use crate::params::{Params, ParamsBase};
use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// Returns the sum of the L1 norms of the bias and weight arrays.
    pub fn l1_norm(&self) -> A {
        let bias = self.bias.l1_norm();
        let weights = self.weights.l1_norm();
        bias + weights
    }
    /// Returns the sum of the L2 norms of the bias and weight arrays; note that the
    /// two norms are added directly rather than combined as the norm of the
    /// concatenated parameter vector.
    pub fn l2_norm(&self) -> A {
        let bias = self.bias.l2_norm();
        let weights = self.weights.l2_norm();
        bias + weights
    }

    /// Applies a single gradient-descent step with learning rate `lr`, delegating to
    /// this type's [`ApplyGradient`] implementation.
    pub fn apply_gradient<Grad, Z>(&mut self, grad: &Grad, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient(self, grad, lr)
    }

    /// Applies a gradient step with weight decay, delegating to this type's
    /// [`ApplyGradient`] implementation.
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
    }

    /// Applies a momentum-accelerated gradient step, updating `velocity` in place;
    /// delegates to this type's [`ApplyGradientExt`] implementation.
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }

    /// Applies a gradient step with both weight decay and momentum, updating
    /// `velocity` in place; delegates to [`ApplyGradientExt`].
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}
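
// A minimal usage sketch for the helpers above. The `zeros` constructor is assumed
// here for illustration only; substitute whatever constructor the crate provides.
//
//     let mut params = Params::<f64, Ix2>::zeros((4, 3));
//     let grad = Params::<f64, Ix2>::zeros((4, 3));
//     params.apply_gradient(&grad, 0.01)?;                  // plain SGD step
//     params.apply_gradient_with_decay(&grad, 0.01, 1e-4)?; // SGD with weight decay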

impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        self.bias.apply_gradient(grad.bias(), lr)?;
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}
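
// Sketch of a momentum training loop; `zeros` and the `grads` iterator are
// hypothetical, used only for illustration:
//
//     let mut velocity = Params::<f64, Ix2>::zeros((4, 3));
//     for grad in grads {
//         params.apply_gradient_with_momentum(&grad, 0.01, 0.9, &mut velocity)?;
//     }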

// Backward pass for a batched input against rank-one parameters: `input` is
// (batch, n) and `delta` is (batch,), so the weight gradient is `deltaᵀ · input`
// and the bias gradient is the batch sum of `delta`.
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        let weight_delta = delta.t().dot(input);
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        // report the squared magnitude of the delta
        Ok(delta.pow2().sum())
    }
}

// Backward pass for a single rank-one sample: `input` is (n,) and `delta` is a
// scalar, so the weight gradient is `input` scaled by `delta`.
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        let weight_delta = input * delta;
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}

// Backward pass for a single sample against rank-two parameters: `input` is (n,)
// and `delta` is (m,), so the weight gradient is the outer product of `input` and
// `delta`, matching the (n, m) weight matrix.
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // outer product: (n, 1) · (1, m) = (n, m)
        let weight_delta = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}

// Batched backward pass: `input` is (batch, n) and `delta` is (batch, m), so the
// weight gradient is `inputᵀ · delta` with shape (n, m) and the bias gradient is
// the batch sum of `delta`.
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // (n, batch) · (batch, m) = (n, m)
        let weight_delta = input.t().dot(delta);
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&bias_delta, gamma)?;
        Ok(delta.pow2().sum())
    }
}
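
// Sketch of a single batched training step; `zeros` is a hypothetical constructor
// used only for illustration:
//
//     let mut params = Params::<f64, Ix2>::zeros((4, 3));
//     let input = Array2::<f64>::zeros((8, 4)); // (batch, in)
//     let delta = Array2::<f64>::zeros((8, 3)); // (batch, out)
//     let loss = params.backward(&input, &delta, 0.01)?;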

// Affine forward pass: `input · weights + bias`, with the output type driven by the
// `Dot` and `Add` implementations of the arguments.
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}
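
// Sketch (hypothetical `zeros` constructor, shapes for illustration):
//
//     let params = Params::<f64, Ix2>::zeros((4, 3));
//     let input = Array2::<f64>::zeros((8, 4));
//     let output = params.forward(&input)?; // shape (8, 3)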

#[cfg(feature = "rand")]
impl<A, S, D> crate::init::Initialize<A, D> for ParamsBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData<Elem = A>,
{
    type Data = S;

    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
    where
        Ds: rand_distr::Distribution<A>,
        Sh: ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        use rand::SeedableRng;
        Self::rand_with(
            shape,
            distr,
            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
        )
    }

    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
        Ds: rand_distr::Distribution<A>,
        Sh: ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        let shape = shape.into_shape_with_order();
        // the bias drops the leading axis of the weight shape, e.g. (n, m) -> (m,)
        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
        Self { bias, weights }
    }
}
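
// Sketch (requires the `rand` feature; the distribution and shape are illustrative):
//
//     use crate::init::Initialize;
//     use rand_distr::StandardNormal;
//     let params = Params::<f64, Ix2>::rand((4, 3), StandardNormal);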