ice_nine/grad.rs

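//! A small, hand-rolled feed-forward neural network: `Layer`s of weights with
//! elementwise activations, a `Network` that stacks them, and an `Optimizer`
//! that accumulates gradients via backpropagation against a pluggable `Loss`.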
use anyhow::Error;
use byteorder::{LittleEndian, ReadBytesExt};
use ndarray::{Array1, Array2, ErrorKind, ShapeError};
use std::io::{Read, Write};
use std::path::Path;

/// A feed-forward network: an ordered stack of layers applied in sequence.
pub struct Network {
    pub layers: Vec<Layer>,
}

impl Network {
    /// Run inference without calculating gradients
    pub fn f(&self, v: Array1<f64>) -> Array1<f64> {
        self.layers.iter().fold(v, |prev_v, layer| layer.f(&prev_v))
    }

    /// Add a new layer at the end
    pub fn push(&mut self, layer: Layer) -> Result<(), ShapeError> {
        if let Some(last_layer) = self.layers.last() {
            if last_layer.dims().0 != layer.dims().1 {
                return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
            }
        }
        self.layers.push(layer);
        Ok(())
    }

    /// Create a new network with no layers
    pub fn new() -> Self {
        Network { layers: Vec::new() }
    }

    /// Update weights with stored gradients
    pub fn update(&mut self, learning_rate: f64) {
        self.layers
            .iter_mut()
            .for_each(|layer| layer.update(learning_rate));
    }

    /// Reset gradients to 0
    pub fn zero_grad(&mut self) {
        self.layers.iter_mut().for_each(|layer| layer.zero_grad());
    }

    /// Save the weights as a binary file: each layer's weights are written in
    /// order as little-endian `f64` values (in each weight matrix's in-memory order)
    pub fn save_weights(&self, output_path: &Path) -> Result<(), Error> {
        let output_dir = match output_path.parent() {
            Some(output_dir) => output_dir,
            None => return Err(Error::msg("Failed to get output_path directory")),
        };
        std::fs::create_dir_all(output_dir)?;
        let mut output_file = std::fs::File::create(output_path)?;
        let bytes: Vec<u8> = self
            .layers
            .iter()
            .flat_map(|layer| layer.weights.clone().into_raw_vec())
            .flat_map(|x| x.to_le_bytes())
            .collect();
        output_file.write_all(&bytes)?;
        Ok(())
    }

    /// Load the weights from a binary file
    /// Returns ErrorKind::UnexpectedEof if there are not enough weights
    pub fn load_weights(&mut self, input_path: &Path) -> Result<(), Error> {
        let mut input_file = std::fs::File::open(input_path)?;
        let mut bytes = Vec::new();
        input_file.read_to_end(&mut bytes)?;

        // Read the buffer as a stream of little-endian f64s; running out of bytes
        // yields an io error of kind UnexpectedEof instead of panicking on a slice
        let mut reader = bytes.as_slice();
        for layer in &mut self.layers {
            let num_elements = layer.weights.len();
            let mut raw_vec = Vec::with_capacity(num_elements);
            for _ in 0..num_elements {
                raw_vec.push(reader.read_f64::<LittleEndian>()?);
            }
            layer.weights = Array2::from_shape_vec(layer.weights.raw_dim(), raw_vec)?;
        }
        Ok(())
    }
}

impl Default for Network {
    fn default() -> Self {
        Self::new()
    }
}

pub struct Layer {
    /// Elementwise activation applied after the linear map
    pub activation: Box<dyn HasGrad<f64>>,
    /// Weight matrix of shape (outputs, inputs)
    pub weights: Array2<f64>,
    /// Accumulated gradients of the loss with respect to `weights` (same shape)
    pub gradients: Array2<f64>,
}

impl Layer {
    /// Apply the layer to an input vector (does not touch gradients)
    pub fn f(&self, v: &Array1<f64>) -> Array1<f64> {
        let linear = self.weights.dot(v);
        linear.mapv_into(|x| self.activation.f(x))
    }

    /// Get weight matrix dims as a tuple
    pub fn dims(&self) -> (usize, usize) {
        self.weights.dim()
    }

    /// Update layer weights in-place
    pub fn update(&mut self, learning_rate: f64) {
        self.weights -= &(learning_rate * &self.gradients);
    }

    /// Set gradients to 0
    pub fn zero_grad(&mut self) {
        self.gradients = Array2::zeros(self.gradients.raw_dim());
    }
}

pub trait HasGrad<T> {
    /// Differentiable function
    fn f(&self, x: T) -> T;
    /// Derivative of function
    fn d_f(&self, x: T) -> T;
}

pub trait Loss<T: Clone> {
    /// Loss value for a network output against a target
    fn l(&self, output: &Array1<f64>, target: &T) -> f64;
    /// Gradient of the loss with respect to the network output
    fn d_l(&self, output: &Array1<f64>, target: &T) -> Array1<f64>;
}

pub struct Optimizer<T> {
    pub loss: Box<dyn Loss<T>>,
    pub network: Network,
    /// If set, each gradient entry is clamped to [-max_gradient, max_gradient]
    pub max_gradient: Option<f64>,
}

const EMPTY_ACTIVATIONS_ERROR_MESSAGE: &str = "Fatal: Vec [activations] was empty, but it should always have at least one member (logic error)";

impl<T: Clone> Optimizer<T> {
    /// Apply the network to an input and update gradients, returning the output and loss as a
    /// tuple
    pub fn apply(&mut self, input: &Array1<f64>, target: &T) -> (Array1<f64>, f64) {
        let num_layers = self.network.layers.len();
        if num_layers == 0 {
            return (input.clone(), self.loss.l(input, target));
        }

        // Forward pass
        let mut linear_outputs: Vec<Array1<f64>> = vec![Array1::zeros(input.raw_dim())];
        let mut activations: Vec<Array1<f64>> = vec![input.clone()];
        let mut d_activations: Vec<Array1<f64>> = vec![Array1::zeros(input.raw_dim())];
        for layer in self.network.layers.iter() {
            let prev_activation = activations.last().expect(EMPTY_ACTIVATIONS_ERROR_MESSAGE);
            let linear_output = layer.weights.dot(prev_activation);
            let activation = linear_output.clone().mapv_into(|x| layer.activation.f(x));
            let d_activation = linear_output.clone().mapv_into(|x| layer.activation.d_f(x));
            linear_outputs.push(linear_output);
            activations.push(activation);
            d_activations.push(d_activation);
        }
        let output = activations.last().expect(EMPTY_ACTIVATIONS_ERROR_MESSAGE);
        let loss = self.loss.l(output, target);

        // Backward pass
        let d_l = self.loss.d_l(output, target);
        let mut loss_activation_gradients: Vec<Array1<f64>> = vec![d_l];

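        // Index conventions used below (a descriptive note derived from the code above):
        //   activations[k]     = output of layer k-1 (activations[0] is the network input)
        //   d_activations[k]   = activation derivative at layer k-1's linear output
        //   loss_activation_gradients[num_layers - k] = dL/d(activations[k])
        //
        // For layer k-1 the chain rule gives, entry by entry:
        //   dL/dW[i][j]    += dL/da_k[i] * f'(z_k[i]) * a_{k-1}[j]
        //   dL/da_{k-1}[i] += sum_j dL/da_k[j] * f'(z_k[j]) * W[j][i]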
        for k in (1..num_layers + 1).rev() {
            let layer = &mut self.network.layers[k - 1];

            for i in 0..layer.gradients.nrows() {
                for j in 0..layer.gradients.ncols() {
                    layer.gradients[[i, j]] += loss_activation_gradients[num_layers - k][i]
                        * d_activations[k][i]
                        * activations[k - 1][j];
                    if let Some(max_gradient) = self.max_gradient {
                        layer.gradients[[i, j]] =
                            layer.gradients[[i, j]].clamp(-max_gradient, max_gradient);
                    }
                }
            }

            let mut new_loss_activation_gradient: Array1<f64> =
                Array1::zeros(activations[k - 1].raw_dim());
            for i in 0..activations[k - 1].len() {
                for j in 0..activations[k].len() {
                    new_loss_activation_gradient[i] += loss_activation_gradients[num_layers - k][j]
                        * d_activations[k][j]
                        * self.network.layers[k - 1].weights[[j, i]];
                }
            }
            loss_activation_gradients.push(new_loss_activation_gradient);
        }

        (
            activations.pop().expect(EMPTY_ACTIVATIONS_ERROR_MESSAGE),
            loss,
        )
    }

    /// Update weights and zero the gradients
    pub fn step(&mut self, learning_rate: f64) {
        self.network.update(learning_rate);
        self.network.zero_grad();
    }
}
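
// A minimal usage sketch (not part of the original module): it assumes a
// hypothetical identity activation and squared-error loss written against the
// HasGrad and Loss traits above, purely to illustrate how Network, Layer, and
// Optimizer fit together; the real crate may ship different implementations.
#[cfg(test)]
mod usage_sketch {
    use super::*;
    use ndarray::{arr1, Array1, Array2};

    /// Hypothetical activation: f(x) = x, f'(x) = 1
    struct Identity;
    impl HasGrad<f64> for Identity {
        fn f(&self, x: f64) -> f64 {
            x
        }
        fn d_f(&self, _x: f64) -> f64 {
            1.0
        }
    }

    /// Hypothetical loss: sum of squared errors against a target vector
    struct SquaredError;
    impl Loss<Array1<f64>> for SquaredError {
        fn l(&self, output: &Array1<f64>, target: &Array1<f64>) -> f64 {
            (output - target).mapv(|x| x * x).sum()
        }
        fn d_l(&self, output: &Array1<f64>, target: &Array1<f64>) -> Array1<f64> {
            (output - target) * 2.0
        }
    }

    #[test]
    fn one_gradient_step_reduces_loss() {
        // One layer mapping 2 inputs to 1 output
        let mut network = Network::new();
        network
            .push(Layer {
                activation: Box::new(Identity),
                weights: Array2::from_elem((1, 2), 0.5),
                gradients: Array2::zeros((1, 2)),
            })
            .unwrap();

        let mut optimizer = Optimizer {
            loss: Box::new(SquaredError),
            network,
            max_gradient: None,
        };

        let input = arr1(&[1.0, 2.0]);
        let target = arr1(&[0.0]);

        // Forward + backward accumulates gradients; step applies and zeroes them
        let (_, loss_before) = optimizer.apply(&input, &target);
        optimizer.step(0.01);
        let (_, loss_after) = optimizer.apply(&input, &target);
        assert!(loss_after < loss_before);
    }
}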