concision_core/traits/apply.rs

/*
    Appellation: apply <module>
    Contrib: @FL03
*/

use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};
/// A trait declaring basic gradient-related routines for a neural network.
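///
/// # Example
///
/// A minimal sketch (illustrative, not part of the original source), assuming the
/// blanket `ndarray` implementation below is in scope:
///
/// ```ignore
/// use ndarray::array;
///
/// let mut weights = array![[1.0_f64, 2.0], [3.0, 4.0]];
/// let grad = array![[0.5_f64, 0.5], [0.5, 0.5]];
/// // one plain step: weights -= (lr / batch_size) * grad, where the batch size
/// // is taken from the leading axis (here 2)
/// weights.apply_gradient(&grad, 0.1).unwrap();
/// ```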
pub trait ApplyGradient<Delta> {
    type Elem;
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: Self::Elem) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        decay: Self::Elem,
    ) -> crate::Result<Self::Output>;
}

/// This trait extends the [`ApplyGradient`] trait by allowing for momentum-based optimization.
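///
/// # Example
///
/// A minimal sketch (illustrative, not part of the original source); the caller owns
/// the velocity buffer, which starts at zero and is updated in place:
///
/// ```ignore
/// use ndarray::{array, Array1};
///
/// let mut weights: Array1<f64> = array![1.0, 2.0, 3.0];
/// let grad = array![0.1_f64, 0.1, 0.1];
/// let mut velocity = Array1::<f64>::zeros(weights.raw_dim());
/// weights
///     .apply_gradient_with_momentum(&grad, 0.1, 0.9, &mut velocity)
///     .unwrap();
/// ```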
pub trait ApplyGradientExt<Delta>: ApplyGradient<Delta> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        decay: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Elem = A;
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        // scale the step by the length of the leading axis, treating it as the batch size
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // descend along the gradient: w <- w - (lr / batch) * g
        self.scaled_add(-(lr / batch_size), grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // L2 weight decay folds into the gradient: w <- w - (lr / batch) * (g + decay * w)
        self.scaled_add(-(lr / batch_size), &(grad + &*self * decay));
        Ok(())
    }
}
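
// Momentum variant: `velocity` holds an exponential moving average of the
// (optionally decay-adjusted) gradient, v <- momentum * v + (1 - momentum) * g,
// and each step descends along v rather than the raw gradient.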
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // update the moving average: v <- momentum * v + (1 - momentum) * g
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // step against the smoothed gradient: w <- w - (lr / batch) * v
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // fold L2 weight decay into the raw gradient before smoothing
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }
}
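
#[cfg(test)]
mod tests {
    //! Illustrative checks (not part of the original source); they assume the error
    //! behind `crate::Result` implements `Debug` and converts from `ndarray::ShapeError`,
    //! as the impls above already require.
    use super::*;
    use ndarray::array;

    #[test]
    fn apply_gradient_descends() -> crate::Result<()> {
        let mut w = array![1.0_f64, 1.0];
        let g = array![1.0_f64, 1.0];
        // the leading axis has length 2, so the effective step is (0.1 / 2) * g
        w.apply_gradient(&g, 0.1)?;
        for &x in w.iter() {
            assert!((x - 0.95).abs() < 1e-12);
        }
        Ok(())
    }

    #[test]
    fn mismatched_shapes_are_rejected() {
        let mut w = array![1.0_f64, 1.0];
        let g = array![1.0_f64, 1.0, 1.0];
        assert!(w.apply_gradient(&g, 0.1).is_err());
    }
}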