ftir_mlp_feedback/
feedback.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6    fs::File,
7    io::{BufRead, BufReader},
8};
9
10fn data(
11    path: &str,
12) -> (
13    (
14        Vec<tensor::Tensor>,
15        Vec<tensor::Tensor>,
16        Vec<tensor::Tensor>,
17    ),
18    (
19        Vec<tensor::Tensor>,
20        Vec<tensor::Tensor>,
21        Vec<tensor::Tensor>,
22    ),
23    (
24        Vec<tensor::Tensor>,
25        Vec<tensor::Tensor>,
26        Vec<tensor::Tensor>,
27    ),
28) {
29    let reader = BufReader::new(File::open(&path).unwrap());
30
31    let mut x_train: Vec<tensor::Tensor> = Vec::new();
32    let mut y_train: Vec<tensor::Tensor> = Vec::new();
33    let mut class_train: Vec<tensor::Tensor> = Vec::new();
34
35    let mut x_test: Vec<tensor::Tensor> = Vec::new();
36    let mut y_test: Vec<tensor::Tensor> = Vec::new();
37    let mut class_test: Vec<tensor::Tensor> = Vec::new();
38
39    let mut x_val: Vec<tensor::Tensor> = Vec::new();
40    let mut y_val: Vec<tensor::Tensor> = Vec::new();
41    let mut class_val: Vec<tensor::Tensor> = Vec::new();
42
43    for line in reader.lines().skip(1) {
44        let line = line.unwrap();
45        let record: Vec<&str> = line.split(',').collect();
46
47        let mut data: Vec<f32> = Vec::new();
48        for i in 0..571 {
49            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
50        }
51        match record.get(573).unwrap() {
52            &"Train" => {
53                x_train.push(tensor::Tensor::single(data));
54                y_train.push(tensor::Tensor::single(vec![record
55                    .get(571)
56                    .unwrap()
57                    .parse::<f32>()
58                    .unwrap()]));
59                class_train.push(tensor::Tensor::one_hot(
60                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
61                    28,
62                ));
63            }
64            &"Test" => {
65                x_test.push(tensor::Tensor::single(data));
66                y_test.push(tensor::Tensor::single(vec![record
67                    .get(571)
68                    .unwrap()
69                    .parse::<f32>()
70                    .unwrap()]));
71                class_test.push(tensor::Tensor::one_hot(
72                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
73                    28,
74                ));
75            }
76            &"Val" => {
77                x_val.push(tensor::Tensor::single(data));
78                y_val.push(tensor::Tensor::single(vec![record
79                    .get(571)
80                    .unwrap()
81                    .parse::<f32>()
82                    .unwrap()]));
83                class_val.push(tensor::Tensor::one_hot(
84                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
85                    28,
86                ));
87            }
88            _ => panic!("> Unknown class."),
89        }
90    }
91
92    // let mut generator = random::Generator::create(12345);
93    // let mut indices: Vec<usize> = (0..x.len()).collect();
94    // generator.shuffle(&mut indices);
95
96    (
97        (x_train, y_train, class_train),
98        (x_test, y_test, class_test),
99        (x_val, y_val, class_val),
100    )
101}
102
103fn main() {
104    // Load the ftir dataset
105    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
106        data("./examples/datasets/ftir.csv");
107
108    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
109    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
110    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
111
112    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
113    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
114    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
115
116    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
117    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
118    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
119
120    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
121    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
122    println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
123
124    vec!["REGRESSION", "CLASSIFICATION"]
125        .iter()
126        .for_each(|method| {
127            // Create the network
128            let mut network = network::Network::new(tensor::Shape::Single(571));
129
130            network.dense(128, activation::Activation::ReLU, false, None);
131
132            network.feedback(
133                vec![
134                    feedback::Layer::Dense(256, activation::Activation::ReLU, false, None),
135                    feedback::Layer::Dense(128, activation::Activation::ReLU, false, None),
136                ],
137                2,
138                false,
139                false,
140                feedback::Accumulation::Mean,
141            );
142
143            if method == &"REGRESSION" {
144                network.dense(1, activation::Activation::Linear, false, None);
145                network.set_objective(objective::Objective::RMSE, None);
146            } else {
147                network.dense(28, activation::Activation::Softmax, false, None);
148                network.set_objective(objective::Objective::CrossEntropy, None);
149            }
150
151            // Include skip connection bypassing the feedback block
152            // network.connect(1, 2);
153            // network.set_accumulation(feedback::Accumulation::Add);
154
155            network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
156
157            println!("{}", network);
158
159            // Train the network
160            let (train_loss, val_loss, val_acc);
161            if method == &"REGRESSION" {
162                println!("> Training the network for regression.");
163
164                (train_loss, val_loss, val_acc) = network.learn(
165                    &x_train,
166                    &y_train,
167                    Some((&x_val, &y_val, 50)),
168                    16,
169                    500,
170                    Some(100),
171                );
172            } else {
173                println!("> Training the network for classification.");
174
175                (train_loss, val_loss, val_acc) = network.learn(
176                    &x_train,
177                    &class_train,
178                    Some((&x_val, &class_val, 50)),
179                    16,
180                    500,
181                    Some(100),
182                );
183            }
184            plot::loss(
185                &train_loss,
186                &val_loss,
187                &val_acc,
188                &format!("FEEDBACK : FTIR : {}", method),
189                &format!("./output/ftir/mlp-{}-feedback.png", method.to_lowercase()),
190            );
191
192            if method == &"REGRESSION" {
193                // Use the network
194                let prediction = network.predict(x_test.get(0).unwrap());
195                println!(
196                    "Prediction. Target: {}. Output: {}.",
197                    y_test[0].data, prediction.data
198                );
199            } else {
200                // Validate the network
201                let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
202                println!(
203                    "Final validation accuracy: {:.2} % and loss: {:.5}",
204                    val_acc * 100.0,
205                    val_loss
206                );
207
208                // Use the network
209                let prediction = network.predict(x_test.get(0).unwrap());
210                println!(
211                    "Prediction. Target: {}. Output: {}.",
212                    class_test[0].argmax(),
213                    prediction.argmax()
214                );
215            }
216        });
217}