//! Multiresolution hash-grid positional encoding (Instant-NGP style).
//! File: oxicuda_nerf/encoding/hash_grid.rs

use crate::error::{NerfError, NerfResult};
use crate::handle::LcgRng;
/// Large primes used to decorrelate the y and z lattice coordinates in the
/// spatial hash (x is left unmultiplied), following the Instant-NGP hashing
/// scheme. Multiplication is done with wrapping arithmetic in `hash_coord`.
const PI2: u64 = 2_654_435_761;
const PI3: u64 = 805_459_861;
/// Configuration for a multiresolution hash grid.
///
/// All fields are validated by `HashGrid::new`; the comments below state the
/// accepted ranges enforced there.
#[derive(Debug, Clone)]
pub struct HashGridConfig {
    /// Number of resolution levels in the pyramid (must be > 0).
    pub n_levels: usize,
    /// Feature-vector width stored per hash bucket (must be > 0).
    pub n_features_per_level: usize,
    /// log2 of the per-level hash-table size; each level holds 2^this buckets
    /// (must be in 1..=32).
    pub log2_hashmap_size: usize,
    /// Grid resolution of the coarsest level (must be > 0).
    pub base_resolution: usize,
    /// Target grid resolution of the finest level
    /// (must be >= `base_resolution`).
    pub max_resolution: usize,
}
31#[derive(Debug, Clone)]
35pub struct HashGrid {
36 pub config: HashGridConfig,
38 pub data: Vec<f32>,
40 level_resolutions: Vec<usize>,
42}
43
44impl HashGrid {
45 pub fn new(cfg: HashGridConfig, rng: &mut LcgRng) -> NerfResult<Self> {
51 if cfg.n_levels == 0 {
52 return Err(NerfError::InvalidHashConfig {
53 msg: "n_levels must be > 0".into(),
54 });
55 }
56 if cfg.n_features_per_level == 0 {
57 return Err(NerfError::InvalidHashConfig {
58 msg: "n_features_per_level must be > 0".into(),
59 });
60 }
61 if cfg.log2_hashmap_size == 0 || cfg.log2_hashmap_size > 32 {
62 return Err(NerfError::InvalidHashConfig {
63 msg: "log2_hashmap_size must be in 1..=32".into(),
64 });
65 }
66 if cfg.base_resolution == 0 {
67 return Err(NerfError::InvalidHashConfig {
68 msg: "base_resolution must be > 0".into(),
69 });
70 }
71 if cfg.max_resolution < cfg.base_resolution {
72 return Err(NerfError::InvalidHashConfig {
73 msg: "max_resolution must be >= base_resolution".into(),
74 });
75 }
76
77 let t = 1_usize << cfg.log2_hashmap_size;
78
79 let level_resolutions = if cfg.n_levels == 1 {
81 vec![cfg.base_resolution]
82 } else {
83 let b = ((cfg.max_resolution as f64) / (cfg.base_resolution as f64)).ln()
84 / (cfg.n_levels - 1) as f64;
85 (0..cfg.n_levels)
86 .map(|l| {
87 let n_l = (cfg.base_resolution as f64 * (b * l as f64).exp()).floor() as usize;
88 n_l.max(1)
89 })
90 .collect()
91 };
92
93 let total = cfg.n_levels * t * cfg.n_features_per_level;
94 let mut data = vec![0.0_f32; total];
95 for v in data.iter_mut() {
96 *v = rng.next_f32_range(-0.0001, 0.0001);
97 }
98
99 Ok(Self {
100 config: cfg,
101 data,
102 level_resolutions,
103 })
104 }
105
106 #[must_use]
108 pub fn output_dim(&self) -> usize {
109 self.config.n_levels * self.config.n_features_per_level
110 }
111
112 pub fn query(&self, xyz: [f32; 3]) -> NerfResult<Vec<f32>> {
120 let t = 1_usize << self.config.log2_hashmap_size;
121 let f = self.config.n_features_per_level;
122 let mut out = vec![0.0_f32; self.output_dim()];
123
124 for (level, &n_l) in self.level_resolutions.iter().enumerate() {
125 let sx = xyz[0].clamp(0.0, 1.0) * (n_l as f32);
127 let sy = xyz[1].clamp(0.0, 1.0) * (n_l as f32);
128 let sz = xyz[2].clamp(0.0, 1.0) * (n_l as f32);
129
130 let ix = sx.floor() as i64;
131 let iy = sy.floor() as i64;
132 let iz = sz.floor() as i64;
133 let fx = sx - ix as f32;
134 let fy = sy - iy as f32;
135 let fz = sz - iz as f32;
136
137 let level_offset = level * t * f;
138
139 for cx in 0_u8..=1 {
141 for cy in 0_u8..=1 {
142 for cz in 0_u8..=1 {
143 let xi = ix + i64::from(cx);
144 let yi = iy + i64::from(cy);
145 let zi = iz + i64::from(cz);
146
147 let bucket = hash_coord(xi, yi, zi, t);
148 let w = trilinear_weight(fx, fy, fz, cx, cy, cz);
149 let base = level_offset + bucket * f;
150
151 let out_base = level * f;
152 for feat in 0..f {
153 out[out_base + feat] += w * self.data[base + feat];
154 }
155 }
156 }
157 }
158 }
159
160 Ok(out)
161 }
162
163 pub fn query_batch(&self, xyz_batch: &[f32], n: usize) -> NerfResult<Vec<f32>> {
171 if xyz_batch.len() != n * 3 {
172 return Err(NerfError::DimensionMismatch {
173 expected: n * 3,
174 got: xyz_batch.len(),
175 });
176 }
177 let out_dim = self.output_dim();
178 let mut out = vec![0.0_f32; n * out_dim];
179
180 for (i, out_chunk) in out.chunks_mut(out_dim).enumerate() {
181 let x = xyz_batch[i * 3];
182 let y = xyz_batch[i * 3 + 1];
183 let z = xyz_batch[i * 3 + 2];
184 let feat = self.query([x, y, z])?;
185 out_chunk.copy_from_slice(&feat);
186 }
187
188 Ok(out)
189 }
190}
191
192#[inline]
196fn hash_coord(xi: i64, yi: i64, zi: i64, t: usize) -> usize {
197 let hx = xi as u64;
198 let hy = (yi as u64).wrapping_mul(PI2);
199 let hz = (zi as u64).wrapping_mul(PI3);
200 (hx ^ hy ^ hz) as usize % t
201}
202
/// Trilinear interpolation weight of the cell corner `(cx, cy, cz)` (each
/// component 0 or 1) for a point at fractional offset `(fx, fy, fz)` within
/// the cell. For offsets in [0, 1] the eight corner weights are non-negative
/// and sum to 1.
#[inline]
fn trilinear_weight(fx: f32, fy: f32, fz: f32, cx: u8, cy: u8, cz: u8) -> f32 {
    // Per axis: weight is the offset toward the "1" corner, else its
    // complement toward the "0" corner.
    let wx = if cx == 1 { fx } else { 1.0 - fx };
    let wy = if cy == 1 { fy } else { 1.0 - fy };
    let wz = if cz == 1 { fz } else { 1.0 - fz };
    wx * wy * wz
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small 4-level grid with a deterministic RNG seed.
    fn make_grid(seed: u64) -> HashGrid {
        let cfg = HashGridConfig {
            n_levels: 4,
            n_features_per_level: 2,
            log2_hashmap_size: 8,
            base_resolution: 4,
            max_resolution: 32,
        };
        let mut rng = LcgRng::new(seed);
        HashGrid::new(cfg, &mut rng).unwrap()
    }

    #[test]
    fn query_output_shape() {
        let grid = make_grid(1);
        let feat = grid.query([0.5, 0.5, 0.5]).unwrap();
        assert_eq!(feat.len(), grid.output_dim());
    }

    #[test]
    fn batch_output_shape() {
        let grid = make_grid(2);
        let pts: Vec<f32> = (0..5).flat_map(|i| [i as f32 * 0.2; 3]).collect();
        let out = grid.query_batch(&pts, 5).unwrap();
        assert_eq!(out.len(), 5 * grid.output_dim());
    }

    // Added: the error path of query_batch was previously untested.
    #[test]
    fn batch_length_mismatch_is_error() {
        let grid = make_grid(3);
        assert!(grid.query_batch(&[0.0; 4], 5).is_err());
    }

    #[test]
    fn hash_coord_deterministic() {
        assert_eq!(hash_coord(1, 2, 3, 256), hash_coord(1, 2, 3, 256));
    }

    #[test]
    fn trilinear_weights_sum_to_one() {
        let (fx, fy, fz) = (0.3, 0.7, 0.1);
        let mut sum = 0.0_f32;
        for cx in 0_u8..=1 {
            for cy in 0_u8..=1 {
                for cz in 0_u8..=1 {
                    sum += trilinear_weight(fx, fy, fz, cx, cy, cz);
                }
            }
        }
        assert!((sum - 1.0).abs() < 1e-6, "weights sum={sum}");
    }
}