// xor/xor.rs

1use echo_state_network::*;
2use rand::prelude::*;
3
/// Number of samples generated for the training sequence.
const TRAIN_STEP: usize = 5_000;
/// Number of samples generated for the test sequence; also the length of the
/// x-axis handed to the plotter.
const TEST_STEP: usize = 100;
/// Third argument to `EchoStateNetwork::new`.
/// NOTE(review): presumably the reservoir size — confirm against the crate.
const N_X: u64 = 100;
/// Last argument to `EchoStateNetwork::new`.
/// NOTE(review): presumably a regularization coefficient — confirm against the crate.
const BETA: f64 = 0.0;

/// Seed used when generating the training data.
const RANDOM_SEED: u64 = 41;
/// Seed used when generating the test data.
const TEST_RANDOM_SEED: u64 = 91;
11
12fn main() {
13    let (train_input, train_expected_output) = xor_data_gen(TRAIN_STEP, RANDOM_SEED);
14    let (test_input, test_expected_output) = xor_data_gen(TEST_STEP, TEST_RANDOM_SEED);
15
16    let path = format!("{}/examples/graph", env!("CARGO_MANIFEST_DIR"));
17
18    let n_u = train_input.first().unwrap().len() as u64;
19    let n_y = train_expected_output.first().unwrap().len() as u64;
20
21    let mut model = EchoStateNetwork::new(
22        n_u,
23        n_y,
24        N_X,
25        0.1,
26        1.0,
27        0.9,
28        |x| x.tanh(),
29        None,
30        None,
31        1.0,
32        |x| x.clone_owned(),
33        |x| x.clone_owned(),
34        false,
35        BETA,
36    );
37
38    model.offline_train(&train_input, &train_expected_output);
39
40    let mut estimated_output = vec![];
41    for input in test_input.iter() {
42        estimated_output.push(model.estimate(input));
43    }
44
45    let (bits_l2_error, bits_l1_error) =
46        get_bits_error_rate(estimated_output.clone(), test_expected_output.clone(), 2);
47    let (l2_error, l1_error) =
48        get_error_rate(estimated_output.clone(), test_expected_output.clone(), 2);
49    println!("Bits Mean Squared Error: {}", bits_l2_error);
50    println!("Bits Mean Absolute Error: {}", bits_l1_error);
51    println!("Mean Squared Error: {}", l2_error);
52    println!("Mean Absolute Error: {}", l1_error);
53
54    let y_estimated = estimated_output.iter().map(|x| x[0]).collect::<Vec<f64>>();
55    let y_expected = test_expected_output
56        .clone()
57        .into_iter()
58        .flatten()
59        .collect::<Vec<f64>>();
60
61    plotter::plot(
62        "XOR",
63        (0..TEST_STEP).map(|v| v as f64).collect::<Vec<f64>>(),
64        vec![y_expected, y_estimated],
65        vec!["Expected".to_string(), "Output".to_string()],
66        Some(&path),
67    )
68    .unwrap();
69
70    write_as_serde(
71        model,
72        &train_input,
73        &train_expected_output,
74        &test_input,
75        &test_expected_output,
76        estimated_output,
77        None,
78    );
79}
80
81fn xor_data_gen(step: usize, seed: u64) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
82    let tau = 2;
83
84    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
85
86    let input_vec = (0..step)
87        .map(|_| vec![rng.gen_range(0..2) as f64])
88        .collect::<Vec<Vec<f64>>>();
89
90    let mut output_vec = vec![vec![0.0]; step];
91    for n in tau..step {
92        output_vec[n][0] = ((input_vec[n - 1][0] as u32) ^ (input_vec[n - 2][0] as u32)) as f64;
93    }
94
95    (input_vec, output_vec)
96}
97
98fn get_bits_error_rate(
99    estimated_output: Vec<Vec<f64>>,
100    expected_output: Vec<Vec<f64>>,
101    ignore_bits: usize,
102) -> (f64, f64) {
103    let mut y_tested_binary = vec![0.0; estimated_output.len()];
104
105    for (n, estimated) in estimated_output.iter().enumerate() {
106        if estimated[0] > 0.5 {
107            y_tested_binary[n] = 1.0;
108        } else {
109            y_tested_binary[n] = 0.0;
110        }
111    }
112
113    let expected_output = expected_output.into_iter().flatten().collect::<Vec<f64>>();
114
115    let mse = mean_squared_error(
116        &expected_output[ignore_bits..],
117        &y_tested_binary[ignore_bits..],
118    );
119    let mae = mean_absolute_error(
120        &expected_output[ignore_bits..],
121        &y_tested_binary[ignore_bits..],
122    );
123
124    (mse, mae)
125}
126
127fn get_error_rate(
128    estimated_output: Vec<Vec<f64>>,
129    expected_output: Vec<Vec<f64>>,
130    ignore_bits: usize,
131) -> (f64, f64) {
132    let estimated_output = estimated_output.iter().map(|x| x[0]).collect::<Vec<f64>>();
133    let expected_output = expected_output.into_iter().flatten().collect::<Vec<f64>>();
134
135    let mse = mean_squared_error(
136        &expected_output[ignore_bits..],
137        &estimated_output[ignore_bits..],
138    );
139    let mae = mean_absolute_error(
140        &expected_output[ignore_bits..],
141        &estimated_output[ignore_bits..],
142    );
143
144    (mse, mae)
145}