neuralnetwork/
lib.rs

1pub mod matrix;
2pub mod neuralnetwork;
3pub mod xorshift;
4
5use crate::matrix::Matrix;
6use crate::neuralnetwork::NeuralNetwork;
7
8use std::fs;
9use std::time::{SystemTime, UNIX_EPOCH};
10
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// This reads `SystemTime`, which is not monotonic, so two readings should not be
/// subtracted to measure an interval; `std::time::Instant` is the tool for that.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
pub fn current_millis() -> u128 {
    let now = SystemTime::now();
    now.duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_millis())
        .unwrap()
}
17
18pub fn parse_csv(
19    filename: &str,
20    input_size: usize,
21    output_size: usize,
22) -> (Vec<Matrix>, Vec<Matrix>) {
23    let mut inputs = Vec::new();
24    let mut outputs = Vec::new();
25    let content = fs::read_to_string(filename).expect("Error: Can't open file!");
26    let lines: Vec<&str> = content.lines().collect();
27    for line_index in 0..lines.len() {
28        let line = lines[line_index];
29        let values: Vec<&str> = line.split(",").collect();
30        let mut input_vector = Matrix::new(input_size, 1);
31        let mut output_vector = Matrix::new(output_size, 1);
32        for value_index in 0..values.len() {
33            if value_index < input_size {
34                input_vector[value_index][0] = values[value_index].parse::<f32>().unwrap();
35            } else {
36                output_vector[value_index - input_size][0] =
37                    values[value_index].parse::<f32>().unwrap();
38            }
39        }
40        inputs.push(input_vector);
41        outputs.push(output_vector);
42    }
43    (inputs, outputs)
44}
45
46pub fn get_accuracy(nn: &NeuralNetwork, filename: &str) -> f32 {
47    let (inputs, outputs) = parse_csv(filename, nn.input_nodes, nn.output_nodes);
48    let mut num_right: usize = 0;
49    for i in 0..inputs.len() {
50        if nn.predict(&inputs[i]).index_of_max() == outputs[i].index_of_max() {
51            num_right += 1;
52        }
53    }
54    num_right as f32 / inputs.len() as f32
55}
56
57pub fn train_on_dataset(nn: &mut NeuralNetwork, filename: &str, epochs: u32) {
58    let (inputs, outputs) = parse_csv(filename, nn.input_nodes, nn.output_nodes);
59    let start_time = current_millis();
60    for i in 0..epochs {
61        for j in 0..inputs.len() {
62            nn.train(&inputs[j], &outputs[j]);
63        }
64        print!("{} of {} epochs done\n", i + 1, epochs);
65    }
66    let end_time = (current_millis() - start_time) as f32 / 1000 as f32;
67    print!("Training took {}s\n", end_time);
68}