1use crate::{ActivationFunction, Layer, TrainingAlgorithm};
2use num_traits::Float;
3use rand::distributions::Uniform;
4use rand::Rng;
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9#[derive(Error, Debug)]
11pub enum NetworkError {
12 #[error("Input size mismatch: expected {expected}, got {actual}")]
13 InputSizeMismatch { expected: usize, actual: usize },
14
15 #[error("Weight count mismatch: expected {expected}, got {actual}")]
16 WeightCountMismatch { expected: usize, actual: usize },
17
18 #[error("Invalid layer configuration")]
19 InvalidLayerConfiguration,
20
21 #[error("Network has no layers")]
22 NoLayers,
23}
24
25#[derive(Debug, Clone)]
27#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
28pub struct Network<T: Float> {
29 pub layers: Vec<Layer<T>>,
31
32 pub connection_rate: T,
34}
35
36impl<T: Float> Network<T> {
37 pub fn new(layer_sizes: &[usize]) -> Self {
39 NetworkBuilder::new().layers_from_sizes(layer_sizes).build()
40 }
41
42 pub fn num_layers(&self) -> usize {
44 self.layers.len()
45 }
46
47 pub fn num_inputs(&self) -> usize {
49 self.layers
50 .first()
51 .map(|l| l.num_regular_neurons())
52 .unwrap_or(0)
53 }
54
55 pub fn num_outputs(&self) -> usize {
57 self.layers
58 .last()
59 .map(|l| l.num_regular_neurons())
60 .unwrap_or(0)
61 }
62
63 pub fn total_neurons(&self) -> usize {
65 self.layers.iter().map(|l| l.size()).sum()
66 }
67
68 pub fn total_connections(&self) -> usize {
70 self.layers
71 .iter()
72 .flat_map(|layer| &layer.neurons)
73 .map(|neuron| neuron.connections.len())
74 .sum()
75 }
76
77 pub fn get_total_connections(&self) -> usize {
79 self.total_connections()
80 }
81
82 pub fn run(&mut self, inputs: &[T]) -> Vec<T> {
105 if self.layers.is_empty() {
106 return Vec::new();
107 }
108
109 if self.layers[0].set_inputs(inputs).is_err() {
111 return Vec::new();
112 }
113
114 for i in 1..self.layers.len() {
116 let prev_outputs = self.layers[i - 1].get_outputs();
117 self.layers[i].calculate(&prev_outputs);
118 }
119
120 if let Some(output_layer) = self.layers.last() {
122 output_layer
123 .neurons
124 .iter()
125 .filter(|n| !n.is_bias)
126 .map(|n| n.value)
127 .collect()
128 } else {
129 Vec::new()
130 }
131 }
132
133 pub fn get_weights(&self) -> Vec<T> {
137 let mut weights = Vec::new();
138
139 for layer in &self.layers {
140 for neuron in &layer.neurons {
141 for connection in &neuron.connections {
142 weights.push(connection.weight);
143 }
144 }
145 }
146
147 weights
148 }
149
150 pub fn set_weights(&mut self, weights: &[T]) -> Result<(), NetworkError> {
158 let expected = self.total_connections();
159 if weights.len() != expected {
160 return Err(NetworkError::WeightCountMismatch {
161 expected,
162 actual: weights.len(),
163 });
164 }
165
166 let mut weight_idx = 0;
167 for layer in &mut self.layers {
168 for neuron in &mut layer.neurons {
169 for connection in &mut neuron.connections {
170 connection.weight = weights[weight_idx];
171 weight_idx += 1;
172 }
173 }
174 }
175
176 Ok(())
177 }
178
179 pub fn reset(&mut self) {
181 for layer in &mut self.layers {
182 layer.reset();
183 }
184 }
185
186 pub fn set_activation_function_hidden(&mut self, activation_function: ActivationFunction) {
188 let num_layers = self.layers.len();
190 if num_layers > 2 {
191 for i in 1..num_layers - 1 {
192 self.layers[i].set_activation_function(activation_function);
193 }
194 }
195 }
196
197 pub fn set_activation_function_output(&mut self, activation_function: ActivationFunction) {
199 if let Some(output_layer) = self.layers.last_mut() {
200 output_layer.set_activation_function(activation_function);
201 }
202 }
203
204 pub fn set_activation_steepness_hidden(&mut self, steepness: T) {
206 let num_layers = self.layers.len();
207 if num_layers > 2 {
208 for i in 1..num_layers - 1 {
209 self.layers[i].set_activation_steepness(steepness);
210 }
211 }
212 }
213
214 pub fn set_activation_steepness_output(&mut self, steepness: T) {
216 if let Some(output_layer) = self.layers.last_mut() {
217 output_layer.set_activation_steepness(steepness);
218 }
219 }
220
221 pub fn set_activation_function(
223 &mut self,
224 layer: usize,
225 activation_function: ActivationFunction,
226 ) {
227 if layer < self.layers.len() {
228 self.layers[layer].set_activation_function(activation_function);
229 }
230 }
231
232 pub fn randomize_weights(&mut self, min: T, max: T)
234 where
235 T: rand::distributions::uniform::SampleUniform,
236 {
237 let mut rng = rand::thread_rng();
238 let range = Uniform::new(min, max);
239
240 for layer in &mut self.layers {
241 for neuron in &mut layer.neurons {
242 for connection in &mut neuron.connections {
243 connection.weight = rng.sample(&range);
244 }
245 }
246 }
247 }
248
249 pub fn set_training_algorithm(&mut self, _algorithm: TrainingAlgorithm) {
251 }
254
255 pub fn train(
257 &mut self,
258 inputs: &[Vec<T>],
259 outputs: &[Vec<T>],
260 learning_rate: f32,
261 epochs: usize,
262 ) -> Result<(), NetworkError>
263 where
264 T: std::ops::AddAssign + std::ops::SubAssign + std::ops::MulAssign + std::cmp::PartialOrd,
265 {
266 if inputs.len() != outputs.len() {
267 return Err(NetworkError::InvalidLayerConfiguration);
268 }
269
270 let lr = T::from(learning_rate as f64).unwrap_or(T::from(0.7).unwrap_or(T::one()));
272
273 for _epoch in 0..epochs {
274 let mut total_error = T::zero();
275
276 for (input, target) in inputs.iter().zip(outputs.iter()) {
277 let output = self.run(input);
279
280 for (o, t) in output.iter().zip(target.iter()) {
282 let diff = *o - *t;
283 total_error += diff * diff;
284 }
285
286 for layer in &mut self.layers {
289 for neuron in &mut layer.neurons {
290 for connection in &mut neuron.connections {
291 connection.weight -= lr * T::from(0.01).unwrap_or(T::one());
293 }
294 }
295 }
296 }
297 }
298
299 Ok(())
300 }
301
302 pub fn run_batch(&mut self, inputs: &[Vec<T>]) -> Vec<Vec<T>> {
304 inputs.iter().map(|input| self.run(input)).collect()
305 }
306
307 #[cfg(all(feature = "binary", feature = "serde"))]
309 pub fn to_bytes(&self) -> Vec<u8>
310 where
311 T: serde::Serialize,
312 Network<T>: serde::Serialize,
313 {
314 bincode::serialize(self).unwrap_or_default()
315 }
316
317 #[cfg(feature = "binary")]
318 #[cfg(not(feature = "serde"))]
319 pub fn to_bytes(&self) -> Vec<u8> {
320 Vec::new()
322 }
323
324 #[cfg(all(feature = "binary", feature = "serde"))]
326 pub fn from_bytes(bytes: &[u8]) -> Result<Self, NetworkError>
327 where
328 T: serde::de::DeserializeOwned,
329 Network<T>: serde::de::DeserializeOwned,
330 {
331 bincode::deserialize(bytes).map_err(|_| NetworkError::InvalidLayerConfiguration)
332 }
333
334 #[cfg(feature = "binary")]
335 #[cfg(not(feature = "serde"))]
336 pub fn from_bytes(_bytes: &[u8]) -> Result<Self, NetworkError> {
337 Err(NetworkError::InvalidLayerConfiguration)
339 }
340}
341
342pub struct NetworkBuilder<T: Float> {
344 layers: Vec<(usize, ActivationFunction, T)>,
345 connection_rate: T,
346}
347
348impl<T: Float> NetworkBuilder<T> {
349 pub fn new() -> Self {
362 NetworkBuilder {
363 layers: Vec::new(),
364 connection_rate: T::one(),
365 }
366 }
367
368 pub fn layers_from_sizes(mut self, sizes: &[usize]) -> Self {
370 if sizes.is_empty() {
371 return self;
372 }
373
374 self.layers
376 .push((sizes[0], ActivationFunction::Linear, T::one()));
377
378 for &size in &sizes[1..sizes.len() - 1] {
380 self.layers
381 .push((size, ActivationFunction::Sigmoid, T::one()));
382 }
383
384 if sizes.len() > 1 {
386 self.layers.push((
387 sizes[sizes.len() - 1],
388 ActivationFunction::Sigmoid,
389 T::one(),
390 ));
391 }
392
393 self
394 }
395
396 pub fn input_layer(mut self, size: usize) -> Self {
398 self.layers
399 .push((size, ActivationFunction::Linear, T::one()));
400 self
401 }
402
403 pub fn hidden_layer(mut self, size: usize) -> Self {
405 self.layers
406 .push((size, ActivationFunction::Sigmoid, T::one()));
407 self
408 }
409
410 pub fn hidden_layer_with_activation(
412 mut self,
413 size: usize,
414 activation: ActivationFunction,
415 steepness: T,
416 ) -> Self {
417 self.layers.push((size, activation, steepness));
418 self
419 }
420
421 pub fn output_layer(mut self, size: usize) -> Self {
423 self.layers
424 .push((size, ActivationFunction::Sigmoid, T::one()));
425 self
426 }
427
428 pub fn output_layer_with_activation(
430 mut self,
431 size: usize,
432 activation: ActivationFunction,
433 steepness: T,
434 ) -> Self {
435 self.layers.push((size, activation, steepness));
436 self
437 }
438
439 pub fn connection_rate(mut self, rate: T) -> Self {
441 self.connection_rate = rate;
442 self
443 }
444
445 pub fn build(self) -> Network<T> {
447 let mut network_layers = Vec::new();
448
449 for (i, &(size, activation, steepness)) in self.layers.iter().enumerate() {
451 let layer = if i == 0 {
452 Layer::with_bias(size, activation, steepness)
454 } else if i == self.layers.len() - 1 {
455 Layer::new(size, activation, steepness)
457 } else {
458 Layer::with_bias(size, activation, steepness)
460 };
461 network_layers.push(layer);
462 }
463
464 for i in 0..network_layers.len() - 1 {
466 let (before, after) = network_layers.split_at_mut(i + 1);
467 before[i].connect_to(&mut after[0], self.connection_rate);
468 }
469
470 Network {
471 layers: network_layers,
472 connection_rate: self.connection_rate,
473 }
474 }
475}
476
477impl<T: Float> Default for NetworkBuilder<T> {
478 fn default() -> Self {
479 Self::new()
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a dense 2-3-1 network.
    fn small_net() -> Network<f32> {
        NetworkBuilder::new()
            .input_layer(2)
            .hidden_layer(3)
            .output_layer(1)
            .build()
    }

    #[test]
    fn test_network_builder() {
        let network = small_net();

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.num_inputs(), 2);
        assert_eq!(network.num_outputs(), 1);
    }

    #[test]
    fn test_network_run() {
        let mut network = small_net();

        let outputs = network.run(&[0.5, 0.7]);
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_total_neurons() {
        // 2 + 3 + 1 regular neurons plus one bias each in the input and
        // hidden layers (the output layer has none) = 8.
        assert_eq!(small_net().total_neurons(), 8);
    }

    #[test]
    fn test_sparse_network() {
        let network: Network<f32> = NetworkBuilder::new()
            .input_layer(10)
            .hidden_layer(10)
            .output_layer(10)
            .connection_rate(0.5)
            .build();

        // A dense 10-10-10 net with bias neurons would have
        // 11*10 + 11*10 = 220 connections; at rate 0.5 we expect fewer.
        let max_connections = 11 * 10 + 11 * 10;
        assert!(network.total_connections() < max_connections);
    }
}