concision_core/traits/
apply.rs

/*
    Appellation: train <module>
    Contrib: @FL03
*/
/// A trait declaring basic gradient-related routines for a neural network
pub trait ApplyGradient<Grad, A> {
    type Output;

    fn apply_gradient(&mut self, grad: &Grad, lr: A) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output>;
}

/// This trait extends the [ApplyGradient] trait by allowing for momentum-based optimization
pub trait ApplyGradientExt<Grad, A>: ApplyGradient<Grad, A> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}
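
// Sketch of intended usage: a caller can stay generic over anything implementing the
// traits above. The helper names (`sgd_step`, `momentum_step`) are illustrative only
// and are not part of the crate's public API; they assume nothing beyond the
// `crate::Result` alias already used by the trait methods.
#[allow(dead_code)]
fn sgd_step<P, G, A>(params: &mut P, grad: &G, lr: A) -> crate::Result<P::Output>
where
    P: ApplyGradient<G, A>,
{
    // delegate to the implementation on the concrete parameter type
    params.apply_gradient(grad, lr)
}

#[allow(dead_code)]
fn momentum_step<P, G, A>(
    params: &mut P,
    grad: &G,
    lr: A,
    momentum: A,
    velocity: &mut P::Velocity,
) -> crate::Result<P::Output>
where
    P: ApplyGradientExt<G, A>,
{
    // the velocity buffer is owned by the caller and updated in place
    params.apply_gradient_with_momentum(grad, lr, momentum, velocity)
}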

use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

// Implementations of the gradient traits for `ndarray` array types
impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        // the gradient must have exactly the same shape as the parameters
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        // normalize by the length of the leading (batch) axis; scalars fall back to one
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // apply the gradient in place, scaled by `lr / batch_size`
        self.scaled_add(lr / batch_size, grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // fold the weight-decay term (`decay * self`) into the gradient before applying it
        self.scaled_add(lr / batch_size, &(grad + &*self * decay));
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // exponential moving average of the gradient: v <- momentum * v + (1 - momentum) * g
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // step the parameters along the velocity, scaled by `lr / batch_size`
        self.scaled_add(lr / batch_size, velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if self.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // fold the weight-decay term into the gradient before the momentum update
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, velocity);
        Ok(())
    }
}
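
// A minimal sanity-check module (sketch): these tests exercise the ndarray
// implementations above and assume only what is defined in this file plus the
// crate-wide `crate::Result` alias; adjust to match the crate's test conventions.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn apply_gradient_rejects_mismatched_shapes() {
        let mut params = Array1::<f64>::zeros(3);
        let grad = Array1::<f64>::zeros(4);
        assert!(params.apply_gradient(&grad, 0.1).is_err());
    }

    #[test]
    fn apply_gradient_scales_by_the_leading_axis() {
        let mut params = array![1.0_f64, 2.0, 3.0, 4.0];
        let grad = array![4.0_f64, 4.0, 4.0, 4.0];
        // the batch size is the leading axis length (4), so each entry moves by (lr / 4) * grad
        assert!(params.apply_gradient(&grad, 1.0).is_ok());
        assert_eq!(params, array![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn momentum_updates_the_velocity_buffer() {
        let mut params = array![0.0_f64, 0.0];
        let grad = array![1.0_f64, 1.0];
        let mut velocity = Array1::<f64>::zeros(2);
        assert!(params
            .apply_gradient_with_momentum(&grad, 1.0, 0.5, &mut velocity)
            .is_ok());
        // v = 0.5 * v + 0.5 * g = [0.5, 0.5]; params += (lr / 2) * v = [0.25, 0.25]
        assert_eq!(velocity, array![0.5, 0.5]);
        assert_eq!(params, array![0.25, 0.25]);
    }
}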