// selector/selector.rs

use ritenn::{Activation, HaltCondition, NN};

/// Number of distinct action indices the network is trained on.
///
/// Each index in `0..ACTIONS` becomes one training example, and the network's
/// output layer has `ACTIONS` neurons (one per one-hot target position).
const ACTIONS: u32 = 10;

5fn main() {
6    // create examples of the xor function
7    let mut examples = Vec::new();
8    for i in 0..ACTIONS {
9        let mut result = Vec::new();
10        for j in 0..ACTIONS {
11            if j == i {
12                result.push(1.0);
13            } else {
14                result.push(0.0);
15            }
16        }
17        let example = (vec![i as f64], result);
18        examples.push(example);
19    }
20
21    // create a new neural network
22    let mut nn = NN::new(&[1, 10, ACTIONS], Activation::PELU, Activation::Sigmoid);
23
24    // train the network
25    nn.train(&examples)
26        .log_interval(Some(1000))
27        .halt_condition(HaltCondition::MSE(0.01))
28        .rate(0.025)
29        .momentum(0.5)
30        .lambda(0.00005)
31        .go();
32
33    // print results of the trained network
34    for &(ref input, _) in examples.iter() {
35        let result = nn.run(input);
36        let print: Vec<String> = result
37            .iter()
38            .map(|x: &f64| format!("{:4.2}", (*x * 100.0).round() / 100.0))
39            .collect();
40        println!("{:1.0} -> {:?}", input[0], print);
41    }
42}