matrux/nn.rs

use std::collections::HashMap;

use crate::{Scalar, MatrixPlan, Matrix, Layer, Activation, DenseLayer, Optimizer};

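/// Builder for a feed-forward neural network: collects an input shape and a
/// stack of layers, and accumulates a `MatrixPlan` for forward evaluation.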
#[derive(Default)]
pub struct NeuralNetworkBuilder<I: Scalar> {
    plan: Option<MatrixPlan<I>>,
    layers: Vec<Box<dyn Layer<I>>>,
    inputs: usize,
    trained_steps: usize,
}

impl<I: Scalar> NeuralNetworkBuilder<I> {
    pub fn new() -> Self {
        Default::default()
    }

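    /// Declares the input vector size. Must be called exactly once, before
    /// any layers are added.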
    pub fn input(mut self, count: usize) -> Self {
        assert!(self.plan.is_none());
        self.plan = Some(MatrixPlan::input(count, 1, "input"));
        self.inputs = count;
        self
    }

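    /// Adds a dense layer of `count` neurons with every weight initialized to
    /// 0.5, inferring the input width from the previous layer's output shape
    /// (or from the network input for the first layer).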
    pub fn add_dense_layer<A: Activation>(self, count: usize, activation: A) -> Self {
        let last_count = if self.layers.is_empty() {
            self.inputs
        } else {
            self.layers.last().unwrap().output_shape().0
        };
        let weights = Matrix::new(count, last_count).fill(I::from_f64(0.5));

        self.add_dense_layer_weighted(weights, activation)
    }

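    /// Adds a dense layer with explicit initial weights and appends its
    /// forward pass to the evaluation plan.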
    pub fn add_dense_layer_weighted<A: Activation>(mut self, weights: Matrix<I>, activation: A) -> Self {
        assert!(self.plan.is_some());
        let mut layer = DenseLayer::new(
            weights,
            activation,
        );
        layer.prepare_input(self.layers.len());

        self.plan = Some(layer.forward_plan(self.plan.take().unwrap()));
        self.layers.push(Box::new(layer));
        self
    }

    // pub fn layers(&self) -> &[Layer<I>] {
    //     &self.layers[..]
    // }

    pub fn hidden_layers(&self) -> usize {
        self.layers.len()
    }

    // pub fn set_layer_weights(&mut self, layer: usize, weights: &[I]) {
    //     assert!(layer < self.layers.len());
    //     assert_eq!(self.layers[layer].weights.rows(), weights.len());
    //     for (i, weight) in weights.iter().copied().enumerate() {
    //         self.layers[layer].weights[i][0] = weight;
    //     }
    // }

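    /// Copies every layer's current weights into `output`, keyed by the input
    /// names the plan expects.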
    pub fn fill_plan_weights(&self, output: &mut HashMap<String, Matrix<I>>) {
        for layer in self.layers.iter() {
            layer.assign_input(output);
        }
    }

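    /// Runs a forward pass over a single input column and returns the output
    /// column.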
    pub fn eval(&self, inputs: &[I]) -> Vec<I> {
        assert!(self.plan.is_some());

        let matrix = Matrix::from_col(inputs.iter().copied());

        let mut plan_inputs = HashMap::new();
        plan_inputs.insert("input".to_string(), matrix);
        self.fill_plan_weights(&mut plan_inputs);
        let (outputs, _) = self.plan.as_ref().unwrap().execute_cpu(&plan_inputs);
        assert_eq!(outputs.cols(), 1);
        outputs.col(0).collect()
    }

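    /// Applies one optimizer step to each layer's weights, consuming one
    /// gradient matrix per layer, then advances the step counter.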
    pub fn apply_backprop<O: Optimizer<I>>(&mut self, optimizer: &mut O, gradients: Vec<Matrix<I>>) {
        assert_eq!(self.layers.len(), gradients.len());
        // A plain loop (rather than a `for_each` closure) keeps the mutable
        // borrow of `self.layers` disjoint from the read of `self.trained_steps`.
        for (current, gradient) in self.layers.iter_mut().zip(gradients) {
            //TODO: weight-less layers
            let weights = current.get_weights().unwrap().clone();
            current.set_weights(optimizer.optimize(weights, gradient, self.trained_steps));
        }
        self.trained_steps += 1;
    }

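    /// Builds a plan that runs a forward pass over a batch, then walks the
    /// layers in reverse to emit one weight gradient per layer (named
    /// `gradient_{i}`) alongside the batch outputs.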
    pub fn plan_backprop(&self, batch_size: usize) -> MatrixPlan<I> {
        assert!(!self.layers.is_empty());

        let targets = MatrixPlan::<I>::input(self.plan.as_ref().unwrap().rows(), batch_size, "targets");
        let inputs = MatrixPlan::<I>::input(self.inputs, batch_size, "inputs");

        let mut state = inputs;
        let mut layer_values = vec![state.clone()];

        for layer in &self.layers {
            state = layer.forward_plan(state);
            layer_values.push(state.clone());
        }

        let final_value = layer_values.last().cloned().unwrap();
        let outputs = final_value.clone().output("outputs");
        let diff = final_value - targets;

        // Walk the layers last-to-first, pairing each layer with its own
        // forward value and with the value feeding into it.
        let mut prior = diff;
        let mut output = vec![];
        for ((layer, layer_value), lower_layer_value) in self.layers.iter().rev()
            .zip(layer_values.iter().rev())
            .zip(layer_values.iter().rev().skip(1)) {
            let (new_prior, out) = layer.backward_plan(prior, layer_value.clone(), lower_layer_value.clone());
            prior = new_prior;
            output.push(out);
        }

        // Gradients were collected in reverse; restore layer order and name
        // each one before merging everything into a single plan.
        output.reverse();
        let mut output: Vec<_> = output
            .into_iter()
            .enumerate()
            .map(|(i, sigma)| sigma.output(format!("gradient_{}", i)))
            .collect();
        output.push(outputs);

        MatrixPlan::merge_outputs(output)
    }
}
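
// A minimal usage sketch of the builder above. Two assumptions here are not
// visible in this file and are hypothetical: that `f64` implements `Scalar`,
// and that the crate exposes a unit struct `Sigmoid` implementing
// `Activation`. Substitute whatever concrete types the crate actually
// provides.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Sigmoid; // assumed `Activation` implementor

    #[test]
    fn builds_and_evaluates() {
        // 2 inputs -> dense(3) -> dense(1), all weights at the 0.5 default.
        let net = NeuralNetworkBuilder::<f64>::new()
            .input(2)
            .add_dense_layer(3, Sigmoid)
            .add_dense_layer(1, Sigmoid);
        assert_eq!(net.hidden_layers(), 2);

        // Single-column forward pass.
        let out = net.eval(&[1.0, 0.0]);
        assert_eq!(out.len(), 1);
    }
}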