1use neurons::{activation, network, objective, optimizer, plot, random, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8};
9
10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11 let reader = BufReader::new(File::open(&path).unwrap());
12
13 let mut x: Vec<tensor::Tensor> = Vec::new();
14 let mut y: Vec<tensor::Tensor> = Vec::new();
15
16 for line in reader.lines().skip(1) {
17 let line = line.unwrap();
18 let record: Vec<&str> = line.split(',').collect();
19
20 let mut data: Vec<f32> = Vec::new();
21 for i in 2..14 {
22 data.push(record.get(i).unwrap().parse::<f32>().unwrap());
23 }
24 x.push(tensor::Tensor::single(data));
25
26 y.push(tensor::Tensor::single(vec![record
27 .get(16)
28 .unwrap()
29 .parse::<f32>()
30 .unwrap()]));
31 }
32
33 let mut generator = random::Generator::create(12345);
34 let mut indices: Vec<usize> = (0..x.len()).collect();
35 generator.shuffle(&mut indices);
36
37 let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
38 let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
39
40 (x, y)
41}
42
43fn main() {
44 let (x, y) = data("./examples/datasets/bike/hour.csv");
46
47 let split = (x.len() as f32 * 0.8) as usize;
48 let x = x.split_at(split);
49 let y = y.split_at(split);
50
51 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
52 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
53 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
54 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
55
56 let mut network = network::Network::new(tensor::Shape::Single(12));
58
59 network.dense(24, activation::Activation::ReLU, false, None);
60 network.dense(24, activation::Activation::ReLU, false, None);
61 network.dense(24, activation::Activation::ReLU, false, None);
62
63 network.dense(1, activation::Activation::Linear, false, None);
64 network.set_objective(objective::Objective::RMSE, None);
65
66 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
67
68 println!("{}", network);
69
70 let (train_loss, val_loss, val_acc) = network.learn(
73 &x_train,
74 &y_train,
75 Some((&x_test, &y_test, 25)),
76 64,
77 600,
78 Some(100),
79 );
80 plot::loss(
81 &train_loss,
82 &val_loss,
83 &val_acc,
84 &"PLAIN : BIKE",
85 &"./output/bike/plain.png",
86 );
87
88 let prediction = network.predict(x_test.get(0).unwrap());
90 println!(
91 "Prediction. Target: {}. Output: {}.",
92 y_test[0].data, prediction.data
93 );
94}