1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/*
    Appellation: kinds <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use ndarray::*;
use strum::{AsRefStr, EnumDiscriminants, EnumIs, VariantNames};

/// An entry holding either a bias array or a weight array.
///
/// The bias array has one dimension fewer than the weight array (`D::Smaller`), which is
/// why `D` is required to implement [`RemoveAxis`].
///
/// Deriving [`EnumDiscriminants`] also generates the fieldless `Param` enum (see the
/// `strum_discriminants(name(Param))` attribute), which names the kind of entry without
/// carrying any array data.
#[derive(AsRefStr, EnumDiscriminants, EnumIs, VariantNames)]
// NOTE: `untagged` must not be used here. Under serde's untagged representation every
// unit variant serializes as `null`, so `Weight` would deserialize back as `Bias` and
// `rename_all` would have no effect. The default (externally tagged) representation
// serializes the unit variants as the lowercase strings "bias" / "weight".
#[cfg_attr(
    feature = "serde",
    strum_discriminants(
        derive(serde::Deserialize, serde::Serialize),
        serde(rename_all = "lowercase"),
    )
)]
#[non_exhaustive]
#[strum(serialize_all = "lowercase")]
#[strum_discriminants(
    name(Param),
    derive(
        AsRefStr,
        Hash,
        Ord,
        PartialOrd,
        VariantNames,
        strum::Display,
        strum::EnumCount,
        EnumIs,
        strum::EnumIter,
        strum::EnumString,
        strum::VariantArray
    ),
    strum(serialize_all = "lowercase")
)]
pub enum Entry<S, D>
where
    S: RawData,
    D: RemoveAxis,
{
    /// Bias array, with one fewer dimension than the weight (`D::Smaller`).
    Bias(ArrayBase<S, D::Smaller>),
    /// Weight array of dimension `D`.
    Weight(ArrayBase<S, D>),
}

impl<A, S, D> Entry<S, D>
where
    D: RemoveAxis,
    S: RawData<Elem = A>,
{
    pub fn bias(data: ArrayBase<S, D::Smaller>) -> Self {
        Self::Bias(data)
    }

    pub fn weight(data: ArrayBase<S, D>) -> Self {
        Self::Weight(data)
    }
}

impl Param {
    pub fn bias() -> Self {
        Self::Bias
    }

    pub fn weight() -> Self {
        Self::Weight
    }
}