neat/
runnable.rs

1use crate::topology::*;
2
3#[cfg(not(feature = "rayon"))]
4use std::{cell::RefCell, rc::Rc};
5
6#[cfg(feature = "rayon")]
7use rayon::prelude::*;
8#[cfg(feature = "rayon")]
9use std::sync::{Arc, RwLock};
10
11/// A runnable, stated Neural Network generated from a [NeuralNetworkTopology]. Use [`NeuralNetwork::from`] to go from stateles to runnable.
12/// Because this has state, you need to run [`NeuralNetwork::flush_state`] between [`NeuralNetwork::predict`] calls.
13#[derive(Debug)]
14#[cfg(not(feature = "rayon"))]
15pub struct NeuralNetwork<const I: usize, const O: usize> {
16    input_layer: [Rc<RefCell<Neuron>>; I],
17    hidden_layers: Vec<Rc<RefCell<Neuron>>>,
18    output_layer: [Rc<RefCell<Neuron>>; O],
19}
20
/// Thread-safe, parallelized counterpart of the [`NeuralNetwork`] struct, used when the
/// `rayon` feature is enabled. Each neuron lives behind an `Arc<RwLock<_>>` so it can be
/// shared between worker threads.
#[cfg(feature = "rayon")]
#[derive(Debug)]
pub struct NeuralNetwork<const I: usize, const O: usize> {
    /// Neurons receiving the network inputs, one per input.
    input_layer: [Arc<RwLock<Neuron>>; I],
    /// All hidden neurons; `NeuronLocation::Hidden(i)` refers to element `i`.
    hidden_layers: Vec<Arc<RwLock<Neuron>>>,
    /// Neurons producing the network outputs, one per output.
    output_layer: [Arc<RwLock<Neuron>>; O],
}
29
30impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
31    /// Predicts an output for the given inputs.
32    #[cfg(not(feature = "rayon"))]
33    pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
34        for (i, v) in inputs.iter().enumerate() {
35            let mut nw = self.input_layer[i].borrow_mut();
36            nw.state.value = *v;
37            nw.state.processed = true;
38        }
39
40        (0..O)
41            .map(NeuronLocation::Output)
42            .map(|loc| self.process_neuron(loc))
43            .collect::<Vec<_>>()
44            .try_into()
45            .unwrap()
46    }
47
48    /// Parallelized prediction of outputs from inputs.
49    #[cfg(feature = "rayon")]
50    pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
51        inputs.par_iter().enumerate().for_each(|(i, v)| {
52            let mut nw = self.input_layer[i].write().unwrap();
53            nw.state.value = *v;
54            nw.state.processed = true;
55        });
56
57        (0..O)
58            .map(NeuronLocation::Output)
59            .collect::<Vec<_>>()
60            .into_par_iter()
61            .map(|loc| self.process_neuron(loc))
62            .collect::<Vec<_>>()
63            .try_into()
64            .unwrap()
65    }
66
67    #[cfg(not(feature = "rayon"))]
68    fn process_neuron(&self, loc: NeuronLocation) -> f32 {
69        let n = self.get_neuron(loc);
70
71        {
72            let nr = n.borrow();
73
74            if nr.state.processed {
75                return nr.state.value;
76            }
77        }
78
79        let mut n = n.borrow_mut();
80
81        for (l, w) in n.inputs.clone() {
82            n.state.value += self.process_neuron(l) * w;
83        }
84
85        n.activate();
86
87        n.state.value
88    }
89
90    #[cfg(feature = "rayon")]
91    fn process_neuron(&self, loc: NeuronLocation) -> f32 {
92        let n = self.get_neuron(loc);
93
94        {
95            let nr = n.read().unwrap();
96
97            if nr.state.processed {
98                return nr.state.value;
99            }
100        }
101
102        let val: f32 = n
103            .read()
104            .unwrap()
105            .inputs
106            .par_iter()
107            .map(|&(n2, w)| {
108                let processed = self.process_neuron(n2);
109                processed * w
110            })
111            .sum();
112
113        let mut nw = n.write().unwrap();
114        nw.state.value += val;
115        nw.activate();
116
117        nw.state.value
118    }
119
120    #[cfg(not(feature = "rayon"))]
121    fn get_neuron(&self, loc: NeuronLocation) -> Rc<RefCell<Neuron>> {
122        match loc {
123            NeuronLocation::Input(i) => self.input_layer[i].clone(),
124            NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
125            NeuronLocation::Output(i) => self.output_layer[i].clone(),
126        }
127    }
128
129    #[cfg(feature = "rayon")]
130    fn get_neuron(&self, loc: NeuronLocation) -> Arc<RwLock<Neuron>> {
131        match loc {
132            NeuronLocation::Input(i) => self.input_layer[i].clone(),
133            NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
134            NeuronLocation::Output(i) => self.output_layer[i].clone(),
135        }
136    }
137
138    /// Flushes the network's state after a [prediction][NeuralNetwork::predict].
139    #[cfg(not(feature = "rayon"))]
140    pub fn flush_state(&self) {
141        for n in &self.input_layer {
142            n.borrow_mut().flush_state();
143        }
144
145        for n in &self.hidden_layers {
146            n.borrow_mut().flush_state();
147        }
148
149        for n in &self.output_layer {
150            n.borrow_mut().flush_state();
151        }
152    }
153
154    /// Flushes the neural network's state.
155    #[cfg(feature = "rayon")]
156    pub fn flush_state(&self) {
157        self.input_layer
158            .par_iter()
159            .for_each(|n| n.write().unwrap().flush_state());
160
161        self.hidden_layers
162            .par_iter()
163            .for_each(|n| n.write().unwrap().flush_state());
164
165        self.output_layer
166            .par_iter()
167            .for_each(|n| n.write().unwrap().flush_state());
168    }
169}
170
171impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NeuralNetwork<I, O> {
172    #[cfg(not(feature = "rayon"))]
173    fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
174        let input_layer = value
175            .input_layer
176            .iter()
177            .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
178            .collect::<Vec<_>>()
179            .try_into()
180            .unwrap();
181
182        let hidden_layers = value
183            .hidden_layers
184            .iter()
185            .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
186            .collect();
187
188        let output_layer = value
189            .output_layer
190            .iter()
191            .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
192            .collect::<Vec<_>>()
193            .try_into()
194            .unwrap();
195
196        Self {
197            input_layer,
198            hidden_layers,
199            output_layer,
200        }
201    }
202
203    #[cfg(feature = "rayon")]
204    fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
205        let input_layer = value
206            .input_layer
207            .iter()
208            .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
209            .collect::<Vec<_>>()
210            .try_into()
211            .unwrap();
212
213        let hidden_layers = value
214            .hidden_layers
215            .iter()
216            .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
217            .collect();
218
219        let output_layer = value
220            .output_layer
221            .iter()
222            .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
223            .collect::<Vec<_>>()
224            .try_into()
225            .unwrap();
226
227        Self {
228            input_layer,
229            hidden_layers,
230            output_layer,
231        }
232    }
233}
234
235/// A state-filled neuron.
236#[derive(Clone, Debug)]
237pub struct Neuron {
238    inputs: Vec<(NeuronLocation, f32)>,
239    bias: f32,
240
241    /// The current state of the neuron.
242    pub state: NeuronState,
243
244    /// The neuron's activation function
245    pub activation: ActivationFn,
246}
247
248impl Neuron {
249    /// Flushes a neuron's state. Called by [`NeuralNetwork::flush_state`]
250    pub fn flush_state(&mut self) {
251        self.state.value = self.bias;
252    }
253
254    /// Applies the activation function to the neuron
255    pub fn activate(&mut self) {
256        self.state.value = self.activation.func.activate(self.state.value);
257    }
258}
259
260impl From<&NeuronTopology> for Neuron {
261    fn from(value: &NeuronTopology) -> Self {
262        Self {
263            inputs: value.inputs.clone(),
264            bias: value.bias,
265            state: NeuronState {
266                value: value.bias,
267                ..Default::default()
268            },
269            activation: value.activation.clone(),
270        }
271    }
272}
273
/// Per-neuron cache used by [`Neuron`]s while a prediction is being computed.
#[derive(Default, Debug, Clone)]
pub struct NeuronState {
    /// Running value of the neuron; set back to the neuron's bias on flush.
    pub value: f32,

    /// `true` once [`value`][NeuronState::value] is final for the current prediction.
    pub processed: bool,
}
283
/// Blanket-implemented helper for iterators, intended to make the output of a
/// [`NeuralNetwork`] easier to interpret.
#[cfg(feature = "max-index")]
pub trait MaxIndex<T: PartialOrd> {
    /// Returns the position of the greatest item yielded by the iterator.
    fn max_index(self) -> usize;
}
290
#[cfg(feature = "max-index")]
impl<I: Iterator<Item = T>, T: PartialOrd> MaxIndex<T> for I {
    /// Returns the index of the largest item in a single pass. If several items are
    /// equally large, the index of the last one is returned.
    ///
    /// Items that are not comparable with themselves (such as `f32::NAN`) are skipped, so a
    /// single NaN network output no longer aborts the lookup.
    ///
    /// # Panics
    /// Panics if the iterator yields no comparable items, or if two of the remaining items
    /// have no relative ordering.
    fn max_index(self) -> usize {
        self.enumerate()
            // `v.partial_cmp(v)` is `None` only for irreflexive values such as NaN.
            .filter(|(_, v)| v.partial_cmp(v).is_some())
            .max_by(|(_, a), (_, b)| {
                a.partial_cmp(b)
                    .expect("max_index: encountered two values with no relative ordering")
            })
            .expect("max_index: the iterator yielded no comparable items")
            .0
    }
}
300}