xor/
xor.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use neurons::{activation, network, objective, optimizer, tensor};
4
5fn main() {
6    // Create the training data for the binary AND operation
7    let x: Vec<tensor::Tensor> = vec![
8        tensor::Tensor::single(vec![0.0, 0.0]),
9        tensor::Tensor::single(vec![0.0, 1.0]),
10        tensor::Tensor::single(vec![1.0, 0.0]),
11        tensor::Tensor::single(vec![1.0, 1.0]),
12    ];
13    let y: Vec<tensor::Tensor> = vec![
14        tensor::Tensor::single(vec![0.0]),
15        tensor::Tensor::single(vec![0.0]),
16        tensor::Tensor::single(vec![0.0]),
17        tensor::Tensor::single(vec![1.0]),
18    ];
19
20    let inputs: Vec<&tensor::Tensor> = x.iter().collect();
21    let targets: Vec<&tensor::Tensor> = y.iter().collect();
22
23    // Create the network
24    let mut network = network::Network::new(tensor::Shape::Single(2));
25
26    network.dense(10, activation::Activation::ReLU, true, None);
27    network.dense(1, activation::Activation::Sigmoid, false, None);
28
29    network.set_optimizer(optimizer::SGD::create(0.1, Some(0.01)));
30    network.set_objective(objective::Objective::BinaryCrossEntropy, None);
31
32    // Train the network
33    let (_epoch_loss, _val_loss, _val_acc) =
34        network.learn(&inputs, &targets, None, 4, 500, Some(50));
35
36    // Validate the network
37    let (val_loss, val_acc) = network.validate(&inputs, &targets, 1e-1);
38    println!(
39        "Final validation accuracy: {:.2} % and loss: {:.5}",
40        val_acc * 100.0,
41        val_loss
42    );
43
44    // Use the network
45    let prediction = network.predict(inputs.get(0).unwrap());
46    println!(
47        "Prediction on input: {} Target: {} Output: {}",
48        inputs[0].data, targets[0].data, prediction.data
49    );
50}