bike_loop/
looping.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, objective, optimizer, plot, random, tensor};
4
5use std::{
6    fs::File,
7    io::{BufRead, BufReader},
8    sync::Arc,
9};
10
11fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
12    let reader = BufReader::new(File::open(&path).unwrap());
13
14    let mut x: Vec<tensor::Tensor> = Vec::new();
15    let mut y: Vec<tensor::Tensor> = Vec::new();
16
17    for line in reader.lines().skip(1) {
18        let line = line.unwrap();
19        let record: Vec<&str> = line.split(',').collect();
20
21        let mut data: Vec<f32> = Vec::new();
22        for i in 2..14 {
23            data.push(record.get(i).unwrap().parse::<f32>().unwrap());
24        }
25        x.push(tensor::Tensor::single(data));
26
27        y.push(tensor::Tensor::single(vec![record
28            .get(16)
29            .unwrap()
30            .parse::<f32>()
31            .unwrap()]));
32    }
33
34    let mut generator = random::Generator::create(12345);
35    let mut indices: Vec<usize> = (0..x.len()).collect();
36    generator.shuffle(&mut indices);
37
38    let x: Vec<tensor::Tensor> = indices.iter().map(|i| x[*i].clone()).collect();
39    let y: Vec<tensor::Tensor> = indices.iter().map(|i| y[*i].clone()).collect();
40
41    (x, y)
42}
43
44fn main() {
45    // Load the ftir dataset
46    let (x, y) = data("./examples/datasets/bike/hour.csv");
47
48    let split = (x.len() as f32 * 0.8) as usize;
49    let x = x.split_at(split);
50    let y = y.split_at(split);
51
52    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
53    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
54    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
55    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
56
57    // Create the network
58    let mut network = network::Network::new(tensor::Shape::Single(12));
59
60    network.dense(24, activation::Activation::ReLU, false, None);
61    network.dense(24, activation::Activation::ReLU, false, None);
62    network.dense(24, activation::Activation::ReLU, false, None);
63
64    network.dense(1, activation::Activation::Linear, false, None);
65    network.set_objective(objective::Objective::RMSE, None);
66
67    network.loopback(2, 1, 2, Arc::new(|_loops| 1.0), false);
68
69    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
70
71    println!("{}", network);
72
73    // Train the network
74
75    let (train_loss, val_loss, val_acc) = network.learn(
76        &x_train,
77        &y_train,
78        Some((&x_test, &y_test, 25)),
79        64,
80        600,
81        Some(100),
82    );
83    plot::loss(
84        &train_loss,
85        &val_loss,
86        &val_acc,
87        &"LOOP : BIKE",
88        &"./output/bike/loop.png",
89    );
90
91    // Use the network
92    let prediction = network.predict(x_test.get(0).unwrap());
93    println!(
94        "Prediction. Target: {}. Output: {}.",
95        y_test[0].data, prediction.data
96    );
97}