only_brain/
neural_network.rs

use crate::activation_functions::{sigmoid, ActivationFunction, ACTION_FUNCTIONS_MAP};
use crate::layer::Layer;
use nalgebra::{DMatrix, DVector};
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;

/// Neural Network
///
/// This is the main struct of the library. It contains a vector of layers and an
/// activation function. You can use this struct and its methods to create, manipulate,
/// and even implement your own ways to train a neural network.
///
/// # Example
///
/// ```
/// use only_brain::NeuralNetwork;
/// use nalgebra::dmatrix;
/// use nalgebra::dvector;
///
/// fn main() {
///     let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
///
///     nn.set_layer_weights(1, dmatrix![0.1, 0.2;
///                                      0.3, 0.4]);
///     nn.set_layer_biases(1, dvector![0.1, 0.2]);
///
///     nn.set_layer_weights(2, dmatrix![0.9, 0.8]);
///     nn.set_layer_biases(2, dvector![0.1]);
///
///     let input = vec![0.5, 0.2];
///     let output = nn.feed_forward(&input);
///
///     println!("{:?}", output);
/// }
/// ```
#[derive(Serialize, Deserialize)]
pub struct NeuralNetwork {
    layers: Vec<Layer>,
    activation_function: Option<ActivationFunction>,
}

impl NeuralNetwork {
    /// Creates a new Neural Network with the given layers. The layers vector must contain
    /// the number of neurons for each layer. Weights and biases are initialized randomly.
    ///
    /// # Example
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// # fn main() {
    /// let nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// # }
    /// ```
    pub fn new(layers: &Vec<usize>) -> Self {
        let mut rng = thread_rng();

        // Pair each layer size with the next one: each (inputs, neurons) pair
        // becomes one weighted layer, so the input layer itself holds no weights.
        let layers = layers
            .iter()
            .zip(layers.iter().skip(1))
            .map(|(a, b)| Layer::from_size(*b, *a, &mut rng))
            .collect::<Vec<Layer>>();

        Self {
            layers,
            activation_function: None,
        }
    }

    /// Feeds the given inputs to the neural network and returns the output. The inputs
    /// vector must have the same size as the first layer of the network.
    ///
    /// # Example
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// # use nalgebra::dmatrix;
    /// # use nalgebra::dvector;
    /// # fn main() {
    /// let mut nn = NeuralNetwork::new(&vec![1, 1]);
    ///
    /// nn.set_layer_weights(1, dmatrix![0.5]);
    /// nn.set_layer_biases(1, dvector![0.5]);
    ///
    /// let input = vec![0.5];
    /// let output = nn.feed_forward(&input);
    /// assert_eq!(output, vec![0.679178699175393]);
    /// # }
    /// ```
    pub fn feed_forward(&self, inputs: &Vec<f64>) -> Vec<f64> {
        let mut outputs = DVector::from(inputs.clone());

        // Propagate the values through each layer, applying the activation function.
        for layer in &self.layers {
            outputs = layer.forward(&outputs, self.activation_function());
        }

        outputs.data.into()
    }

    /// Sets the weights for the given layer. The weights matrix must have dimensions
    /// `layer neurons x layer inputs`. The layer index must be greater than 0, since
    /// layer 0 is the input layer and receives no weights.
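    ///
    /// # Example
    ///
    /// For instance, in a 2-2-1 network, layer 1 has 2 neurons with 2 inputs each,
    /// so it takes a 2 x 2 weights matrix:
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// # use nalgebra::dmatrix;
    /// let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// nn.set_layer_weights(1, dmatrix![0.1, 0.2;
    ///                                  0.3, 0.4]);
    /// ```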
    pub fn set_layer_weights(&mut self, layer: usize, weights: DMatrix<f64>) {
        if layer == 0 {
            panic!("Invalid layer index");
        }
        self.layers[layer - 1].set_weights(weights);
    }

    /// Sets the biases for the given layer. The biases vector must have one entry per
    /// neuron in the layer. The layer index must be greater than 0, since the input
    /// layer does not have biases.
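    ///
    /// # Example
    ///
    /// For instance, layer 1 of a 2-2-1 network has 2 neurons, so it takes 2 biases:
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// # use nalgebra::dvector;
    /// let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// nn.set_layer_biases(1, dvector![0.1, 0.2]);
    /// ```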
    pub fn set_layer_biases(&mut self, layer: usize, biases: DVector<f64>) {
        if layer == 0 {
            panic!("Invalid layer index");
        }
        self.layers[layer - 1].set_biases(biases);
    }

    /// Sets the weight of a specific neuron connection. The layer index must be greater
    /// than 0, since the input layer does not have weights.
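    ///
    /// # Example
    ///
    /// A minimal sketch: set the weight that connects input 0 to neuron 1 of layer 1.
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// nn.set_weight(1, 1, 0, 0.3);
    /// ```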
    pub fn set_weight(&mut self, layer: usize, neuron: usize, input: usize, weight: f64) {
        if layer == 0 {
            panic!("Invalid layer index");
        }
        self.layers[layer - 1].set_weight(neuron, input, weight);
    }

    /// Gets the weight of a specific neuron connection. The layer index must be greater
    /// than 0, since the input layer does not have weights.
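    ///
    /// # Example
    ///
    /// A round trip through `set_weight` and `get_weight`:
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// nn.set_weight(1, 0, 1, 0.42);
    /// assert_eq!(nn.get_weight(1, 0, 1), 0.42);
    /// ```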
    pub fn get_weight(&self, layer: usize, neuron: usize, input: usize) -> f64 {
        if layer == 0 {
            panic!("Invalid layer index");
        }
        self.layers[layer - 1].weights()[(neuron, input)]
    }

    /// Returns the number of layers of the neural network, including the input layer.
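    ///
    /// # Example
    ///
    /// A 2-2-1 network has 3 layers:
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// let nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// assert_eq!(nn.num_layers(), 3);
    /// ```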
    pub fn num_layers(&self) -> usize {
        self.layers.len() + 1
    }

    /// Returns the number of neurons of the given layer. Layer 0 is the input layer.
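    ///
    /// # Example
    ///
    /// For instance, in a 2-2-1 network:
    ///
    /// ```
    /// # use only_brain::NeuralNetwork;
    /// let nn = NeuralNetwork::new(&vec![2, 2, 1]);
    /// assert_eq!(nn.layer_size(0), 2);
    /// assert_eq!(nn.layer_size(2), 1);
    /// ```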
    pub fn layer_size(&self, layer: usize) -> usize {
        if layer == 0 {
            return self.input_layer_size();
        }
        self.layers[layer - 1].size()
    }

    /// The input layer size equals the number of inputs (columns) of the first
    /// weighted layer.
    fn input_layer_size(&self) -> usize {
        self.layers[0].weights().ncols()
    }

    /// Resolves the configured activation function, defaulting to `sigmoid` when
    /// none is set.
    fn activation_function(&self) -> fn(f64) -> f64 {
        let function = match self.activation_function {
            None => return sigmoid,
            Some(function) => function,
        };

        let functions_map = ACTION_FUNCTIONS_MAP
            .iter()
            .cloned()
            .collect::<HashMap<ActivationFunction, _>>();

        functions_map[&function]
    }

    /// Prints the weights and biases of every layer to standard output.
    pub fn print(&self) {
        for layer in &self.layers {
            println!("{} {}", layer.weights(), layer.biases());
        }
    }
}

impl fmt::Display for NeuralNetwork {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Neural Network")?;
        writeln!(f, "Activation Function: {:?}", self.activation_function)?;
        writeln!(f)?;
        writeln!(f, "Input Layer Size: {}", self.input_layer_size())?;
        writeln!(f)?;
        for (index, layer) in self.layers.iter().enumerate() {
            writeln!(f, "Layer {}: {}", index + 1, layer)?;
        }
        Ok(())
    }
}