oxicuda_nerf/encoding/
integrated_pe.rs1use crate::error::{NerfError, NerfResult};
12
/// Configuration for [`integrated_pe`].
///
/// `Default` is intentionally not derived: an all-zero config would have `n_freq == 0`,
/// which `integrated_pe` rejects with `NerfError::InvalidFreqLevels`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IpeConfig {
    /// Number of frequency levels `k` (`0..n_freq`); level `k` uses angular frequency `2^k · π`.
    /// Must be non-zero.
    pub n_freq: usize,
    /// Expected length of both the mean and the diagonal-variance slices.
    pub input_dim: usize,
}
21
22impl IpeConfig {
23 #[must_use]
25 pub fn output_dim(&self) -> usize {
26 self.input_dim * 2 * self.n_freq
27 }
28}
29
30pub fn integrated_pe(mu: &[f32], diag_sigma2: &[f32], cfg: &IpeConfig) -> NerfResult<Vec<f32>> {
40 if cfg.n_freq == 0 {
41 return Err(NerfError::InvalidFreqLevels { levels: 0 });
42 }
43 if mu.len() != cfg.input_dim {
44 return Err(NerfError::DimensionMismatch {
45 expected: cfg.input_dim,
46 got: mu.len(),
47 });
48 }
49 if diag_sigma2.len() != cfg.input_dim {
50 return Err(NerfError::DimensionMismatch {
51 expected: cfg.input_dim,
52 got: diag_sigma2.len(),
53 });
54 }
55
56 let out_dim = cfg.output_dim();
57 let mut out = vec![0.0_f32; out_dim];
58 let mut write_pos = 0;
59
60 for k in 0..cfg.n_freq {
61 let omega = (2_u32.pow(k as u32)) as f32 * std::f32::consts::PI;
62 let omega_sq = omega * omega;
63
64 for d in 0..cfg.input_dim {
65 let mu_d = mu[d];
66 let sigma2_d = diag_sigma2[d].max(0.0);
67 let attenuation = (-0.5 * omega_sq * sigma2_d).exp();
68
69 out[write_pos] = (omega * mu_d).sin() * attenuation;
70 write_pos += 1;
71 out[write_pos] = (omega * mu_d).cos() * attenuation;
72 write_pos += 1;
73 }
74 }
75
76 Ok(out)
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ipe_output_dim() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 3,
        };
        assert_eq!(cfg.output_dim(), 24);
    }

    #[test]
    fn ipe_zero_variance_matches_pe() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let sigma2 = [0.0_f32];
        let ipe_out = integrated_pe(&mu, &sigma2, &cfg).unwrap();

        let expected: Vec<f32> = (0..2)
            .flat_map(|k| {
                let omega = (2_u32.pow(k)) as f32 * std::f32::consts::PI;
                [(omega * 0.5).sin(), (omega * 0.5).cos()]
            })
            .collect();

        // `zip` stops at the shorter side, so pin the length explicitly; otherwise a
        // truncated encoding would still pass the element-wise comparison below.
        assert_eq!(ipe_out.len(), expected.len());
        for (a, b) in ipe_out.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "IPE mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn ipe_layout_is_frequency_major_with_interleaved_sin_cos() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 2,
        };
        let out = integrated_pe(&[0.5, 0.25], &[0.0, 0.0], &cfg).unwrap();
        let pi = std::f32::consts::PI;
        // Level 0 for both dims, then level 1 for both dims; sin before cos per dim.
        let expected = [
            (pi * 0.5).sin(),
            (pi * 0.5).cos(),
            (pi * 0.25).sin(),
            (pi * 0.25).cos(),
            (2.0 * pi * 0.5).sin(),
            (2.0 * pi * 0.5).cos(),
            (2.0 * pi * 0.25).sin(),
            (2.0 * pi * 0.25).cos(),
        ];
        assert_eq!(out.len(), expected.len());
        for (a, b) in out.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "IPE layout mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn ipe_high_variance_attenuates() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let low_var = [0.0001_f32];
        let high_var = [100.0_f32];
        let out_low = integrated_pe(&mu, &low_var, &cfg).unwrap();
        let out_high = integrated_pe(&mu, &high_var, &cfg).unwrap();

        let mag_low: f32 = out_low.iter().map(|v| v * v).sum::<f32>().sqrt();
        let mag_high: f32 = out_high.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            mag_high < mag_low,
            "high variance should attenuate encoding"
        );
    }

    #[test]
    fn ipe_negative_variance_is_clamped_to_zero() {
        let cfg = IpeConfig {
            n_freq: 3,
            input_dim: 1,
        };
        let negative = integrated_pe(&[0.3], &[-2.0], &cfg).unwrap();
        let zero = integrated_pe(&[0.3], &[0.0], &cfg).unwrap();
        assert_eq!(negative, zero);
    }

    #[test]
    fn ipe_rejects_zero_frequency_levels() {
        let cfg = IpeConfig {
            n_freq: 0,
            input_dim: 1,
        };
        assert!(matches!(
            integrated_pe(&[0.0], &[0.0], &cfg),
            Err(NerfError::InvalidFreqLevels { levels: 0 })
        ));
    }

    #[test]
    fn ipe_rejects_mismatched_lengths() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 2,
        };
        assert!(matches!(
            integrated_pe(&[0.0], &[0.0, 0.0], &cfg),
            Err(NerfError::DimensionMismatch { expected: 2, got: 1 })
        ));
        assert!(matches!(
            integrated_pe(&[0.0, 0.0], &[0.0, 0.0, 0.0], &cfg),
            Err(NerfError::DimensionMismatch { expected: 2, got: 3 })
        ));
    }
}