1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8};
9
10fn data(
11 path: &str,
12) -> (
13 (
14 Vec<tensor::Tensor>,
15 Vec<tensor::Tensor>,
16 Vec<tensor::Tensor>,
17 ),
18 (
19 Vec<tensor::Tensor>,
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 ),
23 (
24 Vec<tensor::Tensor>,
25 Vec<tensor::Tensor>,
26 Vec<tensor::Tensor>,
27 ),
28) {
29 let reader = BufReader::new(File::open(&path).unwrap());
30
31 let mut x_train: Vec<tensor::Tensor> = Vec::new();
32 let mut y_train: Vec<tensor::Tensor> = Vec::new();
33 let mut class_train: Vec<tensor::Tensor> = Vec::new();
34
35 let mut x_test: Vec<tensor::Tensor> = Vec::new();
36 let mut y_test: Vec<tensor::Tensor> = Vec::new();
37 let mut class_test: Vec<tensor::Tensor> = Vec::new();
38
39 let mut x_val: Vec<tensor::Tensor> = Vec::new();
40 let mut y_val: Vec<tensor::Tensor> = Vec::new();
41 let mut class_val: Vec<tensor::Tensor> = Vec::new();
42
43 for line in reader.lines().skip(1) {
44 let line = line.unwrap();
45 let record: Vec<&str> = line.split(',').collect();
46
47 let mut data: Vec<f32> = Vec::new();
48 for i in 0..571 {
49 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
50 }
51 match record.get(573).unwrap() {
52 &"Train" => {
53 x_train.push(tensor::Tensor::single(data));
54 y_train.push(tensor::Tensor::single(vec![record
55 .get(571)
56 .unwrap()
57 .parse::<f32>()
58 .unwrap()]));
59 class_train.push(tensor::Tensor::one_hot(
60 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
62 ));
63 }
64 &"Test" => {
65 x_test.push(tensor::Tensor::single(data));
66 y_test.push(tensor::Tensor::single(vec![record
67 .get(571)
68 .unwrap()
69 .parse::<f32>()
70 .unwrap()]));
71 class_test.push(tensor::Tensor::one_hot(
72 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
74 ));
75 }
76 &"Val" => {
77 x_val.push(tensor::Tensor::single(data));
78 y_val.push(tensor::Tensor::single(vec![record
79 .get(571)
80 .unwrap()
81 .parse::<f32>()
82 .unwrap()]));
83 class_val.push(tensor::Tensor::one_hot(
84 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
86 ));
87 }
88 _ => panic!("> Unknown class."),
89 }
90 }
91
92 (
97 (x_train, y_train, class_train),
98 (x_test, y_test, class_test),
99 (x_val, y_val, class_val),
100 )
101}
102
103fn main() {
104 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
106 data("./examples/datasets/ftir.csv");
107
108 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
109 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
110 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
111
112 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
113 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
114 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
115
116 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
117 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
118 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
119
120 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
121 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
122 println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
123
124 vec!["REGRESSION", "CLASSIFICATION"]
125 .iter()
126 .for_each(|method| {
127 let mut network = network::Network::new(tensor::Shape::Single(571));
129
130 network.dense(128, activation::Activation::ReLU, false, None);
131
132 network.feedback(
133 vec![
134 feedback::Layer::Dense(256, activation::Activation::ReLU, false, None),
135 feedback::Layer::Dense(128, activation::Activation::ReLU, false, None),
136 ],
137 2,
138 false,
139 false,
140 feedback::Accumulation::Mean,
141 );
142
143 if method == &"REGRESSION" {
144 network.dense(1, activation::Activation::Linear, false, None);
145 network.set_objective(objective::Objective::RMSE, None);
146 } else {
147 network.dense(28, activation::Activation::Softmax, false, None);
148 network.set_objective(objective::Objective::CrossEntropy, None);
149 }
150
151 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
156
157 println!("{}", network);
158
159 let (train_loss, val_loss, val_acc);
161 if method == &"REGRESSION" {
162 println!("> Training the network for regression.");
163
164 (train_loss, val_loss, val_acc) = network.learn(
165 &x_train,
166 &y_train,
167 Some((&x_val, &y_val, 50)),
168 16,
169 500,
170 Some(100),
171 );
172 } else {
173 println!("> Training the network for classification.");
174
175 (train_loss, val_loss, val_acc) = network.learn(
176 &x_train,
177 &class_train,
178 Some((&x_val, &class_val, 50)),
179 16,
180 500,
181 Some(100),
182 );
183 }
184 plot::loss(
185 &train_loss,
186 &val_loss,
187 &val_acc,
188 &format!("FEEDBACK : FTIR : {}", method),
189 &format!("./output/ftir/mlp-{}-feedback.png", method.to_lowercase()),
190 );
191
192 if method == &"REGRESSION" {
193 let prediction = network.predict(x_test.get(0).unwrap());
195 println!(
196 "Prediction. Target: {}. Output: {}.",
197 y_test[0].data, prediction.data
198 );
199 } else {
200 let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
202 println!(
203 "Final validation accuracy: {:.2} % and loss: {:.5}",
204 val_acc * 100.0,
205 val_loss
206 );
207
208 let prediction = network.predict(x_test.get(0).unwrap());
210 println!(
211 "Prediction. Target: {}. Output: {}.",
212 class_test[0].argmax(),
213 prediction.argmax()
214 );
215 }
216 });
217}