// oxicuda_nerf/encoding/integrated_pe.rs

//! Mip-NeRF Integrated Positional Encoding (IPE).
//!
//! Encodes a Gaussian-approximated cone frustum rather than a single point.
//! For a Gaussian with mean μ and diagonal variance σ²:
//!
//! `E[sin(ω·x)] = sin(ω·μ) · exp(-ω²·σ²/2)`
//! `E[cos(ω·x)] = cos(ω·μ) · exp(-ω²·σ²/2)`
//!
//! where `ω = 2^k · π` for frequency level k.
10
11use crate::error::{NerfError, NerfResult};
12
/// Configuration for integrated positional encoding.
///
/// The two fields fully determine the length of the encoded vector; see
/// [`IpeConfig::output_dim`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpeConfig {
    /// Number of frequency levels L.
    pub n_freq: usize,
    /// Input dimensionality.
    pub input_dim: usize,
}

impl IpeConfig {
    /// Output dimensionality: `input_dim * 2 * n_freq`.
    ///
    /// Every (frequency level, input dimension) pair contributes one sine
    /// and one cosine term, hence the factor of two. The function is `const`
    /// so the result can be used to size compile-time constants.
    #[must_use]
    pub const fn output_dim(&self) -> usize {
        self.input_dim * 2 * self.n_freq
    }
}
29
30/// Compute IPE for a Gaussian with mean `mu` and diagonal variance `diag_sigma2`.
31///
32/// Output layout: for each frequency level k, for each dimension d:
33///   `[sin(ω·μ_d)·exp(-ω²·σ²_d/2), cos(ω·μ_d)·exp(-ω²·σ²_d/2)]`
34///
35/// # Errors
36///
37/// Returns `DimensionMismatch` if `mu.len() != diag_sigma2.len()` or they
38/// don't match `cfg.input_dim`, or `InvalidFreqLevels` if `n_freq == 0`.
39pub fn integrated_pe(mu: &[f32], diag_sigma2: &[f32], cfg: &IpeConfig) -> NerfResult<Vec<f32>> {
40    if cfg.n_freq == 0 {
41        return Err(NerfError::InvalidFreqLevels { levels: 0 });
42    }
43    if mu.len() != cfg.input_dim {
44        return Err(NerfError::DimensionMismatch {
45            expected: cfg.input_dim,
46            got: mu.len(),
47        });
48    }
49    if diag_sigma2.len() != cfg.input_dim {
50        return Err(NerfError::DimensionMismatch {
51            expected: cfg.input_dim,
52            got: diag_sigma2.len(),
53        });
54    }
55
56    let out_dim = cfg.output_dim();
57    let mut out = vec![0.0_f32; out_dim];
58    let mut write_pos = 0;
59
60    for k in 0..cfg.n_freq {
61        let omega = (2_u32.pow(k as u32)) as f32 * std::f32::consts::PI;
62        let omega_sq = omega * omega;
63
64        for d in 0..cfg.input_dim {
65            let mu_d = mu[d];
66            let sigma2_d = diag_sigma2[d].max(0.0);
67            let attenuation = (-0.5 * omega_sq * sigma2_d).exp();
68
69            out[write_pos] = (omega * mu_d).sin() * attenuation;
70            write_pos += 1;
71            out[write_pos] = (omega * mu_d).cos() * attenuation;
72            write_pos += 1;
73        }
74    }
75
76    Ok(out)
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// The encoded length is `input_dim * 2 * n_freq`.
    #[test]
    fn ipe_output_dim() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 3,
        };
        assert_eq!(cfg.output_dim(), 24);
    }

    /// With zero variance the attenuation factor is 1, so IPE must reduce to
    /// the regular positional encoding of the mean.
    #[test]
    fn ipe_zero_variance_matches_pe() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let sigma2 = [0.0_f32];
        let ipe_out = integrated_pe(&mu, &sigma2, &cfg).unwrap();

        // Manually compute PE
        let expected: Vec<f32> = (0..2)
            .flat_map(|k| {
                let omega = (2_u32.pow(k)) as f32 * std::f32::consts::PI;
                [(omega * 0.5).sin(), (omega * 0.5).cos()]
            })
            .collect();

        // `zip` stops at the shorter side, so check the lengths explicitly;
        // otherwise a truncated output would pass unnoticed.
        assert_eq!(ipe_out.len(), expected.len());
        assert_eq!(ipe_out.len(), cfg.output_dim());

        for (a, b) in ipe_out.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "IPE mismatch: {a} vs {b}");
        }
    }

    /// A wide Gaussian must produce a smaller-magnitude encoding than a
    /// narrow one with the same mean.
    #[test]
    fn ipe_high_variance_attenuates() {
        let cfg = IpeConfig {
            n_freq: 4,
            input_dim: 1,
        };
        let mu = [0.5_f32];
        let low_var = [0.0001_f32];
        let high_var = [100.0_f32];
        let out_low = integrated_pe(&mu, &low_var, &cfg).unwrap();
        let out_high = integrated_pe(&mu, &high_var, &cfg).unwrap();

        // High variance should attenuate higher frequencies more
        let mag_low: f32 = out_low.iter().map(|v| v * v).sum::<f32>().sqrt();
        let mag_high: f32 = out_high.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            mag_high < mag_low,
            "high variance should attenuate encoding"
        );
    }

    /// `n_freq == 0` is rejected before any dimension checks.
    #[test]
    fn ipe_rejects_zero_freq_levels() {
        let cfg = IpeConfig {
            n_freq: 0,
            input_dim: 1,
        };
        assert!(matches!(
            integrated_pe(&[0.5], &[0.0], &cfg),
            Err(NerfError::InvalidFreqLevels { .. })
        ));
    }

    /// Both `mu` and `diag_sigma2` must have exactly `input_dim` entries.
    #[test]
    fn ipe_rejects_dimension_mismatch() {
        let cfg = IpeConfig {
            n_freq: 2,
            input_dim: 2,
        };
        assert!(matches!(
            integrated_pe(&[0.5], &[0.0, 0.0], &cfg),
            Err(NerfError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        ));
        assert!(matches!(
            integrated_pe(&[0.5, 0.5], &[0.0, 0.0, 0.0], &cfg),
            Err(NerfError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }
}