1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
use super::*;
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;

/// A serializable snapshot of a [`NeuralNetworkTopology`].
///
/// Holds owned, cloned [`NeuronTopology`] values instead of the lock-guarded neuron handles used
/// by the live topology, so it can be handed straight to serde. Build one with
/// [`NNTSerde::from`] (from a `&NeuralNetworkTopology`) and convert it back with `.into()`.
#[derive(Serialize, Deserialize)]
pub struct NNTSerde<const I: usize, const O: usize> {
    /// Snapshot of the `I` input-layer neurons.
    ///
    /// `BigArray` is required because serde's built-in array support only covers lengths up to
    /// 32, while `I` is an arbitrary const generic.
    #[serde(with = "BigArray")]
    pub(crate) input_layer: [NeuronTopology; I],

    /// Snapshot of every hidden neuron, in the same order as the source topology.
    pub(crate) hidden_layers: Vec<NeuronTopology>,

    /// Snapshot of the `O` output-layer neurons (uses `BigArray` for the same reason as
    /// `input_layer`).
    #[serde(with = "BigArray")]
    pub(crate) output_layer: [NeuronTopology; O],

    /// Copied verbatim from the source topology's `mutation_rate`.
    pub(crate) mutation_rate: f32,
    /// Copied verbatim from the source topology's `mutation_passes`.
    pub(crate) mutation_passes: usize,
}

impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NNTSerde<I, O> {
    /// Snapshots `value` by read-locking each neuron and cloning its contents.
    ///
    /// The fixed-size layers are filled directly with [`std::array::from_fn`], so no
    /// intermediate `Vec` is allocated and no `Vec -> [_; N]` conversion has to be unwrapped.
    ///
    /// # Panics
    ///
    /// Panics if reading any neuron fails, i.e. its lock has been poisoned by a thread that
    /// panicked while holding it for writing.
    fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
        const POISONED: &str = "neuron lock poisoned while snapshotting topology";

        // NOTE(review): indexing relies on `input_layer` / `output_layer` holding exactly
        // `I` / `O` neurons — the same fact the previous `try_into` into `[_; I]` / `[_; O]`
        // depended on.
        let input_layer = std::array::from_fn(|i| {
            value.input_layer[i].read().expect(POISONED).clone()
        });

        let hidden_layers = value
            .hidden_layers
            .iter()
            .map(|n| n.read().expect(POISONED).clone())
            .collect();

        let output_layer = std::array::from_fn(|i| {
            value.output_layer[i].read().expect(POISONED).clone()
        });

        Self {
            input_layer,
            hidden_layers,
            output_layer,
            mutation_rate: value.mutation_rate,
            mutation_passes: value.mutation_passes,
        }
    }
}

/// Round-trips a random topology through [`NNTSerde`] and bincode, and checks that nothing
/// changes along the way.
#[cfg(test)]
#[test]
fn serde() {
    let mut rng = rand::thread_rng();
    let nnt = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng);
    let nnts = NNTSerde::from(&nnt);

    let encoded = bincode::serialize(&nnts).unwrap();

    // `option_env!` is resolved at compile time, so the variable must be set when the test
    // binary is built (e.g. `TEST_CREATEFILE=1 cargo test`).
    if option_env!("TEST_CREATEFILE").is_some() {
        std::fs::write("serde-test.nn", &encoded).unwrap();
    }

    let decoded: NNTSerde<10, 10> = bincode::deserialize(&encoded).unwrap();
    let nnt2: NeuralNetworkTopology<10, 10> = decoded.into();

    // A lossless round trip must re-encode to exactly the same bytes.
    let reencoded = bincode::serialize(&NNTSerde::from(&nnt2)).unwrap();
    assert_eq!(
        encoded, reencoded,
        "topology changed across a serde round trip"
    );

    dbg!(nnt, nnt2);
}