1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/*
    Appellation: serde <impls>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
#![cfg(feature = "serde")]

use crate::params::{Entry, ParamMode, ParamsBase};
use core::marker::PhantomData;
use nd::*;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Rebuilds [`ParamsBase`] from the two-element sequence written by its
/// `Serialize` impl: the bias first, then the weights.
///
/// The mode marker `K` carries no data, so it is restored as a fresh
/// `PhantomData` rather than being read from the input.
impl<'de, A, S, D, K> Deserialize<'de> for ParamsBase<S, D, K>
where
    A: Deserialize<'de>,
    D: Deserialize<'de> + RemoveAxis,
    S: DataOwned<Elem = A>,
    <D as Dimension>::Smaller: Deserialize<'de> + Dimension,
{
    /// Decodes a `(bias, weights)` tuple and assembles the parameters.
    ///
    /// # Errors
    /// Propagates any error reported by `deserializer` while decoding the tuple.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        // The element types are left to inference; the struct literal below pins them.
        let (bias, weights) = <(_, _) as Deserialize<'de>>::deserialize(deserializer)?;
        let params = Self {
            bias,
            weights,
            _mode: PhantomData,
        };
        Ok(params)
    }
}

impl<A, S, D, K> Serialize for ParamsBase<S, D, K>
where
    A: Serialize,
    D: RemoveAxis + Serialize,
    K: ParamMode,
    S: Data<Elem = A>,
    <D as Dimension>::Smaller: Dimension + Serialize,
{
    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
    where
        Ser: Serializer,
    {
        (self.bias(), self.weights()).serialize(serializer)
    }
}

impl<A, S, D> Serialize for Entry<S, D>
where
    A: Serialize,
    S: Data<Elem = A>,
    D: RemoveAxis + Serialize,
    <D as Dimension>::Smaller: Dimension + Serialize,
{
    fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
    where
        Ser: Serializer,
    {
        match self {
            Self::Bias(bias) => bias.serialize(serializer),
            Self::Weight(weight) => weight.serialize(serializer),
        }
    }
}