concision_linear/params/
item.rs

/*
    Appellation: item <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::params::ParamsBase;
6use core::marker::PhantomData;
7use ndarray::*;
8use strum::{AsRefStr, EnumDiscriminants, EnumIs, VariantNames};
9
10#[derive(AsRefStr, EnumDiscriminants, EnumIs, VariantNames)]
11#[cfg_attr(
12    feature = "serde",
13    strum_discriminants(
14        derive(serde::Deserialize, serde::Serialize),
15        serde(rename_all = "lowercase", untagged),
16    )
17)]
18#[non_exhaustive]
19#[strum(serialize_all = "lowercase")]
20#[strum_discriminants(
21    name(Param),
22    derive(
23        AsRefStr,
24        Hash,
25        Ord,
26        PartialOrd,
27        VariantNames,
28        strum::Display,
29        strum::EnumCount,
30        EnumIs,
31        strum::EnumIter,
32        strum::EnumString,
33        strum::VariantArray
34    ),
35    strum(serialize_all = "lowercase")
36)]
37pub enum Parameter<S, D>
38where
39    S: RawData,
40    D: RemoveAxis,
41{
42    Bias(ArrayBase<S, D::Smaller>),
43    Weight(ArrayBase<S, D>),
44}
45
46impl<A, S, D> Parameter<S, D>
47where
48    D: RemoveAxis,
49    S: RawData<Elem = A>,
50{
51    pub fn from_bias(data: ArrayBase<S, D::Smaller>) -> Self {
52        Self::Bias(data)
53    }
54
55    pub fn from_weight(data: ArrayBase<S, D>) -> Self {
56        Self::Weight(data)
57    }
58}
59
60impl Param {
61    pub fn bias() -> Self {
62        Self::Bias
63    }
64
65    pub fn weight() -> Self {
66        Self::Weight
67    }
68}
69
70pub struct Item<S, D, E>
71where
72    D: Dimension<Smaller = E>,
73    E: RemoveAxis,
74    S: RawData,
75{
76    pub data: ParamsBase<S, E>,
77    _parent: PhantomData<D>,
78}