Skip to main content

oxicuda_nerf/field/
tensorf.rs

//! TensoRF: CP (CANDECOMP/PARAFAC) tensor decomposition radiance field.
//!
//! Density field:
//!   `σ(x,y,z) = ReLU(Σ_{r=1}^{R} v_r^X(x) · v_r^Y(y) · v_r^Z(z))`
//!
//! Color field:
//!   `c(x,y,z) = Σ_{r=1}^{R} v_r^{X,c}(x) · v_r^{Y,c}(y) · v_r^{Z,c}(z)` → \[n_color_feat\]
//!
//! Factor vectors are stored flat. Each 1D factor is linearly interpolated at the query
//! coordinate (clamped to `[-1, 1]`). The product of the three axis interpolants is
//! equivalent to trilinear interpolation of the reconstructed grid.
10
11use crate::error::{NerfError, NerfResult};
12use crate::handle::LcgRng;
13
14// ─── Config ──────────────────────────────────────────────────────────────────
15
/// Configuration for TensoRF CP decomposition.
///
/// All three dimensions must be non-zero; [`TensorRf::new`] rejects the
/// configuration otherwise.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorRfConfig {
    /// R — number of CP rank components.
    pub rank: usize,
    /// Grid resolution (number of samples) per axis.
    pub grid_dim: usize,
    /// Number of output color features.
    pub n_color_feat: usize,
}
26
27// ─── TensorRf ────────────────────────────────────────────────────────────────
28
29/// TensoRF CP radiance field.
30#[derive(Debug, Clone)]
31pub struct TensorRf {
32    /// Density factor vectors: `[rank * 3 * grid_dim]` (3 axes per rank component).
33    density_vecs: Vec<f32>,
34    /// Color factor vectors: `[rank * 3 * grid_dim * n_color_feat]`.
35    color_vecs: Vec<f32>,
36    /// Configuration.
37    config: TensorRfConfig,
38}
39
40impl TensorRf {
41    /// Create a new TensoRF with small random initialization.
42    ///
43    /// # Errors
44    ///
45    /// Returns `TensorDecompError` if any dimension is zero.
46    pub fn new(cfg: TensorRfConfig, rng: &mut LcgRng) -> NerfResult<Self> {
47        if cfg.rank == 0 {
48            return Err(NerfError::TensorDecompError {
49                msg: "rank must be > 0".into(),
50            });
51        }
52        if cfg.grid_dim == 0 {
53            return Err(NerfError::TensorDecompError {
54                msg: "grid_dim must be > 0".into(),
55            });
56        }
57        if cfg.n_color_feat == 0 {
58            return Err(NerfError::TensorDecompError {
59                msg: "n_color_feat must be > 0".into(),
60            });
61        }
62
63        let density_size = cfg.rank * 3 * cfg.grid_dim;
64        let color_size = cfg.rank * 3 * cfg.grid_dim * cfg.n_color_feat;
65
66        let mut density_vecs = vec![0.0_f32; density_size];
67        let mut color_vecs = vec![0.0_f32; color_size];
68
69        let scale = 0.01_f32;
70        for v in density_vecs.iter_mut() {
71            let (a, _) = rng.next_normal_pair();
72            *v = a * scale;
73        }
74        for v in color_vecs.iter_mut() {
75            let (a, _) = rng.next_normal_pair();
76            *v = a * scale;
77        }
78
79        Ok(Self {
80            density_vecs,
81            color_vecs,
82            config: cfg,
83        })
84    }
85
86    /// Query density at a 3D point in `[-1, 1]^3`.
87    ///
88    /// Returns `ReLU(Σ_r v_r^X(x) * v_r^Y(y) * v_r^Z(z))`.
89    ///
90    /// # Errors
91    ///
92    /// Returns `NanEncountered` if an NaN occurs.
93    pub fn query_density(&self, xyz: [f32; 3]) -> NerfResult<f32> {
94        let g = self.config.grid_dim;
95        let r = self.config.rank;
96
97        let mut sum = 0.0_f32;
98        for rank_idx in 0..r {
99            // Each rank has 3 axis vectors of length grid_dim
100            let x_val = interp_vector(
101                &self.density_vecs[rank_idx * 3 * g..rank_idx * 3 * g + g],
102                xyz[0],
103            );
104            let y_val = interp_vector(
105                &self.density_vecs[rank_idx * 3 * g + g..rank_idx * 3 * g + 2 * g],
106                xyz[1],
107            );
108            let z_val = interp_vector(
109                &self.density_vecs[rank_idx * 3 * g + 2 * g..rank_idx * 3 * g + 3 * g],
110                xyz[2],
111            );
112            sum += x_val * y_val * z_val;
113        }
114
115        if !sum.is_finite() {
116            return Err(NerfError::NanEncountered {
117                context: "TensorRf::query_density".into(),
118            });
119        }
120
121        Ok(sum.max(0.0))
122    }
123
124    /// Query color feature vector at a 3D point in `[-1, 1]^3`.
125    ///
126    /// Returns `[n_color_feat]` features.
127    ///
128    /// # Errors
129    ///
130    /// Returns `NanEncountered` if an NaN occurs.
131    pub fn query_color(&self, xyz: [f32; 3]) -> NerfResult<Vec<f32>> {
132        let g = self.config.grid_dim;
133        let r = self.config.rank;
134        let nf = self.config.n_color_feat;
135
136        let mut out = vec![0.0_f32; nf];
137
138        for rank_idx in 0..r {
139            let base = rank_idx * 3 * g * nf;
140            // X axis: shape [g * nf], take the per-feature interp
141            let x_base = base;
142            let y_base = base + g * nf;
143            let z_base = base + 2 * g * nf;
144
145            let x_val = interp_vector_scalar(&self.color_vecs[x_base..x_base + g], xyz[0]);
146            let y_val = interp_vector_scalar(&self.color_vecs[y_base..y_base + g], xyz[1]);
147            let z_val = interp_vector_scalar(&self.color_vecs[z_base..z_base + g], xyz[2]);
148
149            let scalar = x_val * y_val * z_val;
150            // Each feature gets the same scalar contribution for the scalar CP version
151            // (For a proper vectorized CP, each axis would return n_color_feat values)
152            for feat in out.iter_mut() {
153                *feat += scalar;
154            }
155            let _ = nf; // Used above
156        }
157
158        for v in &out {
159            if !v.is_finite() {
160                return Err(NerfError::NanEncountered {
161                    context: "TensorRf::query_color".into(),
162                });
163            }
164        }
165
166        Ok(out)
167    }
168
169    /// Total number of parameters.
170    #[must_use]
171    pub fn param_count(&self) -> usize {
172        let g = self.config.grid_dim;
173        let r = self.config.rank;
174        let nf = self.config.n_color_feat;
175        r * 3 * g + r * 3 * g * nf
176    }
177}
178
179// ─── Interpolation helpers ────────────────────────────────────────────────────
180
/// Linearly interpolate a 1D factor vector at a coordinate in `[-1, 1]`.
///
/// The coordinate is clamped to `[-1, 1]` and mapped onto sample indices
/// `[0, len - 1]`, so `-1` hits the first sample and `1` the last.
///
/// Edge cases:
/// - An empty vector yields `0.0`.
/// - A single-sample vector is constant.
/// - A NaN coordinate propagates to a NaN result.
fn interp_vector(vec: &[f32], coord: f32) -> f32 {
    match vec {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = vec.len() - 1;
            // Map coord from [-1, 1] to [0, last].
            let t = (coord.clamp(-1.0, 1.0) + 1.0) * 0.5 * last as f32;
            // `last as f32` can round *up* for very large grids, so clamp the
            // index as well to rule out an out-of-bounds access.
            let lo = (t.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = t - lo as f32;
            vec[lo] * (1.0 - frac) + vec[hi] * frac
        }
    }
}
197
198/// Same as `interp_vector` but named distinctly for the scalar CP color path.
199fn interp_vector_scalar(vec: &[f32], coord: f32) -> f32 {
200    interp_vector(vec, coord)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared small configuration used by every test.
    fn small_cfg() -> TensorRfConfig {
        TensorRfConfig {
            rank: 4,
            grid_dim: 8,
            n_color_feat: 3,
        }
    }

    fn make_tensorf(seed: u64) -> TensorRf {
        let mut rng = LcgRng::new(seed);
        TensorRf::new(small_cfg(), &mut rng).unwrap()
    }

    #[test]
    fn density_nonneg() {
        let tf = make_tensorf(42);
        let d = tf.query_density([0.1, -0.3, 0.5]).unwrap();
        assert!(d >= 0.0);
    }

    #[test]
    fn color_shape() {
        let tf = make_tensorf(17);
        let c = tf.query_color([0.0, 0.0, 0.0]).unwrap();
        assert_eq!(c.len(), tf.config.n_color_feat);
    }

    #[test]
    fn param_count() {
        let tf = make_tensorf(1);
        // density: 4*3*8=96, color: 4*3*8*3=288, total=384
        assert_eq!(tf.param_count(), 96 + 288);
    }

    #[test]
    fn new_rejects_zero_dimensions() {
        let mut rng = LcgRng::new(7);
        for bad in [
            TensorRfConfig {
                rank: 0,
                ..small_cfg()
            },
            TensorRfConfig {
                grid_dim: 0,
                ..small_cfg()
            },
            TensorRfConfig {
                n_color_feat: 0,
                ..small_cfg()
            },
        ] {
            assert!(matches!(
                TensorRf::new(bad, &mut rng),
                Err(NerfError::TensorDecompError { .. })
            ));
        }
    }

    #[test]
    fn out_of_range_coordinates_are_clamped() {
        let tf = make_tensorf(3);
        let edge = [1.0, -1.0, 1.0];
        let outside = [4.0, -9.0, 2.5];
        assert_eq!(
            tf.query_density(edge).unwrap(),
            tf.query_density(outside).unwrap()
        );
        assert_eq!(
            tf.query_color(edge).unwrap(),
            tf.query_color(outside).unwrap()
        );
    }

    #[test]
    fn same_seed_is_deterministic() {
        let a = make_tensorf(99);
        let b = make_tensorf(99);
        let p = [0.25, -0.5, 0.75];
        assert_eq!(a.query_density(p).unwrap(), b.query_density(p).unwrap());
        assert_eq!(a.query_color(p).unwrap(), b.query_color(p).unwrap());
    }

    #[test]
    fn nan_coordinate_is_reported() {
        let tf = make_tensorf(5);
        assert!(matches!(
            tf.query_density([f32::NAN, 0.0, 0.0]),
            Err(NerfError::NanEncountered { .. })
        ));
    }
}
239        let tf = TensorRf::new(cfg, &mut rng).unwrap();
240        // density: 4*3*8=96, color: 4*3*8*3=288, total=384
241        assert_eq!(tf.param_count(), 96 + 288);
242    }
243}