// ftir_mlp_loop/looping.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6    fs::File,
7    io::{BufRead, BufReader},
8    sync::Arc,
9};
10
11fn data(
12    path: &str,
13) -> (
14    (
15        Vec<tensor::Tensor>,
16        Vec<tensor::Tensor>,
17        Vec<tensor::Tensor>,
18    ),
19    (
20        Vec<tensor::Tensor>,
21        Vec<tensor::Tensor>,
22        Vec<tensor::Tensor>,
23    ),
24    (
25        Vec<tensor::Tensor>,
26        Vec<tensor::Tensor>,
27        Vec<tensor::Tensor>,
28    ),
29) {
30    let reader = BufReader::new(File::open(&path).unwrap());
31
32    let mut x_train: Vec<tensor::Tensor> = Vec::new();
33    let mut y_train: Vec<tensor::Tensor> = Vec::new();
34    let mut class_train: Vec<tensor::Tensor> = Vec::new();
35
36    let mut x_test: Vec<tensor::Tensor> = Vec::new();
37    let mut y_test: Vec<tensor::Tensor> = Vec::new();
38    let mut class_test: Vec<tensor::Tensor> = Vec::new();
39
40    let mut x_val: Vec<tensor::Tensor> = Vec::new();
41    let mut y_val: Vec<tensor::Tensor> = Vec::new();
42    let mut class_val: Vec<tensor::Tensor> = Vec::new();
43
44    for line in reader.lines().skip(1) {
45        let line = line.unwrap();
46        let record: Vec<&str> = line.split(',').collect();
47
48        let mut data: Vec<f32> = Vec::new();
49        for i in 0..571 {
50            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
51        }
52        match record.get(573).unwrap() {
53            &"Train" => {
54                x_train.push(tensor::Tensor::single(data));
55                y_train.push(tensor::Tensor::single(vec![record
56                    .get(571)
57                    .unwrap()
58                    .parse::<f32>()
59                    .unwrap()]));
60                class_train.push(tensor::Tensor::one_hot(
61                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
62                    28,
63                ));
64            }
65            &"Test" => {
66                x_test.push(tensor::Tensor::single(data));
67                y_test.push(tensor::Tensor::single(vec![record
68                    .get(571)
69                    .unwrap()
70                    .parse::<f32>()
71                    .unwrap()]));
72                class_test.push(tensor::Tensor::one_hot(
73                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
74                    28,
75                ));
76            }
77            &"Val" => {
78                x_val.push(tensor::Tensor::single(data));
79                y_val.push(tensor::Tensor::single(vec![record
80                    .get(571)
81                    .unwrap()
82                    .parse::<f32>()
83                    .unwrap()]));
84                class_val.push(tensor::Tensor::one_hot(
85                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
86                    28,
87                ));
88            }
89            _ => panic!("> Unknown class."),
90        }
91    }
92
93    // let mut generator = random::Generator::create(12345);
94    // let mut indices: Vec<usize> = (0..x.len()).collect();
95    // generator.shuffle(&mut indices);
96
97    (
98        (x_train, y_train, class_train),
99        (x_test, y_test, class_test),
100        (x_val, y_val, class_val),
101    )
102}
103
104fn main() {
105    // Load the ftir dataset
106    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
107        data("./examples/datasets/ftir.csv");
108
109    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
110    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
111    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
112
113    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
114    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
115    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
116
117    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
118    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
119    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
120
121    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
122    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
123    println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
124
125    vec!["REGRESSION", "CLASSIFICATION"]
126        .iter()
127        .for_each(|method| {
128            // Create the network
129            let mut network = network::Network::new(tensor::Shape::Single(571));
130
131            network.dense(128, activation::Activation::ReLU, false, None);
132            network.dense(256, activation::Activation::ReLU, false, None);
133            network.dense(128, activation::Activation::ReLU, false, None);
134
135            if method == &"REGRESSION" {
136                network.dense(1, activation::Activation::Linear, false, None);
137                network.set_objective(objective::Objective::RMSE, None);
138            } else {
139                network.dense(28, activation::Activation::Softmax, false, None);
140                network.set_objective(objective::Objective::CrossEntropy, None);
141            }
142
143            network.loopback(2, 1, 1, Arc::new(|_loops| 1.0), false);
144            network.set_accumulation(feedback::Accumulation::Mean, feedback::Accumulation::Mean);
145
146            network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
147
148            println!("{}", network);
149
150            // Train the network
151            let (train_loss, val_loss, val_acc);
152            if method == &"REGRESSION" {
153                println!("> Training the network for regression.");
154
155                (train_loss, val_loss, val_acc) = network.learn(
156                    &x_train,
157                    &y_train,
158                    Some((&x_val, &y_val, 50)),
159                    16,
160                    500,
161                    Some(100),
162                );
163            } else {
164                println!("> Training the network for classification.");
165
166                (train_loss, val_loss, val_acc) = network.learn(
167                    &x_train,
168                    &class_train,
169                    Some((&x_val, &class_val, 50)),
170                    16,
171                    500,
172                    Some(100),
173                );
174            }
175            plot::loss(
176                &train_loss,
177                &val_loss,
178                &val_acc,
179                &format!("LOOP : FTIR : {}", method),
180                &format!("./output/ftir/mlp-{}-loop.png", method.to_lowercase()),
181            );
182
183            if method == &"REGRESSION" {
184                // Use the network
185                let prediction = network.predict(x_test.get(0).unwrap());
186                println!(
187                    "Prediction. Target: {}. Output: {}.",
188                    y_test[0].data, prediction.data
189                );
190            } else {
191                // Validate the network
192                let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
193                println!(
194                    "Final validation accuracy: {:.2} % and loss: {:.5}",
195                    val_acc * 100.0,
196                    val_loss
197                );
198
199                // Use the network
200                let prediction = network.predict(x_test.get(0).unwrap());
201                println!(
202                    "Prediction. Target: {}. Output: {}.",
203                    class_test[0].argmax(),
204                    prediction.argmax()
205                );
206            }
207        });
208}