use std::collections::HashMap;

use crate::{Scalar, MatrixPlan, Matrix, Layer, Activation, DenseLayer, Optimizer};
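
/// Builds a feed-forward network layer by layer while recording the forward
/// pass as a `MatrixPlan` computation graph.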
#[derive(Default)]
pub struct NeuralNetworkBuilder<I: Scalar> {
    /// Forward-pass plan; `None` until `input()` is called.
    plan: Option<MatrixPlan<I>>,
    layers: Vec<Box<dyn Layer<I>>>,
    /// Width of the input vector, recorded by `input()`.
    inputs: usize,
    /// Number of optimizer steps applied so far; forwarded to `Optimizer::optimize`.
    trained_steps: usize,
}

impl<I: Scalar> NeuralNetworkBuilder<I> {
    pub fn new() -> Self {
        Default::default()
    }
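
    /// Declares the network input as a `count`-row column placeholder named
    /// "input". Must be called exactly once, before any layers are added.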
    pub fn input(mut self, count: usize) -> Self {
        assert!(self.plan.is_none(), "input() must be called exactly once");
        self.plan = Some(MatrixPlan::input(count, 1, "input"));
        self.inputs = count;
        self
    }
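
    /// Appends a dense layer of `count` neurons whose weights are all initialised
    /// to the constant 0.5; the previous layer's output width (or the input width,
    /// for the first layer) fixes the weight matrix's column count.
    ///
    /// A minimal usage sketch, assuming `f64` implements `Scalar` and `Sigmoid` is
    /// one of this crate's `Activation` implementations:
    ///
    /// ```ignore
    /// let net = NeuralNetworkBuilder::<f64>::new()
    ///     .input(2)
    ///     .add_dense_layer(3, Sigmoid)
    ///     .add_dense_layer(1, Sigmoid);
    /// ```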
    pub fn add_dense_layer<A: Activation>(self, count: usize, activation: A) -> Self {
        // The weight matrix's columns must match the width of whatever this layer consumes.
        let last_count = if self.layers.is_empty() {
            self.inputs
        } else {
            self.layers.last().unwrap().output_shape().0
        };
        let weights = Matrix::new(count, last_count).fill(I::from_f64(0.5));

        self.add_dense_layer_weighted(weights, activation)
    }
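
    /// Appends a dense layer with an explicit weight matrix and splices its
    /// forward pass into the running plan.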
    pub fn add_dense_layer_weighted<A: Activation>(mut self, weights: Matrix<I>, activation: A) -> Self {
        assert!(self.plan.is_some(), "call input() before adding layers");
        let mut layer = DenseLayer::new(weights, activation);
        // prepare_input receives this layer's index in the stack.
        layer.prepare_input(self.layers.len());

        self.plan = Some(layer.forward_plan(self.plan.take().unwrap()));
        self.layers.push(Box::new(layer));
        self
    }
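
    /// Number of layers added so far (the output layer included).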
    pub fn hidden_layers(&self) -> usize {
        self.layers.len()
    }
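
    /// Asks each layer to insert its current weights into `output`, so that the
    /// plan's weight placeholders can be resolved at execution time.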
    pub fn fill_plan_weights(&self, output: &mut HashMap<String, Matrix<I>>) {
        for layer in self.layers.iter() {
            layer.assign_input(output);
        }
    }
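
    /// Runs a single forward pass on the CPU and returns the network output as a
    /// flat vector.
    ///
    /// A minimal sketch, continuing the two-input network built above:
    ///
    /// ```ignore
    /// let prediction = net.eval(&[0.0, 1.0]);
    /// assert_eq!(prediction.len(), 1);
    /// ```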
    pub fn eval(&self, inputs: &[I]) -> Vec<I> {
        assert!(self.plan.is_some(), "call input() before eval()");

        let matrix = Matrix::from_col(inputs.iter().copied());

        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), matrix);
        self.fill_plan_weights(&mut inputs);
        let (outputs, _) = self.plan.as_ref().unwrap().execute_cpu(&inputs);
        assert_eq!(outputs.cols(), 1);
        outputs.col(0).collect()
    }
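
    /// Applies one optimizer step to every layer, pairing layers with gradients
    /// in the order the layers were added, then advances the step counter.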
    pub fn apply_backprop<O: Optimizer<I>>(&mut self, optimizer: &mut O, gradients: Vec<Matrix<I>>) {
        assert_eq!(self.layers.len(), gradients.len());
        // Copy the step counter out so the closure does not capture `self`
        // while `self.layers` is mutably borrowed.
        let step = self.trained_steps;
        self.layers.iter_mut().zip(gradients.into_iter()).for_each(|(current, gradient)| {
            let weights = current.get_weights().unwrap().clone();
            current.set_weights(optimizer.optimize(weights, gradient, step));
        });
        self.trained_steps += 1;
    }
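
    /// Builds one combined plan that runs a batched forward pass and then walks
    /// the layers in reverse, emitting a weight-gradient plan per layer. The plan
    /// reads the "inputs" and "targets" placeholders (plus each layer's weights)
    /// and exposes outputs named "outputs" and "gradient_0" .. "gradient_{n-1}".
    ///
    /// A sketch of the intended training step; how the plan is executed and its
    /// gradient outputs are collected is up to the caller:
    ///
    /// ```ignore
    /// let plan = net.plan_backprop(batch_size);
    /// // Execute `plan` with "inputs", "targets", and the layer weights bound,
    /// // gather the "gradient_i" outputs in layer order, then:
    /// net.apply_backprop(&mut optimizer, gradients);
    /// ```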
    pub fn plan_backprop(&self, batch_size: usize) -> MatrixPlan<I> {
        assert!(!self.layers.is_empty());

        let targets = MatrixPlan::<I>::input(self.plan.as_ref().unwrap().rows(), batch_size, "targets");
        let inputs = MatrixPlan::<I>::input(self.inputs, batch_size, "inputs");

        // Re-run the forward pass symbolically, keeping every intermediate
        // value for the backward pass.
        let mut state = inputs;
        let mut layer_values = vec![state.clone()];

        for layer in &self.layers {
            state = layer.forward_plan(state);
            layer_values.push(state.clone());
        }

        let final_value = layer_values.last().cloned().unwrap();
        let outputs = final_value.clone().output("outputs");

        // Error signal at the output: prediction minus target.
        let diff = final_value - targets;

        // Walk the layers in reverse, threading the error signal backwards and
        // collecting one weight-gradient plan per layer.
        let mut prior = diff;
        let mut output = vec![];
        for ((layer, layer_value), lower_layer_value) in self.layers.iter().rev()
            .zip(layer_values.iter().rev())
            .zip(layer_values.iter().rev().skip(1)) {
            let (new_prior, out) = layer.backward_plan(prior, layer_value.clone(), lower_layer_value.clone());
            prior = new_prior;
            output.push(out);
        }

        // Gradients were collected back-to-front; restore layer order and name them.
        output.reverse();
        output.iter_mut().enumerate().for_each(|(i, sigma)| {
            let old = std::mem::take(sigma);
            *sigma = old.output(format!("gradient_{}", i));
        });
        output.push(outputs);

        MatrixPlan::merge_outputs(output)
    }
}