concision_params/impls/
impl_params.rs

1/*
2    Appellation: impl_params <module>
3    Created At: 2026.01.13:18:36:16
4    Contrib: @FL03
5*/
6use crate::Params;
7use crate::params_base::ParamsBase;
8use crate::utils::extract_bias_dim;
9use ndarray::{
10    ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, LayoutRef, RawData, RawRef, RemoveAxis,
11    ShapeArg, ShapeBuilder,
12};
13
14impl<A, S, D> ParamsBase<S, D, A>
15where
16    D: Dimension,
17    S: RawData<Elem = A>,
18{
19    /// create a new instance of the [`ParamsBase`] with the given bias and weights
20    pub const fn new(bias: ArrayBase<S, D::Smaller, A>, weights: ArrayBase<S, D, A>) -> Self {
21        Self { bias, weights }
22    }
23    /// returns a new instance of the [`ParamsBase`] using the initialization routine
24    pub fn init_from_fn<Sh, F>(shape: Sh, init: F) -> Self
25    where
26        A: Clone,
27        D: RemoveAxis,
28        S: DataOwned,
29        Sh: ShapeBuilder<Dim = D>,
30        F: Fn() -> A,
31    {
32        let weights = ArrayBase::from_shape_fn(shape, |_| init());
33        // initialize the bias using a shape that is 1 rank lower then the weights
34        let bias = ArrayBase::from_shape_fn(extract_bias_dim(&weights), |_| init());
35        // create a new instance from the generated bias and weights
36        Self::new(bias, weights)
37    }
38    /// returns a new instance of the [`ParamsBase`] initialized use the given shape_function
39    pub fn from_shape_fn<Sh, F1, F2>(shape: Sh, w: F1, b: F2) -> Self
40    where
41        A: Clone,
42        D: RemoveAxis,
43        S: DataOwned,
44        Sh: ShapeBuilder<Dim = D>,
45        D::Smaller: Dimension + ShapeArg,
46        F1: Fn(<D as Dimension>::Pattern) -> A,
47        F2: Fn(<D::Smaller as Dimension>::Pattern) -> A,
48    {
49        // initialize the weights with some shape using the given function
50        let weights = ArrayBase::from_shape_fn(shape, w);
51        // initialize the bias tensor w.r.t. the weights
52        let bias = ArrayBase::from_shape_fn(extract_bias_dim(&weights), b);
53        // return a new instance
54        Self::new(bias, weights)
55    }
56    /// create a new instance of the [`ParamsBase`] with the given bias used the default weights
57    pub fn from_bias<Sh>(shape: Sh, bias: ArrayBase<S, D::Smaller, A>) -> Self
58    where
59        A: Clone + Default,
60        D: RemoveAxis,
61        S: DataOwned,
62        Sh: ShapeBuilder<Dim = D>,
63    {
64        let weights = ArrayBase::from_elem(shape, A::default());
65        let bdim = extract_bias_dim(&weights);
66        if bias.raw_dim() != bdim {
67            panic!("the given bias shape is invalid");
68        }
69        Self::new(bias, weights)
70    }
71    /// create a new instance of the [`ParamsBase`] with the given weights used the default
72    /// bias
73    pub fn from_weights(weights: ArrayBase<S, D, A>) -> Self
74    where
75        A: Clone + Default,
76        D: RemoveAxis,
77        S: DataOwned,
78    {
79        let bias = ArrayBase::from_elem(extract_bias_dim(&weights), A::default());
80        Self::new(bias, weights)
81    }
82    /// create a new instance of the [`ParamsBase`] from the given shape and element;
83    pub fn from_elem<Sh: ShapeBuilder<Dim = D>>(shape: Sh, elem: A) -> Self
84    where
85        A: Clone,
86        D: RemoveAxis,
87        S: DataOwned,
88    {
89        let weights = ArrayBase::from_elem(shape, elem.clone());
90        let bias = ArrayBase::from_elem(extract_bias_dim(&weights), elem);
91        Self::new(bias, weights)
92    }
93    #[allow(clippy::should_implement_trait)]
94    /// create an instance of the parameters with all values set to the default value
95    pub fn default<Sh>(shape: Sh) -> Self
96    where
97        A: Clone + Default,
98        D: RemoveAxis,
99        S: DataOwned,
100        Sh: ShapeBuilder<Dim = D>,
101    {
102        Self::from_elem(shape, A::default())
103    }
104    /// initialize the parameters with all values set to zero
105    pub fn ones<Sh>(shape: Sh) -> Self
106    where
107        A: Clone + num_traits::One,
108        D: RemoveAxis,
109        S: DataOwned,
110        Sh: ShapeBuilder<Dim = D>,
111    {
112        Self::from_elem(shape, A::one())
113    }
114    /// create an instance of the parameters with all values set to zero
115    pub fn zeros<Sh>(shape: Sh) -> Self
116    where
117        A: Clone + num_traits::Zero,
118        D: RemoveAxis,
119        S: DataOwned,
120        Sh: ShapeBuilder<Dim = D>,
121    {
122        Self::from_elem(shape, A::zero())
123    }
124    /// returns an immutable reference to the bias
125    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller, A> {
126        &self.bias
127    }
128    /// returns a mutable reference to the bias
129    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
130        &mut self.bias
131    }
132    /// returns an immutable reference to the weights
133    pub const fn weights(&self) -> &ArrayBase<S, D, A> {
134        &self.weights
135    }
136    /// returns a mutable reference to the weights
137    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
138        &mut self.weights
139    }
140    /// returns a raw reference to the bias tensor
141    pub fn bias_as_raw_ref(&self) -> &RawRef<A, D::Smaller>
142    where
143        S: Data,
144    {
145        self.bias().as_raw_ref()
146    }
147    /// returns a raw reference of the weights
148    pub fn weights_as_raw_ref(&self) -> &RawRef<A, D>
149    where
150        S: Data,
151    {
152        self.weights().as_raw_ref()
153    }
154    /// returns an immutable rererence to the bias as a layout reference
155    pub fn bias_layout_ref(&self) -> &LayoutRef<A, D::Smaller>
156    where
157        S: Data,
158    {
159        self.bias().as_layout_ref()
160    }
161    /// returns a mutable rererence to the weights as a layout reference
162    pub fn bias_layout_ref_mut(&mut self) -> &mut LayoutRef<A, D::Smaller>
163    where
164        S: DataMut,
165    {
166        self.bias_mut().as_layout_ref_mut()
167    }
168    /// returns an immutable rererence to the weights as a layout reference
169    pub fn weights_layout_ref(&self) -> &LayoutRef<A, D>
170    where
171        S: Data,
172    {
173        self.weights().as_layout_ref()
174    }
175    /// returns a mutable rererence to the weights as a layout reference
176    pub fn weights_layout_ref_mut(&mut self) -> &mut LayoutRef<A, D>
177    where
178        S: DataMut,
179    {
180        self.weights_mut().as_layout_ref_mut()
181    }
182    /// assign the bias
183    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller, A>) -> &mut Self
184    where
185        A: Clone,
186        S: DataMut,
187    {
188        self.bias_mut().assign(bias);
189        self
190    }
191    /// assign the weights
192    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D, A>) -> &mut Self
193    where
194        A: Clone,
195        S: DataMut,
196    {
197        self.weights_mut().assign(weights);
198        self
199    }
200    /// replace the bias and return the previous state; uses [replace](core::mem::replace)
201    pub fn replace_bias(
202        &mut self,
203        bias: ArrayBase<S, D::Smaller, A>,
204    ) -> ArrayBase<S, D::Smaller, A> {
205        core::mem::replace(&mut self.bias, bias)
206    }
207    /// replace the weights and return the previous state; uses [replace](core::mem::replace)
208    pub fn replace_weights(&mut self, weights: ArrayBase<S, D, A>) -> ArrayBase<S, D, A> {
209        core::mem::replace(&mut self.weights, weights)
210    }
211    /// set the bias
212    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> &mut Self {
213        *self.bias_mut() = bias;
214        self
215    }
216    /// set the weights
217    pub fn set_weights(&mut self, weights: ArrayBase<S, D, A>) -> &mut Self {
218        *self.weights_mut() = weights;
219        self
220    }
221    /// returns the dimensions of the weights
222    pub fn dim(&self) -> D::Pattern {
223        self.weights().dim()
224    }
225    /// returns true if both the weights and bias are empty; uses [`is_empty`](ArrayBase::is_empty)
226    pub fn is_empty(&self) -> bool {
227        self.is_weights_empty() && self.is_bias_empty()
228    }
229    /// returns true if the weights are empty
230    pub fn is_weights_empty(&self) -> bool {
231        self.weights().is_empty()
232    }
233    /// returns true if the bias is empty
234    pub fn is_bias_empty(&self) -> bool {
235        self.bias().is_empty()
236    }
237    /// the total number of elements within the weight tensor
238    pub fn count_weights(&self) -> usize {
239        self.weights().len()
240    }
241    /// the total number of elements within the bias tensor
242    pub fn count_bias(&self) -> usize {
243        self.bias().len()
244    }
245    /// returns the raw dimensions of the weights;
246    pub fn raw_dim(&self) -> D {
247        self.weights().raw_dim()
248    }
249    /// returns the shape of the parameters; uses the shape of the weight tensor
250    pub fn shape<'a>(&'a self) -> &'a [usize]
251    where
252        A: 'a,
253    {
254        self.weights.shape()
255    }
256    /// returns the shape of the bias tensor; the shape should be equivalent to that of the
257    /// weight tensor minus the "zero-th" axis
258    pub fn shape_bias(&self) -> &[usize]
259    where
260        A: 'static,
261    {
262        self.bias.shape()
263    }
264    /// returns the total number of parameters within the layer
265    pub fn size(&self) -> usize {
266        self.weights().len() + self.bias().len()
267    }
268    /// returns an owned instance of the parameters
269    pub fn to_owned(&self) -> Params<A, D>
270    where
271        A: Clone,
272        S: DataOwned,
273    {
274        ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
275    }
276    /// change the shape of the parameters; the shape of the bias parameters is determined by
277    /// removing the "zero-th" axis of the given shape
278    pub fn to_shape<Sh>(
279        &self,
280        shape: Sh,
281    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
282    where
283        A: Clone,
284        S: DataOwned,
285        Sh: ShapeBuilder<Dim = D>,
286        Sh::Dim: Dimension + RemoveAxis,
287    {
288        let shape = shape.into_shape_with_order();
289        let dim = shape.raw_dim().clone();
290        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
291        let weights = self.weights().to_shape(dim)?;
292        Ok(ParamsBase::new(bias, weights))
293    }
294    /// returns a new [`ParamsBase`] instance with the same paramaters, but using a shared
295    /// representation of the data;
296    pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
297    where
298        A: Clone,
299        S: Data,
300    {
301        ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
302    }
303    /// returns a "view" of the parameters; see [`view`](ndarray::ViewRepr) for more information
304    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
305    where
306        S: Data,
307    {
308        ParamsBase::new(self.bias().view(), self.weights().view())
309    }
310    /// returns mutable view of the parameters
311    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
312    where
313        S: DataMut,
314    {
315        ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
316    }
317    /// clamps all values within the parameters between the given min and max values
318    pub fn clamp(&mut self, min: A, max: A) -> Params<A, D>
319    where
320        A: 'static + Clone + PartialOrd,
321        S: Data,
322    {
323        ParamsBase {
324            bias: self.bias().clamp(min.clone(), max.clone()),
325            weights: self.weights().clamp(min, max),
326        }
327    }
328    /// applies the given function onto each element, capturing the results in a new instance
329    pub fn mapv<F, U>(&self, f: F) -> Params<U, D>
330    where
331        A: Clone,
332        S: Data,
333        F: Fn(A) -> U,
334    {
335        ParamsBase {
336            bias: self.bias().mapv(&f),
337            weights: self.weights().mapv(&f),
338        }
339    }
340}