1use super::layers::*;
2use ndarray::{arr1, Array1, Array2};
3use std::iter::zip;
4
5pub fn train<L>(
6 model: &Vec<L>,
7 train_data: Array2<f64>,
8 train_lbl: Array2<f64>,
9 test_data: Array2<f64>,
10 test_lbl: Array2<f64>,
11) where
12 L: Layer1d,
13{
14 todo!()
15}
16
17pub fn forward_pass<L>(model: &Vec<L>, data: Array1<f64>) -> (Vec<Array1<f64>>, Vec<Array1<f64>>)
19where
20 L: Layer1d,
21{
22 let mut weights_bias_vec: Vec<Array1<f64>> = Vec::with_capacity(model.len());
23 let mut activation_vec: Vec<Array1<f64>> = Vec::with_capacity(model.len());
24
25 let mut activation_pass = data;
26 let mut weight_pass;
27
28 for layer in model.iter() {
29 (weight_pass, activation_pass) = layer.pass(activation_pass.clone());
30
31 weights_bias_vec.push(weight_pass);
32 activation_vec.push(activation_pass.clone());
33 }
34
35 (weights_bias_vec, activation_vec)
36}
37
38pub fn back_propagation<L>(
39 model: &Vec<L>,
40 weights_bias_vec: Vec<Array1<f64>>,
41 activation_vec: Vec<Array1<f64>>,
42 target_out: Array1<f64>,
43) where
44 L: Layer1d,
45{
46 todo!()
47}
48
#[cfg(test)]
mod train_tests {
    use super::*;
    use crate::activations::*;
    use crate::layers::*;
    use ndarray::arr1;

    /// Every layer contributes exactly one entry to each returned vector, sized to that
    /// layer's output width.
    #[test]
    fn forwards_pass_1() {
        let model = vec![
            Dense1d::new(1, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(3, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) = forward_pass(&model, arr1(&[1.]));

        assert_eq!(weights_bias_vec.len(), 3);
        assert_eq!(activation_vec.len(), 3);

        // Output width of each layer is its second constructor argument.
        for (i, &size) in [3usize, 5, 10].iter().enumerate() {
            assert_eq!(weights_bias_vec[i].shape(), [size]);
            assert_eq!(activation_vec[i].shape(), [size]);
        }
    }

    /// The first layer's outputs have that layer's output width.
    #[test]
    fn forwards_pass_2() {
        let model = vec![
            Dense1d::new(5, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 5, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) =
            forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));

        assert_eq!(weights_bias_vec.first().unwrap().shape(), [5]);
        assert_eq!(activation_vec.first().unwrap().shape(), [5]);
    }

    /// The last layer's outputs have that layer's output width, even when widths vary.
    #[test]
    fn forwards_pass_3() {
        let model = vec![
            Dense1d::new(5, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(3, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        let (weights_bias_vec, activation_vec) =
            forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));

        assert_eq!(weights_bias_vec.last().unwrap().shape(), [10]);
        assert_eq!(activation_vec.last().unwrap().shape(), [10]);
    }

    /// A 3-wide layer feeding a 4-input layer is a size mismatch and must panic.
    #[test]
    #[should_panic]
    fn forwards_pass_4() {
        let model = vec![
            Dense1d::new(5, 3, relu_1d, deriv_relu_1d),
            Dense1d::new(4, 5, relu_1d, deriv_relu_1d),
            Dense1d::new(5, 10, softmax_1d, deriv_relu_1d),
        ];

        // Only the panic is under test; the result itself is irrelevant.
        let _ = forward_pass(&model, arr1(&[1., 2., 0.2, 1., 0.32]));
    }
}