1use crate::topology::*;
2
3#[cfg(not(feature = "rayon"))]
4use std::{cell::RefCell, rc::Rc};
5
6#[cfg(feature = "rayon")]
7use rayon::prelude::*;
8#[cfg(feature = "rayon")]
9use std::sync::{Arc, RwLock};
10
11#[derive(Debug)]
14#[cfg(not(feature = "rayon"))]
15pub struct NeuralNetwork<const I: usize, const O: usize> {
16 input_layer: [Rc<RefCell<Neuron>>; I],
17 hidden_layers: Vec<Rc<RefCell<Neuron>>>,
18 output_layer: [Rc<RefCell<Neuron>>; O],
19}
20
/// A runnable neural network with `I` inputs and `O` outputs, instantiated from a
/// `NeuralNetworkTopology`.
///
/// Parallel (`rayon`) build: neurons are shared through `Arc<RwLock<_>>` so they
/// can be evaluated from several threads.
#[cfg(feature = "rayon")]
#[derive(Debug)]
pub struct NeuralNetwork<const I: usize, const O: usize> {
    /// Neurons that receive the values passed to `predict`, in order.
    input_layer: [Arc<RwLock<Neuron>>; I],
    /// Every neuron that is neither an input nor an output, stored flat and
    /// addressed by `NeuronLocation::Hidden(index)`.
    hidden_layers: Vec<Arc<RwLock<Neuron>>>,
    /// Neurons whose values `predict` returns, in order.
    output_layer: [Arc<RwLock<Neuron>>; O],
}
29
30impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
31 #[cfg(not(feature = "rayon"))]
33 pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
34 for (i, v) in inputs.iter().enumerate() {
35 let mut nw = self.input_layer[i].borrow_mut();
36 nw.state.value = *v;
37 nw.state.processed = true;
38 }
39
40 (0..O)
41 .map(NeuronLocation::Output)
42 .map(|loc| self.process_neuron(loc))
43 .collect::<Vec<_>>()
44 .try_into()
45 .unwrap()
46 }
47
48 #[cfg(feature = "rayon")]
50 pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
51 inputs.par_iter().enumerate().for_each(|(i, v)| {
52 let mut nw = self.input_layer[i].write().unwrap();
53 nw.state.value = *v;
54 nw.state.processed = true;
55 });
56
57 (0..O)
58 .map(NeuronLocation::Output)
59 .collect::<Vec<_>>()
60 .into_par_iter()
61 .map(|loc| self.process_neuron(loc))
62 .collect::<Vec<_>>()
63 .try_into()
64 .unwrap()
65 }
66
67 #[cfg(not(feature = "rayon"))]
68 fn process_neuron(&self, loc: NeuronLocation) -> f32 {
69 let n = self.get_neuron(loc);
70
71 {
72 let nr = n.borrow();
73
74 if nr.state.processed {
75 return nr.state.value;
76 }
77 }
78
79 let mut n = n.borrow_mut();
80
81 for (l, w) in n.inputs.clone() {
82 n.state.value += self.process_neuron(l) * w;
83 }
84
85 n.activate();
86
87 n.state.value
88 }
89
90 #[cfg(feature = "rayon")]
91 fn process_neuron(&self, loc: NeuronLocation) -> f32 {
92 let n = self.get_neuron(loc);
93
94 {
95 let nr = n.read().unwrap();
96
97 if nr.state.processed {
98 return nr.state.value;
99 }
100 }
101
102 let val: f32 = n
103 .read()
104 .unwrap()
105 .inputs
106 .par_iter()
107 .map(|&(n2, w)| {
108 let processed = self.process_neuron(n2);
109 processed * w
110 })
111 .sum();
112
113 let mut nw = n.write().unwrap();
114 nw.state.value += val;
115 nw.activate();
116
117 nw.state.value
118 }
119
120 #[cfg(not(feature = "rayon"))]
121 fn get_neuron(&self, loc: NeuronLocation) -> Rc<RefCell<Neuron>> {
122 match loc {
123 NeuronLocation::Input(i) => self.input_layer[i].clone(),
124 NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
125 NeuronLocation::Output(i) => self.output_layer[i].clone(),
126 }
127 }
128
129 #[cfg(feature = "rayon")]
130 fn get_neuron(&self, loc: NeuronLocation) -> Arc<RwLock<Neuron>> {
131 match loc {
132 NeuronLocation::Input(i) => self.input_layer[i].clone(),
133 NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
134 NeuronLocation::Output(i) => self.output_layer[i].clone(),
135 }
136 }
137
138 #[cfg(not(feature = "rayon"))]
140 pub fn flush_state(&self) {
141 for n in &self.input_layer {
142 n.borrow_mut().flush_state();
143 }
144
145 for n in &self.hidden_layers {
146 n.borrow_mut().flush_state();
147 }
148
149 for n in &self.output_layer {
150 n.borrow_mut().flush_state();
151 }
152 }
153
154 #[cfg(feature = "rayon")]
156 pub fn flush_state(&self) {
157 self.input_layer
158 .par_iter()
159 .for_each(|n| n.write().unwrap().flush_state());
160
161 self.hidden_layers
162 .par_iter()
163 .for_each(|n| n.write().unwrap().flush_state());
164
165 self.output_layer
166 .par_iter()
167 .for_each(|n| n.write().unwrap().flush_state());
168 }
169}
170
171impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NeuralNetwork<I, O> {
172 #[cfg(not(feature = "rayon"))]
173 fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
174 let input_layer = value
175 .input_layer
176 .iter()
177 .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
178 .collect::<Vec<_>>()
179 .try_into()
180 .unwrap();
181
182 let hidden_layers = value
183 .hidden_layers
184 .iter()
185 .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
186 .collect();
187
188 let output_layer = value
189 .output_layer
190 .iter()
191 .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
192 .collect::<Vec<_>>()
193 .try_into()
194 .unwrap();
195
196 Self {
197 input_layer,
198 hidden_layers,
199 output_layer,
200 }
201 }
202
203 #[cfg(feature = "rayon")]
204 fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
205 let input_layer = value
206 .input_layer
207 .iter()
208 .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
209 .collect::<Vec<_>>()
210 .try_into()
211 .unwrap();
212
213 let hidden_layers = value
214 .hidden_layers
215 .iter()
216 .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
217 .collect();
218
219 let output_layer = value
220 .output_layer
221 .iter()
222 .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
223 .collect::<Vec<_>>()
224 .try_into()
225 .unwrap();
226
227 Self {
228 input_layer,
229 hidden_layers,
230 output_layer,
231 }
232 }
233}
234
235#[derive(Clone, Debug)]
237pub struct Neuron {
238 inputs: Vec<(NeuronLocation, f32)>,
239 bias: f32,
240
241 pub state: NeuronState,
243
244 pub activation: ActivationFn,
246}
247
248impl Neuron {
249 pub fn flush_state(&mut self) {
251 self.state.value = self.bias;
252 }
253
254 pub fn activate(&mut self) {
256 self.state.value = self.activation.func.activate(self.state.value);
257 }
258}
259
260impl From<&NeuronTopology> for Neuron {
261 fn from(value: &NeuronTopology) -> Self {
262 Self {
263 inputs: value.inputs.clone(),
264 bias: value.bias,
265 state: NeuronState {
266 value: value.bias,
267 ..Default::default()
268 },
269 activation: value.activation.clone(),
270 }
271 }
272}
273
/// Per-pass mutable state of a neuron.
#[derive(Clone, Debug, Default)]
pub struct NeuronState {
    /// The neuron's current value. For an input neuron this is the raw input; for any
    /// other neuron it is the bias until processed, then the activated bias plus the
    /// weighted input sum.
    pub value: f32,

    /// Whether `value` is final for the current pass (set once the neuron is evaluated).
    pub processed: bool,
}
283
/// Locates the greatest element of a sequence of partially ordered items.
#[cfg(feature = "max-index")]
pub trait MaxIndex<T>
where
    T: PartialOrd,
{
    /// Consumes the sequence and returns the zero-based position of its maximum.
    fn max_index(self) -> usize;
}
290
#[cfg(feature = "max-index")]
impl<I: Iterator<Item = T>, T: PartialOrd> MaxIndex<T> for I {
    /// Returns the zero-based index of the greatest item.
    ///
    /// On ties, the last of the equal maxima wins (the behavior of [`Iterator::max_by`]).
    ///
    /// # Panics
    ///
    /// Panics if the iterator is empty, or if two compared items are unordered
    /// (for example, an `f32` NaN).
    fn max_index(self) -> usize {
        self.enumerate()
            .max_by(|(_, a), (_, b)| {
                a.partial_cmp(b)
                    .expect("max_index: items must be comparable (no NaN)")
            })
            .map(|(index, _)| index)
            .expect("max_index: called on an empty iterator")
    }
}