concision_core/params/impls/
impl_ops.rs

/*
    Appellation: impl_ops <module>
    Contrib: @FL03
*/
use crate::params::{Params, ParamsBase};
use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// Returns the L1 norm of the parameters (bias and weights).
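    ///
    /// A sketch of the underlying computation with plain `ndarray`, assuming the
    /// standard definition of the L1 norm (sum of absolute values):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let bias = array![0.5_f64, -0.5];
    /// let weights = array![[1.0_f64, -2.0], [3.0, -4.0]];
    /// // l1(params) = l1(bias) + l1(weights)
    /// let l1 = bias.mapv(f64::abs).sum() + weights.mapv(f64::abs).sum();
    /// assert_eq!(l1, 11.0);
    /// ```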
    pub fn l1_norm(&self) -> A {
        let bias = self.bias.l1_norm();
        let weights = self.weights.l1_norm();
        bias + weights
    }
    /// Returns the sum of the L2 norms of the bias and the weights.
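    ///
    /// Note that this is the sum of the two norms, not the Euclidean norm of the
    /// concatenated parameters. A sketch with plain `ndarray`:
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let bias = array![0.0_f64];
    /// let weights = array![[3.0_f64, 4.0]];
    /// // l2(bias) + l2(weights), each computed independently
    /// let l2 = bias.mapv(|x| x * x).sum().sqrt() + weights.mapv(|x| x * x).sum().sqrt();
    /// assert_eq!(l2, 5.0);
    /// ```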
    pub fn l2_norm(&self) -> A {
        let bias = self.bias.l2_norm();
        let weights = self.weights.l2_norm();
        bias + weights
    }
    /// Applies a gradient to the parameters using the given learning rate; a
    /// convenience wrapper around [`ApplyGradient::apply_gradient`].
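    ///
    /// A sketch of a conventional SGD-style step with plain `ndarray` (the exact
    /// update performed by the underlying [`ApplyGradient`] impl may differ):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let mut weights = array![1.0_f64, 2.0];
    /// let grad = array![0.5_f64, -0.5];
    /// let lr = 0.5;
    /// // w <- w - lr * g
    /// weights = &weights - &(&grad * lr);
    /// assert_eq!(weights, array![0.75, 2.25]);
    /// ```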
    pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Delta, Elem = A, Output = Z>,
    {
        <Self as ApplyGradient<Delta>>::apply_gradient(self, grad, lr)
    }

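    /// Applies a gradient to the parameters with L2 weight decay, delegating to
    /// [`ApplyGradient::apply_gradient_with_decay`].
    ///
    /// A sketch of a conventional decayed step, assuming the common formulation
    /// `g_eff = g + decay * w` (the underlying impl may differ in detail):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let mut w = array![1.0_f64];
    /// let g = array![0.5_f64];
    /// let (lr, decay) = (0.5, 0.25);
    /// let g_eff = &g + &(&w * decay); // fold the decay term into the gradient
    /// w = &w - &(&g_eff * lr);
    /// assert_eq!(w, array![0.625]);
    /// ```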
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, Elem = A, Output = Z>,
    {
        <Self as ApplyGradient<Grad>>::apply_gradient_with_decay(self, grad, lr, decay)
    }

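    /// Applies a gradient to the parameters using momentum accumulated in the
    /// caller-supplied `velocity` buffer, delegating to
    /// [`ApplyGradientExt::apply_gradient_with_momentum`].
    ///
    /// A sketch of a conventional momentum step, assuming the common formulation
    /// `v = momentum * v + g; w = w - lr * v` (the underlying impl may differ):
    ///
    /// ```
    /// use ndarray::array;
    ///
    /// let mut w = array![1.0_f64];
    /// let mut v = array![0.0_f64];
    /// let g = array![0.5_f64];
    /// let (lr, momentum) = (0.5, 0.5);
    /// v = &(&v * momentum) + &g; // v <- momentum * v + g
    /// w = &w - &(&v * lr); // w <- w - lr * v
    /// assert_eq!(w, array![0.75]);
    /// ```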
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, Elem = A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }

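    /// Applies a gradient to the parameters with both weight decay and momentum,
    /// combining [`apply_gradient_with_decay`](Self::apply_gradient_with_decay)
    /// and [`apply_gradient_with_momentum`](Self::apply_gradient_with_momentum);
    /// delegates to [`ApplyGradientExt::apply_gradient_with_decay_and_momentum`].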
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, Elem = A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}

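// when the gradient is itself a `ParamsBase`, update each field against its
// counterpart: the bias against the bias gradient, the weights against the
// weight gradient.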
impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Elem = A;
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias.apply_gradient(grad.bias(), lr)?;
        // apply the weight gradient
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        // apply the weight gradient
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

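// the momentum-aware extension mirrors the plain impl, but threads a velocity
// buffer (another `Params`) through each field-wise update.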
impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        // apply the weight gradient
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        // apply the weight gradient
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}

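/// Backward pass for vector parameters over a batched input: `input` is
/// `(batch, in)` and `delta` is `(batch,)`, so `deltaᵀ · input` yields the
/// `(in,)` weight gradient and the summed delta yields the scalar bias
/// gradient. Returns the squared norm of the delta.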
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient
        let weight_delta = delta.t().dot(input);
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

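/// Backward pass for vector parameters over a single sample: `input` is
/// `(in,)` and `delta` is a scalar, so the weight gradient is the input scaled
/// by the delta. Returns the squared delta.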
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient by scaling the input with the (scalar) delta
        let weight_delta = input * delta;
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

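/// Backward pass for matrix parameters over a single sample: `input` is
/// `(in,)` and `delta` is `(out,)`, so the weight gradient is the outer
/// product `input ⊗ delta` of shape `(in, out)`, matching the weight layout
/// used by the forward pass. Returns the squared norm of the delta.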
impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient as the outer product `input ⊗ delta`,
        // matching the `(in, out)` layout of the weights
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        // update the weights and bias
        self.weights.apply_gradient(&dw, gamma)?;
        self.bias.apply_gradient(delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

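/// Backward pass for matrix parameters over a batched input: `input` is
/// `(batch, in)` and `delta` is `(batch, out)`, so `inputᵀ · delta` yields the
/// `(in, out)` weight gradient and summing the delta over the batch axis
/// yields the `(out,)` bias gradient. Returns the squared norm of the delta.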
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::Elem,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient: `inputᵀ · delta` has shape `(in, out)`
        let weight_delta = input.t().dot(delta);
        // compute the bias gradient by summing the delta over the batch axis
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&bias_delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

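/// The forward pass is the affine map `y = x · W + b`.
///
/// A sketch of the underlying arithmetic with plain `ndarray` (shapes:
/// `x: (batch, in)`, `W: (in, out)`, `b: (out,)`):
///
/// ```
/// use ndarray::array;
///
/// let x = array![[1.0_f64, 2.0]]; // (1, 2)
/// let w = array![[1.0_f64], [1.0]]; // (2, 1)
/// let b = array![0.5_f64]; // (1,)
/// let y = x.dot(&w) + &b;
/// assert_eq!(y, array![[3.5]]);
/// ```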
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}

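// random initialization samples the weights over the full shape and the bias
// over that shape with its leading axis removed, drawing both from the same
// distribution (e.g. `rand_distr::StandardNormal`).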
296#[cfg(feature = "rand")]
297impl<A, S, D> crate::init::Initialize<S, D> for ParamsBase<S, D>
298where
299    D: ndarray::RemoveAxis,
300    S: ndarray::RawData<Elem = A>,
301{
302    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
303    where
304        Ds: rand_distr::Distribution<A>,
305        Sh: ndarray::ShapeBuilder<Dim = D>,
306        S: ndarray::DataOwned,
307    {
308        use rand::SeedableRng;
309        Self::rand_with(
310            shape,
311            distr,
312            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
313        )
314    }
315
316    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
317    where
318        R: rand::Rng + ?Sized,
319        Ds: rand_distr::Distribution<A>,
320        Sh: ShapeBuilder<Dim = D>,
321        S: ndarray::DataOwned,
322    {
323        let shape = shape.into_shape_with_order();
324        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
325        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
326        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
327        Self { bias, weights }
328    }
329}