1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8};
9
10fn data(
11 path: &str,
12) -> (
13 (
14 Vec<tensor::Tensor>,
15 Vec<tensor::Tensor>,
16 Vec<tensor::Tensor>,
17 ),
18 (
19 Vec<tensor::Tensor>,
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 ),
23 (
24 Vec<tensor::Tensor>,
25 Vec<tensor::Tensor>,
26 Vec<tensor::Tensor>,
27 ),
28) {
29 let reader = BufReader::new(File::open(&path).unwrap());
30
31 let mut x_train: Vec<tensor::Tensor> = Vec::new();
32 let mut y_train: Vec<tensor::Tensor> = Vec::new();
33 let mut class_train: Vec<tensor::Tensor> = Vec::new();
34
35 let mut x_test: Vec<tensor::Tensor> = Vec::new();
36 let mut y_test: Vec<tensor::Tensor> = Vec::new();
37 let mut class_test: Vec<tensor::Tensor> = Vec::new();
38
39 let mut x_val: Vec<tensor::Tensor> = Vec::new();
40 let mut y_val: Vec<tensor::Tensor> = Vec::new();
41 let mut class_val: Vec<tensor::Tensor> = Vec::new();
42
43 for line in reader.lines().skip(1) {
44 let line = line.unwrap();
45 let record: Vec<&str> = line.split(',').collect();
46
47 let mut data: Vec<f32> = Vec::new();
48 for i in 0..571 {
49 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
50 }
51 match record.get(573).unwrap() {
52 &"Train" => {
53 x_train.push(tensor::Tensor::single(data));
54 y_train.push(tensor::Tensor::single(vec![record
55 .get(571)
56 .unwrap()
57 .parse::<f32>()
58 .unwrap()]));
59 class_train.push(tensor::Tensor::one_hot(
60 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
62 ));
63 }
64 &"Test" => {
65 x_test.push(tensor::Tensor::single(data));
66 y_test.push(tensor::Tensor::single(vec![record
67 .get(571)
68 .unwrap()
69 .parse::<f32>()
70 .unwrap()]));
71 class_test.push(tensor::Tensor::one_hot(
72 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
74 ));
75 }
76 &"Val" => {
77 x_val.push(tensor::Tensor::single(data));
78 y_val.push(tensor::Tensor::single(vec![record
79 .get(571)
80 .unwrap()
81 .parse::<f32>()
82 .unwrap()]));
83 class_val.push(tensor::Tensor::one_hot(
84 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
86 ));
87 }
88 _ => panic!("> Unknown class."),
89 }
90 }
91
92 (
97 (x_train, y_train, class_train),
98 (x_test, y_test, class_test),
99 (x_val, y_val, class_val),
100 )
101}
102
103fn main() {
104 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
106 data("./examples/datasets/ftir.csv");
107
108 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
109 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
110 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
111
112 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
113 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
114 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
115
116 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
117 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
118 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
119
120 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
121 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
122 println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
123
124 vec!["REGRESSION", "CLASSIFICATION"]
125 .iter()
126 .for_each(|method| {
127 let mut network = network::Network::new(tensor::Shape::Single(571));
129
130 network.dense(128, activation::Activation::ReLU, false, None);
131 network.dense(256, activation::Activation::ReLU, false, None);
132 network.dense(128, activation::Activation::ReLU, false, None);
133
134 if method == &"REGRESSION" {
135 network.dense(1, activation::Activation::Linear, false, None);
136 network.set_objective(objective::Objective::RMSE, None);
137 } else {
138 network.dense(28, activation::Activation::Softmax, false, None);
139 network.set_objective(objective::Objective::CrossEntropy, None);
140 }
141
142 network.connect(1, 3);
143 network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
144
145 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
146
147 println!("{}", network);
148
149 let (train_loss, val_loss, val_acc);
151 if method == &"REGRESSION" {
152 println!("> Training the network for regression.");
153
154 (train_loss, val_loss, val_acc) = network.learn(
155 &x_train,
156 &y_train,
157 Some((&x_val, &y_val, 50)),
158 16,
159 500,
160 Some(100),
161 );
162 } else {
163 println!("> Training the network for classification.");
164
165 (train_loss, val_loss, val_acc) = network.learn(
166 &x_train,
167 &class_train,
168 Some((&x_val, &class_val, 50)),
169 16,
170 500,
171 Some(100),
172 );
173 }
174 plot::loss(
175 &train_loss,
176 &val_loss,
177 &val_acc,
178 &format!("SKIP : FTIR : {}", method),
179 &format!("./output/ftir/mlp-{}-skip.png", method.to_lowercase()),
180 );
181
182 if method == &"REGRESSION" {
183 let prediction = network.predict(x_test.get(0).unwrap());
185 println!(
186 "Prediction. Target: {}. Output: {}.",
187 y_test[0].data, prediction.data
188 );
189 } else {
190 let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
192 println!(
193 "Final validation accuracy: {:.2} % and loss: {:.5}",
194 val_acc * 100.0,
195 val_loss
196 );
197
198 let prediction = network.predict(x_test.get(0).unwrap());
200 println!(
201 "Prediction. Target: {}. Output: {}.",
202 class_test[0].argmax(),
203 prediction.argmax()
204 );
205 }
206 });
207}