concision_core/params/impls/impl_ops.rs

/*
    Appellation: impl_ops <module>
    Contrib: @FL03
*/
use crate::params::{Params, ParamsBase};
use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// Returns the L1 norm of the parameters, i.e. the sum of the L1 norms of the bias and
    /// weights.
    pub fn l1_norm(&self) -> A {
        let bias = self.bias.l1_norm();
        let weights = self.weights.l1_norm();
        bias + weights
    }
    /// Returns the sum of the L2 norms of the bias and weights.
    pub fn l2_norm(&self) -> A {
        let bias = self.bias.l2_norm();
        let weights = self.weights.l2_norm();
        bias + weights
    }
    /// A convenience method that applies the given gradient to the parameters using the
    /// provided learning rate.
    pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Delta, A, Output = Z>,
    {
        <Self as ApplyGradient<Delta, A>>::apply_gradient(self, grad, lr)
    }

    /// Applies the given gradient to the parameters using the provided learning rate and
    /// weight decay factor.
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
    }

    /// Applies the given gradient using momentum, updating the provided velocity buffer in
    /// place.
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }

    /// Applies the given gradient using both weight decay and momentum, updating the
    /// provided velocity buffer in place.
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}

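// NOTE: an illustrative call pattern for the convenience wrappers above (a sketch, not part
// of the public API); `params`, `grad`, and `velocity` are assumed to be `Params<f64, Ix2>`
// values with matching shapes:
//
//     params.apply_gradient(&grad, 0.01)?;
//     params.apply_gradient_with_decay(&grad, 0.01, 1e-4)?;
//     params.apply_gradient_with_momentum(&grad, 0.01, 0.9, &mut velocity)?;
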
impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias.apply_gradient(grad.bias(), lr)?;
        // apply the weight gradient
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        // apply the weight gradient
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        // apply the weight gradient
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        // apply the weight gradient
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}

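// NOTE(sketch): the momentum-based updates above thread a velocity buffer of type
// `Params<A, D>` whose `weights`/`bias` mirror the parameter shapes; a caller would
// typically allocate it once (e.g. zero-initialized) and reuse it across steps:
//
//     let mut velocity = Params {
//         weights: Array2::<f64>::zeros((3, 2)),
//         bias: Array1::<f64>::zeros(2),
//     };
//     params.apply_gradient_with_momentum(&grad, 0.01, 0.9, &mut velocity)?;
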
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient
        let weight_delta = delta.t().dot(input);
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

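// Shape sketch for the batched 1-D case above (illustrative): with `input: (n, d)` and
// `delta: (n,)`, `delta.t().dot(input)` yields a weight gradient of shape `(d,)`, while
// `delta.sum_axis(Axis(0))` reduces to the zero-dimensional bias gradient expected by
// `Params<A, Ix1>`.
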
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient by scaling the input with the scalar delta
        let weight_delta = input * delta;
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient as the outer product of the input and the delta,
        // matching the `(inputs, outputs)` layout of the weights
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        // update the weights and bias
        self.weights.apply_gradient(&dw, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient: input^T · delta matches the weight shape
        let weight_delta = input.t().dot(delta);
        // compute the bias gradient by summing the delta over the batch axis
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&bias_delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

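// Shape sketch for the batched 2-D case above (illustrative): with `input: (n, d_in)` and
// `delta: (n, d_out)`, `input.t().dot(delta)` produces a weight gradient of shape
// `(d_in, d_out)` matching the weights, and `delta.sum_axis(Axis(0))` produces a bias
// gradient of shape `(d_out,)`.
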
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    /// Computes the affine transformation `input · weights + bias`.
    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}

#[cfg(feature = "rand")]
impl<A, S, D> crate::init::Initialize<S, D> for ParamsBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData<Elem = A>,
{
    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
    where
        Ds: rand_distr::Distribution<A>,
        Sh: ndarray::ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        use rand::SeedableRng;
        // seed a small, fast RNG from the thread-local generator
        Self::rand_with(
            shape,
            distr,
            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
        )
    }

    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
        Ds: rand_distr::Distribution<A>,
        Sh: ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        let shape = shape.into_shape_with_order();
        // the bias drops the leading axis of the weight shape
        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
        Self { bias, weights }
    }
}
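
// A couple of compile-level usage sketches for the impls above; kept behind `cfg(test)` so
// they do not affect the library surface. Names such as `sketch_tests` and `sample_params`
// are illustrative only.
#[cfg(test)]
mod sketch_tests {
    //! Minimal usage sketches for the impls in this module. They assume that `Params<A, D>`
    //! is the owned alias over `ParamsBase` whose `weights`/`bias` fields are visible within
    //! the crate (as the `Initialize` impl above relies on), and that the crate's error type
    //! implements `Debug`; adjust the constructors to the crate's actual API if needed.
    use super::*;
    use ndarray::{Array1, Array2};

    /// Builds a small parameter set with weights of shape `(inputs, outputs)` and a bias
    /// spanning the trailing (output) axis.
    fn sample_params() -> Params<f64, Ix2> {
        Params {
            weights: Array2::<f64>::ones((3, 2)),
            bias: Array1::<f64>::zeros(2),
        }
    }

    #[test]
    fn forward_produces_batch_by_output() {
        let params = sample_params();
        let input = Array2::<f64>::ones((4, 3));
        // `Forward` computes `input · weights + bias`, so a (4, 3) batch maps to (4, 2)
        let output = params.forward(&input).expect("forward pass failed");
        assert_eq!(output.shape(), &[4, 2]);
    }

    #[test]
    fn apply_gradient_preserves_shapes() {
        let mut params = sample_params();
        let grad = sample_params();
        params
            .apply_gradient(&grad, 0.1)
            .expect("gradient application failed");
        assert_eq!(params.weights().shape(), &[3, 2]);
        assert_eq!(params.bias().shape(), &[2]);
    }
}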