concision_core/params/
params.rs

1/*
2    Appellation: params <module>
3    Contrib: @FL03
4*/
5use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8/// this structure extends the `ArrayBase` type to include bias
9pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11    D: Dimension,
12    S: RawData,
13{
14    pub(crate) bias: ArrayBase<S, D::Smaller>,
15    pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20    D: Dimension,
21    S: RawData<Elem = A>,
22{
23    pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24    where
25        A: Clone,
26        S: DataOwned,
27    {
28        Self { bias, weights }
29    }
30
31    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32    where
33        A: Clone,
34        D: RemoveAxis,
35        S: DataOwned,
36        Sh: ShapeBuilder<Dim = D>,
37    {
38        let weights = ArrayBase::from_elem(shape, elem.clone());
39        let dim = weights.raw_dim();
40        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41        Self { bias, weights }
42    }
43    /// create an instance of the parameters with all values set to the default value
44    pub fn default<Sh>(shape: Sh) -> Self
45    where
46        A: Clone + Default,
47        D: RemoveAxis,
48        S: DataOwned,
49        Sh: ShapeBuilder<Dim = D>,
50    {
51        let weights = ArrayBase::default(shape);
52        let dim = weights.raw_dim();
53        let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54        Self { bias, weights }
55    }
56    /// initialize the parameters with all values set to zero
57    pub fn ones<Sh>(shape: Sh) -> Self
58    where
59        A: Clone + num_traits::One,
60        D: RemoveAxis,
61        S: DataOwned,
62        Sh: ShapeBuilder<Dim = D>,
63    {
64        let weights = ArrayBase::ones(shape);
65        let dim = weights.raw_dim();
66        let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67        Self { bias, weights }
68    }
69    /// create an instance of the parameters with all values set to zero
70    pub fn zeros<Sh>(shape: Sh) -> Self
71    where
72        A: Clone + num_traits::Zero,
73        D: RemoveAxis,
74        S: DataOwned,
75        Sh: ShapeBuilder<Dim = D>,
76    {
77        let weights = ArrayBase::zeros(shape);
78        let dim = weights.raw_dim();
79        let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80        Self { bias, weights }
81    }
82    /// returns an immutable reference to the bias
83    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84        &self.bias
85    }
86    /// returns a mutable reference to the bias
87    #[inline]
88    pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89        &mut self.bias
90    }
91    /// returns an immutable reference to the weights
92    pub const fn weights(&self) -> &ArrayBase<S, D> {
93        &self.weights
94    }
95    /// returns a mutable reference to the weights
96    #[inline]
97    pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98        &mut self.weights
99    }
100    /// perform a single backpropagation step
101    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
102    where
103        A: Clone,
104        S: Data,
105        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
106    {
107        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
108    }
109    /// forward propagation
110    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
111    where
112        A: Clone,
113        S: Data,
114        Self: crate::Forward<X, Output = Y>,
115    {
116        <Self as crate::Forward<X>>::forward(self, input)
117    }
118    /// returns the dimensions of the weights
119    pub fn dim(&self) -> D::Pattern {
120        self.weights.dim()
121    }
122    /// an iterator of the parameters; the created iterator zips together an axis iterator over
123    /// the columns of the weights and an iterator over the bias
124    pub fn iter(&self) -> super::iter::Iter<'_, A, D>
125    where
126        D: RemoveAxis,
127        S: Data,
128    {
129        super::iter::Iter {
130            bias: self.bias.iter(),
131            weights: self.weights.axis_iter(Axis(1)),
132        }
133    }
134    /// a mutable iterator of the parameters
135    pub fn iter_mut(
136        &mut self,
137    ) -> core::iter::Zip<
138        ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
139        ndarray::iter::IterMut<'_, A, D::Smaller>,
140    >
141    where
142        D: RemoveAxis,
143        S: DataMut,
144    {
145        self.weights
146            .axis_iter_mut(Axis(1))
147            .zip(self.bias.iter_mut())
148    }
149    /// returns an iterator over the bias
150    pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
151    where
152        S: Data,
153    {
154        self.bias.iter()
155    }
156    /// returns a mutable iterator over the bias
157    pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
158    where
159        S: DataMut,
160    {
161        self.bias.iter_mut()
162    }
163    /// returns an iterator over the weights
164    pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
165    where
166        S: Data,
167    {
168        self.weights.iter()
169    }
170    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
171    pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
172    where
173        S: DataMut,
174    {
175        self.weights.iter_mut()
176    }
177    /// the total number of elements within the weight tensor
178    pub fn len(&self) -> usize {
179        self.weights.len()
180    }
181    /// returns the raw dimensions of the weights;
182    pub fn raw_dim(&self) -> D {
183        self.weights.raw_dim()
184    }
185    /// returns the shape of the parameters; uses the shape of the weight tensor
186    pub fn shape(&self) -> &[usize] {
187        self.weights.shape()
188    }
189    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
190    /// weight tensor minus the "zero-th" axis
191    pub fn shape_bias(&self) -> &[usize] {
192        self.bias.shape()
193    }
194    /// returns an owned instance of the parameters
195    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
196    where
197        A: Clone,
198        S: DataOwned,
199    {
200        ParamsBase {
201            bias: self.bias.to_owned(),
202            weights: self.weights.to_owned(),
203        }
204    }
205    /// change the shape of the parameters; the shape of the bias parameters is determined by
206    /// removing the "zero-th" axis of the given shape
207    pub fn to_shape<Sh>(
208        &self,
209        shape: Sh,
210    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
211    where
212        A: Clone,
213        S: DataOwned,
214        Sh: ShapeBuilder,
215        Sh::Dim: Dimension + RemoveAxis,
216    {
217        let shape = shape.into_shape_with_order();
218        let dim = shape.raw_dim().clone();
219        let bias = self.bias.to_shape(dim.remove_axis(Axis(0)))?;
220        let weights = self.weights.to_shape(dim)?;
221        Ok(ParamsBase { bias, weights })
222    }
223    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
224    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
225    where
226        S: Data,
227    {
228        ParamsBase {
229            bias: self.bias.view(),
230            weights: self.weights.view(),
231        }
232    }
233    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
234    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
235    where
236        S: ndarray::DataMut,
237    {
238        ParamsBase {
239            bias: self.bias.view_mut(),
240            weights: self.weights.view_mut(),
241        }
242    }
243}
244
245impl<A, S> ParamsBase<S, Ix1>
246where
247    S: RawData<Elem = A>,
248{
249    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
250    where
251        A: Clone,
252        S: DataOwned,
253    {
254        Self {
255            bias: ArrayBase::from_elem((), bias),
256            weights,
257        }
258    }
259
260    pub fn nrows(&self) -> usize {
261        self.weights.len()
262    }
263}
264
265impl<A, S> ParamsBase<S, Ix2>
266where
267    S: RawData<Elem = A>,
268{
269    pub fn ncols(&self) -> usize {
270        self.weights.ncols()
271    }
272
273    pub fn nrows(&self) -> usize {
274        self.weights.nrows()
275    }
276}
277
278impl<A, S, D> Clone for ParamsBase<S, D>
279where
280    D: Dimension,
281    S: ndarray::RawDataClone<Elem = A>,
282    A: Clone,
283{
284    fn clone(&self) -> Self {
285        Self {
286            bias: self.bias.clone(),
287            weights: self.weights.clone(),
288        }
289    }
290}
291
292impl<A, S, D> Copy for ParamsBase<S, D>
293where
294    D: Dimension + Copy,
295    <D as Dimension>::Smaller: Copy,
296    S: ndarray::RawDataClone<Elem = A> + Copy,
297    A: Copy,
298{
299}
300
301impl<A, S, D> PartialEq for ParamsBase<S, D>
302where
303    D: Dimension,
304    S: Data<Elem = A>,
305    A: PartialEq,
306{
307    fn eq(&self, other: &Self) -> bool {
308        self.bias == other.bias && self.weights == other.weights
309    }
310}
311
312impl<A, S, D> Eq for ParamsBase<S, D>
313where
314    D: Dimension,
315    S: Data<Elem = A>,
316    A: Eq,
317{
318}
319
320impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
321where
322    D: Dimension,
323    S: Data<Elem = A>,
324    A: core::fmt::Debug,
325{
326    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
327        f.debug_struct("ModelParams")
328            .field("bias", &self.bias)
329            .field("weights", &self.weights)
330            .finish()
331    }
332}
333
334// impl<A, S, D> PartialOrd for ModelParams<S, D>
335// where
336//     D: Dimension,
337//     S: Data<Elem = A>,
338//     A: PartialOrd,
339// {
340//     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
341//         match self.bias.iter().partial_cmp(&other.bias.iter()) {
342//             Some(core::cmp::Ordering::Equal) => self.weights.iter().partial_cmp(&other.weights.iter()),
343//             other => other,
344//         }
345//     }
346// }