concision_traits/gradient.rs

/*
    appellation: gradient <module>
    authors: @FL03
*/

/// The [`Gradient`] trait describes the gradient of a function: given an input,
/// it produces a delta, the change in the function's output with respect to
/// that input.
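///
/// # Example
///
/// A minimal sketch (the `Quadratic` type is illustrative, not part of this
/// crate, and the import path assumes the trait is re-exported at the crate
/// root): for `f(x) = x^2`, the gradient at `x` is `2 * x`, so the delta is a
/// plain scalar and the GAT parameter goes unused.
///
/// ```rust,ignore
/// use concision_traits::Gradient;
///
/// struct Quadratic;
///
/// impl Gradient<f64> for Quadratic {
///     type Delta<_U> = f64;
///
///     fn grad(&self, rhs: &Self::Delta<f64>) -> Self::Delta<f64> {
///         rhs * 2.0
///     }
/// }
///
/// assert_eq!(Quadratic.grad(&3.0), 6.0);
/// ```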
pub trait Gradient<T> {
    type Delta<_U>;

    fn grad(&self, rhs: &Self::Delta<T>) -> Self::Delta<T>;
}

/// A trait declaring basic gradient-application routines for a neural network.
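///
/// Each method returns `Some(Self::Output)` on success and `None` when the
/// update cannot be applied (for the `ndarray` implementation below, on a
/// shape mismatch).
///
/// # Example
///
/// A sketch against the `ndarray` implementation provided below; values are
/// chosen so the arithmetic stays exact in floating point:
///
/// ```rust,ignore
/// use ndarray::array;
///
/// let mut params = array![1.0_f64, 2.0, 3.0, 4.0];
/// let grad = array![4.0_f64, 4.0, 4.0, 4.0];
/// // descent step scaled by the learning rate over the leading (batch) axis:
/// // each element moves by (1.0 / 4) * 4.0 = 1.0
/// params.apply_gradient(&grad, 1.0);
/// assert_eq!(params, array![0.0_f64, 1.0, 2.0, 3.0]);
/// ```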
pub trait ApplyGradient<Delta, T> {
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> Option<Self::Output>;

    fn apply_gradient_with_decay(&mut self, grad: &Delta, lr: T, decay: T) -> Option<Self::Output>;
}

/// This trait extends [`ApplyGradient`] with momentum-based optimization, tracking a
/// caller-owned velocity buffer that is updated in place.
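///
/// # Example
///
/// A sketch against the `ndarray` implementation below (values chosen to stay
/// exact in floating point):
///
/// ```rust,ignore
/// use ndarray::{array, Array1};
///
/// let mut params = array![1.0_f64, 1.0];
/// let grad = array![2.0_f64, 2.0];
/// let mut velocity: Array1<f64> = Array1::zeros(2);
/// // velocity = 0.5 * velocity + 0.5 * grad, then a descent step along it
/// params.apply_gradient_with_momentum(&grad, 1.0, 0.5, &mut velocity);
/// assert_eq!(velocity, array![1.0_f64, 1.0]);
/// assert_eq!(params, array![0.5_f64, 0.5]);
/// ```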
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output>;
}

/*
 ************* Implementations *************
*/

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> Option<Self::Output> {
        // the gradient must match the parameter shape exactly
        if self.shape() != grad.shape() {
            return None;
        }
        // treat the leading axis as the batch dimension; scalars use a batch of one
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0])?
        } else {
            A::one()
        };
        // descent step: self -= (lr / batch_size) * grad
        self.scaled_add(-(lr / batch_size), grad);
        Some(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0])?
        } else {
            A::one()
        };
        // weight decay (L2 regularization): shift the gradient by `decay * self`
        let rhs = grad + &*self * decay;
        // descent step: self -= (lr / batch_size) * (grad + decay * self)
        self.scaled_add(-(lr / batch_size), &rhs);
        Some(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0])?
        } else {
            A::one()
        };
        // exponential moving average of the gradient
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // descent step along the smoothed gradient
        self.scaled_add(-(lr / batch_size), velocity);
        Some(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0])?
        } else {
            A::one()
        };
        // apply weight decay before folding the gradient into the velocity
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), velocity);
        Some(())
    }
}
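
// Illustrative smoke tests (not part of the original module) exercising the
// `ndarray`-backed implementations above; values are chosen to stay exact in
// floating point.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn apply_gradient_rejects_mismatched_shapes() {
        let mut params = array![1.0_f64, 2.0, 3.0];
        let grad = array![1.0_f64, 1.0];
        assert!(params.apply_gradient(&grad, 0.5).is_none());
    }

    #[test]
    fn apply_gradient_takes_a_descent_step() {
        let mut params = array![1.0_f64, 2.0, 3.0, 4.0];
        let grad = array![4.0_f64, 4.0, 4.0, 4.0];
        // batch size is 4, so each element moves by (1.0 / 4) * 4.0 = 1.0
        params.apply_gradient(&grad, 1.0).unwrap();
        assert_eq!(params, array![0.0_f64, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn momentum_updates_velocity_in_place() {
        let mut params = array![1.0_f64, 1.0];
        let grad = array![2.0_f64, 2.0];
        let mut velocity: Array1<f64> = Array1::zeros(2);
        params
            .apply_gradient_with_momentum(&grad, 1.0, 0.5, &mut velocity)
            .unwrap();
        // velocity = 0.5 * 0 + 0.5 * grad; step = (1.0 / 2) * velocity
        assert_eq!(velocity, array![1.0_f64, 1.0]);
        assert_eq!(params, array![0.5_f64, 0.5]);
    }
}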