concision_core/traits/gradient.rs

/*
    appellation: gradient <module>
    authors: @FL03
*/

/// The [`Gradient`] trait defines the gradient of a function: a mapping that takes an
/// input and returns a delta, i.e. the change in the output with respect to that input.
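///
/// # Example
///
/// A minimal sketch of an implementation; the `Square` type below is illustrative and
/// not part of this crate. Here both the representation and the delta are plain
/// scalars, so the dimension parameter is simply `()`:
///
/// ```ignore
/// struct Square;
///
/// impl Gradient<f64, ()> for Square {
///     // inputs are represented by the scalar itself
///     type Repr<_A> = f64;
///     // deltas are likewise plain scalars
///     type Delta<_S, _D> = f64;
///
///     fn grad(&self, rhs: &f64) -> f64 {
///         // d/dx x^2 = 2x
///         2.0 * rhs
///     }
/// }
/// ```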
pub trait Gradient<T, D> {
    type Repr<_A>;
    type Delta<_S, _D>;

    fn grad(&self, rhs: &Self::Delta<Self::Repr<T>, D>) -> Self::Delta<Self::Repr<T>, D>;
}

/// A trait declaring the basic routines for applying a gradient to the parameters of a
/// neural network. Implementations are expected to take a descent step, i.e.
/// `param <- param - lr * grad`, normalized by the batch size where one is available.
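///
/// # Example
///
/// A sketch of the intended usage against the [`ArrayBase`](ndarray::ArrayBase)
/// implementation provided below; the values are illustrative:
///
/// ```ignore
/// use ndarray::array;
///
/// let mut weights = array![1.0_f64, 2.0];
/// let grad = array![1.0_f64, 1.0];
/// // the step size is lr divided by the batch size (the leading axis, here 2)
/// weights.apply_gradient(&grad, 0.5)?;
/// assert_eq!(weights, array![0.75, 1.75]);
/// ```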
pub trait ApplyGradient<Delta, T> {
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
    ) -> crate::Result<Self::Output>;
}

/// This trait extends [`ApplyGradient`] with momentum-based optimization routines: the
/// caller supplies a velocity buffer that is updated as an exponential moving average
/// of the incoming gradients.
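///
/// # Example
///
/// A sketch of a momentum step using the [`ArrayBase`](ndarray::ArrayBase)
/// implementation below; the caller owns the velocity buffer and threads it through
/// every update:
///
/// ```ignore
/// use ndarray::{array, Array1};
///
/// let mut weights = array![1.0_f64, 2.0];
/// let grad = array![1.0_f64, 1.0];
/// let mut velocity = Array1::zeros(weights.raw_dim());
/// weights.apply_gradient_with_momentum(&grad, 0.5, 0.9, &mut velocity)?;
/// ```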
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

/*
 ************* Implementations *************
*/

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand, ShapeError};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        // treat the leading axis as the batch dimension; 0-dimensional (scalar)
        // arrays have no such axis and fall back to a batch size of one
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // descend along the gradient: param <- param - (lr / batch_size) * grad
        self.scaled_add(-lr / batch_size, grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // L2 weight decay folds `decay * param` into the gradient before the step,
        // so the descent direction also shrinks the parameters toward zero
        self.scaled_add(-lr / batch_size, &(grad + &*self * decay));
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // the velocity is an exponential moving average of the gradient
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // descend along the smoothed gradient
        self.scaled_add(-lr / batch_size, velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // fold L2 weight decay into the gradient, then smooth it with momentum
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-lr / batch_size, velocity);
        Ok(())
    }
}
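
/*
 ************* Tests *************
*/

// A small smoke-test sketch for the ndarray implementations above; it assumes the
// surrounding crate's error type implements `Debug` (so `unwrap` works) and uses
// binary-exact constants so the equality assertions are stable.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn apply_gradient_descends() {
        let mut weights = array![1.0_f64, 2.0];
        let grad = array![1.0_f64, 1.0];
        // the leading axis has length 2, so the effective step is -(0.5 / 2) * grad
        weights.apply_gradient(&grad, 0.5).unwrap();
        assert_eq!(weights, array![0.75, 1.75]);
    }

    #[test]
    fn momentum_smooths_the_gradient() {
        let mut weights = array![1.0_f64, 2.0];
        let grad = array![1.0_f64, 1.0];
        let mut velocity = Array1::zeros(weights.raw_dim());
        weights
            .apply_gradient_with_momentum(&grad, 0.5, 0.5, &mut velocity)
            .unwrap();
        // v = 0.5 * 0 + 0.5 * g = [0.5, 0.5]; step = -(0.5 / 2) * v = [-0.125, -0.125]
        assert_eq!(velocity, array![0.5, 0.5]);
        assert_eq!(weights, array![0.875, 1.875]);
    }

    #[test]
    fn mismatched_shapes_are_rejected() {
        let mut weights = array![1.0_f64, 2.0];
        let grad = array![1.0_f64, 1.0, 1.0];
        assert!(weights.apply_gradient(&grad, 0.5).is_err());
    }
}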