only_brain/neural_network.rs
1use crate::activation_functions::{sigmoid, ActivationFunction, ACTION_FUNCTIONS_MAP};
2use crate::layer::Layer;
3use nalgebra::{DMatrix, DVector};
4use rand::thread_rng;
5use std::collections::HashMap;
6use std::fmt;
7use serde::{Deserialize, Serialize};
8
9/// Neural Network
10///
11/// This is the main struct of the library. It contains a vector of layers and an
12/// activation function. You can use this struct and its methods to create, manipulate and
13/// even implement your ways to train a neural network.
14///
15/// # Example
16///
17/// ```
18/// use only_brain::NeuralNetwork;
19/// use nalgebra::dmatrix;
20/// use nalgebra::dvector;
21///
22/// fn main() {
23/// let mut nn = NeuralNetwork::new(&vec![2, 2, 1]);
24///
25/// nn.set_layer_weights(1, dmatrix![0.1, 0.2;
26/// 0.3, 0.4]);
27/// nn.set_layer_biases(1, dvector![0.1, 0.2]);
28///
29/// nn.set_layer_weights(2, dmatrix![0.9, 0.8]);
30/// nn.set_layer_biases(2, dvector![0.1]);
31///
32/// let input = vec![0.5, 0.2];
33/// let output = nn.feed_forward(&input);
34///
35/// println!("{:?}", output);
36/// }
37/// ```
38#[derive(Serialize, Deserialize)]
39pub struct NeuralNetwork {
40 layers: Vec<Layer>,
41 activation_function: Option<ActivationFunction>,
42}
43
44impl NeuralNetwork {
45 /// Creates a new Neural Network with the given layers. The layers vector must contain
46 /// the number of neurons for each layer.
47 ///
48 /// # Example
49 ///
50 /// ```
51 /// # use only_brain::NeuralNetwork;
52 /// # fn main() {
53 /// let nn = NeuralNetwork::new(&vec![2, 2, 1]);
54 /// # }
55 pub fn new(layers: &Vec<usize>) -> Self {
56 let mut rng = thread_rng();
57
58 let layers = layers
59 .iter()
60 .zip(layers.iter().skip(1))
61 .map(|(a, b)| Layer::from_size(*b, *a, &mut rng))
62 .collect::<Vec<Layer>>();
63
64 Self {
65 layers,
66 activation_function: None,
67 }
68 }
69
70 /// Feeds the given inputs to the neural network and returns the output. The inputs
71 /// vector must have the same size as the first layer of the network.
72 ///
73 /// # Example
74 ///
75 /// ```
76 /// # use only_brain::NeuralNetwork;
77 /// # use nalgebra::dmatrix;
78 /// # use nalgebra::dvector;
79 /// # fn main() {
80 /// let mut nn = NeuralNetwork::new(&vec![1, 1]);
81 ///
82 /// nn.set_layer_weights(1, dmatrix![0.5]);
83 /// nn.set_layer_biases(1, dvector![0.5]);
84 ///
85 /// let input = vec![0.5];
86 /// let output = nn.feed_forward(&input);
87 /// assert_eq!(output, vec![0.679178699175393]);
88 /// # }
89 /// ```
90 pub fn feed_forward(&self, inputs: &Vec<f64>) -> Vec<f64> {
91 let mut outputs = DVector::from(Vec::clone(inputs));
92
93 for layer in &self.layers {
94 outputs = layer.forward(&outputs, self.activation_function());
95 }
96
97 outputs.data.into()
98 }
99
100 /// Sets the layer weights for the given layer. The weights matrix must have the size
101 /// of the layer neurons x layer inputs. The layer index must be greater than 0 since it
102 /// corresponds to the layer number that receives these weights.
103 pub fn set_layer_weights(&mut self, layer: usize, weights: DMatrix<f64>) {
104 if layer <= 0 {
105 panic!("Invalid layer index");
106 }
107 self.layers[layer - 1].set_weights(weights);
108 }
109
110 /// Sets the layer biases for the given layer. The biases vector must have the size
111 /// of the layer neurons. The layer index must be greater than 0 since the input layer
112 /// does not have biases.
113 pub fn set_layer_biases(&mut self, layer: usize, biases: DVector<f64>) {
114 if layer <= 0 {
115 panic!("Invalid layer index");
116 }
117 self.layers[layer - 1].set_biases(biases);
118 }
119
120 /// Sets the weight of a specific neuron connection. The layer index must be greater
121 /// than 0 since the input layer does not have weights.
122 pub fn set_weight(&mut self, layer: usize, neuron: usize, input: usize, weight: f64) {
123 if layer <= 0 {
124 panic!("Invalid layer index");
125 }
126 self.layers[layer - 1].set_weight(neuron, input, weight);
127 }
128
129 /// Gets the weight of a specific neuron connection. The layer index must be greater
130 /// than 0 since the input layer does not have weights.
131 pub fn get_weight(&self, layer: usize, neuron: usize, input: usize) -> f64 {
132 if layer <= 0 {
133 panic!("Invalid layer index");
134 }
135 self.layers[layer - 1].weights()[(neuron, input)]
136 }
137
138 /// Returns the number of layers of the neural network.
139 pub fn num_layers(&self) -> usize {
140 self.layers.len() + 1
141 }
142
143 /// Returns the number of neurons of the given layer.
144 pub fn layer_size(&self, layer: usize) -> usize {
145 if layer == 0 {
146 return self.input_layer_size();
147 }
148 self.layers[layer - 1].size()
149 }
150
151 fn input_layer_size(&self) -> usize {
152 self.layers[0].weights().ncols()
153 }
154
155 fn activation_function(&self) -> fn(f64) -> f64 {
156 if self.activation_function.is_none() {
157 return sigmoid;
158 }
159 let functions_map = ACTION_FUNCTIONS_MAP
160 .iter()
161 .cloned()
162 .collect::<HashMap<ActivationFunction, _>>();
163
164 functions_map
165 .get(&self.activation_function.unwrap())
166 .unwrap()
167 .clone()
168 }
169
170 pub fn print(&self) {
171 for layer in &self.layers {
172 println!("{} {}", layer.weights(), layer.biases());
173 }
174 }
175}
176
177impl fmt::Display for NeuralNetwork {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 writeln!(f, "Neural Network")?;
180 writeln!(f, "Activation Function: {:?}", self.activation_function)?;
181 writeln!(f)?;
182 writeln!(f, "Input Layer Size: {}", self.input_layer_size())?;
183 writeln!(f)?;
184 for (index, layer) in self.layers.iter().enumerate() {
185 writeln!(f, "Layer {}: {}", index + 1, layer)?;
186 }
187 Ok(())
188 }
189}