use crate::params::{Params, ParamsBase};
use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

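/// Convenience methods for computing norms and applying gradients.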
impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
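    /// Computes the sum of the L1 norms of the bias and the weights.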
    pub fn l1_norm(&self) -> A {
        let bias = self.bias.l1_norm();
        let weights = self.weights.l1_norm();
        bias + weights
    }
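    /// Computes the sum of the L2 norms of the bias and the weights; note that
    /// this is the sum of the two norms, not the norm of the concatenation.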
    pub fn l2_norm(&self) -> A {
        let bias = self.bias.l2_norm();
        let weights = self.weights.l2_norm();
        bias + weights
    }
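    /// Applies the gradient `grad` to the parameters, scaled by the learning
    /// rate `lr`; delegates to the [`ApplyGradient`] implementation.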
    pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Delta, A, Output = Z>,
    {
        <Self as ApplyGradient<Delta, A>>::apply_gradient(self, grad, lr)
    }

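    /// Applies the gradient with an additional decay term, as in weight-decay
    /// (L2) regularization.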
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
    }

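    /// Applies the gradient with momentum, updating the caller-supplied
    /// velocity buffer in place.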
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }

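    /// Applies the gradient with both a decay term and momentum.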
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}

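/// Field-wise gradient descent: the bias and weights of `self` are updated
/// against the corresponding fields of the gradient.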
impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        self.bias.apply_gradient(grad.bias(), lr)?;
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

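/// Momentum-aware updates for a parameter set; the velocity is an owned
/// [`Params`] with the same dimensionality as the parameters.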
impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}

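/// Backward pass for a batch of inputs (`Ix2`) against vector parameters
/// (`Ix1`); returns the sum of the squared deltas.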
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        let weight_delta = delta.t().dot(input);
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        Ok(delta.pow2().sum())
    }
}

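/// Backward pass for a single input vector (`Ix1`) with a scalar delta (`Ix0`).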
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // the scalar delta is broadcast across the input: dw = x * δ
        let weight_delta = input * delta;
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}

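/// Backward pass for a single input vector (`Ix1`) through matrix parameters (`Ix2`).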
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // the weight gradient is the outer product of the input and the delta,
        // dw[i, j] = input[i] * delta[j], matching the (n, m) weight matrix
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        self.weights.apply_gradient(&dw, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}

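/// Backward pass for a batch of inputs (`Ix2`) through matrix parameters
/// (`Ix2`); gradients are accumulated over the batch axis.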
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // accumulate over the batch: (n, batch) · (batch, m) -> (n, m)
        let weight_delta = input.t().dot(delta);
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&bias_delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}

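/// Forward pass: computes `input · weights + bias` for any input that
/// implements [`Dot`] against the weight tensor.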
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}

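/// Random initialization for the parameters (requires the `rand` feature);
/// the bias shape is the weight shape with its leading axis removed.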
#[cfg(feature = "rand")]
impl<A, S, D> crate::init::Initialize<S, D> for ParamsBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData<Elem = A>,
{
    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
    where
        Ds: rand_distr::Distribution<A>,
        Sh: ndarray::ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        use rand::SeedableRng;
        Self::rand_with(
            shape,
            distr,
            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
        )
    }

    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
        Ds: rand_distr::Distribution<A>,
        Sh: ndarray::ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        let shape = shape.into_shape_with_order();
        // the bias mirrors the weight shape with the leading axis removed
        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
        Self { bias, weights }
    }
}