1use ndarray::{
6    ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder,
7};
8
9pub struct ParamsBase<S, D = ndarray::Ix2>
14where
15    D: Dimension,
16    S: RawData,
17{
18    pub(crate) bias: ArrayBase<S, D::Smaller>,
19    pub(crate) weights: ArrayBase<S, D>,
20}
21
22impl<A, S, D> ParamsBase<S, D>
23where
24    D: Dimension,
25    S: RawData<Elem = A>,
26{
27    pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
29        Self { bias, weights }
30    }
31    pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
33    where
34        A: Clone,
35        D: RemoveAxis,
36        S: DataOwned,
37        Sh: ShapeBuilder<Dim = D>,
38    {
39        let weights = ArrayBase::from_elem(shape, elem.clone());
40        let dim = weights.raw_dim();
41        let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
42        Self::new(bias, weights)
43    }
44    #[allow(clippy::should_implement_trait)]
45    pub fn default<Sh>(shape: Sh) -> Self
47    where
48        A: Clone + Default,
49        D: RemoveAxis,
50        S: DataOwned,
51        Sh: ShapeBuilder<Dim = D>,
52    {
53        Self::from_elems(shape, A::default())
54    }
55    pub fn ones<Sh>(shape: Sh) -> Self
57    where
58        A: Clone + num_traits::One,
59        D: RemoveAxis,
60        S: DataOwned,
61        Sh: ShapeBuilder<Dim = D>,
62    {
63        Self::from_elems(shape, A::one())
64    }
65    pub fn zeros<Sh>(shape: Sh) -> Self
67    where
68        A: Clone + num_traits::Zero,
69        D: RemoveAxis,
70        S: DataOwned,
71        Sh: ShapeBuilder<Dim = D>,
72    {
73        Self::from_elems(shape, A::zero())
74    }
75    pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
77        &self.bias
78    }
79    pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
81        &mut self.bias
82    }
83    pub const fn weights(&self) -> &ArrayBase<S, D> {
85        &self.weights
86    }
87    pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
89        &mut self.weights
90    }
91    pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
93    where
94        A: Clone,
95        S: DataMut,
96    {
97        self.bias_mut().assign(bias);
98        self
99    }
100    pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
102    where
103        A: Clone,
104        S: DataMut,
105    {
106        self.weights_mut().assign(weights);
107        self
108    }
109    pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
111        core::mem::replace(&mut self.bias, bias)
112    }
113    pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
115        core::mem::replace(&mut self.weights, weights)
116    }
117    pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
119        *self.bias_mut() = bias;
120        self
121    }
122    pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
124        *self.weights_mut() = weights;
125        self
126    }
127    pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
129    where
130        A: Clone,
131        S: Data,
132        Self: crate::Backward<X, Y, Elem = A, Output = Z>,
133    {
134        <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
135    }
136    pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
138    where
139        A: Clone,
140        S: Data,
141        Self: crate::Forward<X, Output = Y>,
142    {
143        <Self as crate::Forward<X>>::forward(self, input)
144    }
145    pub fn dim(&self) -> D::Pattern {
147        self.weights().dim()
148    }
149    pub fn is_empty(&self) -> bool {
151        self.is_weights_empty() && self.is_bias_empty()
152    }
153    pub fn is_weights_empty(&self) -> bool {
155        self.weights().is_empty()
156    }
157    pub fn is_bias_empty(&self) -> bool {
159        self.bias().is_empty()
160    }
161    pub fn count_weight(&self) -> usize {
163        self.weights().len()
164    }
165    pub fn count_bias(&self) -> usize {
167        self.bias().len()
168    }
169    pub fn raw_dim(&self) -> D {
171        self.weights().raw_dim()
172    }
173    pub fn shape(&self) -> &[usize] {
175        self.weights().shape()
176    }
177    pub fn shape_bias(&self) -> &[usize] {
180        self.bias().shape()
181    }
182    pub fn size(&self) -> usize {
184        self.weights().len() + self.bias().len()
185    }
186    pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
188    where
189        A: Clone,
190        S: DataOwned,
191    {
192        ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
193    }
194    pub fn to_shape<Sh>(
197        &self,
198        shape: Sh,
199    ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
200    where
201        A: Clone,
202        S: DataOwned,
203        Sh: ShapeBuilder,
204        Sh::Dim: Dimension + RemoveAxis,
205    {
206        let shape = shape.into_shape_with_order();
207        let dim = shape.raw_dim().clone();
208        let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
209        let weights = self.weights().to_shape(dim)?;
210        Ok(ParamsBase::new(bias, weights))
211    }
212    pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
215    where
216        A: Clone,
217        S: Data,
218    {
219        ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
220    }
221    pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
223    where
224        S: Data,
225    {
226        ParamsBase::new(self.bias().view(), self.weights().view())
227    }
228    pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
230    where
231        S: ndarray::DataMut,
232    {
233        ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
234    }
235}