Skip to main content

oxicuda_nerf/encoding/
positional.rs

1//! NeRF positional encoding: γ(p).
2//!
3//! For each input dimension d and frequency level k:
4//!   [sin(2^k · π · p_d), cos(2^k · π · p_d)]
5//!
6//! With L levels and `input_dim` inputs:
7//!   encoded_dim = input_dim * 2 * L
8//!   (optionally: encoded_dim += input_dim if `include_input`)
9
10use crate::error::{NerfError, NerfResult};
11
/// Configuration for positional encoding.
///
/// All fields are plain values, so the type is cheap to copy and can be
/// compared or hashed (e.g. to key a cache of per-configuration resources).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PosEncConfig {
    /// L — number of frequency levels. `positional_encode` rejects `0`.
    pub n_freq: usize,
    /// Whether to prepend the raw (unencoded) input to each encoded vector.
    pub include_input: bool,
    /// Input dimensionality (3 for xyz, 3 for view direction).
    /// Expected to be non-zero.
    pub input_dim: usize,
}
22
23impl PosEncConfig {
24    /// Compute the output dimensionality of the encoding.
25    #[must_use]
26    pub fn output_dim(&self) -> usize {
27        let enc = self.input_dim * 2 * self.n_freq;
28        if self.include_input {
29            enc + self.input_dim
30        } else {
31            enc
32        }
33    }
34}
35
36/// Encode a flat array of N input vectors of length `cfg.input_dim`.
37///
38/// Input layout: `[x0_0, x0_1, ..., x0_{D-1}, x1_0, ...]` (N × D).
39/// Output layout: `[enc(x0), enc(x1), ...]` (N × output_dim).
40///
41/// # Errors
42///
43/// Returns `InvalidFreqLevels` if `n_freq == 0`,
44/// `EmptyInput` if `input` is empty,
45/// `DimensionMismatch` if `input.len() % input_dim != 0`.
46pub fn positional_encode(input: &[f32], cfg: &PosEncConfig) -> NerfResult<Vec<f32>> {
47    if cfg.n_freq == 0 {
48        return Err(NerfError::InvalidFreqLevels { levels: 0 });
49    }
50    if input.is_empty() {
51        return Err(NerfError::EmptyInput);
52    }
53    if !input.len().is_multiple_of(cfg.input_dim) {
54        return Err(NerfError::DimensionMismatch {
55            expected: cfg.input_dim,
56            got: input.len() % cfg.input_dim,
57        });
58    }
59
60    let n = input.len() / cfg.input_dim;
61    let out_dim = cfg.output_dim();
62    let mut out = vec![0.0_f32; n * out_dim];
63
64    for (pt_idx, out_chunk) in out.chunks_mut(out_dim).enumerate() {
65        let in_chunk = &input[pt_idx * cfg.input_dim..(pt_idx + 1) * cfg.input_dim];
66        let mut write_pos = 0;
67
68        if cfg.include_input {
69            out_chunk[..cfg.input_dim].copy_from_slice(in_chunk);
70            write_pos += cfg.input_dim;
71        }
72
73        for k in 0..cfg.n_freq {
74            let freq = (2_u32.pow(k as u32)) as f32 * std::f32::consts::PI;
75            for &val in in_chunk {
76                out_chunk[write_pos] = (freq * val).sin();
77                write_pos += 1;
78                out_chunk[write_pos] = (freq * val).cos();
79                write_pos += 1;
80            }
81        }
82    }
83
84    Ok(out)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing `f32` sin/cos results.
    const TOL: f32 = 1e-6;

    /// Shorthand constructor so each test states only the values it cares about.
    fn cfg(n_freq: usize, include_input: bool, input_dim: usize) -> PosEncConfig {
        PosEncConfig {
            n_freq,
            include_input,
            input_dim,
        }
    }

    /// Assert that two slices have equal length and agree element-wise within `TOL`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < TOL, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn output_dim_no_include() {
        assert_eq!(cfg(4, false, 3).output_dim(), 24);
    }

    #[test]
    fn output_dim_with_include() {
        assert_eq!(cfg(4, true, 3).output_dim(), 27);
    }

    #[test]
    fn single_point_shape() {
        let c = cfg(2, false, 3);
        let input = vec![0.1_f32, 0.2, 0.3];
        let out = positional_encode(&input, &c).unwrap();
        assert_eq!(out.len(), c.output_dim());
    }

    #[test]
    fn multi_point_shape() {
        let c = cfg(3, true, 2);
        let input = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6];
        let out = positional_encode(&input, &c).unwrap();
        assert_eq!(out.len(), 3 * c.output_dim());
    }

    #[test]
    fn error_on_zero_freq() {
        let res = positional_encode(&[0.0_f32; 3], &cfg(0, false, 3));
        assert!(matches!(res, Err(NerfError::InvalidFreqLevels { .. })));
    }

    #[test]
    fn error_on_empty_input() {
        let res = positional_encode(&[], &cfg(2, false, 3));
        assert!(matches!(res, Err(NerfError::EmptyInput)));
    }

    #[test]
    fn error_on_ragged_input() {
        // Four values cannot be split into whole 3-vectors.
        let res = positional_encode(&[0.0_f32; 4], &cfg(2, false, 3));
        assert!(matches!(res, Err(NerfError::DimensionMismatch { .. })));
    }

    #[test]
    fn zero_input_encodes_to_zero_sin_unit_cos() {
        // sin(0) = 0 and cos(0) = 1 at every frequency level.
        let out = positional_encode(&[0.0_f32], &cfg(2, false, 1)).unwrap();
        assert_close(&out, &[0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn include_input_prepends_raw_values_per_point() {
        // Per point: [raw, sin(π·x), cos(π·x)].
        let out = positional_encode(&[0.0_f32, 0.5], &cfg(1, true, 1)).unwrap();
        assert_close(&out, &[0.0, 0.0, 1.0, 0.5, 1.0, 0.0]);
    }

    #[test]
    fn levels_are_outer_and_components_inner() {
        // Two components, two levels: all of level 0 comes before level 1.
        // x = (0.5, 0.0): level 0 -> sin/cos(π/2), sin/cos(0);
        //                 level 1 -> sin/cos(π),   sin/cos(0).
        let out = positional_encode(&[0.5_f32, 0.0], &cfg(2, false, 2)).unwrap();
        assert_close(&out, &[1.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 1.0]);
    }
}