1use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8pub struct ParamsBase<S, D = ndarray::Ix2>
13where
14    D: Dimension,
15    S: RawData,
16{
17    pub(crate) bias: ArrayBase<S, D::Smaller>,
18    pub(crate) weights: ArrayBase<S, D>,
19}
20
21impl<A, S, D> ParamsBase<S, D>
22where
23    D: Dimension,
24    S: RawData<Elem = A>,
25{
26    pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
28        Self { bias, weights }
29    }
30    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32    where
33        A: Clone,
34        D: RemoveAxis,
35        S: DataOwned,
36        Sh: ShapeBuilder<Dim = D>,
37    {
38        let weights = ArrayBase::from_elem(shape, elem.clone());
39        let dim = weights.raw_dim();
40        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41        Self::new(bias, weights)
42    }
43    pub fn default<Sh>(shape: Sh) -> Self
45    where
46        A: Clone + Default,
47        D: RemoveAxis,
48        S: DataOwned,
49        Sh: ShapeBuilder<Dim = D>,
50    {
51        Self::from_elems(shape, A::default())
52    }
53    pub fn ones<Sh>(shape: Sh) -> Self
55    where
56        A: Clone + num_traits::One,
57        D: RemoveAxis,
58        S: DataOwned,
59        Sh: ShapeBuilder<Dim = D>,
60    {
61        Self::from_elems(shape, A::one())
62    }
63    pub fn zeros<Sh>(shape: Sh) -> Self
65    where
66        A: Clone + num_traits::Zero,
67        D: RemoveAxis,
68        S: DataOwned,
69        Sh: ShapeBuilder<Dim = D>,
70    {
71        Self::from_elems(shape, A::zero())
72    }
73    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
75        &self.bias
76    }
77    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
79        &mut self.bias
80    }
81    pub const fn weights(&self) -> &ArrayBase<S, D> {
83        &self.weights
84    }
85    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
87        &mut self.weights
88    }
89    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
91    where
92        A: Clone,
93        S: DataMut,
94    {
95        self.bias_mut().assign(bias);
96        self
97    }
98    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
100    where
101        A: Clone,
102        S: DataMut,
103    {
104        self.weights_mut().assign(weights);
105        self
106    }
107    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
109        core::mem::replace(&mut self.bias, bias)
110    }
111    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
113        core::mem::replace(&mut self.weights, weights)
114    }
115    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
117        *self.bias_mut() = bias;
118        self
119    }
120    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
122        *self.weights_mut() = weights;
123        self
124    }
125    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
127    where
128        A: Clone,
129        S: Data,
130        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
131    {
132        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
133    }
134    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
136    where
137        A: Clone,
138        S: Data,
139        Self: crate::Forward<X, Output = Y>,
140    {
141        <Self as crate::Forward<X>>::forward(self, input)
142    }
143    pub fn dim(&self) -> D::Pattern {
145        self.weights().dim()
146    }
147    pub fn iter(&self) -> super::iter::Iter<'_, A, D>
150    where
151        D: RemoveAxis,
152        S: Data,
153    {
154        super::iter::Iter {
155            bias: self.bias().iter(),
156            weights: self.weights().axis_iter(Axis(1)),
157        }
158    }
159    pub fn iter_mut(
161        &mut self,
162    ) -> core::iter::Zip<
163        ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
164        ndarray::iter::IterMut<'_, A, D::Smaller>,
165    >
166    where
167        D: RemoveAxis,
168        S: DataMut,
169    {
170        self.weights
171            .axis_iter_mut(Axis(1))
172            .zip(self.bias.iter_mut())
173    }
174    pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
176    where
177        S: Data,
178    {
179        self.bias().iter()
180    }
181    pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
183    where
184        S: DataMut,
185    {
186        self.bias_mut().iter_mut()
187    }
188    pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
190    where
191        S: Data,
192    {
193        self.weights().iter()
194    }
195    pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
197    where
198        S: DataMut,
199    {
200        self.weights_mut().iter_mut()
201    }
202    pub fn is_empty(&self) -> bool {
204        self.is_weights_empty() && self.is_bias_empty()
205    }
206    pub fn is_weights_empty(&self) -> bool {
208        self.weights().is_empty()
209    }
210    pub fn is_bias_empty(&self) -> bool {
212        self.bias().is_empty()
213    }
214    pub fn count_weight(&self) -> usize {
216        self.weights().len()
217    }
218    pub fn count_bias(&self) -> usize {
220        self.bias().len()
221    }
222    pub fn raw_dim(&self) -> D {
224        self.weights().raw_dim()
225    }
226    pub fn shape(&self) -> &[usize] {
228        self.weights().shape()
229    }
230    pub fn shape_bias(&self) -> &[usize] {
233        self.bias().shape()
234    }
235    pub fn size(&self) -> usize {
237        self.weights().len() + self.bias().len()
238    }
239    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
241    where
242        A: Clone,
243        S: DataOwned,
244    {
245        ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
246    }
247    pub fn to_shape<Sh>(
250        &self,
251        shape: Sh,
252    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
253    where
254        A: Clone,
255        S: DataOwned,
256        Sh: ShapeBuilder,
257        Sh::Dim: Dimension + RemoveAxis,
258    {
259        let shape = shape.into_shape_with_order();
260        let dim = shape.raw_dim().clone();
261        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
262        let weights = self.weights().to_shape(dim)?;
263        Ok(ParamsBase::new(bias, weights))
264    }
265    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
267    where
268        S: Data,
269    {
270        ParamsBase::new(self.bias().view(), self.weights().view())
271    }
272    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
274    where
275        S: ndarray::DataMut,
276    {
277        ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
278    }
279}
280
281impl<A, S> ParamsBase<S, Ix1>
282where
283    S: RawData<Elem = A>,
284{
285    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
286    where
287        A: Clone,
288        S: DataOwned,
289    {
290        Self {
291            bias: ArrayBase::from_elem((), bias),
292            weights,
293        }
294    }
295
296    pub fn nrows(&self) -> usize {
297        self.weights.len()
298    }
299}
300
301impl<A, S> ParamsBase<S, Ix2>
302where
303    S: RawData<Elem = A>,
304{
305    pub fn ncols(&self) -> usize {
306        self.weights.ncols()
307    }
308
309    pub fn nrows(&self) -> usize {
310        self.weights.nrows()
311    }
312}
313
314impl<A, S, D> Clone for ParamsBase<S, D>
315where
316    D: Dimension,
317    S: ndarray::RawDataClone<Elem = A>,
318    A: Clone,
319{
320    fn clone(&self) -> Self {
321        Self::new(self.bias().clone(), self.weights().clone())
322    }
323}
324
325impl<A, S, D> Copy for ParamsBase<S, D>
326where
327    D: Dimension + Copy,
328    <D as Dimension>::Smaller: Copy,
329    S: ndarray::RawDataClone<Elem = A> + Copy,
330    A: Copy,
331{
332}
333
334impl<A, S, D> PartialEq for ParamsBase<S, D>
335where
336    D: Dimension,
337    S: Data<Elem = A>,
338    A: PartialEq,
339{
340    fn eq(&self, other: &Self) -> bool {
341        self.bias() == other.bias() && self.weights() == other.weights()
342    }
343}
344
345impl<A, S, D> Eq for ParamsBase<S, D>
346where
347    D: Dimension,
348    S: Data<Elem = A>,
349    A: Eq,
350{
351}
352
353impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
354where
355    D: Dimension,
356    S: Data<Elem = A>,
357    A: core::fmt::Debug,
358{
359    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
360        f.debug_struct("ModelParams")
361            .field("bias", self.bias())
362            .field("weights", self.weights())
363            .finish()
364    }
365}
366
367