concision_linear/params/
item.rs1use crate::params::ParamsBase;
6use core::marker::PhantomData;
7use ndarray::*;
8use strum::{AsRefStr, EnumDiscriminants, EnumIs, VariantNames};
9
10#[derive(AsRefStr, EnumDiscriminants, EnumIs, VariantNames)]
11#[cfg_attr(
12 feature = "serde",
13 strum_discriminants(
14 derive(serde::Deserialize, serde::Serialize),
15 serde(rename_all = "lowercase", untagged),
16 )
17)]
18#[non_exhaustive]
19#[strum(serialize_all = "lowercase")]
20#[strum_discriminants(
21 name(Param),
22 derive(
23 AsRefStr,
24 Hash,
25 Ord,
26 PartialOrd,
27 VariantNames,
28 strum::Display,
29 strum::EnumCount,
30 EnumIs,
31 strum::EnumIter,
32 strum::EnumString,
33 strum::VariantArray
34 ),
35 strum(serialize_all = "lowercase")
36)]
37pub enum Parameter<S, D>
38where
39 S: RawData,
40 D: RemoveAxis,
41{
42 Bias(ArrayBase<S, D::Smaller>),
43 Weight(ArrayBase<S, D>),
44}
45
46impl<A, S, D> Parameter<S, D>
47where
48 D: RemoveAxis,
49 S: RawData<Elem = A>,
50{
51 pub fn from_bias(data: ArrayBase<S, D::Smaller>) -> Self {
52 Self::Bias(data)
53 }
54
55 pub fn from_weight(data: ArrayBase<S, D>) -> Self {
56 Self::Weight(data)
57 }
58}
59
60impl Param {
61 pub fn bias() -> Self {
62 Self::Bias
63 }
64
65 pub fn weight() -> Self {
66 Self::Weight
67 }
68}
69
70pub struct Item<S, D, E>
71where
72 D: Dimension<Smaller = E>,
73 E: RemoveAxis,
74 S: RawData,
75{
76 pub data: ParamsBase<S, E>,
77 _parent: PhantomData<D>,
78}