oxicuda_nerf/field/
tensorf.rs1use crate::error::{NerfError, NerfResult};
12use crate::handle::LcgRng;
13
/// Hyper-parameters of a [`TensorRf`] field.
///
/// All three values must be non-zero; [`TensorRf::new`] rejects a zero in any
/// of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorRfConfig {
    /// Number of rank-1 components whose per-axis products are summed by a query.
    pub rank: usize,
    /// Number of samples along every per-axis factor vector.
    pub grid_dim: usize,
    /// Length of the feature vector returned by [`TensorRf::query_color`].
    pub n_color_feat: usize,
}
26
27#[derive(Debug, Clone)]
31pub struct TensorRf {
32 density_vecs: Vec<f32>,
34 color_vecs: Vec<f32>,
36 config: TensorRfConfig,
38}
39
40impl TensorRf {
41 pub fn new(cfg: TensorRfConfig, rng: &mut LcgRng) -> NerfResult<Self> {
47 if cfg.rank == 0 {
48 return Err(NerfError::TensorDecompError {
49 msg: "rank must be > 0".into(),
50 });
51 }
52 if cfg.grid_dim == 0 {
53 return Err(NerfError::TensorDecompError {
54 msg: "grid_dim must be > 0".into(),
55 });
56 }
57 if cfg.n_color_feat == 0 {
58 return Err(NerfError::TensorDecompError {
59 msg: "n_color_feat must be > 0".into(),
60 });
61 }
62
63 let density_size = cfg.rank * 3 * cfg.grid_dim;
64 let color_size = cfg.rank * 3 * cfg.grid_dim * cfg.n_color_feat;
65
66 let mut density_vecs = vec![0.0_f32; density_size];
67 let mut color_vecs = vec![0.0_f32; color_size];
68
69 let scale = 0.01_f32;
70 for v in density_vecs.iter_mut() {
71 let (a, _) = rng.next_normal_pair();
72 *v = a * scale;
73 }
74 for v in color_vecs.iter_mut() {
75 let (a, _) = rng.next_normal_pair();
76 *v = a * scale;
77 }
78
79 Ok(Self {
80 density_vecs,
81 color_vecs,
82 config: cfg,
83 })
84 }
85
86 pub fn query_density(&self, xyz: [f32; 3]) -> NerfResult<f32> {
94 let g = self.config.grid_dim;
95 let r = self.config.rank;
96
97 let mut sum = 0.0_f32;
98 for rank_idx in 0..r {
99 let x_val = interp_vector(
101 &self.density_vecs[rank_idx * 3 * g..rank_idx * 3 * g + g],
102 xyz[0],
103 );
104 let y_val = interp_vector(
105 &self.density_vecs[rank_idx * 3 * g + g..rank_idx * 3 * g + 2 * g],
106 xyz[1],
107 );
108 let z_val = interp_vector(
109 &self.density_vecs[rank_idx * 3 * g + 2 * g..rank_idx * 3 * g + 3 * g],
110 xyz[2],
111 );
112 sum += x_val * y_val * z_val;
113 }
114
115 if !sum.is_finite() {
116 return Err(NerfError::NanEncountered {
117 context: "TensorRf::query_density".into(),
118 });
119 }
120
121 Ok(sum.max(0.0))
122 }
123
124 pub fn query_color(&self, xyz: [f32; 3]) -> NerfResult<Vec<f32>> {
132 let g = self.config.grid_dim;
133 let r = self.config.rank;
134 let nf = self.config.n_color_feat;
135
136 let mut out = vec![0.0_f32; nf];
137
138 for rank_idx in 0..r {
139 let base = rank_idx * 3 * g * nf;
140 let x_base = base;
142 let y_base = base + g * nf;
143 let z_base = base + 2 * g * nf;
144
145 let x_val = interp_vector_scalar(&self.color_vecs[x_base..x_base + g], xyz[0]);
146 let y_val = interp_vector_scalar(&self.color_vecs[y_base..y_base + g], xyz[1]);
147 let z_val = interp_vector_scalar(&self.color_vecs[z_base..z_base + g], xyz[2]);
148
149 let scalar = x_val * y_val * z_val;
150 for feat in out.iter_mut() {
153 *feat += scalar;
154 }
155 let _ = nf; }
157
158 for v in &out {
159 if !v.is_finite() {
160 return Err(NerfError::NanEncountered {
161 context: "TensorRf::query_color".into(),
162 });
163 }
164 }
165
166 Ok(out)
167 }
168
169 #[must_use]
171 pub fn param_count(&self) -> usize {
172 let g = self.config.grid_dim;
173 let r = self.config.rank;
174 let nf = self.config.n_color_feat;
175 r * 3 * g + r * 3 * g * nf
176 }
177}
178
/// Linearly interpolates a 1-D factor vector at a normalised coordinate.
///
/// `coord` is clamped to `[-1, 1]` and mapped onto sample positions
/// `0..=vec.len() - 1`, so `-1` hits the first sample and `1` the last.
/// Edge cases:
/// - an empty vector yields `0.0`;
/// - a single-sample vector yields that sample;
/// - a NaN `coord` yields NaN, which callers report through their finiteness check.
fn interp_vector(vec: &[f32], coord: f32) -> f32 {
    match vec {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = vec.len() - 1;
            let t = (coord.clamp(-1.0, 1.0) + 1.0) * 0.5 * last as f32;
            // `last as f32` can round up for very long vectors, so clamp the
            // index. A NaN `t` saturates to 0 here and still propagates via `frac`.
            let lo = (t.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = t - lo as f32;
            vec[lo] * (1.0 - frac) + vec[hi] * frac
        }
    }
}
197
198fn interp_vector_scalar(vec: &[f32], coord: f32) -> f32 {
200 interp_vector(vec, coord)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a config in tests.
    fn config(rank: usize, grid_dim: usize, n_color_feat: usize) -> TensorRfConfig {
        TensorRfConfig {
            rank,
            grid_dim,
            n_color_feat,
        }
    }

    /// Builds the rank-4, 8-sample, 3-feature field used by most tests.
    fn make_tensorf(seed: u64) -> TensorRf {
        let mut rng = LcgRng::new(seed);
        TensorRf::new(config(4, 8, 3), &mut rng).unwrap()
    }

    #[test]
    fn density_nonneg() {
        let tf = make_tensorf(42);
        let d = tf.query_density([0.1, -0.3, 0.5]).unwrap();
        assert!(d >= 0.0);
    }

    #[test]
    fn color_shape() {
        let tf = make_tensorf(17);
        let c = tf.query_color([0.0, 0.0, 0.0]).unwrap();
        assert_eq!(c.len(), tf.config.n_color_feat);
        assert!(c.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn param_count() {
        let tf = make_tensorf(1);
        // density: 4 * 3 * 8 = 96; colour: 96 * 3 = 288.
        assert_eq!(tf.param_count(), 96 + 288);
    }

    #[test]
    fn zero_sized_configs_are_rejected() {
        for cfg in [config(0, 8, 3), config(4, 0, 3), config(4, 8, 0)] {
            let mut rng = LcgRng::new(7);
            assert!(matches!(
                TensorRf::new(cfg, &mut rng),
                Err(NerfError::TensorDecompError { .. })
            ));
        }
    }

    #[test]
    fn nan_coordinates_are_reported() {
        let tf = make_tensorf(3);
        assert!(matches!(
            tf.query_density([f32::NAN, 0.0, 0.0]),
            Err(NerfError::NanEncountered { .. })
        ));
        assert!(matches!(
            tf.query_color([0.0, f32::NAN, 0.0]),
            Err(NerfError::NanEncountered { .. })
        ));
    }

    #[test]
    fn queries_are_repeatable() {
        let tf = make_tensorf(9);
        let p = [0.25, -0.5, 0.75];
        assert_eq!(tf.query_density(p).unwrap(), tf.query_density(p).unwrap());
        assert_eq!(tf.query_color(p).unwrap(), tf.query_color(p).unwrap());
    }
}