concision_core/params/impls/impl_params_ops.rs

/*
    Appellation: impl_params_ops <module>
    Contrib: @FL03
*/
use crate::params::{Params, ParamsBase};
use crate::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// Returns the L1 norm of the parameters (the sum of the L1 norms of the
    /// bias and weights).
    pub fn l1_norm(&self) -> A {
        let bias = self.bias().l1_norm();
        let weights = self.weights().l1_norm();
        bias + weights
    }
    /// Returns the sum of the L2 norms of the bias and weights; note that this
    /// is not the L2 norm of the parameters viewed as a single vector.
    pub fn l2_norm(&self) -> A {
        let bias = self.bias().l2_norm();
        let weights = self.weights().l2_norm();
        bias + weights
    }
    /// Applies a gradient to the parameters using the given learning rate; a
    /// convenience wrapper over [`ApplyGradient::apply_gradient`].
    pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Delta, A, Output = Z>,
    {
        <Self as ApplyGradient<Delta, A>>::apply_gradient(self, grad, lr)
    }
    /// Applies a gradient using the given learning rate and weight decay; a
    /// convenience wrapper over [`ApplyGradient::apply_gradient_with_decay`].
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
    }
    /// Applies a gradient using the given learning rate and momentum, updating
    /// the `velocity` buffer in place; a convenience wrapper over
    /// [`ApplyGradientExt::apply_gradient_with_momentum`].
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }
    /// Applies a gradient using the given learning rate, weight decay, and
    /// momentum; a convenience wrapper over
    /// [`ApplyGradientExt::apply_gradient_with_decay_and_momentum`].
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}
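
// A minimal usage sketch of the helpers above. It assumes `weights` and
// `bias` are the only fields of `ParamsBase` and are visible in this module,
// as the direct field access in the `Forward` impl below suggests; swap in
// the crate's constructor if one exists.
#[cfg(test)]
mod sgd_step_sketch {
    use super::*;

    #[test]
    fn plain_sgd_step_runs() {
        // parameters and gradient for a 3-in / 2-out linear map
        let mut params = Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::ones(2),
        };
        let grad = Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::ones(2),
        };
        // conventionally this performs `w <- w - lr * g` elementwise
        params.apply_gradient(&grad, 0.1).unwrap();
        assert_eq!(params.weights.dim(), (3, 2));
        assert_eq!(params.bias.dim(), 2);
    }
}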

impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias_mut().apply_gradient(grad.bias(), lr)?;
        // apply the weight gradient
        self.weights_mut().apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        // apply the bias gradient with decay
        self.bias_mut()
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        // apply the weight gradient with decay
        self.weights_mut()
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient, updating the bias velocity in place
        self.bias_mut().apply_gradient_with_momentum(
            grad.bias(),
            lr,
            momentum,
            velocity.bias_mut(),
        )?;
        // apply the weight gradient, updating the weight velocity in place
        self.weights_mut().apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias_mut().apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        // apply the weight gradient
        self.weights_mut().apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}
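
// A sketch of a momentum step, under the same field-visibility assumption as
// the SGD sketch above: the velocity buffer is itself a `Params` of matching
// shape and is updated in place across calls.
#[cfg(test)]
mod momentum_step_sketch {
    use super::*;

    #[test]
    fn momentum_step_runs() {
        let mut params = Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::ones(2),
        };
        let grad = Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::ones(2),
        };
        // zero-initialized velocity, one buffer per parameter tensor
        let mut velocity = Params {
            weights: Array2::<f64>::zeros((3, 2)),
            bias: Array1::<f64>::zeros(2),
        };
        params
            .apply_gradient_with_momentum(&grad, 0.1, 0.9, &mut velocity)
            .unwrap();
        assert_eq!(params.weights.dim(), (3, 2));
    }
}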

impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient by contracting over the batch axis:
        // (batch,) · (batch, in) -> (in,)
        let weight_delta = delta.dot(input);
        // update the weights and bias
        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
        self.bias_mut()
            .apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient by scaling the input with the scalar delta
        let weight_delta = input * delta;
        // update the weights and bias
        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
        self.bias_mut().apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient as the outer product `input ⊗ delta`,
        // whose shape (in, out) matches the weights
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        // update the weights and bias
        self.weights_mut().apply_gradient(&dw, gamma)?;
        self.bias_mut().apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}
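
// The outer-product gradient above is easy to get wrong, so this sketch
// checks the shape algebra with plain `ndarray` values: for weights of shape
// (in, out), `input ⊗ delta` must also be (in, out).
#[cfg(test)]
mod outer_product_sketch {
    use super::*;

    #[test]
    fn outer_product_matches_weight_shape() {
        let input = arr1(&[1.0_f64, 2.0, 3.0]); // in = 3
        let delta = arr1(&[4.0_f64, 5.0]); // out = 2
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        assert_eq!(dw.dim(), (3, 2));
        assert_eq!(dw[(2, 1)], 15.0); // input[2] * delta[1] = 3.0 * 5.0
    }
}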

impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient by contracting over the batch axis:
        // (in, batch) · (batch, out) -> (in, out)
        let weight_delta = input.t().dot(delta);
        // compute the bias gradient by summing the delta over the batch axis
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
        self.bias_mut().apply_gradient(&bias_delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}
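
// Shape check for the batched case above: with input (batch, in) and delta
// (batch, out), `input^T · delta` contracts over the batch axis and yields
// (in, out), matching the weights. Written against plain `ndarray` only.
#[cfg(test)]
mod batched_gradient_sketch {
    use super::*;

    #[test]
    fn batched_gradient_matches_weight_shape() {
        let input = Array2::<f64>::ones((4, 3)); // (batch, in)
        let delta = Array2::<f64>::ones((4, 2)); // (batch, out)
        let dw = input.t().dot(&delta);
        assert_eq!(dw.dim(), (3, 2));
        assert_eq!(dw[(0, 0)], 4.0); // summed over the batch of ones
    }
}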

impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        // affine transform: x · W + b
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}
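
// A forward-pass sketch, again assuming module-visible `weights`/`bias`
// fields: for x of shape (in,) and weights (in, out), the output is
// x · W + b with shape (out,).
#[cfg(test)]
mod forward_sketch {
    use super::*;

    #[test]
    fn forward_matches_manual_affine_map() {
        let params = Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::ones(2),
        };
        let x = arr1(&[1.0_f64, 2.0, 3.0]);
        let y = params.forward(&x).unwrap();
        // each output is 1 + 2 + 3 (the dot product) plus the unit bias
        assert_eq!(y, arr1(&[7.0, 7.0]));
    }
}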