1use neurons::{activation, network, objective, optimizer, random, tensor};
4
5use std::{
6 fs::File,
7 io::{BufRead, BufReader},
8};
9
10fn data(path: &str) -> (Vec<tensor::Tensor>, Vec<tensor::Tensor>) {
11 let reader = BufReader::new(File::open(&path).unwrap());
12
13 let mut x: Vec<Vec<f32>> = Vec::new();
14 let mut y: Vec<usize> = Vec::new();
15
16 for line in reader.lines().skip(1) {
17 let line = line.unwrap();
18 let record: Vec<&str> = line.split(',').collect();
19 x.push(vec![
20 record.get(1).unwrap().parse::<f32>().unwrap(),
21 record.get(2).unwrap().parse::<f32>().unwrap(),
22 record.get(3).unwrap().parse::<f32>().unwrap(),
23 record.get(4).unwrap().parse::<f32>().unwrap(),
24 ]);
25 y.push(match record.get(5).unwrap() {
26 &"Iris-setosa" => 0,
27 &"Iris-versicolor" => 1,
28 &"Iris-virginica" => 2,
29 _ => panic!("> Unknown class."),
30 });
31 }
32
33 let mut generator = random::Generator::create(12345);
34 let mut indices: Vec<usize> = (0..x.len()).collect();
35 generator.shuffle(&mut indices);
36
37 let x: Vec<tensor::Tensor> = indices
38 .iter()
39 .map(|&i| tensor::Tensor::single(x[i].clone()))
40 .collect();
41 let y: Vec<tensor::Tensor> = indices
42 .iter()
43 .map(|&i| tensor::Tensor::one_hot(y[i], 3))
44 .collect();
45
46 (x, y)
47}
48
49fn main() {
50 let (x, y) = data("./examples/datasets/iris.csv");
52
53 let split = (x.len() as f32 * 0.8) as usize;
54 let x = x.split_at(split);
55 let y = y.split_at(split);
56
57 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
58 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
59 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
60 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
61
62 println!(
63 "Train data {}x{}: {} => {}",
64 x_train.len(),
65 x_train[0].shape,
66 x_train[0].data,
67 y_train[0].data
68 );
69 println!(
70 "Test data {}x{}: {} => {}",
71 x_test.len(),
72 x_test[0].shape,
73 x_test[0].data,
74 y_test[0].data
75 );
76
77 let mut network = network::Network::new(tensor::Shape::Single(4));
79
80 network.dense(50, activation::Activation::ReLU, false, None);
81 network.dense(50, activation::Activation::ReLU, false, None);
82 network.dense(3, activation::Activation::Softmax, false, None);
83
84 network.set_optimizer(optimizer::RMSprop::create(
85 0.0001, 0.0, 1e-8, Some(0.01), Some(0.01), true, ));
92 network.set_objective(
93 objective::Objective::CrossEntropy, Some((-1f32, 1f32)), );
96
97 let (_train_loss, _val_loss, _val_acc) = network.learn(
99 &x_train,
100 &y_train,
101 Some((&x_test, &y_test, 5)),
102 1,
103 5,
104 Some(1),
105 );
106
107 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
109 println!(
110 "Final validation accuracy: {:.2} % and loss: {:.5}",
111 val_acc * 100.0,
112 val_loss
113 );
114
115 let prediction = network.predict(x_test.get(0).unwrap());
117 println!(
118 "Prediction on input: {}. Target: {}. Output: {}.",
119 x_test[0].data,
120 y_test[0].argmax(),
121 prediction.argmax()
122 );
123}