oxicuda_nerf/encoding/
positional.rs1use crate::error::{NerfError, NerfResult};
11
/// Configuration for [`positional_encode`].
///
/// Describes how many sinusoidal frequency levels to emit, whether the raw input point is
/// copied into the output, and how many scalar components make up one input point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PosEncConfig {
    /// Number of frequency levels; level `k` uses the angular frequency `2^k * PI`.
    /// [`positional_encode`] rejects `0`.
    pub n_freq: usize,
    /// When `true`, each output row starts with a verbatim copy of the input point.
    pub include_input: bool,
    /// Number of `f32` components per input point (e.g. 3 for an `(x, y, z)` position).
    pub input_dim: usize,
}
22
23impl PosEncConfig {
24 #[must_use]
26 pub fn output_dim(&self) -> usize {
27 let enc = self.input_dim * 2 * self.n_freq;
28 if self.include_input {
29 enc + self.input_dim
30 } else {
31 enc
32 }
33 }
34}
35
36pub fn positional_encode(input: &[f32], cfg: &PosEncConfig) -> NerfResult<Vec<f32>> {
47 if cfg.n_freq == 0 {
48 return Err(NerfError::InvalidFreqLevels { levels: 0 });
49 }
50 if input.is_empty() {
51 return Err(NerfError::EmptyInput);
52 }
53 if !input.len().is_multiple_of(cfg.input_dim) {
54 return Err(NerfError::DimensionMismatch {
55 expected: cfg.input_dim,
56 got: input.len() % cfg.input_dim,
57 });
58 }
59
60 let n = input.len() / cfg.input_dim;
61 let out_dim = cfg.output_dim();
62 let mut out = vec![0.0_f32; n * out_dim];
63
64 for (pt_idx, out_chunk) in out.chunks_mut(out_dim).enumerate() {
65 let in_chunk = &input[pt_idx * cfg.input_dim..(pt_idx + 1) * cfg.input_dim];
66 let mut write_pos = 0;
67
68 if cfg.include_input {
69 out_chunk[..cfg.input_dim].copy_from_slice(in_chunk);
70 write_pos += cfg.input_dim;
71 }
72
73 for k in 0..cfg.n_freq {
74 let freq = (2_u32.pow(k as u32)) as f32 * std::f32::consts::PI;
75 for &val in in_chunk {
76 out_chunk[write_pos] = (freq * val).sin();
77 write_pos += 1;
78 out_chunk[write_pos] = (freq * val).cos();
79 write_pos += 1;
80 }
81 }
82 }
83
84 Ok(out)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing `f32` trigonometric outputs.
    const EPS: f32 = 1e-6;

    #[test]
    fn output_dim_no_include() {
        let cfg = PosEncConfig {
            n_freq: 4,
            include_input: false,
            input_dim: 3,
        };
        assert_eq!(cfg.output_dim(), 24);
    }

    #[test]
    fn output_dim_with_include() {
        let cfg = PosEncConfig {
            n_freq: 4,
            include_input: true,
            input_dim: 3,
        };
        assert_eq!(cfg.output_dim(), 27);
    }

    #[test]
    fn single_point_shape() {
        let cfg = PosEncConfig {
            n_freq: 2,
            include_input: false,
            input_dim: 3,
        };
        let input = vec![0.1_f32, 0.2, 0.3];
        let out = positional_encode(&input, &cfg).unwrap();
        assert_eq!(out.len(), cfg.output_dim());
    }

    #[test]
    fn multi_point_shape() {
        let cfg = PosEncConfig {
            n_freq: 3,
            include_input: true,
            input_dim: 2,
        };
        let input = [0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6];
        let out = positional_encode(&input, &cfg).unwrap();
        assert_eq!(out.len(), 3 * cfg.output_dim());
    }

    #[test]
    fn include_input_prefixes_each_row() {
        let cfg = PosEncConfig {
            n_freq: 1,
            include_input: true,
            input_dim: 2,
        };
        let input = [0.1_f32, 0.2, 0.3, 0.4];
        let out = positional_encode(&input, &cfg).unwrap();
        for (point, row) in input.chunks(2).zip(out.chunks(cfg.output_dim())) {
            assert_eq!(&row[..2], point);
        }
    }

    #[test]
    fn zero_input_gives_sin_zero_cos_one() {
        let cfg = PosEncConfig {
            n_freq: 3,
            include_input: false,
            input_dim: 3,
        };
        let out = positional_encode(&[0.0_f32; 3], &cfg).unwrap();
        assert_eq!(out.len(), cfg.output_dim());
        for pair in out.chunks_exact(2) {
            assert_eq!(pair, [0.0_f32, 1.0]);
        }
    }

    #[test]
    fn known_values_and_layout() {
        // x = 0.25: level 0 -> sin/cos(PI/4), level 1 -> sin/cos(PI/2), interleaved per level.
        let cfg = PosEncConfig {
            n_freq: 2,
            include_input: false,
            input_dim: 1,
        };
        let out = positional_encode(&[0.25_f32], &cfg).unwrap();
        let h = std::f32::consts::FRAC_1_SQRT_2;
        let expected = [h, h, 1.0, 0.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected) {
            assert!((got - want).abs() < EPS, "got {got}, want {want}");
        }
    }

    #[test]
    fn error_on_zero_freq() {
        let cfg = PosEncConfig {
            n_freq: 0,
            include_input: false,
            input_dim: 3,
        };
        assert!(positional_encode(&[0.0_f32; 3], &cfg).is_err());
    }

    #[test]
    fn error_on_empty_input() {
        let cfg = PosEncConfig {
            n_freq: 2,
            include_input: false,
            input_dim: 3,
        };
        assert!(matches!(
            positional_encode(&[], &cfg),
            Err(NerfError::EmptyInput)
        ));
    }

    #[test]
    fn error_on_ragged_input() {
        let cfg = PosEncConfig {
            n_freq: 2,
            include_input: false,
            input_dim: 3,
        };
        assert!(matches!(
            positional_encode(&[0.0_f32; 4], &cfg),
            Err(NerfError::DimensionMismatch { expected: 3, .. })
        ));
    }
}