concision_core/traits/gradient.rs

/*
    Appellation: gradient <module>
    Contrib: @FL03
*/
/// A trait declaring basic gradient-related routines for a neural network
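///
/// A minimal usage sketch against the `ndarray` implementation provided below; the
/// import path, values, and learning rate are illustrative assumptions:
///
/// ```ignore
/// use concision_core::traits::ApplyGradient; // path is an assumption
/// use ndarray::array;
///
/// let mut weights = array![0.0f64, 0.0];
/// let grad = array![1.0, 2.0];
/// // adds `lr / batch_size * grad` to `weights` (here the batch size is 2)
/// weights.apply_gradient(&grad, 0.1).unwrap();
/// ```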
pub trait ApplyGradient<Delta, T> {
    type Output;

    /// Applies the given gradient to `self`, scaled by the learning rate `lr`.
    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> crate::Result<Self::Output>;

    /// Applies the gradient together with a weight-decay penalty of strength `decay`.
    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
    ) -> crate::Result<Self::Output>;
}

/// Extends the [ApplyGradient] trait with momentum-based optimization routines that
/// maintain a velocity buffer between updates
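///
/// A rough sketch of the momentum update performed by the `ndarray` implementation
/// below; the import path and values are illustrative assumptions:
///
/// ```ignore
/// use concision_core::traits::ApplyGradientExt; // path is an assumption
/// use ndarray::{Array1, array};
///
/// let mut weights = array![0.0f64, 0.0];
/// let mut velocity = Array1::<f64>::zeros(2);
/// let grad = array![1.0, 1.0];
/// // velocity <- velocity * momentum + grad * (1 - momentum)
/// // weights  <- weights + (lr / batch_size) * velocity
/// weights.apply_gradient_with_momentum(&grad, 0.1, 0.9, &mut velocity).unwrap();
/// ```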
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    /// Applies the gradient using a momentum term, updating `velocity` in place.
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    /// Applies the gradient with both weight decay and momentum, updating `velocity` in place.
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

/*
 ************* Implementations *************
*/

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand, ShapeError};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // reject gradients whose shape does not match the parameters
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        // treat the leading axis as the batch dimension; scalar (0-dim) arrays have none
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // add the gradient to the parameters, scaled by `lr / batch_size`
        self.scaled_add(lr / batch_size, grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // weight decay adds `decay * self` to the raw gradient before the update
        let adjusted_grad = grad + &*self * decay;
        self.scaled_add(lr / batch_size, &adjusted_grad);
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // update the velocity as an exponential moving average of the gradient
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // apply the smoothed gradient, scaled by `lr / batch_size`
        self.scaled_add(lr / batch_size, &*velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // weight decay adds `decay * self` to the raw gradient before smoothing
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, &*velocity);
        Ok(())
    }
}
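
// A rough smoke test for the `ndarray` implementations above. This is a sketch under
// the assumptions that the crate's error type implements `Debug` and that `ndarray`
// is available; the expected values simply restate the update rules used above.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, array};

    #[test]
    fn apply_gradient_scales_by_batch_size() {
        let mut weights = Array1::<f64>::zeros(2);
        let grad = array![1.0, 2.0];
        // the leading axis (len 2) is the batch dimension, so the effective step is lr / 2
        weights.apply_gradient(&grad, 0.1).unwrap();
        assert!((weights[0] - 0.05).abs() < 1e-12);
        assert!((weights[1] - 0.10).abs() < 1e-12);
    }

    #[test]
    fn apply_gradient_with_momentum_updates_velocity() {
        let mut weights = Array1::<f64>::zeros(2);
        let mut velocity = Array1::<f64>::zeros(2);
        let grad = array![1.0, 1.0];
        weights.apply_gradient_with_momentum(&grad, 0.1, 0.9, &mut velocity).unwrap();
        // starting from zero, the velocity becomes `(1 - momentum) * grad` on the first step
        assert!((velocity[0] - 0.1).abs() < 1e-12);
        assert!((velocity[1] - 0.1).abs() < 1e-12);
    }
}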