// oxicuda_nerf/network/tiny_nerf.rs

//! Compact 4-layer NeRF for tests and fast experiments.
//!
//! Architecture: input → FC → ReLU → FC → ReLU → FC → ReLU → FC → (sigma, RGB)
//! Last layer: no activation for sigma (then ReLU); Sigmoid for RGB.

use crate::error::{NerfError, NerfResult};
use crate::handle::LcgRng;

/// Tiny NeRF MLP: 3 hidden ReLU layers plus a 4-unit output layer.
#[derive(Debug, Clone)]
pub struct TinyNerf {
    /// `(weights, biases)` per layer, in forward order. Weights are stored
    /// flat, row-major: `fan_out` rows of `fan_in` values each (see how
    /// `forward` slices them with `chunks(in_dim)`).
    layers: Vec<(Vec<f32>, Vec<f32>)>,
    /// Required length of the input slice passed to `forward`.
    input_dim: usize,
}

18impl TinyNerf {
19    /// Create a new TinyNerf.
20    ///
21    /// 3 hidden layers of `hidden_dim` followed by an output of 4 (sigma + RGB).
22    #[must_use]
23    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut LcgRng) -> Self {
24        let scale = (2.0_f32 / input_dim as f32).sqrt();
25        let mut init = |fan_in: usize, fan_out: usize| -> (Vec<f32>, Vec<f32>) {
26            let s = (2.0_f32 / fan_in as f32).sqrt();
27            let mut w = vec![0.0_f32; fan_out * fan_in];
28            for v in w.iter_mut() {
29                let (a, _) = rng.next_normal_pair();
30                *v = a * s;
31            }
32            (w, vec![0.0_f32; fan_out])
33        };
34        let _ = scale; // scale computed but only used via init closure
35
36        let layer0 = init(input_dim, hidden_dim);
37        let layer1 = init(hidden_dim, hidden_dim);
38        let layer2 = init(hidden_dim, hidden_dim);
39        let layer3 = init(hidden_dim, 4); // sigma + RGB
40
41        Self {
42            layers: vec![layer0, layer1, layer2, layer3],
43            input_dim,
44        }
45    }
46
47    /// Forward pass: returns `(sigma: f32, rgb: [f32; 3])`.
48    ///
49    /// sigma = ReLU(output\[0\]), rgb = Sigmoid(output\[1..4\]).
50    ///
51    /// # Errors
52    ///
53    /// Returns `DimensionMismatch` if `x.len() != input_dim`.
54    pub fn forward(&self, x: &[f32]) -> NerfResult<(f32, [f32; 3])> {
55        if x.len() != self.input_dim {
56            return Err(NerfError::DimensionMismatch {
57                expected: self.input_dim,
58                got: x.len(),
59            });
60        }
61
62        let h = self.layers[0].1.len(); // hidden_dim
63
64        // Hidden layers 0–2 with ReLU
65        let a0 = fc_relu(x, &self.layers[0].0, &self.layers[0].1, h);
66        let a1 = fc_relu(&a0, &self.layers[1].0, &self.layers[1].1, h);
67        let a2 = fc_relu(&a1, &self.layers[2].0, &self.layers[2].1, h);
68
69        // Output layer: 4 units, no intermediate activation
70        let out = fc_linear(&a2, &self.layers[3].0, &self.layers[3].1, 4);
71
72        let sigma = out[0].max(0.0);
73        let rgb = [sigmoid(out[1]), sigmoid(out[2]), sigmoid(out[3])];
74
75        Ok((sigma, rgb))
76    }
77}

// ─── Activation utilities ────────────────────────────────────────────────────

/// Logistic sigmoid: `1 / (1 + e^(-x))`, mapping any real to `(0, 1)`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}

/// Fully connected layer with ReLU: `out[i] = max(0, w_row_i · x + b[i])`.
///
/// `w` is flat row-major with `x.len()` weights per output unit; `b` holds
/// one bias per unit. Outputs beyond the available `w`/`b` rows remain 0.
fn fc_relu(x: &[f32], w: &[f32], b: &[f32], out_dim: usize) -> Vec<f32> {
    let in_dim = x.len();
    let mut out = vec![0.0_f32; out_dim];
    let rows = w.chunks(in_dim).zip(b.iter());
    for (slot, (row, &bias)) in out.iter_mut().zip(rows) {
        // Accumulate the dot product in the same left-to-right order as
        // `Iterator::sum` would, so results match exactly.
        let mut dot = 0.0_f32;
        for (&wi, &xi) in row.iter().zip(x.iter()) {
            dot += wi * xi;
        }
        *slot = (dot + bias).max(0.0);
    }
    out
}

/// Fully connected layer, no activation: `out[i] = w_row_i · x + b[i]`.
///
/// Same layout conventions as `fc_relu`: `w` is flat row-major with
/// `x.len()` weights per output unit, one bias per unit in `b`.
fn fc_linear(x: &[f32], w: &[f32], b: &[f32], out_dim: usize) -> Vec<f32> {
    let in_dim = x.len();
    // Dot product of one weight row with the input.
    let dot = |row: &[f32]| -> f32 {
        row.iter().zip(x.iter()).map(|(&wi, &xi)| wi * xi).sum()
    };
    let mut out = vec![0.0_f32; out_dim];
    for (slot, (row, &bias)) in out.iter_mut().zip(w.chunks(in_dim).zip(b.iter())) {
        *slot = dot(row) + bias;
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Output ranges: sigma is ReLU-clamped, RGB channels are sigmoid-bounded.
    #[test]
    fn tiny_nerf_forward_shape() {
        let mut rng = LcgRng::new(77);
        let net = TinyNerf::new(10, 32, &mut rng);
        let input = vec![0.1_f32; 10];
        let (sigma, rgb) = net.forward(&input).unwrap();
        assert!(sigma >= 0.0);
        for channel in rgb {
            assert!((0.0..=1.0).contains(&channel));
        }
    }

    /// An input of the wrong length must be rejected with an error.
    #[test]
    fn tiny_nerf_wrong_input() {
        let mut rng = LcgRng::new(88);
        let net = TinyNerf::new(10, 32, &mut rng);
        assert!(net.forward(&[0.0_f32; 5]).is_err());
    }
}
135}