// examples/ftir_cnn_loop/looping.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6    fs::File,
7    io::{BufRead, BufReader},
8    sync::Arc,
9};
10
11fn data(
12    path: &str,
13) -> (
14    (
15        Vec<tensor::Tensor>,
16        Vec<tensor::Tensor>,
17        Vec<tensor::Tensor>,
18    ),
19    (
20        Vec<tensor::Tensor>,
21        Vec<tensor::Tensor>,
22        Vec<tensor::Tensor>,
23    ),
24    (
25        Vec<tensor::Tensor>,
26        Vec<tensor::Tensor>,
27        Vec<tensor::Tensor>,
28    ),
29) {
30    let reader = BufReader::new(File::open(&path).unwrap());
31
32    let mut x_train: Vec<tensor::Tensor> = Vec::new();
33    let mut y_train: Vec<tensor::Tensor> = Vec::new();
34    let mut class_train: Vec<tensor::Tensor> = Vec::new();
35
36    let mut x_test: Vec<tensor::Tensor> = Vec::new();
37    let mut y_test: Vec<tensor::Tensor> = Vec::new();
38    let mut class_test: Vec<tensor::Tensor> = Vec::new();
39
40    let mut x_val: Vec<tensor::Tensor> = Vec::new();
41    let mut y_val: Vec<tensor::Tensor> = Vec::new();
42    let mut class_val: Vec<tensor::Tensor> = Vec::new();
43
44    for line in reader.lines().skip(1) {
45        let line = line.unwrap();
46        let record: Vec<&str> = line.split(',').collect();
47
48        let mut data: Vec<f32> = Vec::new();
49        for i in 0..571 {
50            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
51        }
52        let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
53        match record.get(573).unwrap() {
54            &"Train" => {
55                x_train.push(tensor::Tensor::triple(data));
56                y_train.push(tensor::Tensor::single(vec![record
57                    .get(571)
58                    .unwrap()
59                    .parse::<f32>()
60                    .unwrap()]));
61                class_train.push(tensor::Tensor::one_hot(
62                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
63                    28,
64                ));
65            }
66            &"Test" => {
67                x_test.push(tensor::Tensor::triple(data));
68                y_test.push(tensor::Tensor::single(vec![record
69                    .get(571)
70                    .unwrap()
71                    .parse::<f32>()
72                    .unwrap()]));
73                class_test.push(tensor::Tensor::one_hot(
74                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
75                    28,
76                ));
77            }
78            &"Val" => {
79                x_val.push(tensor::Tensor::triple(data));
80                y_val.push(tensor::Tensor::single(vec![record
81                    .get(571)
82                    .unwrap()
83                    .parse::<f32>()
84                    .unwrap()]));
85                class_val.push(tensor::Tensor::one_hot(
86                    record.get(572).unwrap().parse::<usize>().unwrap() - 1, // For zero-indexed.
87                    28,
88                ));
89            }
90            _ => panic!("> Unknown class."),
91        }
92    }
93
94    // let mut generator = random::Generator::create(12345);
95    // let mut indices: Vec<usize> = (0..x.len()).collect();
96    // generator.shuffle(&mut indices);
97
98    (
99        (x_train, y_train, class_train),
100        (x_test, y_test, class_test),
101        (x_val, y_val, class_val),
102    )
103}
104
105fn main() {
106    // Load the ftir dataset
107    let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
108        data("./examples/datasets/ftir.csv");
109
110    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
111    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
112    let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
113
114    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
115    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
116    let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
117
118    let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
119    let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
120    let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
121
122    println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
123    println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
124    println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
125
126    vec!["REGRESSION", "CLASSIFICATION"]
127        .iter()
128        .for_each(|method| {
129            // Create the network
130            let mut network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
131
132            network.convolution(
133                1,
134                (1, 9),
135                (1, 1),
136                (0, 4),
137                (1, 1),
138                activation::Activation::ReLU,
139                None,
140            );
141            network.dense(32, activation::Activation::ReLU, false, None);
142
143            if method == &"REGRESSION" {
144                network.dense(1, activation::Activation::Linear, false, None);
145                network.set_objective(objective::Objective::RMSE, None);
146            } else {
147                network.dense(28, activation::Activation::Softmax, false, None);
148                network.set_objective(objective::Objective::CrossEntropy, None);
149            }
150
151            network.loopback(0, 0, 1, Arc::new(|loops| loops), false);
152            network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
153
154            network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
155
156            println!("{}", network);
157
158            // Train the network
159            let (train_loss, val_loss, val_acc);
160            if method == &"REGRESSION" {
161                println!("> Training the network for regression.");
162
163                (train_loss, val_loss, val_acc) = network.learn(
164                    &x_train,
165                    &y_train,
166                    Some((&x_val, &y_val, 50)),
167                    16,
168                    500,
169                    Some(100),
170                );
171            } else {
172                println!("> Training the network for classification.");
173
174                (train_loss, val_loss, val_acc) = network.learn(
175                    &x_train,
176                    &class_train,
177                    Some((&x_val, &class_val, 50)),
178                    16,
179                    500,
180                    Some(100),
181                );
182            }
183            plot::loss(
184                &train_loss,
185                &val_loss,
186                &val_acc,
187                &format!("LOOP : FTIR : {}", method),
188                &format!("./output/ftir/cnn-{}-loop.png", method.to_lowercase()),
189            );
190
191            if method == &"REGRESSION" {
192                // Use the network
193                let prediction = network.predict(x_test.get(0).unwrap());
194                println!(
195                    "Prediction. Target: {}. Output: {}.",
196                    y_test[0].data, prediction.data
197                );
198            } else {
199                // Validate the network
200                let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
201                println!(
202                    "Final validation accuracy: {:.2} % and loss: {:.5}",
203                    val_acc * 100.0,
204                    val_loss
205                );
206
207                // Use the network
208                let prediction = network.predict(x_test.get(0).unwrap());
209                println!(
210                    "Prediction. Target: {}. Output: {}.",
211                    class_test[0].argmax(),
212                    prediction.argmax()
213                );
214            }
215        });
216}