pr_ml/
neural.rs

1//! Neural network module.
2
3use std::borrow::Borrow;
4
5use super::{Matrix, RowVector};
6
7/// A linear layer in a neural network, with sigmoid activation.
8pub struct LinearLayer<const I: usize, const O: usize> {
9    weights: Matrix<I, O>,
10    biases: RowVector<O>,
11    learning_rate: f32,
12}
13
14impl<const I: usize, const O: usize> LinearLayer<I, O> {
15    /// Creates a new linear layer with random weights and biases.
16    #[must_use]
17    pub fn new(learning_rate: f32) -> Self {
18        // Xavier/Glorot initialization for sigmoid activation
19        // Scale factor: sqrt(2 / (fan_in + fan_out))
20        #[allow(
21            clippy::cast_precision_loss,
22            reason = "We're not dealing with huge matrices here."
23        )]
24        let scale = (2.0 / (I + O) as f32).sqrt();
25        let mut weights = Matrix::<I, O>::new_random();
26        weights.apply(|x| *x = (*x - 0.5) * 2.0 * scale);
27        let biases = RowVector::zeros();
28        Self {
29            weights,
30            biases,
31            learning_rate,
32        }
33    }
34
35    /// Performs a forward pass through the layer, returning the output and activated output.
36    pub fn feedforward(&self, input: &RowVector<I>) -> (RowVector<O>, RowVector<O>) {
37        let output = input * self.weights + self.biases;
38        let activated_output = output.map(sigmoid);
39        (output, activated_output)
40    }
41
42    /// Performs backpropagation for this layer, updating weights and biases, and returning the error to propagate to the previous layer.
43    ///
44    /// # Arguments
45    ///
46    /// - `input`: The input to this layer (activated output from previous layer, or raw input for first layer)
47    /// - `output`: The pre-activation output from the forward pass
48    /// - `error_next`: The error propagated from the next layer (or output error for final layer)
49    /// - `learning_rate`: The learning rate for weight updates
50    ///
51    /// # Returns
52    ///
53    /// The error to propagate to the previous layer
54    pub fn backpropagate(
55        &mut self,
56        input: &RowVector<I>,
57        output: &RowVector<O>,
58        error_next: &RowVector<O>,
59    ) -> RowVector<I> {
60        // Compute gradient: derivative of sigmoid times error from next layer
61        let gradient = output.map(d_sigmoid).component_mul(error_next);
62
63        // Compute weight and bias deltas
64        let weights_delta = input.transpose() * gradient;
65
66        // Update weights and biases
67        self.weights -= weights_delta * self.learning_rate;
68        self.biases -= gradient * self.learning_rate;
69
70        // Propagate error to previous layer
71        error_next * self.weights.transpose()
72    }
73}
74
75/// A trait for neural networks, with `I` inputs and `O` outputs.
76///
77/// # Associated Types
78///
79/// - [`LayerOutputs`](NeuralNetwork::LayerOutputs): The type representing the outputs of each layer in the network.
80///
81/// # Required methods
82///
83/// - [`feedforward`](NeuralNetwork::feedforward)
84/// - [`backpropagate`](NeuralNetwork::backpropagate)
85///
86/// # Provided methods
87///
88/// - [`train_once`](NeuralNetwork::train_once)
89/// - [`train`](NeuralNetwork::train)
90pub trait NeuralNetwork<const I: usize, const O: usize> {
91    /// The type representing the outputs of each layer in the network.
92    type LayerOutputs: LayerOutputs<O>;
93
94    /// Feedforward the input through the network, updating layer outputs, and returning the final output.
95    fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs;
96
97    /// Perform backpropagation to adjust weights and biases based on the target output, returning the loss.
98    ///
99    /// # Arguments
100    ///
101    /// - `input`: The input matrix of shape (1, I).
102    /// - `output`: The target output matrix of shape (1, O).
103    /// - `layer_outputs`: The outputs for each layer.
104    fn backpropagate(
105        &mut self,
106        input: &RowVector<I>,
107        output: &RowVector<O>,
108        layer_outputs: Self::LayerOutputs,
109    ) -> f32;
110
111    /// Train the neural network once with a pair of input and output, returning the average loss.
112    fn train_once<BI, BO>(&mut self, input: BI, output: BO) -> f32
113    where
114        BI: Borrow<RowVector<I>>,
115        BO: Borrow<RowVector<O>>,
116    {
117        let (input, output) = (input.borrow(), output.borrow());
118        let layer_outputs = self.feedforward(input);
119        self.backpropagate(input, output, layer_outputs)
120    }
121
122    /// Train the neural network with the provided data, calling an optional `callback` after each sample, returning the total number of samples and the final average loss.
123    fn train<D, T>(&mut self, data: D, mut callback: Option<impl FnMut(usize, f32)>) -> (usize, f32)
124    where
125        D: IntoIterator<Item = T>,
126        T: Borrow<(RowVector<I>, RowVector<O>)>,
127    {
128        let data = data.into_iter();
129        let mut total_loss = 0.0;
130        let mut count = 0;
131
132        for item in data {
133            let (input, output) = item.borrow();
134            let single_loss = self.train_once(input, output);
135            total_loss += single_loss;
136            if let Some(ref mut cb) = callback {
137                cb(count, single_loss);
138            }
139            count += 1;
140        }
141
142        if count > 0 {
143            #[allow(
144                clippy::cast_precision_loss,
145                reason = "We're calculating ratios, so precision loss is acceptable."
146            )]
147            (count, total_loss / count as f32)
148        } else {
149            (0, 0.0)
150        }
151    }
152
153    /// Predict the output for a given input.
154    fn predict(&self, input: &RowVector<I>) -> RowVector<O> {
155        let layer_outputs = self.feedforward(input);
156        layer_outputs.get_output()
157    }
158}
159
160/// A trait for layer outputs in a neural network, requiring a method to get the final output.
161pub trait LayerOutputs<const O: usize> {
162    /// Get the final output of the network from the layer outputs.
163    fn get_output(self) -> RowVector<O>;
164}
165
166/// A simple feedforward neural network with `I` inputs, `H` neurons in the hidden layer, and `O` outputs.
167pub struct SimpleNeuralNetwork<const I: usize, const H: usize, const O: usize> {
168    input_layer: LinearLayer<I, H>,
169    output_layer: LinearLayer<H, O>,
170}
171
172impl<const I: usize, const H: usize, const O: usize> NeuralNetwork<I, O>
173    for SimpleNeuralNetwork<I, H, O>
174{
175    type LayerOutputs = SimpleLayerOutputs<H, O>;
176
177    fn feedforward(&self, input: &RowVector<I>) -> Self::LayerOutputs {
178        let (hidden_output, activated_hidden_output) = self.input_layer.feedforward(input);
179        let (final_output, activated_final_output) =
180            self.output_layer.feedforward(&activated_hidden_output);
181        SimpleLayerOutputs {
182            hidden_output,
183            final_output,
184            activated_hidden_output,
185            activated_final_output,
186        }
187    }
188
189    fn backpropagate(
190        &mut self,
191        input: &RowVector<I>,
192        output: &RowVector<O>,
193        layer_outputs: Self::LayerOutputs,
194    ) -> f32 {
195        // Calculate output error
196        let error_output = layer_outputs.activated_final_output - output;
197
198        // Backpropagate through layers
199        let error_hidden = self.output_layer.backpropagate(
200            &layer_outputs.activated_hidden_output,
201            &layer_outputs.final_output,
202            &error_output,
203        );
204        self.input_layer
205            .backpropagate(input, &layer_outputs.hidden_output, &error_hidden);
206
207        // Calculate and return loss
208        error_output.map(|e| e * e).sum()
209    }
210}
211
212impl<const I: usize, const H: usize, const O: usize> SimpleNeuralNetwork<I, H, O> {
213    /// Creates a new [`SimpleNeuralNetwork`] with random weights and biases.
214    #[must_use]
215    pub fn new(learning_rate: f32) -> Self {
216        Self {
217            input_layer: LinearLayer::new(learning_rate),
218            output_layer: LinearLayer::new(learning_rate),
219        }
220    }
221}
222
223/// Outputs for [`SimpleNeuralNetwork`].
224#[allow(
225    clippy::struct_field_names,
226    reason = "Names reflect their purpose in the network."
227)]
228pub struct SimpleLayerOutputs<const H: usize, const O: usize> {
229    hidden_output: RowVector<H>,
230    final_output: RowVector<O>,
231    activated_hidden_output: RowVector<H>,
232    activated_final_output: RowVector<O>, // FIXME: Is activation needed here?
233}
234
235impl<const H: usize, const O: usize> LayerOutputs<O> for SimpleLayerOutputs<H, O> {
236    fn get_output(self) -> RowVector<O> {
237        self.activated_final_output // FIXME: Is activation needed here?
238    }
239}
240
241// Activation functions
242
/// Logistic sigmoid: `1 / (1 + e^(-x))`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    1.0 / denominator
}
246
247fn d_sigmoid(x: f32) -> f32 {
248    let y = sigmoid(x);
249    y * (1.0 - y)
250}