concision_core/params/impls/impl_ops.rs

/*
    Appellation: impl_ops <module>
    Contrib: @FL03
*/
use crate::params::{Params, ParamsBase};
use crate::traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::prelude::*;
use ndarray::{ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// Returns the L1 norm of the parameters, i.e. the sum of the bias and weight L1 norms.
    pub fn l1_norm(&self) -> A {
        let bias = self.bias.l1_norm();
        let weights = self.weights.l1_norm();
        bias + weights
    }
    /// Returns the sum of the L2 norms of the bias and the weights.
    pub fn l2_norm(&self) -> A {
        let bias = self.bias.l2_norm();
        let weights = self.weights.l2_norm();
        bias + weights
    }

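    /// Applies the given gradient to the parameters, scaling the update by the
    /// learning rate `lr`; this simply delegates to the [`ApplyGradient`]
    /// implementation for the concrete gradient type.
    ///
    /// A minimal usage sketch (illustrative only: it assumes an owned `Params`
    /// value and a gradient of the same shape are already in scope):
    ///
    /// ```ignore
    /// // `params` and `grad` are both owned `Params<f64, Ix2>` of matching shape
    /// params.apply_gradient(&grad, 0.01)?;
    /// ```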
    pub fn apply_gradient<Grad, Z>(&mut self, grad: &Grad, lr: A) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient(self, grad, lr)
    }

    /// Applies the gradient with an additional weight-decay term; delegates to
    /// [`ApplyGradient::apply_gradient_with_decay`].
    pub fn apply_gradient_with_decay<Grad, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradient<Grad, A, Output = Z>,
    {
        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
    }

    /// Applies the gradient using momentum, updating the provided `velocity` buffer;
    /// delegates to [`ApplyGradientExt::apply_gradient_with_momentum`].
    pub fn apply_gradient_with_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
            self, grad, lr, momentum, velocity,
        )
    }

    /// Applies the gradient with both weight decay and momentum; delegates to
    /// [`ApplyGradientExt::apply_gradient_with_decay_and_momentum`].
    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut V,
    ) -> crate::Result<Z>
    where
        S: DataMut,
        Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
    {
        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
            self, grad, lr, decay, momentum, velocity,
        )
    }
}

impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias.apply_gradient(grad.bias(), lr)?;
        // apply the weight gradient
        self.weights.apply_gradient(grad.weights(), lr)?;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
        // apply the weight gradient
        self.weights
            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Params<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias
            .apply_gradient_with_momentum(grad.bias(), lr, momentum, velocity.bias_mut())?;
        // apply the weight gradient
        self.weights.apply_gradient_with_momentum(
            grad.weights(),
            lr,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ParamsBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<()> {
        // apply the bias gradient
        self.bias.apply_gradient_with_decay_and_momentum(
            grad.bias(),
            lr,
            decay,
            momentum,
            velocity.bias_mut(),
        )?;
        // apply the weight gradient
        self.weights.apply_gradient_with_decay_and_momentum(
            grad.weights(),
            lr,
            decay,
            momentum,
            velocity.weights_mut(),
        )?;
        Ok(())
    }
}

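// The `Backward` implementations below cover the usual pairings of input and
// delta dimensionality for 1-d and 2-d parameters. In each case a weight
// gradient is formed from the input and the delta, both the weights and the
// bias are updated via `ApplyGradient`, and the sum of the squared delta is
// returned. A rough shape guide, assuming the `(inputs, outputs)` weight
// layout implied by the `Forward` implementation further down:
//
//   Params<_, Ix1>: input `(batch, in)` with delta `(batch,)`, or
//                   input `(in,)` with a scalar (`Ix0`) delta
//   Params<_, Ix2>: input `(in,)` with delta `(out,)`, or
//                   input `(batch, in)` with delta `(batch, out)`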
impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient
        let weight_delta = delta.t().dot(input);
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient
        let weight_delta = input * delta;
        // update the weights and bias
        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient as the outer product of the input and the delta,
        // matching the `(in, out)` layout of the weights
        let dw = input
            .view()
            .insert_axis(Axis(1))
            .dot(&delta.view().insert_axis(Axis(0)));
        // update the weights and bias
        self.weights.apply_gradient(&dw, gamma)?;
        self.bias.apply_gradient(&delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type HParam = A;
    type Output = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        delta: &ArrayBase<T, Ix2>,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output> {
        // compute the weight gradient: `(in, batch) . (batch, out) -> (in, out)`
        let weight_delta = input.t().dot(delta);
        // compute the bias gradient by summing the delta over the batch axis
        let bias_delta = delta.sum_axis(Axis(0));

        self.weights.apply_gradient(&weight_delta, gamma)?;
        self.bias.apply_gradient(&bias_delta, gamma)?;
        // return the sum of the squared delta
        Ok(delta.pow2().sum())
    }
}

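// The forward pass is defined for any input type `X` that can be multiplied
// against the weight array through `ndarray`'s `Dot`, provided the product can
// then be offset by a reference to the bias array; the output type `Z` is
// whatever that addition yields.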
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    X: Dot<ArrayBase<S, D>, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }
}

#[cfg(feature = "rand")]
impl<A, S, D> crate::init::Initialize<A, D> for ParamsBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData<Elem = A>,
{
    type Data = S;

    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
    where
        Ds: rand_distr::Distribution<A>,
        Sh: ndarray::ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        use rand::SeedableRng;
        Self::rand_with(
            shape,
            distr,
            &mut rand::rngs::SmallRng::from_rng(&mut rand::rng()),
        )
    }

    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
        Ds: rand_distr::Distribution<A>,
        Sh: ShapeBuilder<Dim = D>,
        S: ndarray::DataOwned,
    {
        let shape = shape.into_shape_with_order();
        // the bias shape is the weight shape with its leading axis removed
        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
        // sample the bias and weights element-wise from the given distribution
        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
        Self { bias, weights }
    }
}
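
// A minimal sanity-check sketch for the `Forward` and `Backward` implementations
// above. It is illustrative rather than exhaustive and leans on two assumptions:
// `Params<A, D>` is the owned alias of `ParamsBase` whose `bias`/`weights` fields
// can be filled in directly here (exactly as `rand_with` does above), and the
// element-wise `ApplyGradient` implementations used by `backward` are available
// for plain `ndarray` arrays, as the code above already relies on.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2, Ix2};

    #[test]
    fn forward_matches_affine_map() {
        // weights: (in = 2, out = 3), bias: (out = 3)
        let params: Params<f64, Ix2> = Params {
            weights: Array2::ones((2, 3)),
            bias: Array1::zeros(3),
        };
        let input = array![1.0, 2.0];
        // forward = input . weights + bias = [3.0, 3.0, 3.0]
        let output = params.forward(&input).expect("forward pass failed");
        assert_eq!(output, array![3.0, 3.0, 3.0]);
    }

    #[test]
    fn backward_returns_squared_delta_sum() {
        let mut params: Params<f64, Ix2> = Params {
            weights: Array2::zeros((2, 3)),
            bias: Array1::zeros(3),
        };
        // a batch of two samples together with their output deltas
        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let delta = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let loss = params
            .backward(&input, &delta, 0.1)
            .expect("backward pass failed");
        // `backward` reports the sum of the squared delta entries
        assert_eq!(loss, delta.mapv(|x| x * x).sum());
    }
}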