use crate::error::{NerfError, NerfResult};
/// Configuration for [`integrated_pe`], the integrated positional encoding.
///
/// `n_freq` selects how many frequency levels are used (level `k` has angular
/// frequency `2^k · π`), and `input_dim` is the length of the mean and variance
/// vectors being encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpeConfig {
    /// Number of frequency levels; [`integrated_pe`] rejects `0`.
    pub n_freq: usize,
    /// Dimensionality of each input point (expected length of `mu` and `diag_sigma2`).
    pub input_dim: usize,
}
impl IpeConfig {
    /// Number of features produced by [`integrated_pe`] for this configuration.
    ///
    /// Every one of the `input_dim` coordinates contributes a sine feature and a
    /// cosine feature at each of the `n_freq` levels, so the total is
    /// `input_dim * 2 * n_freq`.
    #[must_use]
    pub fn output_dim(&self) -> usize {
        let &Self { n_freq, input_dim } = self;
        input_dim * 2 * n_freq
    }
}
/// Integrated positional encoding of a Gaussian with mean `mu` and diagonal
/// covariance `diag_sigma2`.
///
/// For every frequency level `k` in `0..cfg.n_freq` and every coordinate `d`,
/// two features are emitted, in this order:
///
/// ```text
/// sin(ω_k · μ_d) · exp(-½ · ω_k² · σ²_d)
/// cos(ω_k · μ_d) · exp(-½ · ω_k² · σ²_d)
/// ```
///
/// where `ω_k = 2^k · π`. The output is level-major: all coordinates of level 0,
/// then all coordinates of level 1, and so on, for [`IpeConfig::output_dim`]
/// values in total. Larger variances damp high-frequency features towards zero;
/// with zero variance the result equals the plain sinusoidal encoding.
///
/// Negative variances are clamped to `0.0`; a NaN variance is also treated as
/// `0.0` because [`f32::max`] returns the non-NaN operand.
///
/// Any `n_freq >= 1` is accepted. Very high levels are computed without
/// overflow, but `f32` cannot resolve their phase meaningfully.
///
/// # Errors
///
/// * [`NerfError::InvalidFreqLevels`] if `cfg.n_freq == 0`.
/// * [`NerfError::DimensionMismatch`] if `mu.len()` (checked first) or
///   `diag_sigma2.len()` differs from `cfg.input_dim`.
pub fn integrated_pe(mu: &[f32], diag_sigma2: &[f32], cfg: &IpeConfig) -> NerfResult<Vec<f32>> {
    if cfg.n_freq == 0 {
        return Err(NerfError::InvalidFreqLevels { levels: 0 });
    }
    for got in [mu.len(), diag_sigma2.len()] {
        if got != cfg.input_dim {
            return Err(NerfError::DimensionMismatch {
                expected: cfg.input_dim,
                got,
            });
        }
    }

    let mut out = Vec::with_capacity(cfg.output_dim());
    // 2^k is produced by exact doubling in f32 instead of `2_u32.pow(k)`, which
    // overflows once n_freq exceeds 32 (debug: panic; release: wraps to 0).
    let mut scale = 1.0_f32;
    for _ in 0..cfg.n_freq {
        let omega = scale * std::f32::consts::PI;
        let omega_sq = omega * omega;
        for (&mu_d, &sigma2_d) in mu.iter().zip(diag_sigma2) {
            let sigma2_d = sigma2_d.max(0.0);
            // exp(-0) == 1, so zero variance short-circuits to 1.0. This also
            // keeps an overflowing omega_sq (inf * 0 = NaN) from poisoning the
            // output.
            let attenuation = if sigma2_d > 0.0 {
                (-0.5 * omega_sq * sigma2_d).exp()
            } else {
                1.0
            };
            let (s, c) = (omega * mu_d).sin_cos();
            out.push(s * attenuation);
            out.push(c * attenuation);
        }
        scale *= 2.0;
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum absolute difference tolerated between floating-point features.
    const TOL: f32 = 1e-5;

    /// Asserts two feature vectors have equal length and match element-wise.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        // `zip` alone would silently ignore a length mismatch.
        assert_eq!(actual.len(), expected.len(), "IPE length mismatch");
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert!((a - b).abs() < TOL, "IPE mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn ipe_output_dim() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 3,
        };
        assert_eq!(cfg.output_dim(), 24);
    }

    #[test]
    fn ipe_zero_variance_matches_pe() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let sigma2 = [0.0_f32];
        let ipe_out = integrated_pe(&mu, &sigma2, &cfg).unwrap();
        let expected: Vec<f32> = (0..2)
            .flat_map(|k| {
                let omega = (2_u32.pow(k)) as f32 * std::f32::consts::PI;
                [(omega * 0.5).sin(), (omega * 0.5).cos()]
            })
            .collect();
        assert_close(&ipe_out, &expected);
    }

    #[test]
    fn ipe_layout_is_level_major_sin_then_cos() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 2,
        };
        let out = integrated_pe(&[0.0, 0.25], &[0.0, 0.0], &cfg).unwrap();
        let pi = std::f32::consts::PI;
        let expected = [
            0.0,
            1.0,
            (0.25 * pi).sin(),
            (0.25 * pi).cos(),
            0.0,
            1.0,
            (0.5 * pi).sin(),
            (0.5 * pi).cos(),
        ];
        assert_eq!(out.len(), cfg.output_dim());
        assert_close(&out, &expected);
    }

    #[test]
    fn ipe_high_variance_attenuates() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let low_var = [0.0001_f32];
        let high_var = [100.0_f32];
        let out_low = integrated_pe(&mu, &low_var, &cfg).unwrap();
        let out_high = integrated_pe(&mu, &high_var, &cfg).unwrap();
        let mag_low: f32 = out_low.iter().map(|v| v * v).sum::<f32>().sqrt();
        let mag_high: f32 = out_high.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            mag_high < mag_low,
            "high variance should attenuate encoding"
        );
    }

    #[test]
    fn ipe_negative_variance_is_clamped_to_zero() {
        let cfg = IpeConfig {
            n_freq: 3,
            input_dim: 1,
        };
        let mu = [0.3_f32];
        let clamped = integrated_pe(&mu, &[-1.0], &cfg).unwrap();
        let zero = integrated_pe(&mu, &[0.0], &cfg).unwrap();
        assert_eq!(clamped, zero);
    }

    #[test]
    fn ipe_rejects_zero_frequency_levels() {
        let cfg = IpeConfig {
            n_freq: 0,
            input_dim: 1,
        };
        assert!(matches!(
            integrated_pe(&[0.0], &[0.0], &cfg),
            Err(NerfError::InvalidFreqLevels { levels: 0 })
        ));
    }

    #[test]
    fn ipe_rejects_dimension_mismatch() {
        let cfg = IpeConfig {
            n_freq: 1,
            input_dim: 2,
        };
        assert!(matches!(
            integrated_pe(&[0.0], &[0.0, 0.0], &cfg),
            Err(NerfError::DimensionMismatch { expected: 2, got: 1 })
        ));
        assert!(matches!(
            integrated_pe(&[0.0, 0.0], &[0.0], &cfg),
            Err(NerfError::DimensionMismatch { expected: 2, got: 1 })
        ));
    }
}