oxicuda_nerf/field/
hash_field.rs1use crate::encoding::hash_grid::{HashGrid, HashGridConfig};
4use crate::error::{NerfError, NerfResult};
5use crate::handle::LcgRng;
6
7#[derive(Debug, Clone)]
18pub struct HashField {
19 pub grid: HashGrid,
21 mlp_w1: Vec<f32>,
23 mlp_b1: Vec<f32>,
25 mlp_w2: Vec<f32>,
27 mlp_b2: Vec<f32>,
29 hidden_dim: usize,
31 dir_enc_dim: usize,
33 color_dim: usize,
35}
36
37impl HashField {
38 pub fn new(
44 grid_cfg: HashGridConfig,
45 hidden_dim: usize,
46 dir_enc_dim: usize,
47 color_dim: usize,
48 rng: &mut LcgRng,
49 ) -> NerfResult<Self> {
50 if hidden_dim == 0 {
51 return Err(NerfError::InvalidFeatureDim { dim: 0 });
52 }
53 if color_dim == 0 {
54 return Err(NerfError::InvalidFeatureDim { dim: 0 });
55 }
56 let grid = HashGrid::new(grid_cfg, rng)?;
57 let grid_feat_dim = grid.output_dim();
58 let in_dim = grid_feat_dim + dir_enc_dim;
59 let out_dim = 1 + color_dim;
60
61 let mut init = |fan_in: usize, fan_out: usize| -> (Vec<f32>, Vec<f32>) {
62 let s = (2.0_f32 / fan_in as f32).sqrt();
63 let mut w = vec![0.0_f32; fan_out * fan_in];
64 for v in w.iter_mut() {
65 let (a, _) = rng.next_normal_pair();
66 *v = a * s;
67 }
68 (w, vec![0.0_f32; fan_out])
69 };
70
71 let (mlp_w1, mlp_b1) = init(in_dim, hidden_dim);
72 let (mlp_w2, mlp_b2) = init(hidden_dim, out_dim);
73
74 Ok(Self {
75 grid,
76 mlp_w1,
77 mlp_b1,
78 mlp_w2,
79 mlp_b2,
80 hidden_dim,
81 dir_enc_dim,
82 color_dim,
83 })
84 }
85
86 pub fn forward(&self, xyz: [f32; 3], dir_enc: &[f32]) -> NerfResult<(f32, Vec<f32>)> {
94 if dir_enc.len() != self.dir_enc_dim {
95 return Err(NerfError::DimensionMismatch {
96 expected: self.dir_enc_dim,
97 got: dir_enc.len(),
98 });
99 }
100
101 let grid_feat = self.grid.query(xyz)?;
103
104 let mut input = Vec::with_capacity(grid_feat.len() + self.dir_enc_dim);
106 input.extend_from_slice(&grid_feat);
107 input.extend_from_slice(dir_enc);
108
109 let in_dim = input.len();
110 let h = self.hidden_dim;
111 let out_dim = 1 + self.color_dim;
112
113 let mut hidden = vec![0.0_f32; h];
115 for (o, (wo, &bi)) in hidden
116 .iter_mut()
117 .zip(self.mlp_w1.chunks(in_dim).zip(self.mlp_b1.iter()))
118 {
119 *o = (wo
120 .iter()
121 .zip(input.iter())
122 .map(|(&wi, &xi)| wi * xi)
123 .sum::<f32>()
124 + bi)
125 .max(0.0);
126 }
127
128 let mut out = vec![0.0_f32; out_dim];
130 for (o, (wo, &bi)) in out
131 .iter_mut()
132 .zip(self.mlp_w2.chunks(h).zip(self.mlp_b2.iter()))
133 {
134 *o = wo
135 .iter()
136 .zip(hidden.iter())
137 .map(|(&wi, &xi)| wi * xi)
138 .sum::<f32>()
139 + bi;
140 }
141
142 let sigma = out[0].max(0.0);
143 let color_feat = out[1..].to_vec();
144
145 Ok((sigma, color_feat))
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Small grid configuration shared by all tests.
    fn test_config() -> HashGridConfig {
        HashGridConfig {
            n_levels: 4,
            n_features_per_level: 2,
            log2_hashmap_size: 8,
            base_resolution: 4,
            max_resolution: 32,
        }
    }

    /// Field with hidden width 16, an 8-wide direction encoding and 3 color features.
    fn make_hash_field(seed: u64) -> HashField {
        let mut rng = LcgRng::new(seed);
        HashField::new(test_config(), 16, 8, 3, &mut rng).unwrap()
    }

    #[test]
    fn forward_output_types() {
        let hf = make_hash_field(42);
        let dir_enc = vec![0.1_f32; 8];
        let (sigma, color) = hf.forward([0.5, 0.3, 0.7], &dir_enc).unwrap();
        assert!(sigma >= 0.0);
        assert_eq!(color.len(), 3);
    }

    #[test]
    fn wrong_dir_enc_dim() {
        let hf = make_hash_field(99);
        let dir_enc = vec![0.0_f32; 5];
        assert!(hf.forward([0.0, 0.0, 0.0], &dir_enc).is_err());
    }

    #[test]
    fn too_long_dir_enc_dim() {
        let hf = make_hash_field(99);
        let dir_enc = vec![0.0_f32; 9];
        assert!(hf.forward([0.0, 0.0, 0.0], &dir_enc).is_err());
    }

    #[test]
    fn rejects_zero_hidden_dim() {
        let mut rng = LcgRng::new(1);
        assert!(HashField::new(test_config(), 0, 8, 3, &mut rng).is_err());
    }

    #[test]
    fn rejects_zero_color_dim() {
        let mut rng = LcgRng::new(1);
        assert!(HashField::new(test_config(), 16, 8, 0, &mut rng).is_err());
    }

    #[test]
    fn color_len_follows_color_dim() {
        let mut rng = LcgRng::new(3);
        let hf = HashField::new(test_config(), 8, 4, 5, &mut rng).unwrap();
        let (sigma, color) = hf.forward([0.1, 0.2, 0.3], &[0.0; 4]).unwrap();
        assert!(sigma >= 0.0);
        assert_eq!(color.len(), 5);
    }

    #[test]
    fn same_seed_is_deterministic() {
        let a = make_hash_field(7);
        let b = make_hash_field(7);
        let dir_enc = vec![0.25_f32; 8];
        let xyz = [0.5, 0.3, 0.7];
        assert_eq!(
            a.forward(xyz, &dir_enc).unwrap(),
            b.forward(xyz, &dir_enc).unwrap()
        );
    }
}