1use ndarray::prelude::*;
6use ndarray::{Data, DataMut, DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder};
7
8pub struct ParamsBase<S, D = ndarray::Ix2>
10where
11 D: Dimension,
12 S: RawData,
13{
14 pub(crate) bias: ArrayBase<S, D::Smaller>,
15 pub(crate) weights: ArrayBase<S, D>,
16}
17
18impl<A, S, D> ParamsBase<S, D>
19where
20 D: Dimension,
21 S: RawData<Elem = A>,
22{
23 pub fn new(bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>) -> Self
24 where
25 A: Clone,
26 S: DataOwned,
27 {
28 Self { bias, weights }
29 }
30
31 pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
32 where
33 A: Clone,
34 D: RemoveAxis,
35 S: DataOwned,
36 Sh: ShapeBuilder<Dim = D>,
37 {
38 let weights = ArrayBase::from_elem(shape, elem.clone());
39 let dim = weights.raw_dim();
40 let bias = ArrayBase::from_elem(dim.remove_axis(Axis(0)), elem);
41 Self { bias, weights }
42 }
43 pub fn default<Sh>(shape: Sh) -> Self
45 where
46 A: Clone + Default,
47 D: RemoveAxis,
48 S: DataOwned,
49 Sh: ShapeBuilder<Dim = D>,
50 {
51 let weights = ArrayBase::default(shape);
52 let dim = weights.raw_dim();
53 let bias = ArrayBase::default(dim.remove_axis(Axis(0)));
54 Self { bias, weights }
55 }
56 pub fn ones<Sh>(shape: Sh) -> Self
58 where
59 A: Clone + num_traits::One,
60 D: RemoveAxis,
61 S: DataOwned,
62 Sh: ShapeBuilder<Dim = D>,
63 {
64 let weights = ArrayBase::ones(shape);
65 let dim = weights.raw_dim();
66 let bias = ArrayBase::ones(dim.remove_axis(Axis(0)));
67 Self { bias, weights }
68 }
69 pub fn zeros<Sh>(shape: Sh) -> Self
71 where
72 A: Clone + num_traits::Zero,
73 D: RemoveAxis,
74 S: DataOwned,
75 Sh: ShapeBuilder<Dim = D>,
76 {
77 let weights = ArrayBase::zeros(shape);
78 let dim = weights.raw_dim();
79 let bias = ArrayBase::zeros(dim.remove_axis(Axis(0)));
80 Self { bias, weights }
81 }
82 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller> {
84 &self.bias
85 }
86 #[inline]
88 pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
89 &mut self.bias
90 }
91 pub const fn weights(&self) -> &ArrayBase<S, D> {
93 &self.weights
94 }
95 #[inline]
97 pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
98 &mut self.weights
99 }
100 pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> crate::Result<Z>
102 where
103 A: Clone,
104 S: Data,
105 Self: crate::Backward<X, Y, Elem = A, Output = Z>,
106 {
107 <Self as crate::Backward<X, Y>>::backward(self, input, grad, lr)
108 }
109 pub fn forward<X, Y>(&self, input: &X) -> crate::Result<Y>
111 where
112 A: Clone,
113 S: Data,
114 Self: crate::Forward<X, Output = Y>,
115 {
116 <Self as crate::Forward<X>>::forward(self, input)
117 }
118 pub fn dim(&self) -> D::Pattern {
120 self.weights.dim()
121 }
122 pub fn iter(&self) -> super::iter::Iter<'_, A, D>
125 where
126 D: RemoveAxis,
127 S: Data,
128 {
129 super::iter::Iter {
130 bias: self.bias.iter(),
131 weights: self.weights.axis_iter(Axis(1)),
132 }
133 }
134 pub fn iter_mut(
136 &mut self,
137 ) -> core::iter::Zip<
138 ndarray::iter::AxisIterMut<'_, A, D::Smaller>,
139 ndarray::iter::IterMut<'_, A, D::Smaller>,
140 >
141 where
142 D: RemoveAxis,
143 S: DataMut,
144 {
145 self.weights
146 .axis_iter_mut(Axis(1))
147 .zip(self.bias.iter_mut())
148 }
149 pub fn iter_bias(&self) -> ndarray::iter::Iter<'_, A, D::Smaller>
151 where
152 S: Data,
153 {
154 self.bias.iter()
155 }
156 pub fn iter_bias_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D::Smaller>
158 where
159 S: DataMut,
160 {
161 self.bias.iter_mut()
162 }
163 pub fn iter_weights(&self) -> ndarray::iter::Iter<'_, A, D>
165 where
166 S: Data,
167 {
168 self.weights.iter()
169 }
170 pub fn iter_weights_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
172 where
173 S: DataMut,
174 {
175 self.weights.iter_mut()
176 }
177 pub fn len(&self) -> usize {
179 self.weights.len()
180 }
181 pub fn raw_dim(&self) -> D {
183 self.weights.raw_dim()
184 }
185 pub fn shape(&self) -> &[usize] {
187 self.weights.shape()
188 }
189 pub fn shape_bias(&self) -> &[usize] {
192 self.bias.shape()
193 }
194 pub fn to_owned(&self) -> ParamsBase<ndarray::OwnedRepr<A>, D>
196 where
197 A: Clone,
198 S: DataOwned,
199 {
200 ParamsBase {
201 bias: self.bias.to_owned(),
202 weights: self.weights.to_owned(),
203 }
204 }
205 pub fn to_shape<Sh>(
208 &self,
209 shape: Sh,
210 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
211 where
212 A: Clone,
213 S: DataOwned,
214 Sh: ShapeBuilder,
215 Sh::Dim: Dimension + RemoveAxis,
216 {
217 let shape = shape.into_shape_with_order();
218 let dim = shape.raw_dim().clone();
219 let bias = self.bias.to_shape(dim.remove_axis(Axis(0)))?;
220 let weights = self.weights.to_shape(dim)?;
221 Ok(ParamsBase { bias, weights })
222 }
223 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
225 where
226 S: Data,
227 {
228 ParamsBase {
229 bias: self.bias.view(),
230 weights: self.weights.view(),
231 }
232 }
233 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
235 where
236 S: ndarray::DataMut,
237 {
238 ParamsBase {
239 bias: self.bias.view_mut(),
240 weights: self.weights.view_mut(),
241 }
242 }
243}
244
245impl<A, S> ParamsBase<S, Ix1>
246where
247 S: RawData<Elem = A>,
248{
249 pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
250 where
251 A: Clone,
252 S: DataOwned,
253 {
254 Self {
255 bias: ArrayBase::from_elem((), bias),
256 weights,
257 }
258 }
259
260 pub fn nrows(&self) -> usize {
261 self.weights.len()
262 }
263}
264
265impl<A, S> ParamsBase<S, Ix2>
266where
267 S: RawData<Elem = A>,
268{
269 pub fn ncols(&self) -> usize {
270 self.weights.ncols()
271 }
272
273 pub fn nrows(&self) -> usize {
274 self.weights.nrows()
275 }
276}
277
278impl<A, S, D> Clone for ParamsBase<S, D>
279where
280 D: Dimension,
281 S: ndarray::RawDataClone<Elem = A>,
282 A: Clone,
283{
284 fn clone(&self) -> Self {
285 Self {
286 bias: self.bias.clone(),
287 weights: self.weights.clone(),
288 }
289 }
290}
291
292impl<A, S, D> Copy for ParamsBase<S, D>
293where
294 D: Dimension + Copy,
295 <D as Dimension>::Smaller: Copy,
296 S: ndarray::RawDataClone<Elem = A> + Copy,
297 A: Copy,
298{
299}
300
301impl<A, S, D> PartialEq for ParamsBase<S, D>
302where
303 D: Dimension,
304 S: Data<Elem = A>,
305 A: PartialEq,
306{
307 fn eq(&self, other: &Self) -> bool {
308 self.bias == other.bias && self.weights == other.weights
309 }
310}
311
312impl<A, S, D> Eq for ParamsBase<S, D>
313where
314 D: Dimension,
315 S: Data<Elem = A>,
316 A: Eq,
317{
318}
319
320impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
321where
322 D: Dimension,
323 S: Data<Elem = A>,
324 A: core::fmt::Debug,
325{
326 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
327 f.debug_struct("ModelParams")
328 .field("bias", &self.bias)
329 .field("weights", &self.weights)
330 .finish()
331 }
332}
333
334