1use ndarray::{
6 ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeArg,
7 ShapeBuilder,
8};
9
10pub struct ParamsBase<S, D = ndarray::Ix2>
15where
16 D: Dimension,
17 S: RawData,
18{
19 pub(crate) bias: ArrayBase<S, D::Smaller>,
20 pub(crate) weights: ArrayBase<S, D>,
21}
22
23impl<A, S, D> ParamsBase<S, D>
24where
25 D: Dimension,
26 S: RawData<Elem = A>,
27{
28 pub const fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self {
30 Self { bias, weights }
31 }
32 pub fn init_from_fn<Sh, F>(shape: Sh, init: F) -> Self
34 where
35 A: Clone,
36 D: RemoveAxis,
37 S: DataOwned,
38 Sh: ShapeBuilder<Dim = D>,
39 F: Fn() -> A,
40 {
41 let shape = shape.into_shape_with_order();
42 let bshape = shape.raw_dim().remove_axis(Axis(0));
43 let bias = ArrayBase::from_shape_fn(bshape, |_| init());
45 let weights = ArrayBase::from_shape_fn(shape, |_| init());
46 Self::new(bias, weights)
48 }
49 pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
51 where
52 A: Clone,
53 D: RemoveAxis,
54 S: DataOwned,
55 Sh: ShapeBuilder<Dim = D>,
56 D::Smaller: Dimension + ShapeArg,
57 F: Fn(<D::Smaller as Dimension>::Pattern) -> A + Fn(<D as Dimension>::Pattern) -> A,
58 {
59 let shape = shape.into_shape_with_order();
60 let bdim = shape.raw_dim().remove_axis(Axis(0));
61 let bias = ArrayBase::from_shape_fn(bdim, |s| f(s));
62 let weights = ArrayBase::from_shape_fn(shape, |s| f(s));
63 Self::new(bias, weights)
64 }
65 pub fn from_bias<Sh>(shape: Sh, bias: ArrayBase<S, D::Smaller>) -> Self
67 where
68 A: Clone + Default,
69 D: RemoveAxis,
70 S: DataOwned,
71 Sh: ShapeBuilder<Dim = D>,
72 {
73 let weights = ArrayBase::from_elem(shape, A::default());
74 Self::new(bias, weights)
75 }
76 pub fn from_weights<Sh>(shape: Sh, weights: ArrayBase<S, D>) -> Self
79 where
80 A: Clone + Default,
81 D: RemoveAxis,
82 S: DataOwned,
83 Sh: ShapeBuilder<Dim = D>,
84 {
85 let shape = shape.into_shape_with_order();
86 let dim_bias = shape.raw_dim().remove_axis(Axis(0));
87 let bias = ArrayBase::from_elem(dim_bias, A::default());
88 Self::new(bias, weights)
89 }
90 pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
92 where
93 A: Clone,
94 D: RemoveAxis,
95 S: DataOwned,
96 Sh: ShapeBuilder<Dim = D>,
97 {
98 let weights = ArrayBase::from_elem(shape, elem.clone());
99 let dim = weights.raw_dim();
100 let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
101 Self::new(bias, weights)
102 }
103 #[allow(clippy::should_implement_trait)]
104 pub fn default<Sh>(shape: Sh) -> Self
106 where
107 A: Clone + Default,
108 D: RemoveAxis,
109 S: DataOwned,
110 Sh: ShapeBuilder<Dim = D>,
111 {
112 Self::from_elem(shape, A::default())
113 }
114 pub fn ones<Sh>(shape: Sh) -> Self
116 where
117 A: Clone + num_traits::One,
118 D: RemoveAxis,
119 S: DataOwned,
120 Sh: ShapeBuilder<Dim = D>,
121 {
122 Self::from_elem(shape, A::one())
123 }
124 pub fn zeros<Sh>(shape: Sh) -> Self
126 where
127 A: Clone + num_traits::Zero,
128 D: RemoveAxis,
129 S: DataOwned,
130 Sh: ShapeBuilder<Dim = D>,
131 {
132 Self::from_elem(shape, A::zero())
133 }
134 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
136 &self.bias
137 }
138 pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
140 &mut self.bias
141 }
142 pub const fn weights(&self) -> &ArrayBase<S, D> {
144 &self.weights
145 }
146 pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
148 &mut self.weights
149 }
150 pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
152 where
153 A: Clone,
154 S: DataMut,
155 {
156 self.bias_mut().assign(bias);
157 self
158 }
159 pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
161 where
162 A: Clone,
163 S: DataMut,
164 {
165 self.weights_mut().assign(weights);
166 self
167 }
168 pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
170 core::mem::replace(&mut self.bias, bias)
171 }
172 pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
174 core::mem::replace(&mut self.weights, weights)
175 }
176 pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
178 *self.bias_mut() = bias;
179 self
180 }
181 pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
183 *self.weights_mut() = weights;
184 self
185 }
186 pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
188 where
189 A: Clone,
190 S: Data,
191 Self: crate::Backward<X, Y, Elem = A, Output = Z>,
192 {
193 <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
194 }
195 pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
197 where
198 A: Clone,
199 S: Data,
200 Self: crate::Forward<X, Output = Y>,
201 {
202 <Self as crate::Forward<X>>::forward(self, input)
203 }
204 pub fn dim(&self) -> D::Pattern {
206 self.weights().dim()
207 }
208 pub fn is_empty(&self) -> bool {
210 self.is_weights_empty() && self.is_bias_empty()
211 }
212 pub fn is_weights_empty(&self) -> bool {
214 self.weights().is_empty()
215 }
216 pub fn is_bias_empty(&self) -> bool {
218 self.bias().is_empty()
219 }
220 pub fn count_weight(&self) -> usize {
222 self.weights().len()
223 }
224 pub fn count_bias(&self) -> usize {
226 self.bias().len()
227 }
228 pub fn raw_dim(&self) -> D {
230 self.weights().raw_dim()
231 }
232 pub fn shape(&self) -> &[usize] {
234 self.weights().shape()
235 }
236 pub fn shape_bias(&self) -> &[usize] {
239 self.bias().shape()
240 }
241 pub fn size(&self) -> usize {
243 self.weights().len() + self.bias().len()
244 }
245 pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
247 where
248 A: Clone,
249 S: DataOwned,
250 {
251 ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
252 }
253 pub fn to_shape<Sh>(
256 &self,
257 shape: Sh,
258 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
259 where
260 A: Clone,
261 S: DataOwned,
262 Sh: ShapeBuilder,
263 Sh::Dim: Dimension + RemoveAxis,
264 {
265 let shape = shape.into_shape_with_order();
266 let dim = shape.raw_dim().clone();
267 let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
268 let weights = self.weights().to_shape(dim)?;
269 Ok(ParamsBase::new(bias, weights))
270 }
271 pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
274 where
275 A: Clone,
276 S: Data,
277 {
278 ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
279 }
280 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
282 where
283 S: Data,
284 {
285 ParamsBase::new(self.bias().view(), self.weights().view())
286 }
287 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
289 where
290 S: ndarray::DataMut,
291 {
292 ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
293 }
294}