1use neurons::{activation, feedback, network, objective, optimizer, plot, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8 sync::Arc,
9};
10
11fn data(
12 path: &str,
13) -> (
14 (
15 Vec<tensor::Tensor>,
16 Vec<tensor::Tensor>,
17 Vec<tensor::Tensor>,
18 ),
19 (
20 Vec<tensor::Tensor>,
21 Vec<tensor::Tensor>,
22 Vec<tensor::Tensor>,
23 ),
24 (
25 Vec<tensor::Tensor>,
26 Vec<tensor::Tensor>,
27 Vec<tensor::Tensor>,
28 ),
29) {
30 let reader = BufReader::new(File::open(&path).unwrap());
31
32 let mut x_train: Vec<tensor::Tensor> = Vec::new();
33 let mut y_train: Vec<tensor::Tensor> = Vec::new();
34 let mut class_train: Vec<tensor::Tensor> = Vec::new();
35
36 let mut x_test: Vec<tensor::Tensor> = Vec::new();
37 let mut y_test: Vec<tensor::Tensor> = Vec::new();
38 let mut class_test: Vec<tensor::Tensor> = Vec::new();
39
40 let mut x_val: Vec<tensor::Tensor> = Vec::new();
41 let mut y_val: Vec<tensor::Tensor> = Vec::new();
42 let mut class_val: Vec<tensor::Tensor> = Vec::new();
43
44 for line in reader.lines().skip(1) {
45 let line = line.unwrap();
46 let record: Vec<&str> = line.split(',').collect();
47
48 let mut data: Vec<f32> = Vec::new();
49 for i in 0..571 {
50 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
51 }
52 let data: Vec<Vec<Vec<f32>>> = vec![vec![data]];
53 match record.get(573).unwrap() {
54 &"Train" => {
55 x_train.push(tensor::Tensor::triple(data));
56 y_train.push(tensor::Tensor::single(vec![record
57 .get(571)
58 .unwrap()
59 .parse::<f32>()
60 .unwrap()]));
61 class_train.push(tensor::Tensor::one_hot(
62 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
64 ));
65 }
66 &"Test" => {
67 x_test.push(tensor::Tensor::triple(data));
68 y_test.push(tensor::Tensor::single(vec![record
69 .get(571)
70 .unwrap()
71 .parse::<f32>()
72 .unwrap()]));
73 class_test.push(tensor::Tensor::one_hot(
74 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
76 ));
77 }
78 &"Val" => {
79 x_val.push(tensor::Tensor::triple(data));
80 y_val.push(tensor::Tensor::single(vec![record
81 .get(571)
82 .unwrap()
83 .parse::<f32>()
84 .unwrap()]));
85 class_val.push(tensor::Tensor::one_hot(
86 record.get(572).unwrap().parse::<usize>().unwrap() - 1, 28,
88 ));
89 }
90 _ => panic!("> Unknown class."),
91 }
92 }
93
94 (
99 (x_train, y_train, class_train),
100 (x_test, y_test, class_test),
101 (x_val, y_val, class_val),
102 )
103}
104
105fn main() {
106 let ((x_train, y_train, class_train), (x_test, y_test, class_test), (x_val, y_val, class_val)) =
108 data("./examples/datasets/ftir.csv");
109
110 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
111 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
112 let class_train: Vec<&tensor::Tensor> = class_train.iter().collect();
113
114 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
115 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
116 let class_test: Vec<&tensor::Tensor> = class_test.iter().collect();
117
118 let x_val: Vec<&tensor::Tensor> = x_val.iter().collect();
119 let y_val: Vec<&tensor::Tensor> = y_val.iter().collect();
120 let class_val: Vec<&tensor::Tensor> = class_val.iter().collect();
121
122 println!("Train data {}x{}", x_train.len(), x_train[0].shape,);
123 println!("Test data {}x{}", x_test.len(), x_test[0].shape,);
124 println!("Validation data {}x{}", x_val.len(), x_val[0].shape,);
125
126 vec!["REGRESSION", "CLASSIFICATION"]
127 .iter()
128 .for_each(|method| {
129 let mut network = network::Network::new(tensor::Shape::Triple(1, 1, 571));
131
132 network.convolution(
133 1,
134 (1, 9),
135 (1, 1),
136 (0, 4),
137 (1, 1),
138 activation::Activation::ReLU,
139 None,
140 );
141 network.dense(32, activation::Activation::ReLU, false, None);
142
143 if method == &"REGRESSION" {
144 network.dense(1, activation::Activation::Linear, false, None);
145 network.set_objective(objective::Objective::RMSE, None);
146 } else {
147 network.dense(28, activation::Activation::Softmax, false, None);
148 network.set_objective(objective::Objective::CrossEntropy, None);
149 }
150
151 network.loopback(0, 0, 1, Arc::new(|loops| loops), false);
152 network.set_accumulation(feedback::Accumulation::Add, feedback::Accumulation::Add);
153
154 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
155
156 println!("{}", network);
157
158 let (train_loss, val_loss, val_acc);
160 if method == &"REGRESSION" {
161 println!("> Training the network for regression.");
162
163 (train_loss, val_loss, val_acc) = network.learn(
164 &x_train,
165 &y_train,
166 Some((&x_val, &y_val, 50)),
167 16,
168 500,
169 Some(100),
170 );
171 } else {
172 println!("> Training the network for classification.");
173
174 (train_loss, val_loss, val_acc) = network.learn(
175 &x_train,
176 &class_train,
177 Some((&x_val, &class_val, 50)),
178 16,
179 500,
180 Some(100),
181 );
182 }
183 plot::loss(
184 &train_loss,
185 &val_loss,
186 &val_acc,
187 &format!("LOOP : FTIR : {}", method),
188 &format!("./output/ftir/cnn-{}-loop.png", method.to_lowercase()),
189 );
190
191 if method == &"REGRESSION" {
192 let prediction = network.predict(x_test.get(0).unwrap());
194 println!(
195 "Prediction. Target: {}. Output: {}.",
196 y_test[0].data, prediction.data
197 );
198 } else {
199 let (val_loss, val_acc) = network.validate(&x_test, &class_test, 1e-6);
201 println!(
202 "Final validation accuracy: {:.2} % and loss: {:.5}",
203 val_acc * 100.0,
204 val_loss
205 );
206
207 let prediction = network.predict(x_test.get(0).unwrap());
209 println!(
210 "Prediction. Target: {}. Output: {}.",
211 class_test[0].argmax(),
212 prediction.argmax()
213 );
214 }
215 });
216}