Skip to main content

oxicuda_nerf/encoding/
hash_grid.rs

1//! Instant-NGP multi-resolution hash grid encoding.
2//!
3//! L levels, T buckets per level, F features per entry.
4//! Resolution at level l: `N_l = floor(N_min * b^l)` where `b = exp(ln(N_max/N_min)/(L-1))`.
5//! Hash: `h(x1,x2,x3) = (x1 XOR x2*pi2 XOR x3*pi3) % T`
6//! with pi1=1, pi2=2654435761, pi3=805459861.
7
8use crate::error::{NerfError, NerfResult};
9use crate::handle::LcgRng;
10
/// Spatial-hash multiplier applied to the y coordinate (π₂ = 2654435761; π₁ = 1 for x).
const PI2: u64 = 0x9E37_79B1;
/// Spatial-hash multiplier applied to the z coordinate (π₃ = 805459861).
const PI3: u64 = 0x3002_5795;
13
14// ─── Config ──────────────────────────────────────────────────────────────────
15
/// Configuration for the multi-resolution hash grid.
///
/// [`Default`] yields the typical settings quoted on each field
/// (16 levels, 2 features, `T = 2^19`, resolutions 16 → 2048).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashGridConfig {
    /// L — number of resolution levels (typical: 16).
    pub n_levels: usize,
    /// F — number of features per hash-table entry (typical: 2).
    pub n_features_per_level: usize,
    /// log2(T) where T = 2^this is the number of hash buckets (typical: 19 → T=524288).
    pub log2_hashmap_size: usize,
    /// N_min — base (coarsest) grid resolution (typical: 16).
    pub base_resolution: usize,
    /// N_max — finest grid resolution (typical: 2048).
    pub max_resolution: usize,
}

impl Default for HashGridConfig {
    /// The typical configuration documented on the individual fields.
    fn default() -> Self {
        Self {
            n_levels: 16,
            n_features_per_level: 2,
            log2_hashmap_size: 19,
            base_resolution: 16,
            max_resolution: 2048,
        }
    }
}
30
31// ─── HashGrid ────────────────────────────────────────────────────────────────
32
33/// Multi-resolution hash grid with trilinear interpolation.
34#[derive(Debug, Clone)]
35pub struct HashGrid {
36    /// Grid configuration.
37    pub config: HashGridConfig,
38    /// Flat feature storage: `[n_levels * T * F]`.
39    pub data: Vec<f32>,
40    /// Per-level grid resolution N_l.
41    level_resolutions: Vec<usize>,
42}
43
44impl HashGrid {
45    /// Create a new hash grid with parameters initialized to U(-0.0001, 0.0001).
46    ///
47    /// # Errors
48    ///
49    /// Returns `InvalidHashConfig` for invalid configuration.
50    pub fn new(cfg: HashGridConfig, rng: &mut LcgRng) -> NerfResult<Self> {
51        if cfg.n_levels == 0 {
52            return Err(NerfError::InvalidHashConfig {
53                msg: "n_levels must be > 0".into(),
54            });
55        }
56        if cfg.n_features_per_level == 0 {
57            return Err(NerfError::InvalidHashConfig {
58                msg: "n_features_per_level must be > 0".into(),
59            });
60        }
61        if cfg.log2_hashmap_size == 0 || cfg.log2_hashmap_size > 32 {
62            return Err(NerfError::InvalidHashConfig {
63                msg: "log2_hashmap_size must be in 1..=32".into(),
64            });
65        }
66        if cfg.base_resolution == 0 {
67            return Err(NerfError::InvalidHashConfig {
68                msg: "base_resolution must be > 0".into(),
69            });
70        }
71        if cfg.max_resolution < cfg.base_resolution {
72            return Err(NerfError::InvalidHashConfig {
73                msg: "max_resolution must be >= base_resolution".into(),
74            });
75        }
76
77        let t = 1_usize << cfg.log2_hashmap_size;
78
79        // Compute per-level resolutions
80        let level_resolutions = if cfg.n_levels == 1 {
81            vec![cfg.base_resolution]
82        } else {
83            let b = ((cfg.max_resolution as f64) / (cfg.base_resolution as f64)).ln()
84                / (cfg.n_levels - 1) as f64;
85            (0..cfg.n_levels)
86                .map(|l| {
87                    let n_l = (cfg.base_resolution as f64 * (b * l as f64).exp()).floor() as usize;
88                    n_l.max(1)
89                })
90                .collect()
91        };
92
93        let total = cfg.n_levels * t * cfg.n_features_per_level;
94        let mut data = vec![0.0_f32; total];
95        for v in data.iter_mut() {
96            *v = rng.next_f32_range(-0.0001, 0.0001);
97        }
98
99        Ok(Self {
100            config: cfg,
101            data,
102            level_resolutions,
103        })
104    }
105
106    /// Total output dimension: `n_levels * n_features_per_level`.
107    #[must_use]
108    pub fn output_dim(&self) -> usize {
109        self.config.n_levels * self.config.n_features_per_level
110    }
111
112    /// Query a single 3D point in `[0, 1]^3`.
113    ///
114    /// Returns a feature vector of length `output_dim`.
115    ///
116    /// # Errors
117    ///
118    /// Returns `DimensionMismatch` for wrong input size.
119    pub fn query(&self, xyz: [f32; 3]) -> NerfResult<Vec<f32>> {
120        let t = 1_usize << self.config.log2_hashmap_size;
121        let f = self.config.n_features_per_level;
122        let mut out = vec![0.0_f32; self.output_dim()];
123
124        for (level, &n_l) in self.level_resolutions.iter().enumerate() {
125            // Scale xyz to level resolution [0, N_l]
126            let sx = xyz[0].clamp(0.0, 1.0) * (n_l as f32);
127            let sy = xyz[1].clamp(0.0, 1.0) * (n_l as f32);
128            let sz = xyz[2].clamp(0.0, 1.0) * (n_l as f32);
129
130            let ix = sx.floor() as i64;
131            let iy = sy.floor() as i64;
132            let iz = sz.floor() as i64;
133            let fx = sx - ix as f32;
134            let fy = sy - iy as f32;
135            let fz = sz - iz as f32;
136
137            let level_offset = level * t * f;
138
139            // Trilinear interpolation over 8 corners
140            for cx in 0_u8..=1 {
141                for cy in 0_u8..=1 {
142                    for cz in 0_u8..=1 {
143                        let xi = ix + i64::from(cx);
144                        let yi = iy + i64::from(cy);
145                        let zi = iz + i64::from(cz);
146
147                        let bucket = hash_coord(xi, yi, zi, t);
148                        let w = trilinear_weight(fx, fy, fz, cx, cy, cz);
149                        let base = level_offset + bucket * f;
150
151                        let out_base = level * f;
152                        for feat in 0..f {
153                            out[out_base + feat] += w * self.data[base + feat];
154                        }
155                    }
156                }
157            }
158        }
159
160        Ok(out)
161    }
162
163    /// Batch query: `xyz_batch` is a flat `[N * 3]` array.
164    ///
165    /// Returns `[N * output_dim]`.
166    ///
167    /// # Errors
168    ///
169    /// Returns `DimensionMismatch` if `xyz_batch.len() != n * 3`.
170    pub fn query_batch(&self, xyz_batch: &[f32], n: usize) -> NerfResult<Vec<f32>> {
171        if xyz_batch.len() != n * 3 {
172            return Err(NerfError::DimensionMismatch {
173                expected: n * 3,
174                got: xyz_batch.len(),
175            });
176        }
177        let out_dim = self.output_dim();
178        let mut out = vec![0.0_f32; n * out_dim];
179
180        for (i, out_chunk) in out.chunks_mut(out_dim).enumerate() {
181            let x = xyz_batch[i * 3];
182            let y = xyz_batch[i * 3 + 1];
183            let z = xyz_batch[i * 3 + 2];
184            let feat = self.query([x, y, z])?;
185            out_chunk.copy_from_slice(&feat);
186        }
187
188        Ok(out)
189    }
190}
191
192// ─── Internal helpers ────────────────────────────────────────────────────────
193
194/// Hash a grid cell coordinate to a bucket index in `[0, t)`.
195#[inline]
196fn hash_coord(xi: i64, yi: i64, zi: i64, t: usize) -> usize {
197    let hx = xi as u64;
198    let hy = (yi as u64).wrapping_mul(PI2);
199    let hz = (zi as u64).wrapping_mul(PI3);
200    (hx ^ hy ^ hz) as usize % t
201}
202
/// Trilinear weight of cell corner `(cx, cy, cz)` for a point with in-cell
/// fractional offsets `(fx, fy, fz)`.
///
/// Along each axis a corner index of `1` contributes the fraction itself; any other
/// index contributes its complement `1 - fraction`. The result is the product of the
/// three per-axis factors.
#[inline]
fn trilinear_weight(fx: f32, fy: f32, fz: f32, cx: u8, cy: u8, cz: u8) -> f32 {
    let axis = |frac: f32, corner: u8| match corner {
        1 => frac,
        _ => 1.0 - frac,
    };
    axis(fx, cx) * axis(fy, cy) * axis(fz, cz)
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Small deterministic grid: 4 levels x 2 features, T = 256, resolutions 4..=32.
    fn make_grid(seed: u64) -> HashGrid {
        let cfg = HashGridConfig {
            n_levels: 4,
            n_features_per_level: 2,
            log2_hashmap_size: 8,
            base_resolution: 4,
            max_resolution: 32,
        };
        let mut rng = LcgRng::new(seed);
        HashGrid::new(cfg, &mut rng).unwrap()
    }

    #[test]
    fn query_output_shape() {
        let grid = make_grid(1);
        let feat = grid.query([0.5, 0.5, 0.5]).unwrap();
        assert_eq!(feat.len(), grid.output_dim());
    }

    #[test]
    fn batch_output_shape() {
        let grid = make_grid(2);
        let pts: Vec<f32> = (0..5).flat_map(|i| [i as f32 * 0.2; 3]).collect();
        let out = grid.query_batch(&pts, 5).unwrap();
        assert_eq!(out.len(), 5 * grid.output_dim());
    }

    #[test]
    fn batch_matches_single_queries() {
        let grid = make_grid(3);
        let pts = [0.1_f32, 0.2, 0.3, 0.9, 0.05, 0.5, 1.0, 0.0, 0.77];
        let batch = grid.query_batch(&pts, 3).unwrap();
        let singles: Vec<f32> = pts
            .chunks(3)
            .flat_map(|p| grid.query([p[0], p[1], p[2]]).unwrap())
            .collect();
        assert_eq!(batch, singles);
    }

    #[test]
    fn batch_rejects_wrong_length() {
        let grid = make_grid(4);
        let err = grid.query_batch(&[0.0; 5], 2).unwrap_err();
        assert!(matches!(
            err,
            NerfError::DimensionMismatch {
                expected: 6,
                got: 5
            }
        ));
    }

    #[test]
    fn out_of_range_coordinates_are_clamped() {
        let grid = make_grid(5);
        assert_eq!(
            grid.query([-1.0, 2.0, 0.5]).unwrap(),
            grid.query([0.0, 1.0, 0.5]).unwrap()
        );
    }

    #[test]
    fn constant_table_interpolates_to_constant() {
        // Trilinear weights form a partition of unity, so a constant table must
        // reproduce that constant regardless of hash collisions.
        let mut grid = make_grid(6);
        grid.data.fill(1.0);
        let feat = grid.query([0.37, 0.81, 0.05]).unwrap();
        assert!(feat.iter().all(|&v| (v - 1.0).abs() < 1e-5), "feat={feat:?}");
    }

    #[test]
    fn new_rejects_invalid_configs() {
        let valid = HashGridConfig {
            n_levels: 2,
            n_features_per_level: 1,
            log2_hashmap_size: 4,
            base_resolution: 2,
            max_resolution: 8,
        };
        let mut rng = LcgRng::new(0);
        assert!(HashGrid::new(valid.clone(), &mut rng).is_ok());

        let bad = [
            HashGridConfig {
                n_levels: 0,
                ..valid.clone()
            },
            HashGridConfig {
                n_features_per_level: 0,
                ..valid.clone()
            },
            HashGridConfig {
                log2_hashmap_size: 0,
                ..valid.clone()
            },
            HashGridConfig {
                log2_hashmap_size: 33,
                ..valid.clone()
            },
            HashGridConfig {
                base_resolution: 0,
                ..valid.clone()
            },
            HashGridConfig {
                max_resolution: 1,
                ..valid.clone()
            },
        ];
        for cfg in bad {
            assert!(HashGrid::new(cfg, &mut rng).is_err());
        }
    }

    #[test]
    fn hash_coord_deterministic() {
        assert_eq!(hash_coord(1, 2, 3, 256), hash_coord(1, 2, 3, 256));
    }

    #[test]
    fn trilinear_weights_sum_to_one() {
        let (fx, fy, fz) = (0.3, 0.7, 0.1);
        let mut sum = 0.0_f32;
        for cx in 0_u8..=1 {
            for cy in 0_u8..=1 {
                for cz in 0_u8..=1 {
                    sum += trilinear_weight(fx, fy, fz, cx, cy, cz);
                }
            }
        }
        assert!((sum - 1.0).abs() < 1e-6, "weights sum={sum}");
    }
}