1use crate::params::{Params, ParamsBase};
6use crate::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
7use ndarray::linalg::Dot;
8use ndarray::prelude::*;
9use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
10use num_traits::{Float, FromPrimitive};
11
12impl<A, S, D> ParamsBase<S, D>
13where
14 A: ScalarOperand + Float + FromPrimitive,
15 D: Dimension,
16 S: Data<Elem = A>,
17{
18 pub fn l1_norm(&self) -> A {
20 let bias = self.bias().l1_norm();
21 let weights = self.weights().l1_norm();
22 bias + weights
23 }
24 pub fn l2_norm(&self) -> A {
26 let bias = self.bias().l2_norm();
27 let weights = self.weights().l2_norm();
28 bias + weights
29 }
30 pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
33 where
34 S: DataMut,
35 Self: ApplyGradient<Delta, A, Output = Z>,
36 {
37 <Self as ApplyGradient<Delta, A>>::apply_gradient(self, grad, lr)
38 }
39
40 pub fn apply_gradient_with_decay<Grad, Z>(
41 &mut self,
42 grad: &Grad,
43 lr: A,
44 decay: A,
45 ) -> crate::Result<Z>
46 where
47 S: DataMut,
48 Self: ApplyGradient<Grad, A, Output = Z>,
49 {
50 <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
51 }
52
53 pub fn apply_gradient_with_momentum<Grad, V, Z>(
54 &mut self,
55 grad: &Grad,
56 lr: A,
57 momentum: A,
58 velocity: &mut V,
59 ) -> crate::Result<Z>
60 where
61 S: DataMut,
62 Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
63 {
64 <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
65 self, grad, lr, momentum, velocity,
66 )
67 }
68
69 pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
70 &mut self,
71 grad: &Grad,
72 lr: A,
73 decay: A,
74 momentum: A,
75 velocity: &mut V,
76 ) -> crate::Result<Z>
77 where
78 S: DataMut,
79 Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
80 {
81 <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
82 self, grad, lr, decay, momentum, velocity,
83 )
84 }
85}
86
87impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
88where
89 A: Float + FromPrimitive + ScalarOperand,
90 S: DataMut<Elem = A>,
91 T: Data<Elem = A>,
92 D: Dimension,
93{
94 type Output = ();
95
96 fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
97 self.bias_mut().apply_gradient(grad.bias(), lr)?;
99 self.weights_mut().apply_gradient(grad.weights(), lr)?;
101 Ok(())
102 }
103
104 fn apply_gradient_with_decay(
105 &mut self,
106 grad: &ParamsBase<T, D>,
107 lr: A,
108 decay: A,
109 ) -> crate::Result<Self::Output> {
110 self.bias_mut()
112 .apply_gradient_with_decay(grad.bias(), lr, decay)?;
113 self.weights_mut()
115 .apply_gradient_with_decay(grad.weights(), lr, decay)?;
116 Ok(())
117 }
118}
119
120impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
121where
122 A: Float + FromPrimitive + ScalarOperand,
123 S: DataMut<Elem = A>,
124 T: Data<Elem = A>,
125 D: Dimension,
126{
127 type Velocity = Params<A, D>;
128
129 fn apply_gradient_with_momentum(
130 &mut self,
131 grad: &ParamsBase<T, D>,
132 lr: A,
133 momentum: A,
134 velocity: &mut Self::Velocity,
135 ) -> crate::Result<()> {
136 self.bias_mut().apply_gradient_with_momentum(
138 grad.bias(),
139 lr,
140 momentum,
141 velocity.bias_mut(),
142 )?;
143 self.weights_mut().apply_gradient_with_momentum(
145 grad.weights(),
146 lr,
147 momentum,
148 velocity.weights_mut(),
149 )?;
150 Ok(())
151 }
152
153 fn apply_gradient_with_decay_and_momentum(
154 &mut self,
155 grad: &ParamsBase<T, D>,
156 lr: A,
157 decay: A,
158 momentum: A,
159 velocity: &mut Self::Velocity,
160 ) -> crate::Result<()> {
161 self.bias_mut().apply_gradient_with_decay_and_momentum(
163 grad.bias(),
164 lr,
165 decay,
166 momentum,
167 velocity.bias_mut(),
168 )?;
169 self.weights_mut().apply_gradient_with_decay_and_momentum(
171 grad.weights(),
172 lr,
173 decay,
174 momentum,
175 velocity.weights_mut(),
176 )?;
177 Ok(())
178 }
179}
180
181impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
182where
183 A: Float + FromPrimitive + ScalarOperand,
184 S: Data<Elem = A>,
185 T: Data<Elem = A>,
186{
187 type Elem = A;
188 type Output = A;
189
190 fn backward(
191 &mut self,
192 input: &ArrayBase<S, Ix2>,
193 delta: &ArrayBase<T, Ix1>,
194 gamma: Self::Elem,
195 ) -> crate::Result<Self::Output> {
196 let weight_delta = delta.t().dot(input);
198 self.weights_mut().apply_gradient(&weight_delta, gamma)?;
200 self.bias_mut()
201 .apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
202 Ok(delta.pow2().sum())
204 }
205}
206
207impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
208where
209 A: Float + FromPrimitive + ScalarOperand,
210 S: Data<Elem = A>,
211 T: Data<Elem = A>,
212{
213 type Elem = A;
214 type Output = A;
215
216 fn backward(
217 &mut self,
218 input: &ArrayBase<S, Ix1>,
219 delta: &ArrayBase<T, Ix0>,
220 gamma: Self::Elem,
221 ) -> crate::Result<Self::Output> {
222 let weight_delta = input * delta;
224 self.weights_mut().apply_gradient(&weight_delta, gamma)?;
226 self.bias_mut().apply_gradient(delta, gamma)?;
227 Ok(delta.pow2().sum())
229 }
230}
231
232impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
233where
234 A: Float + FromPrimitive + ScalarOperand,
235 S: Data<Elem = A>,
236 T: Data<Elem = A>,
237{
238 type Elem = A;
239 type Output = A;
240
241 fn backward(
242 &mut self,
243 input: &ArrayBase<S, Ix1>,
244 delta: &ArrayBase<T, Ix1>,
245 gamma: Self::Elem,
246 ) -> crate::Result<Self::Output> {
247 let dw = &self.weights * delta.t().dot(input);
249 self.weights_mut().apply_gradient(&dw, gamma)?;
251 self.bias_mut().apply_gradient(delta, gamma)?;
252 Ok(delta.pow2().sum())
254 }
255}
256
257impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
258where
259 A: Float + FromPrimitive + ScalarOperand,
260 S: Data<Elem = A>,
261 T: Data<Elem = A>,
262{
263 type Elem = A;
264 type Output = A;
265
266 fn backward(
267 &mut self,
268 input: &ArrayBase<S, Ix2>,
269 delta: &ArrayBase<T, Ix2>,
270 gamma: Self::Elem,
271 ) -> crate::Result<Self::Output> {
272 let weight_delta = input.dot(&delta.t());
274 let bias_delta = delta.sum_axis(Axis(0));
276
277 self.weights_mut().apply_gradient(&weight_delta, gamma)?;
278 self.bias_mut().apply_gradient(&bias_delta, gamma)?;
279 Ok(delta.pow2().sum())
281 }
282}
283
284impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
285where
286 A: Clone,
287 D: Dimension,
288 S: Data<Elem = A>,
289 for<'a> X: Dot<ArrayBase<S, D>, Output = Y>,
290 Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
291{
292 type Output = Z;
293
294 fn forward(&self, input: &X) -> crate::Result<Self::Output> {
295 let output = input.dot(&self.weights) + &self.bias;
296 Ok(output)
297 }
298}