1use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11 D: Dimension,
12 S: RawData,
13{
14 pub(crate) bias: ArrayBase<S, D::Smaller>,
15 pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20 D: Dimension,
21 S: RawData<Elem = A>,
22{
23 pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24 where
25 A: Clone,
26 S: DataOwned,
27 {
28 Self { bias, weights }
29 }
30
31 pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32 where
33 A: Clone,
34 D: RemoveAxis,
35 S: DataOwned,
36 Sh: ShapeBuilder<Dim = D>,
37 {
38 let weights = ArrayBase::from_elem(shape, elem.clone());
39 let dim = weights.raw_dim();
40 let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41 Self { bias, weights }
42 }
43 pub fn default<Sh>(shape: Sh) -> Self
45 where
46 A: Clone + Default,
47 D: RemoveAxis,
48 S: DataOwned,
49 Sh: ShapeBuilder<Dim = D>,
50 {
51 let weights = ArrayBase::default(shape);
52 let dim = weights.raw_dim();
53 let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54 Self { bias, weights }
55 }
56 pub fn ones<Sh>(shape: Sh) -> Self
58 where
59 A: Clone + num_traits::One,
60 D: RemoveAxis,
61 S: DataOwned,
62 Sh: ShapeBuilder<Dim = D>,
63 {
64 let weights = ArrayBase::ones(shape);
65 let dim = weights.raw_dim();
66 let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67 Self { bias, weights }
68 }
69 pub fn zeros<Sh>(shape: Sh) -> Self
71 where
72 A: Clone + num_traits::Zero,
73 D: RemoveAxis,
74 S: DataOwned,
75 Sh: ShapeBuilder<Dim = D>,
76 {
77 let weights = ArrayBase::zeros(shape);
78 let dim = weights.raw_dim();
79 let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80 Self { bias, weights }
81 }
82 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84 &self.bias
85 }
86 #[inline]
88 pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89 &mut self.bias
90 }
91 pub const fn weights(&self) -> &ArrayBase<S, D> {
93 &self.weights
94 }
95 #[inline]
97 pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98 &mut self.weights
99 }
100
101 pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
103 where
104 A: Clone,
105 S: DataMut,
106 {
107 self.bias_mut().assign(bias);
108 self
109 }
110 pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
112 where
113 A: Clone,
114 S: DataMut,
115 {
116 self.weights_mut().assign(weights);
117 self
118 }
119 pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
121 core::mem::replace(&mut self.bias, bias)
122 }
123 pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
125 core::mem::replace(&mut self.weights, weights)
126 }
127 pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
129 *self.bias_mut() = bias;
130 self
131 }
132 pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
134 *self.weights_mut() = weights;
135 self
136 }
137 pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
139 where
140 A: Clone,
141 S: Data,
142 Self: crate::Backward<X, Y, Elem = A, Output = Z>,
143 {
144 <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
145 }
146 pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
148 where
149 A: Clone,
150 S: Data,
151 Self: crate::Forward<X, Output = Y>,
152 {
153 <Self as crate::Forward<X>>::forward(self, input)
154 }
155 pub fn dim(&self) -> D::Pattern {
157 self.weights().dim()
158 }
159 pub fn iter(&self) -> super::iter::Iter<'_, A, D>
162 where
163 D: RemoveAxis,
164 S: Data,
165 {
166 super::iter::Iter {
167 bias: self.bias().iter(),
168 weights: self.weights().axis_iter(Axis(1)),
169 }
170 }
171 pub fn iter_mut(
173 &mut self,
174 ) -> core::iter::Zip<
175 ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
176 ndarray::iter::IterMut<'_, A, D::Smaller>,
177 >
178 where
179 D: RemoveAxis,
180 S: DataMut,
181 {
182 self.weights
183 .axis_iter_mut(Axis(1))
184 .zip(self.bias.iter_mut())
185 }
186 pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
188 where
189 S: Data,
190 {
191 self.bias().iter()
192 }
193 pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
195 where
196 S: DataMut,
197 {
198 self.bias_mut().iter_mut()
199 }
200 pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
202 where
203 S: Data,
204 {
205 self.weights().iter()
206 }
207 pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
209 where
210 S: DataMut,
211 {
212 self.weights_mut().iter_mut()
213 }
214 pub fn is_empty(&self) -> bool {
216 self.is_weights_empty() && self.is_bias_empty()
217 }
218 pub fn is_weights_empty(&self) -> bool {
220 self.weights().is_empty()
221 }
222 pub fn is_bias_empty(&self) -> bool {
224 self.bias().is_empty()
225 }
226 pub fn len(&self) -> usize {
228 self.weights().len()
229 }
230 pub fn raw_dim(&self) -> D {
232 self.weights().raw_dim()
233 }
234 pub fn shape(&self) -> &[usize] {
236 self.weights().shape()
237 }
238 pub fn shape_bias(&self) -> &[usize] {
241 self.bias().shape()
242 }
243 pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
245 where
246 A: Clone,
247 S: DataOwned,
248 {
249 ParamsBase {
250 bias: self.bias().to_owned(),
251 weights: self.weights().to_owned(),
252 }
253 }
254 pub fn to_shape<Sh>(
257 &self,
258 shape: Sh,
259 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
260 where
261 A: Clone,
262 S: DataOwned,
263 Sh: ShapeBuilder,
264 Sh::Dim: Dimension + RemoveAxis,
265 {
266 let shape = shape.into_shape_with_order();
267 let dim = shape.raw_dim().clone();
268 let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
269 let weights = self.weights().to_shape(dim)?;
270 Ok(ParamsBase { bias, weights })
271 }
272 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
274 where
275 S: Data,
276 {
277 ParamsBase {
278 bias: self.bias().view(),
279 weights: self.weights().view(),
280 }
281 }
282 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
284 where
285 S: ndarray::DataMut,
286 {
287 ParamsBase {
288 bias: self.bias.view_mut(),
289 weights: self.weights.view_mut(),
290 }
291 }
292}
293
294impl<A, S> ParamsBase<S, Ix1>
295where
296 S: RawData<Elem = A>,
297{
298 pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
299 where
300 A: Clone,
301 S: DataOwned,
302 {
303 Self {
304 bias: ArrayBase::from_elem((), bias),
305 weights,
306 }
307 }
308
309 pub fn nrows(&self) -> usize {
310 self.weights.len()
311 }
312}
313
314impl<A, S> ParamsBase<S, Ix2>
315where
316 S: RawData<Elem = A>,
317{
318 pub fn ncols(&self) -> usize {
319 self.weights.ncols()
320 }
321
322 pub fn nrows(&self) -> usize {
323 self.weights.nrows()
324 }
325}
326
327impl<A, S, D> Clone for ParamsBase<S, D>
328where
329 D: Dimension,
330 S: ndarray::RawDataClone<Elem = A>,
331 A: Clone,
332{
333 fn clone(&self) -> Self {
334 Self {
335 bias: self.bias.clone(),
336 weights: self.weights.clone(),
337 }
338 }
339}
340
341impl<A, S, D> Copy for ParamsBase<S, D>
342where
343 D: Dimension + Copy,
344 <D as Dimension>::Smaller: Copy,
345 S: ndarray::RawDataClone<Elem = A> + Copy,
346 A: Copy,
347{
348}
349
350impl<A, S, D> PartialEq for ParamsBase<S, D>
351where
352 D: Dimension,
353 S: Data<Elem = A>,
354 A: PartialEq,
355{
356 fn eq(&self, other: &Self) -> bool {
357 self.bias == other.bias && self.weights == other.weights
358 }
359}
360
361impl<A, S, D> Eq for ParamsBase<S, D>
362where
363 D: Dimension,
364 S: Data<Elem = A>,
365 A: Eq,
366{
367}
368
369impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
370where
371 D: Dimension,
372 S: Data<Elem = A>,
373 A: core::fmt::Debug,
374{
375 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
376 f.debug_struct("ModelParams")
377 .field("bias", &self.bias)
378 .field("weights", &self.weights)
379 .finish()
380 }
381}
382
383