// rnn/network/network.rs

1use core::fmt;
2
3use crate::activations::ActivationKind;
4use crate::engine::{forward_dense_plan, ForwardError};
5use crate::layers::{DenseLayerDesc, LayerPlan, LayerSpec};
6
/// Borrowed, allocation-free view of a fully connected network.
///
/// The struct only holds slices; it owns no data. All fields are public, so a
/// value can be built directly with a struct literal, which bypasses the
/// length validation performed by [`NeuralNetwork::from_parts`]. Methods that
/// depend on consistent lengths re-validate before using the buffers.
pub struct NeuralNetwork<'a> {
    /// Layer sizes. Each consecutive pair `(layers[i], layers[i + 1])` is the
    /// input and output size of one dense layer.
    pub layers: &'a [usize],
    /// Flat weight buffer. Layer `i` contributes `layers[i] * layers[i + 1]`
    /// values, and the per-layer blocks are laid out back to back in layer order.
    pub weights: &'a [f32],
    /// Flat bias buffer. Layer `i` contributes `layers[i + 1]` values, and the
    /// per-layer blocks are laid out back to back in layer order.
    pub biases: &'a [f32],
}
12
13impl<'a> NeuralNetwork<'a> {
14    pub fn from_parts(layers: &'a [usize], weights: &'a [f32], biases: &'a [f32]) -> Option<Self> {
15        if layers.len() < 2 { return None; }
16        let mut expected_w = 0usize;
17        let mut expected_b = 0usize;
18        for i in 0..layers.len() - 1 {
19            let in_sz = layers[i];
20            let out_sz = layers[i + 1];
21            expected_w = expected_w.saturating_add(in_sz.saturating_mul(out_sz));
22            expected_b = expected_b.saturating_add(out_sz);
23        }
24        if weights.len() != expected_w || biases.len() != expected_b { return None; }
25        Some(NeuralNetwork { layers, weights, biases })
26    }
27
28    pub fn expected_weights_count(layers: &[usize]) -> Option<usize> {
29        if layers.len() < 2 { return None; }
30        let mut expected = 0usize;
31        for i in 0..layers.len() - 1 {
32            expected = expected.checked_add(layers[i].checked_mul(layers[i+1])?)?;
33        }
34        Some(expected)
35    }
36
37    pub fn expected_biases_count(layers: &[usize]) -> Option<usize> {
38        if layers.len() < 2 { return None; }
39        let mut expected = 0usize;
40        for i in 0..layers.len() - 1 {
41            expected = expected.checked_add(layers[i+1])?;
42        }
43        Some(expected)
44    }
45
46    pub fn conceptualize_5d<'s>(
47        &self,
48        storage: &'s mut [crate::sphere5d::NeuronPoint],
49        radius: f32,
50    ) -> Result<crate::sphere5d::Sphere5D<'s>, crate::sphere5d::SphereError> {
51        crate::sphere5d::Sphere5D::from_network(self, storage, radius)
52    }
53
54    pub fn layer_count(&self) -> usize {
55        self.layers.len().saturating_sub(1)
56    }
57
58    pub fn build_dense_layer_specs(
59        &self,
60        hidden_activation: ActivationKind,
61        output_activation: ActivationKind,
62        out: &mut [LayerSpec],
63    ) -> Option<usize> {
64        let layer_count = self.layer_count();
65        if layer_count == 0 || out.len() < layer_count {
66            return None;
67        }
68
69        let mut w_off = 0usize;
70        let mut b_off = 0usize;
71
72        for i in 0..layer_count {
73            let in_size = self.layers[i];
74            let out_size = self.layers[i + 1];
75            let weight_len = in_size.checked_mul(out_size)?;
76            let activation = if i + 1 == layer_count { output_activation } else { hidden_activation };
77
78            out[i] = LayerSpec::Dense(DenseLayerDesc {
79                input_size: in_size,
80                output_size: out_size,
81                weight_offset: w_off,
82                bias_offset: b_off,
83                activation,
84            });
85
86            w_off = w_off.checked_add(weight_len)?;
87            b_off = b_off.checked_add(out_size)?;
88        }
89
90        if w_off != self.weights.len() || b_off != self.biases.len() {
91            return None;
92        }
93
94        Some(layer_count)
95    }
96
97    pub fn forward_dense(
98        &self,
99        input: &[f32],
100        output: &mut [f32],
101        scratch: &mut [f32],
102        layer_scratch: &mut [LayerSpec],
103        hidden_activation: ActivationKind,
104        output_activation: ActivationKind,
105    ) -> Result<(), ForwardError> {
106        let used_layers = self
107            .build_dense_layer_specs(hidden_activation, output_activation, layer_scratch)
108            .ok_or(ForwardError::InvalidPlan)?;
109
110        let plan = LayerPlan {
111            layers: &layer_scratch[..used_layers],
112            weights: self.weights,
113            biases: self.biases,
114        };
115
116        forward_dense_plan(&plan, input, output, scratch)
117    }
118}
119
120impl<'a> fmt::Debug for NeuralNetwork<'a> {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        f.debug_struct("NeuralNetwork")
123            .field("layers", &self.layers)
124            .field("weights_len", &self.weights.len())
125            .field("biases_len", &self.biases.len())
126            .finish()
127    }
128}
129