concision_linear/params/
store.rs

/*
    Appellation: params <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::{build_bias, Biased, Features, Node, ParamMode, Unbiased};
6use concision::dimensional;
7use core::marker::PhantomData;
8use nd::*;
9use num::{One, Zero};
10
11/// The [ParamsBase] struct is a generic store for linear parameters. The store mimics
12/// the underlying [ArrayBase](ndarray::ArrayBase), enabling developers to specify
13/// the data repr and dimension. Additionally, the store is parameterized to
14/// accept a `K` type, used to designate the store as either [Biased](crate::Biased) or [Unbiased](crate::Unbiased).
15pub struct ParamsBase<S = OwnedRepr<f64>, D = Ix2, K = Biased>
16where
17    D: Dimension,
18    S: RawData,
19{
20    pub(crate) bias: Option<ArrayBase<S, D::Smaller>>,
21    pub(crate) weight: ArrayBase<S, D>,
22    pub(crate) _mode: PhantomData<K>,
23}
24
25impl<A, S, D, K> ParamsBase<S, D, K>
26where
27    D: RemoveAxis,
28    S: RawData<Elem = A>,
29{
30    pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
31    where
32        A: Clone,
33        K: ParamMode,
34        S: DataOwned,
35        Sh: ShapeBuilder<Dim = D>,
36    {
37        let dim = shape.into_shape().raw_dim().clone();
38        let bias = if K::BIASED {
39            Some(ArrayBase::from_elem(
40                crate::bias_dim(dim.clone()),
41                elem.clone(),
42            ))
43        } else {
44            None
45        };
46        Self {
47            bias,
48            weight: ArrayBase::from_elem(dim, elem),
49            _mode: PhantomData::<K>,
50        }
51    }
52
53    pub fn into_biased(self) -> ParamsBase<S, D, Biased>
54    where
55        A: Default,
56        K: 'static,
57        S: DataOwned,
58    {
59        if self.is_biased() {
60            return ParamsBase {
61                bias: self.bias,
62                weight: self.weight,
63                _mode: PhantomData::<Biased>,
64            };
65        }
66        let sm = crate::bias_dim(self.raw_dim());
67        ParamsBase {
68            bias: Some(ArrayBase::default(sm)),
69            weight: self.weight,
70            _mode: PhantomData::<Biased>,
71        }
72    }
73
74    pub fn into_unbiased(self) -> ParamsBase<S, D, Unbiased> {
75        ParamsBase {
76            bias: None,
77            weight: self.weight,
78            _mode: PhantomData::<Unbiased>,
79        }
80    }
81
82    pub const fn weights(&self) -> &ArrayBase<S, D> {
83        &self.weight
84    }
85
86    pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
87        &mut self.weight
88    }
89
90    pub fn features(&self) -> Features {
91        Features::from_shape(self.shape())
92    }
93
94    pub fn in_features(&self) -> usize {
95        self.features().dmodel()
96    }
97
98    pub fn out_features(&self) -> usize {
99        if self.ndim() == 1 {
100            return 1;
101        }
102        self.shape()[1]
103    }
104    /// Returns true if the parameter store is biased;
105    /// Compares the [TypeId](core::any::TypeId) of the store with the [Biased](crate::Biased) type.
106    pub fn is_biased(&self) -> bool
107    where
108        K: 'static,
109    {
110        crate::is_biased::<K>()
111    }
112
113    pbuilder!(new.default where A: Default, S: DataOwned);
114
115    pbuilder!(ones where A: Clone + One, S: DataOwned);
116
117    pbuilder!(zeros where A: Clone + Zero, S: DataOwned);
118
119    dimensional!(weights());
120
121    wnbview!(into_owned::<OwnedRepr>(self) where A: Clone, S: Data);
122
123    wnbview!(into_shared::<OwnedArcRepr>(self) where A: Clone, S: DataOwned);
124
125    wnbview!(to_owned::<OwnedRepr>(&self) where A: Clone, S: Data);
126
127    wnbview!(to_shared::<OwnedArcRepr>(&self) where A: Clone, S: DataOwned);
128
129    wnbview!(view::<'a, ViewRepr>(&self) where A: Clone, S: Data);
130
131    wnbview!(view_mut::<'a, ViewRepr>(&mut self) where A: Clone, S: DataMut);
132}
133
134impl<A, S, D> ParamsBase<S, D, Biased>
135where
136    D: RemoveAxis,
137    S: RawData<Elem = A>,
138{
139    /// Create a new biased parameter store from the given shape.
140    pub fn biased<Sh>(shape: Sh) -> Self
141    where
142        A: Default,
143        S: DataOwned,
144        Sh: ShapeBuilder<Dim = D>,
145    {
146        let dim = shape.into_shape().raw_dim().clone();
147        Self {
148            bias: build_bias(true, dim.clone(), ArrayBase::default),
149            weight: ArrayBase::default(dim),
150            _mode: PhantomData::<Biased>,
151        }
152    }
153    /// Return an unwraped, immutable reference to the bias array.
154    pub fn bias(&self) -> &ArrayBase<S, D::Smaller> {
155        self.bias.as_ref().unwrap()
156    }
157    /// Return an unwraped, mutable reference to the bias array.
158    pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
159        self.bias.as_mut().unwrap()
160    }
161}
162
163impl<A, S, D> ParamsBase<S, D, Unbiased>
164where
165    D: Dimension,
166    S: RawData<Elem = A>,
167{
168    /// Create a new unbiased parameter store from the given shape.
169    pub fn unbiased<Sh>(shape: Sh) -> Self
170    where
171        A: Default,
172        S: DataOwned,
173        Sh: ShapeBuilder<Dim = D>,
174    {
175        Self {
176            bias: None,
177            weight: ArrayBase::default(shape),
178            _mode: PhantomData::<Unbiased>,
179        }
180    }
181}
182impl<A, S, K> ParamsBase<S, Ix2, K>
183where
184    K: 'static,
185    S: RawData<Elem = A>,
186{
187    pub fn set_node(&mut self, idx: usize, node: Node<A>)
188    where
189        A: Clone + Default,
190        S: DataMut + DataOwned,
191    {
192        let (weight, bias) = node;
193        if let Some(bias) = bias {
194            if !self.is_biased() {
195                let mut tmp = ArrayBase::default(self.out_features());
196                tmp.index_axis_mut(Axis(0), idx).assign(&bias);
197                self.bias = Some(tmp);
198            }
199            self.bias
200                .as_mut()
201                .unwrap()
202                .index_axis_mut(Axis(0), idx)
203                .assign(&bias);
204        }
205
206        self.weights_mut()
207            .index_axis_mut(Axis(0), idx)
208            .assign(&weight);
209    }
210}
211
212impl<A, S, D> Default for ParamsBase<S, D, Biased>
213where
214    A: Default,
215    D: Dimension,
216    S: DataOwned<Elem = A>,
217{
218    fn default() -> Self {
219        Self {
220            bias: Some(Default::default()),
221            weight: Default::default(),
222            _mode: PhantomData::<Biased>,
223        }
224    }
225}
226
227impl<A, S, D> Default for ParamsBase<S, D, Unbiased>
228where
229    A: Default,
230    D: Dimension,
231    S: DataOwned<Elem = A>,
232{
233    fn default() -> Self {
234        Self {
235            bias: None,
236            weight: Default::default(),
237            _mode: PhantomData::<Unbiased>,
238        }
239    }
240}