concision_core/params/
params.rs

1/*
2    Appellation: params <module>
3    Contrib: @FL03
4*/
5use ndarray::{
6    ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeArg,
7    ShapeBuilder,
8};
9
10/// The [`ParamsBase`] struct is a generic container for a set of weights and biases for a
11/// model where the bias tensor is always `n-1` dimensions smaller than the `weights` tensor.
12/// Consequently, this constrains the [`ParamsBase`] implementation to only support dimensions
13/// that can be reduced by one axis (i.e. $`\mbox{rank}(D)>0`$), which is typically the "zero-th" axis.
14pub struct ParamsBase<S, D = ndarray::Ix2>
15where
16    D: Dimension,
17    S: RawData,
18{
19    pub(crate) bias: ArrayBase<S, D::Smaller>,
20    pub(crate) weights: ArrayBase<S, D>,
21}
22
23impl<A, S, D> ParamsBase<S, D>
24where
25    D: Dimension,
26    S: RawData<Elem = A>,
27{
28    /// create a new instance of the [`ParamsBase`] with the given bias and weights
29    pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
30        Self { bias, weights }
31    }
32    /// returns a new instance of the [`ParamsBase`] using the initialization routine
33    pub fn init_from_fn<Sh, F>(shape: Sh, init: F) -> Self
34    where
35        A: Clone,
36        D: RemoveAxis,
37        S: DataOwned,
38        Sh: ShapeBuilder<Dim = D>,
39        F: Fn() -> A,
40    {
41        let shape = shape.into_shape_with_order();
42        let bshape = shape.raw_dim().remove_axis(Axis(0));
43        // initialize the bias and weights using the provided function for each element
44        let bias = ArrayBase::from_shape_fn(bshape, |_| init());
45        let weights = ArrayBase::from_shape_fn(shape, |_| init());
46        // create a new instance from the generated bias and weights
47        Self::new(bias, weights)
48    }
49    /// returns a new instance of the [`ParamsBase`] initialized use the given shape_function
50    pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
51    where
52        A: Clone,
53        D: RemoveAxis,
54        S: DataOwned,
55        Sh: ShapeBuilder<Dim = D>,
56        D::Smaller: Dimension + ShapeArg,
57        F: Fn(<D::Smaller as Dimension>::Pattern) -> A + Fn(<D as Dimension>::Pattern) -> A,
58    {
59        let shape = shape.into_shape_with_order();
60        let bdim = shape.raw_dim().remove_axis(Axis(0));
61        let bias = ArrayBase::from_shape_fn(bdim, |s| f(s));
62        let weights = ArrayBase::from_shape_fn(shape, |s| f(s));
63        Self::new(bias, weights)
64    }
65    /// create a new instance of the [`ParamsBase`] with the given bias used the default weights
66    pub fn from_bias<Sh>(shape: Sh, bias: ArrayBase<S, D::Smaller>) -> Self
67    where
68        A: Clone + Default,
69        D: RemoveAxis,
70        S: DataOwned,
71        Sh: ShapeBuilder<Dim = D>,
72    {
73        let weights = ArrayBase::from_elem(shape, A::default());
74        Self::new(bias, weights)
75    }
76    /// create a new instance of the [`ParamsBase`] with the given weights used the default
77    /// bias
78    pub fn from_weights<Sh>(shape: Sh, weights: ArrayBase<S, D>) -> Self
79    where
80        A: Clone + Default,
81        D: RemoveAxis,
82        S: DataOwned,
83        Sh: ShapeBuilder<Dim = D>,
84    {
85        let shape = shape.into_shape_with_order();
86        let dim_bias = shape.raw_dim().remove_axis(Axis(0));
87        let bias = ArrayBase::from_elem(dim_bias, A::default());
88        Self::new(bias, weights)
89    }
90    /// create a new instance of the [`ParamsBase`] from the given shape and element;
91    pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
92    where
93        A: Clone,
94        D: RemoveAxis,
95        S: DataOwned,
96        Sh: ShapeBuilder<Dim = D>,
97    {
98        let weights = ArrayBase::from_elem(shape, elem.clone());
99        let dim = weights.raw_dim();
100        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
101        Self::new(bias, weights)
102    }
103    #[allow(clippy::should_implement_trait)]
104    /// create an instance of the parameters with all values set to the default value
105    pub fn default<Sh>(shape: Sh) -> Self
106    where
107        A: Clone + Default,
108        D: RemoveAxis,
109        S: DataOwned,
110        Sh: ShapeBuilder<Dim = D>,
111    {
112        Self::from_elem(shape, A::default())
113    }
114    /// initialize the parameters with all values set to zero
115    pub fn ones<Sh>(shape: Sh) -> Self
116    where
117        A: Clone + num_traits::One,
118        D: RemoveAxis,
119        S: DataOwned,
120        Sh: ShapeBuilder<Dim = D>,
121    {
122        Self::from_elem(shape, A::one())
123    }
124    /// create an instance of the parameters with all values set to zero
125    pub fn zeros<Sh>(shape: Sh) -> Self
126    where
127        A: Clone + num_traits::Zero,
128        D: RemoveAxis,
129        S: DataOwned,
130        Sh: ShapeBuilder<Dim = D>,
131    {
132        Self::from_elem(shape, A::zero())
133    }
134    /// returns an immutable reference to the bias
135    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
136        &self.bias
137    }
138    /// returns a mutable reference to the bias
139    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
140        &mut self.bias
141    }
142    /// returns an immutable reference to the weights
143    pub const fn weights(&self) -> &ArrayBase<S, D> {
144        &self.weights
145    }
146    /// returns a mutable reference to the weights
147    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
148        &mut self.weights
149    }
150    /// assign the bias
151    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
152    where
153        A: Clone,
154        S: DataMut,
155    {
156        self.bias_mut().assign(bias);
157        self
158    }
159    /// assign the weights
160    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
161    where
162        A: Clone,
163        S: DataMut,
164    {
165        self.weights_mut().assign(weights);
166        self
167    }
168    /// replace the bias and return the previous state; uses [replace](core::mem::replace)
169    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
170        core::mem::replace(&mut self.bias, bias)
171    }
172    /// replace the weights and return the previous state; uses [replace](core::mem::replace)
173    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
174        core::mem::replace(&mut self.weights, weights)
175    }
176    /// set the bias
177    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
178        *self.bias_mut() = bias;
179        self
180    }
181    /// set the weights
182    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
183        *self.weights_mut() = weights;
184        self
185    }
186    /// perform a single backpropagation step
187    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
188    where
189        A: Clone,
190        S: Data,
191        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
192    {
193        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
194    }
195    /// forward propagation
196    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
197    where
198        A: Clone,
199        S: Data,
200        Self: crate::Forward<X, Output = Y>,
201    {
202        <Self as crate::Forward<X>>::forward(self, input)
203    }
204    /// returns the dimensions of the weights
205    pub fn dim(&self) -> D::Pattern {
206        self.weights().dim()
207    }
208    /// returns true if both the weights and bias are empty; uses [`is_empty`](ArrayBase::is_empty)
209    pub fn is_empty(&self) -> bool {
210        self.is_weights_empty() && self.is_bias_empty()
211    }
212    /// returns true if the weights are empty
213    pub fn is_weights_empty(&self) -> bool {
214        self.weights().is_empty()
215    }
216    /// returns true if the bias is empty
217    pub fn is_bias_empty(&self) -> bool {
218        self.bias().is_empty()
219    }
220    /// the total number of elements within the weight tensor
221    pub fn count_weight(&self) -> usize {
222        self.weights().len()
223    }
224    /// the total number of elements within the bias tensor
225    pub fn count_bias(&self) -> usize {
226        self.bias().len()
227    }
228    /// returns the raw dimensions of the weights;
229    pub fn raw_dim(&self) -> D {
230        self.weights().raw_dim()
231    }
232    /// returns the shape of the parameters; uses the shape of the weight tensor
233    pub fn shape(&self) -> &[usize] {
234        self.weights().shape()
235    }
236    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
237    /// weight tensor minus the "zero-th" axis
238    pub fn shape_bias(&self) -> &[usize] {
239        self.bias().shape()
240    }
241    /// returns the total number of parameters within the layer
242    pub fn size(&self) -> usize {
243        self.weights().len() + self.bias().len()
244    }
245    /// returns an owned instance of the parameters
246    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
247    where
248        A: Clone,
249        S: DataOwned,
250    {
251        ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
252    }
253    /// change the shape of the parameters; the shape of the bias parameters is determined by
254    /// removing the "zero-th" axis of the given shape
255    pub fn to_shape<Sh>(
256        &self,
257        shape: Sh,
258    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
259    where
260        A: Clone,
261        S: DataOwned,
262        Sh: ShapeBuilder,
263        Sh::Dim: Dimension + RemoveAxis,
264    {
265        let shape = shape.into_shape_with_order();
266        let dim = shape.raw_dim().clone();
267        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
268        let weights = self.weights().to_shape(dim)?;
269        Ok(ParamsBase::new(bias, weights))
270    }
271    /// returns a new [`ParamsBase`] instance with the same paramaters, but using a shared
272    /// representation of the data;
273    pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
274    where
275        A: Clone,
276        S: Data,
277    {
278        ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
279    }
280    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
281    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
282    where
283        S: Data,
284    {
285        ParamsBase::new(self.bias().view(), self.weights().view())
286    }
287    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
288    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
289    where
290        S: ndarray::DataMut,
291    {
292        ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
293    }
294}