// oxicuda_nerf/field/hash_field.rs
//! Instant-NGP style neural field: hash grid + tiny MLP decoder.

3use crate::encoding::hash_grid::{HashGrid, HashGridConfig};
4use crate::error::{NerfError, NerfResult};
5use crate::handle::LcgRng;

// ─── HashField ───────────────────────────────────────────────────────────────

/// Instant-NGP style neural field: multi-resolution hash grid + 2-layer MLP.
///
/// Architecture:
/// - HashGrid query → `[n_levels * F]` features
/// - Concat with `dir_enc` → `[(n_levels*F) + dir_enc_dim]`
/// - Linear + ReLU → `hidden_dim`
/// - Linear → `(1 + color_dim)`
/// - sigma = ReLU(output\[0\]), color_feat = output\[1..\]
#[derive(Debug, Clone)]
pub struct HashField {
    /// Multi-resolution hash grid.
    pub grid: HashGrid,
    /// MLP layer 1 weights, row-major (one row per hidden unit):
    /// `[hidden_dim * (n_levels*F + dir_enc_dim)]`.
    mlp_w1: Vec<f32>,
    /// MLP layer 1 bias: `[hidden_dim]`.
    mlp_b1: Vec<f32>,
    /// MLP layer 2 weights, row-major (one row per output channel):
    /// `[(1 + color_dim) * hidden_dim]`.
    mlp_w2: Vec<f32>,
    /// MLP layer 2 bias: `[(1 + color_dim)]`.
    mlp_b2: Vec<f32>,
    /// Hidden layer width.
    hidden_dim: usize,
    /// Dimension of the encoded view direction expected by `forward`.
    dir_enc_dim: usize,
    /// Number of output color features (excludes the sigma channel).
    color_dim: usize,
}
37impl HashField {
38    /// Create a new `HashField`.
39    ///
40    /// # Errors
41    ///
42    /// Returns `InvalidHashConfig` or `InvalidFeatureDim` for bad parameters.
43    pub fn new(
44        grid_cfg: HashGridConfig,
45        hidden_dim: usize,
46        dir_enc_dim: usize,
47        color_dim: usize,
48        rng: &mut LcgRng,
49    ) -> NerfResult<Self> {
50        if hidden_dim == 0 {
51            return Err(NerfError::InvalidFeatureDim { dim: 0 });
52        }
53        if color_dim == 0 {
54            return Err(NerfError::InvalidFeatureDim { dim: 0 });
55        }
56        let grid = HashGrid::new(grid_cfg, rng)?;
57        let grid_feat_dim = grid.output_dim();
58        let in_dim = grid_feat_dim + dir_enc_dim;
59        let out_dim = 1 + color_dim;
60
61        let mut init = |fan_in: usize, fan_out: usize| -> (Vec<f32>, Vec<f32>) {
62            let s = (2.0_f32 / fan_in as f32).sqrt();
63            let mut w = vec![0.0_f32; fan_out * fan_in];
64            for v in w.iter_mut() {
65                let (a, _) = rng.next_normal_pair();
66                *v = a * s;
67            }
68            (w, vec![0.0_f32; fan_out])
69        };
70
71        let (mlp_w1, mlp_b1) = init(in_dim, hidden_dim);
72        let (mlp_w2, mlp_b2) = init(hidden_dim, out_dim);
73
74        Ok(Self {
75            grid,
76            mlp_w1,
77            mlp_b1,
78            mlp_w2,
79            mlp_b2,
80            hidden_dim,
81            dir_enc_dim,
82            color_dim,
83        })
84    }
85
86    /// Query the hash field at a 3D world point with encoded view direction.
87    ///
88    /// Returns `(sigma: f32, color_feat: Vec<f32>)` where color_feat has `color_dim` elements.
89    ///
90    /// # Errors
91    ///
92    /// Returns `DimensionMismatch` if `dir_enc.len() != dir_enc_dim`.
93    pub fn forward(&self, xyz: [f32; 3], dir_enc: &[f32]) -> NerfResult<(f32, Vec<f32>)> {
94        if dir_enc.len() != self.dir_enc_dim {
95            return Err(NerfError::DimensionMismatch {
96                expected: self.dir_enc_dim,
97                got: dir_enc.len(),
98            });
99        }
100
101        // Hash grid feature lookup
102        let grid_feat = self.grid.query(xyz)?;
103
104        // Concatenate with direction encoding
105        let mut input = Vec::with_capacity(grid_feat.len() + self.dir_enc_dim);
106        input.extend_from_slice(&grid_feat);
107        input.extend_from_slice(dir_enc);
108
109        let in_dim = input.len();
110        let h = self.hidden_dim;
111        let out_dim = 1 + self.color_dim;
112
113        // Layer 1: FC + ReLU
114        let mut hidden = vec![0.0_f32; h];
115        for (o, (wo, &bi)) in hidden
116            .iter_mut()
117            .zip(self.mlp_w1.chunks(in_dim).zip(self.mlp_b1.iter()))
118        {
119            *o = (wo
120                .iter()
121                .zip(input.iter())
122                .map(|(&wi, &xi)| wi * xi)
123                .sum::<f32>()
124                + bi)
125                .max(0.0);
126        }
127
128        // Layer 2: FC, no activation here
129        let mut out = vec![0.0_f32; out_dim];
130        for (o, (wo, &bi)) in out
131            .iter_mut()
132            .zip(self.mlp_w2.chunks(h).zip(self.mlp_b2.iter()))
133        {
134            *o = wo
135                .iter()
136                .zip(hidden.iter())
137                .map(|(&wi, &xi)| wi * xi)
138                .sum::<f32>()
139                + bi;
140        }
141
142        let sigma = out[0].max(0.0);
143        let color_feat = out[1..].to_vec();
144
145        Ok((sigma, color_feat))
146    }
147}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small deterministic field: 4-level grid, 16 hidden units,
    /// 8-dim direction encoding, 3 color features.
    fn build_field(seed: u64) -> HashField {
        let cfg = HashGridConfig {
            n_levels: 4,
            n_features_per_level: 2,
            log2_hashmap_size: 8,
            base_resolution: 4,
            max_resolution: 32,
        };
        HashField::new(cfg, 16, 8, 3, &mut LcgRng::new(seed)).unwrap()
    }

    #[test]
    fn forward_output_types() {
        let field = build_field(42);
        let (sigma, color) = field.forward([0.5, 0.3, 0.7], &[0.1_f32; 8]).unwrap();
        assert!(sigma >= 0.0);
        assert_eq!(color.len(), 3);
    }

    #[test]
    fn wrong_dir_enc_dim() {
        let field = build_field(99);
        // 5-element encoding does not match dir_enc_dim = 8.
        assert!(field.forward([0.0; 3], &[0.0_f32; 5]).is_err());
    }
}