1pub mod matrix;
2pub mod neuralnetwork;
3pub mod xorshift;
4
5use crate::matrix::Matrix;
6use crate::neuralnetwork::NeuralNetwork;
7
8use std::fs;
9use std::time::{SystemTime, UNIX_EPOCH};
10
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn current_millis() -> u128 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_millis()
}
17
18pub fn parse_csv(
19 filename: &str,
20 input_size: usize,
21 output_size: usize,
22) -> (Vec<Matrix>, Vec<Matrix>) {
23 let mut inputs = Vec::new();
24 let mut outputs = Vec::new();
25 let content = fs::read_to_string(filename).expect("Error: Can't open file!");
26 let lines: Vec<&str> = content.lines().collect();
27 for line_index in 0..lines.len() {
28 let line = lines[line_index];
29 let values: Vec<&str> = line.split(",").collect();
30 let mut input_vector = Matrix::new(input_size, 1);
31 let mut output_vector = Matrix::new(output_size, 1);
32 for value_index in 0..values.len() {
33 if value_index < input_size {
34 input_vector[value_index][0] = values[value_index].parse::<f32>().unwrap();
35 } else {
36 output_vector[value_index - input_size][0] =
37 values[value_index].parse::<f32>().unwrap();
38 }
39 }
40 inputs.push(input_vector);
41 outputs.push(output_vector);
42 }
43 (inputs, outputs)
44}
45
46pub fn get_accuracy(nn: &NeuralNetwork, filename: &str) -> f32 {
47 let (inputs, outputs) = parse_csv(filename, nn.input_nodes, nn.output_nodes);
48 let mut num_right: usize = 0;
49 for i in 0..inputs.len() {
50 if nn.predict(&inputs[i]).index_of_max() == outputs[i].index_of_max() {
51 num_right += 1;
52 }
53 }
54 num_right as f32 / inputs.len() as f32
55}
56
57pub fn train_on_dataset(nn: &mut NeuralNetwork, filename: &str, epochs: u32) {
58 let (inputs, outputs) = parse_csv(filename, nn.input_nodes, nn.output_nodes);
59 let start_time = current_millis();
60 for i in 0..epochs {
61 for j in 0..inputs.len() {
62 nn.train(&inputs[j], &outputs[j]);
63 }
64 print!("{} of {} epochs done\n", i + 1, epochs);
65 }
66 let end_time = (current_millis() - start_time) as f32 / 1000 as f32;
67 print!("Training took {}s\n", end_time);
68}