1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8};
9
10fn data(
11 path: &str,
12) -> (
13 (
14 Vec<tensor::Tensor>,
15 Vec<tensor::Tensor>,
16 Vec<tensor::Tensor>,
17 ),
18 (
19 Vec<tensor::Tensor>,
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 ),
23 (
24 Vec<tensor::Tensor>,
25 Vec<tensor::Tensor>,
26 Vec<tensor::Tensor>,
27 ),
28) {
29 let reader = BufReader::new(File::open(&path).unwrap());
30
31 let mut x_train: Vec<tensor::Tensor> = Vec::new();
32 let mut y_train: Vec<tensor::Tensor> = Vec::new();
33 let mut class_train: Vec<tensor::Tensor> = Vec::new();
34
35 let mut x_test: Vec<tensor::Tensor> = Vec::new();
36 let mut y_test: Vec<tensor::Tensor> = Vec::new();
37 let mut class_test: Vec<tensor::Tensor> = Vec::new();
38
39 let mut x_val: Vec<tensor::Tensor> = Vec::new();
40 let mut y_val: Vec<tensor::Tensor> = Vec::new();
41 let mut class_val: Vec<tensor::Tensor> = Vec::new();
42
43 for line in reader.lines().skip(1) {
44 let line = line.unwrap();
45 let record: Vec<&str> = line.split(',').collect();
46
47 let mut data: Vec<f32> = Vec::new();
48 for i in 0..571 {
49 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
50 }
51 let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
52 match record.get(573).unwrap() {
53 &"Train" => {
54 x_train.push(tensor::Tensor::triple(data));
55 y_train.push(tensor::Tensor::single(vec![record
56 .get(571)
57 .unwrap()
58 .parse::<f32>()
59 .unwrap()]));
60 class_train.push(tensor::Tensor::one_hot(
61 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
63 ));
64 }
65 &"Test" => {
66 x_test.push(tensor::Tensor::triple(data));
67 y_test.push(tensor::Tensor::single(vec![record
68 .get(571)
69 .unwrap()
70 .parse::<f32>()
71 .unwrap()]));
72 class_test.push(tensor::Tensor::one_hot(
73 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
75 ));
76 }
77 &"Val" => {
78 x_val.push(tensor::Tensor::triple(data));
79 y_val.push(tensor::Tensor::single(vec![record
80 .get(571)
81 .unwrap()
82 .parse::<f32>()
83 .unwrap()]));
84 class_val.push(tensor::Tensor::one_hot(
85 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
87 ));
88 }
89 _ => panic!("> Unknown class."),
90 }
91 }
92
93 (
98 (x_train, y_train, class_train),
99 (x_test, y_test, class_test),
100 (x_val, y_val, class_val),
101 )
102}
103
104fn main() {
105 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
107 data("./examples/datasets/ftir.csv");
108
109 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
110 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
111 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
112
113 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
114 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
115 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
116
117 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
118 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
119 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
120
121 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
122 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
123 println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
124
125 vec!["REGRESSION", "CLASSIFICATION"]
126 .iter()
127 .for_each(|method| {
128 let mut network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
130
131 network.convolution(
132 1,
133 (1, 9),
134 (1, 1),
135 (0, 4),
136 (1, 1),
137 activation::Activation::ReLU,
138 None,
139 );
140 network.dense(32, activation::Activation::ReLU, false, None);
141
142 if method == &"REGRESSION" {
143 network.dense(1, activation::Activation::Linear, false, None);
144 network.set_objective(objective::Objective::RMSE, None);
145 } else {
146 network.dense(28, activation::Activation::Softmax, false, None);
147 network.set_objective(objective::Objective::CrossEntropy, None);
148 }
149
150 network.connect(0, 1);
151 network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
152
153 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
154
155 println!("{}", network);
156
157 let (train_loss, val_loss, val_acc);
159 if method == &"REGRESSION" {
160 println!("> Training the network for regression.");
161
162 (train_loss, val_loss, val_acc) = network.learn(
163 &x_train,
164 &y_train,
165 Some((&x_val, &y_val, 50)),
166 16,
167 500,
168 Some(100),
169 );
170 } else {
171 println!("> Training the network for classification.");
172
173 (train_loss, val_loss, val_acc) = network.learn(
174 &x_train,
175 &class_train,
176 Some((&x_val, &class_val, 50)),
177 16,
178 500,
179 Some(100),
180 );
181 }
182 plot::loss(
183 &train_loss,
184 &val_loss,
185 &val_acc,
186 &format!("SKIP : FTIR : {}", method),
187 &format!("./output/ftir/cnn-{}-skip.png", method.to_lowercase()),
188 );
189
190 if method == &"REGRESSION" {
191 let prediction = network.predict(x_test.get(0).unwrap());
193 println!(
194 "Prediction. Target: {}. Output: {}.",
195 y_test[0].data, prediction.data
196 );
197 } else {
198 let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
200 println!(
201 "Final validation accuracy: {:.2} % and loss: {:.5}",
202 val_acc * 100.0,
203 val_loss
204 );
205
206 let prediction = network.predict(x_test.get(0).unwrap());
208 println!(
209 "Prediction. Target: {}. Output: {}.",
210 class_test[0].argmax(),
211 prediction.argmax()
212 );
213 }
214 });
215}