concision_core/params/
params.rs

1/*
2    Appellation: params <module>
3    Contrib: @FL03
4*/
5use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8/// this structure extends the `ArrayBase` type to include bias
9pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11    D: Dimension,
12    S: RawData,
13{
14    pub(crate) bias: ArrayBase<S, D::Smaller>,
15    pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20    D: Dimension,
21    S: RawData<Elem = A>,
22{
23    pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24    where
25        A: Clone,
26        S: DataOwned,
27    {
28        Self { bias, weights }
29    }
30
31    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32    where
33        A: Clone,
34        D: RemoveAxis,
35        S: DataOwned,
36        Sh: ShapeBuilder<Dim = D>,
37    {
38        let weights = ArrayBase::from_elem(shape, elem.clone());
39        let dim = weights.raw_dim();
40        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41        Self { bias, weights }
42    }
43    /// create an instance of the parameters with all values set to the default value
44    pub fn default<Sh>(shape: Sh) -> Self
45    where
46        A: Clone + Default,
47        D: RemoveAxis,
48        S: DataOwned,
49        Sh: ShapeBuilder<Dim = D>,
50    {
51        let weights = ArrayBase::default(shape);
52        let dim = weights.raw_dim();
53        let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54        Self { bias, weights }
55    }
56    /// initialize the parameters with all values set to zero
57    pub fn ones<Sh>(shape: Sh) -> Self
58    where
59        A: Clone + num_traits::One,
60        D: RemoveAxis,
61        S: DataOwned,
62        Sh: ShapeBuilder<Dim = D>,
63    {
64        let weights = ArrayBase::ones(shape);
65        let dim = weights.raw_dim();
66        let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67        Self { bias, weights }
68    }
69    /// create an instance of the parameters with all values set to zero
70    pub fn zeros<Sh>(shape: Sh) -> Self
71    where
72        A: Clone + num_traits::Zero,
73        D: RemoveAxis,
74        S: DataOwned,
75        Sh: ShapeBuilder<Dim = D>,
76    {
77        let weights = ArrayBase::zeros(shape);
78        let dim = weights.raw_dim();
79        let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80        Self { bias, weights }
81    }
82    /// returns an immutable reference to the bias
83    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84        &self.bias
85    }
86    /// returns a mutable reference to the bias
87    #[inline]
88    pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89        &mut self.bias
90    }
91    /// returns an immutable reference to the weights
92    pub const fn weights(&self) -> &ArrayBase<S, D> {
93        &self.weights
94    }
95    /// returns a mutable reference to the weights
96    #[inline]
97    pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98        &mut self.weights
99    }
100
101    /// assign the bias
102    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
103    where
104        A: Clone,
105        S: DataMut,
106    {
107        self.bias_mut().assign(bias);
108        self
109    }
110    /// assign the weights
111    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
112    where
113        A: Clone,
114        S: DataMut,
115    {
116        self.weights_mut().assign(weights);
117        self
118    }
119    /// replace the bias and return the previous state; uses [replace](core::mem::replace)
120    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
121        core::mem::replace(&mut self.bias, bias)
122    }
123    /// replace the weights and return the previous state; uses [replace](core::mem::replace)
124    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
125        core::mem::replace(&mut self.weights, weights)
126    }
127    /// set the bias
128    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
129        *self.bias_mut() = bias;
130        self
131    }
132    /// set the weights
133    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
134        *self.weights_mut() = weights;
135        self
136    }
137    /// perform a single backpropagation step
138    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
139    where
140        A: Clone,
141        S: Data,
142        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
143    {
144        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
145    }
146    /// forward propagation
147    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
148    where
149        A: Clone,
150        S: Data,
151        Self: crate::Forward<X, Output = Y>,
152    {
153        <Self as crate::Forward<X>>::forward(self, input)
154    }
155    /// returns the dimensions of the weights
156    pub fn dim(&self) -> D::Pattern {
157        self.weights().dim()
158    }
159    /// an iterator of the parameters; the created iterator zips together an axis iterator over
160    /// the columns of the weights and an iterator over the bias
161    pub fn iter(&self) -> super::iter::Iter<'_, A, D>
162    where
163        D: RemoveAxis,
164        S: Data,
165    {
166        super::iter::Iter {
167            bias: self.bias().iter(),
168            weights: self.weights().axis_iter(Axis(1)),
169        }
170    }
171    /// a mutable iterator of the parameters
172    pub fn iter_mut(
173        &mut self,
174    ) -> core::iter::Zip<
175        ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
176        ndarray::iter::IterMut<'_, A, D::Smaller>,
177    >
178    where
179        D: RemoveAxis,
180        S: DataMut,
181    {
182        self.weights
183            .axis_iter_mut(Axis(1))
184            .zip(self.bias.iter_mut())
185    }
186    /// returns an iterator over the bias
187    pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
188    where
189        S: Data,
190    {
191        self.bias().iter()
192    }
193    /// returns a mutable iterator over the bias
194    pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
195    where
196        S: DataMut,
197    {
198        self.bias_mut().iter_mut()
199    }
200    /// returns an iterator over the weights
201    pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
202    where
203        S: Data,
204    {
205        self.weights().iter()
206    }
207    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
208    pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
209    where
210        S: DataMut,
211    {
212        self.weights_mut().iter_mut()
213    }
214    /// returns true if both the weights and bias are empty; uses [`is_empty`](ArrayBase::is_empty)
215    pub fn is_empty(&self) -> bool {
216        self.is_weights_empty() && self.is_bias_empty()
217    }
218    /// returns true if the weights are empty
219    pub fn is_weights_empty(&self) -> bool {
220        self.weights().is_empty()
221    }
222    /// returns true if the bias is empty
223    pub fn is_bias_empty(&self) -> bool {
224        self.bias().is_empty()
225    }
226    /// the total number of elements within the weight tensor
227    pub fn len(&self) -> usize {
228        self.weights().len()
229    }
230    /// returns the raw dimensions of the weights;
231    pub fn raw_dim(&self) -> D {
232        self.weights().raw_dim()
233    }
234    /// returns the shape of the parameters; uses the shape of the weight tensor
235    pub fn shape(&self) -> &[usize] {
236        self.weights().shape()
237    }
238    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
239    /// weight tensor minus the "zero-th" axis
240    pub fn shape_bias(&self) -> &[usize] {
241        self.bias().shape()
242    }
243    /// returns an owned instance of the parameters
244    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
245    where
246        A: Clone,
247        S: DataOwned,
248    {
249        ParamsBase {
250            bias: self.bias().to_owned(),
251            weights: self.weights().to_owned(),
252        }
253    }
254    /// change the shape of the parameters; the shape of the bias parameters is determined by
255    /// removing the "zero-th" axis of the given shape
256    pub fn to_shape<Sh>(
257        &self,
258        shape: Sh,
259    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
260    where
261        A: Clone,
262        S: DataOwned,
263        Sh: ShapeBuilder,
264        Sh::Dim: Dimension + RemoveAxis,
265    {
266        let shape = shape.into_shape_with_order();
267        let dim = shape.raw_dim().clone();
268        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
269        let weights = self.weights().to_shape(dim)?;
270        Ok(ParamsBase { bias, weights })
271    }
272    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
273    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
274    where
275        S: Data,
276    {
277        ParamsBase {
278            bias: self.bias().view(),
279            weights: self.weights().view(),
280        }
281    }
282    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
283    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
284    where
285        S: ndarray::DataMut,
286    {
287        ParamsBase {
288            bias: self.bias.view_mut(),
289            weights: self.weights.view_mut(),
290        }
291    }
292}
293
294impl<A, S> ParamsBase<S, Ix1>
295where
296    S: RawData<Elem = A>,
297{
298    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
299    where
300        A: Clone,
301        S: DataOwned,
302    {
303        Self {
304            bias: ArrayBase::from_elem((), bias),
305            weights,
306        }
307    }
308
309    pub fn nrows(&self) -> usize {
310        self.weights.len()
311    }
312}
313
314impl<A, S> ParamsBase<S, Ix2>
315where
316    S: RawData<Elem = A>,
317{
318    pub fn ncols(&self) -> usize {
319        self.weights.ncols()
320    }
321
322    pub fn nrows(&self) -> usize {
323        self.weights.nrows()
324    }
325}
326
327impl<A, S, D> Clone for ParamsBase<S, D>
328where
329    D: Dimension,
330    S: ndarray::RawDataClone<Elem = A>,
331    A: Clone,
332{
333    fn clone(&self) -> Self {
334        Self {
335            bias: self.bias.clone(),
336            weights: self.weights.clone(),
337        }
338    }
339}
340
341impl<A, S, D> Copy for ParamsBase<S, D>
342where
343    D: Dimension + Copy,
344    <D as Dimension>::Smaller: Copy,
345    S: ndarray::RawDataClone<Elem = A> + Copy,
346    A: Copy,
347{
348}
349
350impl<A, S, D> PartialEq for ParamsBase<S, D>
351where
352    D: Dimension,
353    S: Data<Elem = A>,
354    A: PartialEq,
355{
356    fn eq(&self, other: &Self) -> bool {
357        self.bias == other.bias && self.weights == other.weights
358    }
359}
360
361impl<A, S, D> Eq for ParamsBase<S, D>
362where
363    D: Dimension,
364    S: Data<Elem = A>,
365    A: Eq,
366{
367}
368
369impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
370where
371    D: Dimension,
372    S: Data<Elem = A>,
373    A: core::fmt::Debug,
374{
375    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
376        f.debug_struct("ModelParams")
377            .field("bias", &self.bias)
378            .field("weights", &self.weights)
379            .finish()
380    }
381}
382
383// impl<A, S, D> PartialOrd for ModelParams<S, D>
384// where
385//     D: Dimension,
386//     S: Data<Elem = A>,
387//     A: PartialOrd,
388// {
389//     fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
390//         match self.bias.iter().partial_cmp(&other.bias.iter()) {
391//             Some(core::cmp::Ordering::Equal) => self.weights.iter().partial_cmp(&other.weights.iter()),
392//             other => other,
393//         }
394//     }
395// }