oxicuda_nerf/network/
tiny_nerf.rs1use crate::error::{NerfError, NerfResult};
7use crate::handle::LcgRng;
8
9#[derive(Debug, Clone)]
11pub struct TinyNerf {
12 layers: Vec<(Vec<f32>, Vec<f32>)>,
14 input_dim: usize,
16}
17
18impl TinyNerf {
19 #[must_use]
23 pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut LcgRng) -> Self {
24 let scale = (2.0_f32 / input_dim as f32).sqrt();
25 let mut init = |fan_in: usize, fan_out: usize| -> (Vec<f32>, Vec<f32>) {
26 let s = (2.0_f32 / fan_in as f32).sqrt();
27 let mut w = vec![0.0_f32; fan_out * fan_in];
28 for v in w.iter_mut() {
29 let (a, _) = rng.next_normal_pair();
30 *v = a * s;
31 }
32 (w, vec![0.0_f32; fan_out])
33 };
34 let _ = scale; let layer0 = init(input_dim, hidden_dim);
37 let layer1 = init(hidden_dim, hidden_dim);
38 let layer2 = init(hidden_dim, hidden_dim);
39 let layer3 = init(hidden_dim, 4); Self {
42 layers: vec![layer0, layer1, layer2, layer3],
43 input_dim,
44 }
45 }
46
47 pub fn forward(&self, x: &[f32]) -> NerfResult<(f32, [f32; 3])> {
55 if x.len() != self.input_dim {
56 return Err(NerfError::DimensionMismatch {
57 expected: self.input_dim,
58 got: x.len(),
59 });
60 }
61
62 let h = self.layers[0].1.len(); let a0 = fc_relu(x, &self.layers[0].0, &self.layers[0].1, h);
66 let a1 = fc_relu(&a0, &self.layers[1].0, &self.layers[1].1, h);
67 let a2 = fc_relu(&a1, &self.layers[2].0, &self.layers[2].1, h);
68
69 let out = fc_linear(&a2, &self.layers[3].0, &self.layers[3].1, 4);
71
72 let sigma = out[0].max(0.0);
73 let rgb = [sigmoid(out[1]), sigmoid(out[2]), sigmoid(out[3])];
74
75 Ok((sigma, rgb))
76 }
77}
78
/// Logistic function `1 / (1 + e^(-x))`, mapping any real input into `[0, 1]`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
85
/// Dense layer followed by ReLU: `out[i] = max(0, row_i · x + b[i])`.
///
/// `w` is read as consecutive rows of `x.len()` weights, one row per output.
/// The result always has `out_dim` entries; entries for which `w` or `b`
/// supplies no data are left at `0.0`.
///
/// An empty `x` is handled explicitly: every dot product is then empty, so
/// each output is just the ReLU of its bias.
fn fc_relu(x: &[f32], w: &[f32], b: &[f32], out_dim: usize) -> Vec<f32> {
    let in_dim = x.len();
    let mut out = vec![0.0_f32; out_dim];

    if in_dim == 0 {
        // `chunks(0)` panics, so the zero-width case cannot go through the
        // row loop below.
        for (o, &bi) in out.iter_mut().zip(b) {
            *o = bi.max(0.0);
        }
        return out;
    }

    for (o, (row, &bi)) in out.iter_mut().zip(w.chunks(in_dim).zip(b)) {
        let dot: f32 = row.iter().zip(x).map(|(&wi, &xi)| wi * xi).sum();
        *o = (dot + bi).max(0.0);
    }
    out
}
100
/// Dense layer without activation: `out[i] = row_i · x + b[i]`.
///
/// `w` is read as consecutive rows of `x.len()` weights, one row per output.
/// The result always has `out_dim` entries; entries for which `w` or `b`
/// supplies no data are left at `0.0`.
///
/// An empty `x` is handled explicitly: every dot product is then empty, so
/// each output is just its bias.
fn fc_linear(x: &[f32], w: &[f32], b: &[f32], out_dim: usize) -> Vec<f32> {
    let in_dim = x.len();
    let mut out = vec![0.0_f32; out_dim];

    if in_dim == 0 {
        // `chunks(0)` panics, so the zero-width case cannot go through the
        // row loop below.
        for (o, &bi) in out.iter_mut().zip(b) {
            *o = bi;
        }
        return out;
    }

    for (o, (row, &bi)) in out.iter_mut().zip(w.chunks(in_dim).zip(b)) {
        let dot: f32 = row.iter().zip(x).map(|(&wi, &xi)| wi * xi).sum();
        *o = dot + bi;
    }
    out
}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// A correctly sized input yields a non-negative density and colour
    /// channels inside `[0, 1]`.
    #[test]
    fn tiny_nerf_forward_shape() {
        let mut rng = LcgRng::new(77);
        let net = TinyNerf::new(10, 32, &mut rng);
        let input = [0.1_f32; 10];
        let (sigma, rgb) = net.forward(&input).unwrap();
        assert!(sigma >= 0.0);
        for channel in rgb.iter() {
            assert!((0.0..=1.0).contains(channel));
        }
    }

    /// An input whose length differs from `input_dim` is rejected.
    #[test]
    fn tiny_nerf_wrong_input() {
        let mut rng = LcgRng::new(88);
        let net = TinyNerf::new(10, 32, &mut rng);
        let result = net.forward(&[0.0_f32; 5]);
        assert!(result.is_err());
    }
}