1use ndarray::{
6 ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder,
7};
8
9pub struct ParamsBase<S, D = ndarray::Ix2>
14where
15 D: Dimension,
16 S: RawData,
17{
18 pub(crate) bias: ArrayBase<S, D::Smaller>,
19 pub(crate) weights: ArrayBase<S, D>,
20}
21
22impl<A, S, D> ParamsBase<S, D>
23where
24 D: Dimension,
25 S: RawData<Elem = A>,
26{
27 pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
29 Self { bias, weights }
30 }
31 pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
33 where
34 A: Clone,
35 D: RemoveAxis,
36 S: DataOwned,
37 Sh: ShapeBuilder<Dim = D>,
38 {
39 let weights = ArrayBase::from_elem(shape, elem.clone());
40 let dim = weights.raw_dim();
41 let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
42 Self::new(bias, weights)
43 }
44 #[allow(clippy::should_implement_trait)]
45 pub fn default<Sh>(shape: Sh) -> Self
47 where
48 A: Clone + Default,
49 D: RemoveAxis,
50 S: DataOwned,
51 Sh: ShapeBuilder<Dim = D>,
52 {
53 Self::from_elems(shape, A::default())
54 }
55 pub fn ones<Sh>(shape: Sh) -> Self
57 where
58 A: Clone + num_traits::One,
59 D: RemoveAxis,
60 S: DataOwned,
61 Sh: ShapeBuilder<Dim = D>,
62 {
63 Self::from_elems(shape, A::one())
64 }
65 pub fn zeros<Sh>(shape: Sh) -> Self
67 where
68 A: Clone + num_traits::Zero,
69 D: RemoveAxis,
70 S: DataOwned,
71 Sh: ShapeBuilder<Dim = D>,
72 {
73 Self::from_elems(shape, A::zero())
74 }
75 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
77 &self.bias
78 }
79 pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
81 &mut self.bias
82 }
83 pub const fn weights(&self) -> &ArrayBase<S, D> {
85 &self.weights
86 }
87 pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
89 &mut self.weights
90 }
91 pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
93 where
94 A: Clone,
95 S: DataMut,
96 {
97 self.bias_mut().assign(bias);
98 self
99 }
100 pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
102 where
103 A: Clone,
104 S: DataMut,
105 {
106 self.weights_mut().assign(weights);
107 self
108 }
109 pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
111 core::mem::replace(&mut self.bias, bias)
112 }
113 pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
115 core::mem::replace(&mut self.weights, weights)
116 }
117 pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
119 *self.bias_mut() = bias;
120 self
121 }
122 pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
124 *self.weights_mut() = weights;
125 self
126 }
127 pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
129 where
130 A: Clone,
131 S: Data,
132 Self: crate::Backward<X, Y, Elem = A, Output = Z>,
133 {
134 <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
135 }
136 pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
138 where
139 A: Clone,
140 S: Data,
141 Self: crate::Forward<X, Output = Y>,
142 {
143 <Self as crate::Forward<X>>::forward(self, input)
144 }
145 pub fn dim(&self) -> D::Pattern {
147 self.weights().dim()
148 }
149 pub fn is_empty(&self) -> bool {
151 self.is_weights_empty() && self.is_bias_empty()
152 }
153 pub fn is_weights_empty(&self) -> bool {
155 self.weights().is_empty()
156 }
157 pub fn is_bias_empty(&self) -> bool {
159 self.bias().is_empty()
160 }
161 pub fn count_weight(&self) -> usize {
163 self.weights().len()
164 }
165 pub fn count_bias(&self) -> usize {
167 self.bias().len()
168 }
169 pub fn raw_dim(&self) -> D {
171 self.weights().raw_dim()
172 }
173 pub fn shape(&self) -> &[usize] {
175 self.weights().shape()
176 }
177 pub fn shape_bias(&self) -> &[usize] {
180 self.bias().shape()
181 }
182 pub fn size(&self) -> usize {
184 self.weights().len() + self.bias().len()
185 }
186 pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
188 where
189 A: Clone,
190 S: DataOwned,
191 {
192 ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
193 }
194 pub fn to_shape<Sh>(
197 &self,
198 shape: Sh,
199 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
200 where
201 A: Clone,
202 S: DataOwned,
203 Sh: ShapeBuilder,
204 Sh::Dim: Dimension + RemoveAxis,
205 {
206 let shape = shape.into_shape_with_order();
207 let dim = shape.raw_dim().clone();
208 let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
209 let weights = self.weights().to_shape(dim)?;
210 Ok(ParamsBase::new(bias, weights))
211 }
212 pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
215 where
216 A: Clone,
217 S: Data,
218 {
219 ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
220 }
221 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
223 where
224 S: Data,
225 {
226 ParamsBase::new(self.bias().view(), self.weights().view())
227 }
228 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
230 where
231 S: ndarray::DataMut,
232 {
233 ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
234 }
235}