iris/
iris.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, objective, optimizer, random, tensor};
4
5use std::{
6    fs::File,
7    io::{BufRead, BufReader},
8};
9
10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11    let reader = BufReader::new(File::open(&path).unwrap());
12
13    let mut x: Vec<Vec<f32>> = Vec::new();
14    let mut y: Vec<usize> = Vec::new();
15
16    for line in reader.lines().skip(1) {
17        let line = line.unwrap();
18        let record: Vec<&str> = line.split(',').collect();
19        x.push(vec![
20            record.get(1).unwrap().parse::<f32>().unwrap(),
21            record.get(2).unwrap().parse::<f32>().unwrap(),
22            record.get(3).unwrap().parse::<f32>().unwrap(),
23            record.get(4).unwrap().parse::<f32>().unwrap(),
24        ]);
25        y.push(match record.get(5).unwrap() {
26            &"Iris-setosa" => 0,
27            &"Iris-versicolor" => 1,
28            &"Iris-virginica" => 2,
29            _ => panic!("> Unknown class."),
30        });
31    }
32
33    let mut generator = random::Generator::create(12345);
34    let mut indices: Vec<usize> = (0..x.len()).collect();
35    generator.shuffle(&mut indices);
36
37    let x: Vec<tensor::Tensor> = indices
38        .iter()
39        .map(|&i| tensor::Tensor::single(x[i].clone()))
40        .collect();
41    let y: Vec<tensor::Tensor> = indices
42        .iter()
43        .map(|&i| tensor::Tensor::one_hot(y[i], 3))
44        .collect();
45
46    (x, y)
47}
48
49fn main() {
50    // Load the iris dataset
51    let (x, y) = data("./examples/datasets/iris.csv");
52
53    let split = (x.len() as f32 * 0.8) as usize;
54    let x = x.split_at(split);
55    let y = y.split_at(split);
56
57    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
58    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
59    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
60    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
61
62    println!(
63        "Train data {}x{}: {} => {}",
64        x_train.len(),
65        x_train[0].shape,
66        x_train[0].data,
67        y_train[0].data
68    );
69    println!(
70        "Test data {}x{}: {} => {}",
71        x_test.len(),
72        x_test[0].shape,
73        x_test[0].data,
74        y_test[0].data
75    );
76
77    // Create the network
78    let mut network = network::Network::new(tensor::Shape::Single(4));
79
80    network.dense(50, activation::Activation::ReLU, false, None);
81    network.dense(50, activation::Activation::ReLU, false, None);
82    network.dense(3, activation::Activation::Softmax, false, None);
83
84    network.set_optimizer(optimizer::RMSprop::create(
85        0.0001,     // Learning rate
86        0.0,        // Alpha
87        1e-8,       // Epsilon
88        Some(0.01), // Decay
89        Some(0.01), // Momentum
90        true,       // Centered
91    ));
92    network.set_objective(
93        objective::Objective::CrossEntropy, // Objective function
94        Some((-1f32, 1f32)),                // Gradient clipping
95    );
96
97    // Train the network
98    let (_train_loss, _val_loss, _val_acc) = network.learn(
99        &x_train,
100        &y_train,
101        Some((&x_test, &y_test, 5)),
102        1,
103        5,
104        Some(1),
105    );
106
107    // Validate the network
108    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
109    println!(
110        "Final validation accuracy: {:.2} % and loss: {:.5}",
111        val_acc * 100.0,
112        val_loss
113    );
114
115    // Use the network
116    let prediction = network.predict(x_test.get(0).unwrap());
117    println!(
118        "Prediction on input: {}. Target: {}. Output: {}.",
119        x_test[0].data,
120        y_test[0].argmax(),
121        prediction.argmax()
122    );
123}