concision_core/params/
params.rs

1/*
2    Appellation: params <module>
3    Contrib: @FL03
4*/
5use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8/// this structure extends the `ArrayBase` type to include bias
9pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11    D: Dimension,
12    S: RawData,
13{
14    pub(crate) bias: ArrayBase<S, D::Smaller>,
15    pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20    D: Dimension,
21    S: RawData<Elem = A>,
22{
23    pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24    where
25        A: Clone,
26        S: DataOwned,
27    {
28        Self { bias, weights }
29    }
30
31    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32    where
33        A: Clone,
34        D: RemoveAxis,
35        S: DataOwned,
36        Sh: ShapeBuilder<Dim = D>,
37    {
38        let weights = ArrayBase::from_elem(shape, elem.clone());
39        let dim = weights.raw_dim();
40        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41        Self { bias, weights }
42    }
43    /// create an instance of the parameters with all values set to the default value
44    pub fn default<Sh>(shape: Sh) -> Self
45    where
46        A: Clone + Default,
47        D: RemoveAxis,
48        S: DataOwned,
49        Sh: ShapeBuilder<Dim = D>,
50    {
51        let weights = ArrayBase::default(shape);
52        let dim = weights.raw_dim();
53        let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54        Self { bias, weights }
55    }
56    /// initialize the parameters with all values set to zero
57    pub fn ones<Sh>(shape: Sh) -> Self
58    where
59        A: Clone + num_traits::One,
60        D: RemoveAxis,
61        S: DataOwned,
62        Sh: ShapeBuilder<Dim = D>,
63    {
64        let weights = ArrayBase::ones(shape);
65        let dim = weights.raw_dim();
66        let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67        Self { bias, weights }
68    }
69    /// create an instance of the parameters with all values set to zero
70    pub fn zeros<Sh>(shape: Sh) -> Self
71    where
72        A: Clone + num_traits::Zero,
73        D: RemoveAxis,
74        S: DataOwned,
75        Sh: ShapeBuilder<Dim = D>,
76    {
77        let weights = ArrayBase::zeros(shape);
78        let dim = weights.raw_dim();
79        let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80        Self { bias, weights }
81    }
82    /// returns an immutable reference to the bias
83    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84        &self.bias
85    }
86    /// returns a mutable reference to the bias
87    #[inline]
88    pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89        &mut self.bias
90    }
91    /// returns an immutable reference to the weights
92    pub const fn weights(&self) -> &ArrayBase<S, D> {
93        &self.weights
94    }
95    /// returns a mutable reference to the weights
96    #[inline]
97    pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98        &mut self.weights
99    }
100
101    /// assign the bias
102    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
103    where
104        A: Clone,
105        S: DataMut,
106    {
107        self.bias_mut().assign(&bias);
108        self
109    }
110    /// assign the weights
111    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
112    where
113        A: Clone,
114        S: DataMut,
115    {
116        self.weights_mut().assign(&weights);
117        self
118    }
119    /// replace the bias and return the previous state; uses [replace](core::mem::replace)
120    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
121        core::mem::replace(&mut self.bias, bias)
122    }
123    /// replace the weights and return the previous state; uses [replace](core::mem::replace)
124    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
125        core::mem::replace(&mut self.weights, weights)
126    }
127    /// set the bias
128    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
129        *self.bias_mut() = bias;
130        self
131    }
132    /// set the weights
133    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
134        *self.weights_mut() = weights;
135        self
136    }
137    /// perform a single backpropagation step
138    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
139    where
140        A: Clone,
141        S: Data,
142        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
143    {
144        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
145    }
146    /// forward propagation
147    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
148    where
149        A: Clone,
150        S: Data,
151        Self: crate::Forward<X, Output = Y>,
152    {
153        <Self as crate::Forward<X>>::forward(self, input)
154    }
155    /// returns the dimensions of the weights
156    pub fn dim(&self) -> D::Pattern {
157        self.weights().dim()
158    }
159    /// an iterator of the parameters; the created iterator zips together an axis iterator over
160    /// the columns of the weights and an iterator over the bias
161    pub fn iter(&self) -> super::iter::Iter<'_, A, D>
162    where
163        D: RemoveAxis,
164        S: Data,
165    {
166        super::iter::Iter {
167            bias: self.bias().iter(),
168            weights: self.weights().axis_iter(Axis(1)),
169        }
170    }
171    /// a mutable iterator of the parameters
172    pub fn iter_mut(
173        &mut self,
174    ) -> core::iter::Zip<
175        ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
176        ndarray::iter::IterMut<'_, A, D::Smaller>,
177    >
178    where
179        D: RemoveAxis,
180        S: DataMut,
181    {
182        self.weights
183            .axis_iter_mut(Axis(1))
184            .zip(self.bias.iter_mut())
185    }
186    /// returns an iterator over the bias
187    pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
188    where
189        S: Data,
190    {
191        self.bias().iter()
192    }
193    /// returns a mutable iterator over the bias
194    pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
195    where
196        S: DataMut,
197    {
198        self.bias_mut().iter_mut()
199    }
200    /// returns an iterator over the weights
201    pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
202    where
203        S: Data,
204    {
205        self.weights().iter()
206    }
207    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
208    pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
209    where
210        S: DataMut,
211    {
212        self.weights_mut().iter_mut()
213    }
214    /// the total number of elements within the weight tensor
215    pub fn len(&self) -> usize {
216        self.weights().len()
217    }
218    /// returns the raw dimensions of the weights;
219    pub fn raw_dim(&self) -> D {
220        self.weights().raw_dim()
221    }
222    /// returns the shape of the parameters; uses the shape of the weight tensor
223    pub fn shape(&self) -> &[usize] {
224        self.weights().shape()
225    }
226    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
227    /// weight tensor minus the "zero-th" axis
228    pub fn shape_bias(&self) -> &[usize] {
229        self.bias().shape()
230    }
231    /// returns an owned instance of the parameters
232    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
233    where
234        A: Clone,
235        S: DataOwned,
236    {
237        ParamsBase {
238            bias: self.bias().to_owned(),
239            weights: self.weights().to_owned(),
240        }
241    }
242    /// change the shape of the parameters; the shape of the bias parameters is determined by
243    /// removing the "zero-th" axis of the given shape
244    pub fn to_shape<Sh>(
245        &self,
246        shape: Sh,
247    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
248    where
249        A: Clone,
250        S: DataOwned,
251        Sh: ShapeBuilder,
252        Sh::Dim: Dimension + RemoveAxis,
253    {
254        let shape = shape.into_shape_with_order();
255        let dim = shape.raw_dim().clone();
256        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
257        let weights = self.weights().to_shape(dim)?;
258        Ok(ParamsBase { bias, weights })
259    }
260    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
261    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
262    where
263        S: Data,
264    {
265        ParamsBase {
266            bias: self.bias().view(),
267            weights: self.weights().view(),
268        }
269    }
270    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
271    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
272    where
273        S: ndarray::DataMut,
274    {
275        ParamsBase {
276            bias: self.bias.view_mut(),
277            weights: self.weights.view_mut(),
278        }
279    }
280}
281
282impl<A, S> ParamsBase<S, Ix1>
283where
284    S: RawData<Elem = A>,
285{
286    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
287    where
288        A: Clone,
289        S: DataOwned,
290    {
291        Self {
292            bias: ArrayBase::from_elem((), bias),
293            weights,
294        }
295    }
296
297    pub fn nrows(&self) -> usize {
298        self.weights.len()
299    }
300}
301
302impl<A, S> ParamsBase<S, Ix2>
303where
304    S: RawData<Elem = A>,
305{
306    pub fn ncols(&self) -> usize {
307        self.weights.ncols()
308    }
309
310    pub fn nrows(&self) -> usize {
311        self.weights.nrows()
312    }
313}
314
315impl<A, S, D> Clone for ParamsBase<S, D>
316where
317    D: Dimension,
318    S: ndarray::RawDataClone<Elem = A>,
319    A: Clone,
320{
321    fn clone(&self) -> Self {
322        Self {
323            bias: self.bias.clone(),
324            weights: self.weights.clone(),
325        }
326    }
327}
328
329impl<A, S, D> Copy for ParamsBase<S, D>
330where
331    D: Dimension + Copy,
332    <D as Dimension>::Smaller: Copy,
333    S: ndarray::RawDataClone<Elem = A> + Copy,
334    A: Copy,
335{
336}
337
338impl<A, S, D> PartialEq for ParamsBase<S, D>
339where
340    D: Dimension,
341    S: Data<Elem = A>,
342    A: PartialEq,
343{
344    fn eq(&self, other: &Self) -> bool {
345        self.bias == other.bias && self.weights == other.weights
346    }
347}
348
349impl<A, S, D> Eq for ParamsBase<S, D>
350where
351    D: Dimension,
352    S: Data<Elem = A>,
353    A: Eq,
354{
355}
356
357impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
358where
359    D: Dimension,
360    S: Data<Elem = A>,
361    A: core::fmt::Debug,
362{
363    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
364        f.debug_struct("ModelParams")
365            .field("bias", &self.bias)
366            .field("weights", &self.weights)
367            .finish()
368    }
369}
370
371// impl<A, S, D> PartialOrd for ModelParams<S, D>
372// where
373//     D: Dimension,
374//     S: Data<Elem = A>,
375//     A: PartialOrd,
376// {
377//     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
378//         match self.bias.iter().partial_cmp(&other.bias.iter()) {
379//             Some(core::cmp::Ordering::Equal) => self.weights.iter().partial_cmp(&other.weights.iter()),
380//             other => other,
381//         }
382//     }
383// }