1use crate::params::{Params, ParamsBase};
6use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
7use ndarray::linalg::Dot;
8use ndarray::prelude::*;
9use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
10use num_traits::{Float, FromPrimitive};
11
12impl<A, S, D> ParamsBase<S, D>
13where
14 A: ScalarOperand + Float + FromPrimitive,
15 D: Dimension,
16 S: Data<Elem = A>,
17{
18 pub fn l1_norm(&self) -> A {
20 let bias = self.bias.l1_norm();
21 let weights = self.weights.l1_norm();
22 bias + weights
23 }
24 pub fn l2_norm(&self) -> A {
26 let bias = self.bias.l2_norm();
27 let weights = self.weights.l2_norm();
28 bias + weights
29 }
30 pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
33 where
34 S: DataMut,
35 Self: ApplyGradient<Delta, Elem = A, Output = Z>,
36 {
37 <Self as ApplyGradient<Delta>>::apply_gradient(self, grad, lr)
38 }
39
40 pub fn apply_gradient_with_decay<Grad, Z>(
41 &mut self,
42 grad: &Grad,
43 lr: A,
44 decay: A,
45 ) -> crate::Result<Z>
46 where
47 S: DataMut,
48 Self: ApplyGradient<Grad, Elem = A, Output = Z>,
49 {
50 <Self as ApplyGradient<Grad>>::apply_gradient_with_decay(self, grad, lr, decay)
51 }
52
53 pub fn apply_gradient_with_momentum<Grad, V, Z>(
54 &mut self,
55 grad: &Grad,
56 lr: A,
57 momentum: A,
58 velocity: &mut V,
59 ) -> crate::Result<Z>
60 where
61 S: DataMut,
62 Self: ApplyGradientExt<Grad, Elem = A, Output = Z, Velocity = V>,
63 {
64 <Self as ApplyGradientExt<Grad>>::apply_gradient_with_momentum(
65 self, grad, lr, momentum, velocity,
66 )
67 }
68
69 pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
70 &mut self,
71 grad: &Grad,
72 lr: A,
73 decay: A,
74 momentum: A,
75 velocity: &mut V,
76 ) -> crate::Result<Z>
77 where
78 S: DataMut,
79 Self: ApplyGradientExt<Grad, Elem = A, Output = Z, Velocity = V>,
80 {
81 <Self as ApplyGradientExt<Grad>>::apply_gradient_with_decay_and_momentum(
82 self, grad, lr, decay, momentum, velocity,
83 )
84 }
85}
86
/// Plain gradient-descent updates for a bias/weights pair, delegating to the
/// `ApplyGradient` implementations of the two underlying arrays.
impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Elem = A;
    type Output = ();

    /// Applies `grad` scaled by the learning rate `lr`. The bias is updated
    /// first, then the weights; if the bias update fails, the weights are
    /// left untouched (the `?` returns before they are reached).
    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        self.bias.apply_gradient(grad.bias(), lr)?;
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    /// Like `apply_gradient`, but additionally applies the weight decay
    /// factor `decay` to both the bias and the weights.
    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}
120
/// Momentum-based gradient updates for a bias/weights pair. The velocity
/// buffer is an owned `Params` with the same dimensionality, split into its
/// bias/weights components for the per-array updates.
impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    /// Applies `grad` with classical momentum, mutating both the parameters
    /// and `velocity`. The bias (and its velocity) is updated before the
    /// weights; an error in the bias update leaves the weights untouched.
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    /// Like `apply_gradient_with_momentum`, additionally applying the weight
    /// decay factor `decay` to both components.
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}
177
178impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
179where
180 A: Float + FromPrimitive + ScalarOperand,
181 S: Data<Elem = A>,
182 T: Data<Elem = A>,
183{
184 type Elem = A;
185 type Output = A;
186
187 fn backward(
188 &mut self,
189 input: &ArrayBase<S, Ix2>,
190 delta: &ArrayBase<T, Ix1>,
191 gamma: Self::Elem,
192 ) -> crate::Result<Self::Output> {
193 let weight_delta = delta.t().dot(input);
195 self.weights.apply_gradient(&weight_delta, gamma)?;
197 self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
198 Ok(delta.pow2().sum())
200 }
201}
202
203impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
204where
205 A: Float + FromPrimitive + ScalarOperand,
206 S: Data<Elem = A>,
207 T: Data<Elem = A>,
208{
209 type Elem = A;
210 type Output = A;
211
212 fn backward(
213 &mut self,
214 input: &ArrayBase<S, Ix1>,
215 delta: &ArrayBase<T, Ix0>,
216 gamma: Self::Elem,
217 ) -> crate::Result<Self::Output> {
218 let weight_delta = input * delta;
220 self.weights.apply_gradient(&weight_delta, gamma)?;
222 self.bias.apply_gradient(&delta, gamma)?;
223 Ok(delta.pow2().sum())
225 }
226}
227
228impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
229where
230 A: Float + FromPrimitive + ScalarOperand,
231 S: Data<Elem = A>,
232 T: Data<Elem = A>,
233{
234 type Elem = A;
235 type Output = A;
236
237 fn backward(
238 &mut self,
239 input: &ArrayBase<S, Ix1>,
240 delta: &ArrayBase<T, Ix1>,
241 gamma: Self::Elem,
242 ) -> crate::Result<Self::Output> {
243 let dw = &self.weights * delta.t().dot(input);
245 self.weights.apply_gradient(&dw, gamma)?;
247 self.bias.apply_gradient(&delta, gamma)?;
248 Ok(delta.pow2().sum())
250 }
251}
252
253impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
254where
255 A: Float + FromPrimitive + ScalarOperand,
256 S: Data<Elem = A>,
257 T: Data<Elem = A>,
258{
259 type Elem = A;
260 type Output = A;
261
262 fn backward(
263 &mut self,
264 input: &ArrayBase<S, Ix2>,
265 delta: &ArrayBase<T, Ix2>,
266 gamma: Self::Elem,
267 ) -> crate::Result<Self::Output> {
268 let weight_delta = input.dot(&delta.t());
270 let bias_delta = delta.sum_axis(Axis(0));
272
273 self.weights.apply_gradient(&weight_delta, gamma)?;
274 self.bias.apply_gradient(&bias_delta, gamma)?;
275 Ok(delta.pow2().sum())
277 }
278}
279
280impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
281where
282 A: Clone,
283 D: Dimension,
284 S: Data<Elem = A>,
285 for<'a> X: Dot<ArrayBase<S, D>, Output = Y>,
286 Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
287{
288 type Output = Z;
289
290 fn forward(&self, input: &X) -> crate::Result<Self::Output> {
291 let output = input.dot(&self.weights) + &self.bias;
292 Ok(output)
293 }
294}
295
#[cfg(feature = "rand")]
impl<A, S, D> crate::init::Initialize<S, D> for ParamsBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData<Elem = A>,
{
    /// Initializes the parameters by sampling `distr`, using a `SmallRng`
    /// freshly seeded from the thread-local generator (`rand::rng()`).
    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
    where
        Ds: rand_distr::Distribution<A>,
        Sh: ndarray::ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        use rand::SeedableRng;
        Self::rand_with(
            shape,
            distr,
            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
        )
    }

    /// Initializes the parameters with a caller-supplied RNG.
    ///
    /// `shape` sizes the weights; the bias takes the same shape with axis 0
    /// removed. The bias is sampled *before* the weights, so the draw order
    /// is part of the reproducibility contract for a seeded `rng`.
    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
        Ds: rand_distr::Distribution<A>,
        Sh: ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        // resolve the shape (with memory order) once; both dims derive from it
        let shape = shape.into_shape_with_order();
        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
        Self { bias, weights }
    }
}