ducky_learn/
train.rs

1use super::layers::*;
2use ndarray::{arr1, Array1, Array2};
3use std::iter::zip;
4
5pub fn train<L>(
6    model: &Vec<L>,
7    train_data: Array2<f64>,
8    train_lbl: Array2<f64>,
9    test_data: Array2<f64>,
10    test_lbl: Array2<f64>,
11) where
12    L: Layer1d,
13{
14    todo!()
15}
16
//noinspection RsBorrowChecker For some reason the IDE reports that the item is moved, even though it isn't
18pub fn forward_pass<L>(model: &Vec<L>, data: Array1<f64>) -> (Vec<Array1<f64>>, Vec<Array1<f64>>)
19where
20    L: Layer1d,
21{
22    let mut weights_bias_vec: Vec<Array1<f64>> = Vec::with_capacity(model.len());
23    let mut activation_vec: Vec<Array1<f64>> = Vec::with_capacity(model.len());
24
25    let mut activation_pass = data;
26    let mut weight_pass;
27
28    for layer in model.iter() {
29        (weight_pass, activation_pass) = layer.pass(activation_pass.clone());
30
31        weights_bias_vec.push(weight_pass);
32        activation_vec.push(activation_pass.clone());
33    }
34
35    (weights_bias_vec, activation_vec)
36}
37
38pub fn back_propagation<L>(
39    model: &Vec<L>,
40    weights_bias_vec: Vec<Array1<f64>>,
41    activation_vec: Vec<Array1<f64>>,
42    target_out: Array1<f64>,
43) where
44    L: Layer1d,
45{
46    todo!()
47}
48
#[cfg(test)]
mod train_tests {
    use super::*;
    use crate::layers::*;
    use crate::activations::*;
    use ndarray::arr1;

    /// One pair of outputs is recorded per layer.
    #[test]
    fn forwards_pass_1() {
        let model = vec![
            Dense1d::new(1, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(3, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) = forward_pass(&model, arr1(&[1.]));

        assert_eq!(weights_bias_vec.len(), 3);
        assert_eq!(activation_vec.len(), 3);
    }

    /// The first layer's outputs have that layer's output size.
    #[test]
    fn forwards_pass_2() {
        let model = vec![
            Dense1d::new(5, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 5, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) =
            forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));

        assert_eq!(weights_bias_vec.first().unwrap().shape(), [5]);
        assert_eq!(activation_vec.first().unwrap().shape(), [5]);
    }

    /// Every layer's outputs have that layer's output size, including the final layer.
    #[test]
    fn forwards_pass_3() {
        let model = vec![
            Dense1d::new(5, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(3, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) =
            forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));

        assert_eq!(weights_bias_vec.last().unwrap().shape(), [10]);
        assert_eq!(activation_vec.last().unwrap().shape(), [10]);

        for (weights_bias, expected) in weights_bias_vec.iter().zip([3, 5, 10]) {
            assert_eq!(weights_bias.shape(), [expected]);
        }
        for (activation, expected) in activation_vec.iter().zip([3, 5, 10]) {
            assert_eq!(activation.shape(), [expected]);
        }
    }

    /// A layer whose input size (4) disagrees with the previous layer's output size (3) panics.
    #[test]
    #[should_panic]
    fn forwards_pass_4() {
        let model = vec![
            Dense1d::new(5, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(4, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        // Only the panic matters here; the result is intentionally discarded.
        let _ = forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));
    }

    /// An empty model records no outputs.
    #[test]
    fn forwards_pass_empty_model() {
        // Build the vector from a real layer so its element type is inferred, then empty it.
        let mut model = vec![Dense1d::new(1, 3, relu_1d, deriv_relu_1d)];
        model.clear();

        let (weights_bias_vec, activation_vec) = forward_pass(&model, arr1(&[1.]));

        assert!(weights_bias_vec.is_empty());
        assert!(activation_vec.is_empty());
    }
}