1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8 sync::Arc,
9};
10
11fn data(
12 path: &str,
13) -> (
14 (
15 Vec<tensor::Tensor>,
16 Vec<tensor::Tensor>,
17 Vec<tensor::Tensor>,
18 ),
19 (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23 ),
24 (
25 Vec<tensor::Tensor>,
26 Vec<tensor::Tensor>,
27 Vec<tensor::Tensor>,
28 ),
29) {
30 let reader = BufReader::new(File::open(&path).unwrap());
31
32 let mut x_train: Vec<tensor::Tensor> = Vec::new();
33 let mut y_train: Vec<tensor::Tensor> = Vec::new();
34 let mut class_train: Vec<tensor::Tensor> = Vec::new();
35
36 let mut x_test: Vec<tensor::Tensor> = Vec::new();
37 let mut y_test: Vec<tensor::Tensor> = Vec::new();
38 let mut class_test: Vec<tensor::Tensor> = Vec::new();
39
40 let mut x_val: Vec<tensor::Tensor> = Vec::new();
41 let mut y_val: Vec<tensor::Tensor> = Vec::new();
42 let mut class_val: Vec<tensor::Tensor> = Vec::new();
43
44 for line in reader.lines().skip(1) {
45 let line = line.unwrap();
46 let record: Vec<&str> = line.split(',').collect();
47
48 let mut data: Vec<f32> = Vec::new();
49 for i in 0..571 {
50 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
51 }
52 match record.get(573).unwrap() {
53 &"Train" => {
54 x_train.push(tensor::Tensor::single(data));
55 y_train.push(tensor::Tensor::single(vec![record
56 .get(571)
57 .unwrap()
58 .parse::<f32>()
59 .unwrap()]));
60 class_train.push(tensor::Tensor::one_hot(
61 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
63 ));
64 }
65 &"Test" => {
66 x_test.push(tensor::Tensor::single(data));
67 y_test.push(tensor::Tensor::single(vec![record
68 .get(571)
69 .unwrap()
70 .parse::<f32>()
71 .unwrap()]));
72 class_test.push(tensor::Tensor::one_hot(
73 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
75 ));
76 }
77 &"Val" => {
78 x_val.push(tensor::Tensor::single(data));
79 y_val.push(tensor::Tensor::single(vec![record
80 .get(571)
81 .unwrap()
82 .parse::<f32>()
83 .unwrap()]));
84 class_val.push(tensor::Tensor::one_hot(
85 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
87 ));
88 }
89 _ => panic!("> Unknown class."),
90 }
91 }
92
93 (
98 (x_train, y_train, class_train),
99 (x_test, y_test, class_test),
100 (x_val, y_val, class_val),
101 )
102}
103
104fn main() {
105 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
107 data("./examples/datasets/ftir.csv");
108
109 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
110 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
111 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
112
113 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
114 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
115 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
116
117 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
118 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
119 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
120
121 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
122 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
123 println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
124
125 vec!["REGRESSION", "CLASSIFICATION"]
126 .iter()
127 .for_each(|method| {
128 let mut network = network::Network::new(tensor::Shape::Single(571));
130
131 network.dense(128, activation::Activation::ReLU, false, None);
132 network.dense(256, activation::Activation::ReLU, false, None);
133 network.dense(128, activation::Activation::ReLU, false, None);
134
135 if method == &"REGRESSION" {
136 network.dense(1, activation::Activation::Linear, false, None);
137 network.set_objective(objective::Objective::RMSE, None);
138 } else {
139 network.dense(28, activation::Activation::Softmax, false, None);
140 network.set_objective(objective::Objective::CrossEntropy, None);
141 }
142
143 network.loopback(2, 1, 1, Arc::new(|_loops| 1.0), false);
144 network.set_accumulation(feedback::Accumulation::Mean, feedback::Accumulation::Mean);
145
146 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
147
148 println!("{}", network);
149
150 let (train_loss, val_loss, val_acc);
152 if method == &"REGRESSION" {
153 println!("> Training the network for regression.");
154
155 (train_loss, val_loss, val_acc) = network.learn(
156 &x_train,
157 &y_train,
158 Some((&x_val, &y_val, 50)),
159 16,
160 500,
161 Some(100),
162 );
163 } else {
164 println!("> Training the network for classification.");
165
166 (train_loss, val_loss, val_acc) = network.learn(
167 &x_train,
168 &class_train,
169 Some((&x_val, &class_val, 50)),
170 16,
171 500,
172 Some(100),
173 );
174 }
175 plot::loss(
176 &train_loss,
177 &val_loss,
178 &val_acc,
179 &format!("LOOP : FTIR : {}", method),
180 &format!("./output/ftir/mlp-{}-loop.png", method.to_lowercase()),
181 );
182
183 if method == &"REGRESSION" {
184 let prediction = network.predict(x_test.get(0).unwrap());
186 println!(
187 "Prediction. Target: {}. Output: {}.",
188 y_test[0].data, prediction.data
189 );
190 } else {
191 let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
193 println!(
194 "Final validation accuracy: {:.2} % and loss: {:.5}",
195 val_acc * 100.0,
196 val_loss
197 );
198
199 let prediction = network.predict(x_test.get(0).unwrap());
201 println!(
202 "Prediction. Target: {}. Output: {}.",
203 class_test[0].argmax(),
204 prediction.argmax()
205 );
206 }
207 });
208}