concision_core/params/
params.rs

1/*
2    Appellation: params <module>
3    Contrib: @FL03
4*/
5use ndarray::{
6    ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder,
7};
8
9/// The [`ParamsBase`] struct is a generic container for a set of weights and biases for a
10/// model. The implementation is designed around the [`ArrayBase`] type from the
11/// `ndarray` crate, which allows for flexible and efficient storage of multi-dimensional
12/// arrays.
13pub struct ParamsBase<S, D = ndarray::Ix2>
14where
15    D: Dimension,
16    S: RawData,
17{
18    pub(crate) bias: ArrayBase<S, D::Smaller>,
19    pub(crate) weights: ArrayBase<S, D>,
20}
21
22impl<A, S, D> ParamsBase<S, D>
23where
24    D: Dimension,
25    S: RawData<Elem = A>,
26{
27    /// create a new instance of the [`ParamsBase`] with the given bias and weights
28    pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
29        Self { bias, weights }
30    }
31    /// create a new instance of the [`ModelParams`] from the given shape and element;
32    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
33    where
34        A: Clone,
35        D: RemoveAxis,
36        S: DataOwned,
37        Sh: ShapeBuilder<Dim = D>,
38    {
39        let weights = ArrayBase::from_elem(shape, elem.clone());
40        let dim = weights.raw_dim();
41        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
42        Self::new(bias, weights)
43    }
44    #[allow(clippy::should_implement_trait)]
45    /// create an instance of the parameters with all values set to the default value
46    pub fn default<Sh>(shape: Sh) -> Self
47    where
48        A: Clone + Default,
49        D: RemoveAxis,
50        S: DataOwned,
51        Sh: ShapeBuilder<Dim = D>,
52    {
53        Self::from_elems(shape, A::default())
54    }
55    /// initialize the parameters with all values set to zero
56    pub fn ones<Sh>(shape: Sh) -> Self
57    where
58        A: Clone + num_traits::One,
59        D: RemoveAxis,
60        S: DataOwned,
61        Sh: ShapeBuilder<Dim = D>,
62    {
63        Self::from_elems(shape, A::one())
64    }
65    /// create an instance of the parameters with all values set to zero
66    pub fn zeros<Sh>(shape: Sh) -> Self
67    where
68        A: Clone + num_traits::Zero,
69        D: RemoveAxis,
70        S: DataOwned,
71        Sh: ShapeBuilder<Dim = D>,
72    {
73        Self::from_elems(shape, A::zero())
74    }
75    /// returns an immutable reference to the bias
76    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
77        &self.bias
78    }
79    /// returns a mutable reference to the bias
80    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
81        &mut self.bias
82    }
83    /// returns an immutable reference to the weights
84    pub const fn weights(&self) -> &ArrayBase<S, D> {
85        &self.weights
86    }
87    /// returns a mutable reference to the weights
88    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
89        &mut self.weights
90    }
91    /// assign the bias
92    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
93    where
94        A: Clone,
95        S: DataMut,
96    {
97        self.bias_mut().assign(bias);
98        self
99    }
100    /// assign the weights
101    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
102    where
103        A: Clone,
104        S: DataMut,
105    {
106        self.weights_mut().assign(weights);
107        self
108    }
109    /// replace the bias and return the previous state; uses [replace](core::mem::replace)
110    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
111        core::mem::replace(&mut self.bias, bias)
112    }
113    /// replace the weights and return the previous state; uses [replace](core::mem::replace)
114    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
115        core::mem::replace(&mut self.weights, weights)
116    }
117    /// set the bias
118    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
119        *self.bias_mut() = bias;
120        self
121    }
122    /// set the weights
123    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
124        *self.weights_mut() = weights;
125        self
126    }
127    /// perform a single backpropagation step
128    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
129    where
130        A: Clone,
131        S: Data,
132        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
133    {
134        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
135    }
136    /// forward propagation
137    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
138    where
139        A: Clone,
140        S: Data,
141        Self: crate::Forward<X, Output = Y>,
142    {
143        <Self as crate::Forward<X>>::forward(self, input)
144    }
145    /// returns the dimensions of the weights
146    pub fn dim(&self) -> D::Pattern {
147        self.weights().dim()
148    }
149    /// returns true if both the weights and bias are empty; uses [`is_empty`](ArrayBase::is_empty)
150    pub fn is_empty(&self) -> bool {
151        self.is_weights_empty() && self.is_bias_empty()
152    }
153    /// returns true if the weights are empty
154    pub fn is_weights_empty(&self) -> bool {
155        self.weights().is_empty()
156    }
157    /// returns true if the bias is empty
158    pub fn is_bias_empty(&self) -> bool {
159        self.bias().is_empty()
160    }
161    /// the total number of elements within the weight tensor
162    pub fn count_weight(&self) -> usize {
163        self.weights().len()
164    }
165    /// the total number of elements within the bias tensor
166    pub fn count_bias(&self) -> usize {
167        self.bias().len()
168    }
169    /// returns the raw dimensions of the weights;
170    pub fn raw_dim(&self) -> D {
171        self.weights().raw_dim()
172    }
173    /// returns the shape of the parameters; uses the shape of the weight tensor
174    pub fn shape(&self) -> &[usize] {
175        self.weights().shape()
176    }
177    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
178    /// weight tensor minus the "zero-th" axis
179    pub fn shape_bias(&self) -> &[usize] {
180        self.bias().shape()
181    }
182    /// returns the total number of parameters within the layer
183    pub fn size(&self) -> usize {
184        self.weights().len() + self.bias().len()
185    }
186    /// returns an owned instance of the parameters
187    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
188    where
189        A: Clone,
190        S: DataOwned,
191    {
192        ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
193    }
194    /// change the shape of the parameters; the shape of the bias parameters is determined by
195    /// removing the "zero-th" axis of the given shape
196    pub fn to_shape<Sh>(
197        &self,
198        shape: Sh,
199    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
200    where
201        A: Clone,
202        S: DataOwned,
203        Sh: ShapeBuilder,
204        Sh::Dim: Dimension + RemoveAxis,
205    {
206        let shape = shape.into_shape_with_order();
207        let dim = shape.raw_dim().clone();
208        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
209        let weights = self.weights().to_shape(dim)?;
210        Ok(ParamsBase::new(bias, weights))
211    }
212    /// returns a new [`ParamsBase`] instance with the same paramaters, but using a shared
213    /// representation of the data;
214    pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
215    where
216        A: Clone,
217        S: Data,
218    {
219        ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
220    }
221    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
222    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
223    where
224        S: Data,
225    {
226        ParamsBase::new(self.bias().view(), self.weights().view())
227    }
228    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
229    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
230    where
231        S: ndarray::DataMut,
232    {
233        ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
234    }
235}