1use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11 D: Dimension,
12 S: RawData,
13{
14 pub(crate) bias: ArrayBase<S, D::Smaller>,
15 pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20 D: Dimension,
21 S: RawData<Elem = A>,
22{
23 pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24 where
25 A: Clone,
26 S: DataOwned,
27 {
28 Self { bias, weights }
29 }
30
31 pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32 where
33 A: Clone,
34 D: RemoveAxis,
35 S: DataOwned,
36 Sh: ShapeBuilder<Dim = D>,
37 {
38 let weights = ArrayBase::from_elem(shape, elem.clone());
39 let dim = weights.raw_dim();
40 let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41 Self { bias, weights }
42 }
43 pub fn default<Sh>(shape: Sh) -> Self
45 where
46 A: Clone + Default,
47 D: RemoveAxis,
48 S: DataOwned,
49 Sh: ShapeBuilder<Dim = D>,
50 {
51 let weights = ArrayBase::default(shape);
52 let dim = weights.raw_dim();
53 let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54 Self { bias, weights }
55 }
56 pub fn ones<Sh>(shape: Sh) -> Self
58 where
59 A: Clone + num_traits::One,
60 D: RemoveAxis,
61 S: DataOwned,
62 Sh: ShapeBuilder<Dim = D>,
63 {
64 let weights = ArrayBase::ones(shape);
65 let dim = weights.raw_dim();
66 let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67 Self { bias, weights }
68 }
69 pub fn zeros<Sh>(shape: Sh) -> Self
71 where
72 A: Clone + num_traits::Zero,
73 D: RemoveAxis,
74 S: DataOwned,
75 Sh: ShapeBuilder<Dim = D>,
76 {
77 let weights = ArrayBase::zeros(shape);
78 let dim = weights.raw_dim();
79 let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80 Self { bias, weights }
81 }
82 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84 &self.bias
85 }
86 #[inline]
88 pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89 &mut self.bias
90 }
91 pub const fn weights(&self) -> &ArrayBase<S, D> {
93 &self.weights
94 }
95 #[inline]
97 pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98 &mut self.weights
99 }
100
101 pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
103 where
104 A: Clone,
105 S: DataMut,
106 {
107 self.bias_mut().assign(&bias);
108 self
109 }
110 pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
112 where
113 A: Clone,
114 S: DataMut,
115 {
116 self.weights_mut().assign(&weights);
117 self
118 }
119 pub fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
121 core::mem::replace(&mut self.bias, bias)
122 }
123 pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
125 core::mem::replace(&mut self.weights, weights)
126 }
127 pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
129 *self.bias_mut() = bias;
130 self
131 }
132 pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
134 *self.weights_mut() = weights;
135 self
136 }
137 pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
139 where
140 A: Clone,
141 S: Data,
142 Self: crate::Backward<X, Y, Elem = A, Output = Z>,
143 {
144 <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
145 }
146 pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
148 where
149 A: Clone,
150 S: Data,
151 Self: crate::Forward<X, Output = Y>,
152 {
153 <Self as crate::Forward<X>>::forward(self, input)
154 }
155 pub fn dim(&self) -> D::Pattern {
157 self.weights().dim()
158 }
159 pub fn iter(&self) -> super::iter::Iter<'_, A, D>
162 where
163 D: RemoveAxis,
164 S: Data,
165 {
166 super::iter::Iter {
167 bias: self.bias().iter(),
168 weights: self.weights().axis_iter(Axis(1)),
169 }
170 }
171 pub fn iter_mut(
173 &mut self,
174 ) -> core::iter::Zip<
175 ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
176 ndarray::iter::IterMut<'_, A, D::Smaller>,
177 >
178 where
179 D: RemoveAxis,
180 S: DataMut,
181 {
182 self.weights
183 .axis_iter_mut(Axis(1))
184 .zip(self.bias.iter_mut())
185 }
186 pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
188 where
189 S: Data,
190 {
191 self.bias().iter()
192 }
193 pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
195 where
196 S: DataMut,
197 {
198 self.bias_mut().iter_mut()
199 }
200 pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
202 where
203 S: Data,
204 {
205 self.weights().iter()
206 }
207 pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
209 where
210 S: DataMut,
211 {
212 self.weights_mut().iter_mut()
213 }
214 pub fn len(&self) -> usize {
216 self.weights().len()
217 }
218 pub fn raw_dim(&self) -> D {
220 self.weights().raw_dim()
221 }
222 pub fn shape(&self) -> &[usize] {
224 self.weights().shape()
225 }
226 pub fn shape_bias(&self) -> &[usize] {
229 self.bias().shape()
230 }
231 pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
233 where
234 A: Clone,
235 S: DataOwned,
236 {
237 ParamsBase {
238 bias: self.bias().to_owned(),
239 weights: self.weights().to_owned(),
240 }
241 }
242 pub fn to_shape<Sh>(
245 &self,
246 shape: Sh,
247 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
248 where
249 A: Clone,
250 S: DataOwned,
251 Sh: ShapeBuilder,
252 Sh::Dim: Dimension + RemoveAxis,
253 {
254 let shape = shape.into_shape_with_order();
255 let dim = shape.raw_dim().clone();
256 let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
257 let weights = self.weights().to_shape(dim)?;
258 Ok(ParamsBase { bias, weights })
259 }
260 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
262 where
263 S: Data,
264 {
265 ParamsBase {
266 bias: self.bias().view(),
267 weights: self.weights().view(),
268 }
269 }
270 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
272 where
273 S: ndarray::DataMut,
274 {
275 ParamsBase {
276 bias: self.bias.view_mut(),
277 weights: self.weights.view_mut(),
278 }
279 }
280}
281
282impl<A, S> ParamsBase<S, Ix1>
283where
284 S: RawData<Elem = A>,
285{
286 pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
287 where
288 A: Clone,
289 S: DataOwned,
290 {
291 Self {
292 bias: ArrayBase::from_elem((), bias),
293 weights,
294 }
295 }
296
297 pub fn nrows(&self) -> usize {
298 self.weights.len()
299 }
300}
301
302impl<A, S> ParamsBase<S, Ix2>
303where
304 S: RawData<Elem = A>,
305{
306 pub fn ncols(&self) -> usize {
307 self.weights.ncols()
308 }
309
310 pub fn nrows(&self) -> usize {
311 self.weights.nrows()
312 }
313}
314
315impl<A, S, D> Clone for ParamsBase<S, D>
316where
317 D: Dimension,
318 S: ndarray::RawDataClone<Elem = A>,
319 A: Clone,
320{
321 fn clone(&self) -> Self {
322 Self {
323 bias: self.bias.clone(),
324 weights: self.weights.clone(),
325 }
326 }
327}
328
329impl<A, S, D> Copy for ParamsBase<S, D>
330where
331 D: Dimension + Copy,
332 <D as Dimension>::Smaller: Copy,
333 S: ndarray::RawDataClone<Elem = A> + Copy,
334 A: Copy,
335{
336}
337
338impl<A, S, D> PartialEq for ParamsBase<S, D>
339where
340 D: Dimension,
341 S: Data<Elem = A>,
342 A: PartialEq,
343{
344 fn eq(&self, other: &Self) -> bool {
345 self.bias == other.bias && self.weights == other.weights
346 }
347}
348
349impl<A, S, D> Eq for ParamsBase<S, D>
350where
351 D: Dimension,
352 S: Data<Elem = A>,
353 A: Eq,
354{
355}
356
357impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
358where
359 D: Dimension,
360 S: Data<Elem = A>,
361 A: core::fmt::Debug,
362{
363 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
364 f.debug_struct("ModelParams")
365 .field("bias", &self.bias)
366 .field("weights", &self.weights)
367 .finish()
368 }
369}
370
371