// concision_core/traits/gradient.rs

/// In-place gradient updates for a parameter container.
///
/// `Delta` is the gradient type and `T` the scalar type of the step-size
/// coefficients.
pub trait ApplyGradient<Delta, T> {
    type Output;

    /// Plain SGD step: `w ← w − (lr / batch) · grad`.
    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> crate::Result<Self::Output>;

    /// SGD step with L2 weight decay folded into the gradient:
    /// `w ← w − (lr / batch) · (grad + decay · w)`.
    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
    ) -> crate::Result<Self::Output>;
}

/// Extension of [`ApplyGradient`] with momentum-based updates that carry
/// state between steps in a `Velocity` buffer.
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    /// SGD with momentum: `v ← momentum · v + (1 − momentum) · grad`, then
    /// `w ← w − (lr / batch) · v`.
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    /// Momentum update applied to the decay-adjusted gradient
    /// `grad + decay · w`.
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand, ShapeError};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        // Normalize by the length of the leading axis, treated as the batch
        // dimension; zero-dimensional (scalar) gradients are left unscaled.
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Descend against the gradient: w ← w − (lr / batch) · grad.
        self.scaled_add(-lr / batch_size, grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // L2 decay folds into the gradient: w ← w − (lr / batch) · (grad + decay · w).
        self.scaled_add(-lr / batch_size, &(grad + &*self * decay));
        Ok(())
    }
}

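// A minimal usage sketch of the blanket impl above; the arrays and
// hyperparameters here are illustrative, not taken from the crate:
//
//     let mut weights = ndarray::array![[0.5_f64, -0.25], [1.0, 0.75]];
//     let grad = ndarray::array![[0.1, 0.2], [0.3, 0.4]];
//     weights.apply_gradient(&grad, 0.1)?;                  // w ← w − (0.1 / 2) · g
//     weights.apply_gradient_with_decay(&grad, 0.1, 0.01)?; // adds 0.01 · w to g first
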
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Exponential moving average of the gradient, then a step along it.
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(-lr / batch_size, velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };

        // Apply weight decay before accumulating momentum.
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-lr / batch_size, velocity);
        Ok(())
    }
}
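
// A small self-check of the update rules above; a sketch assuming
// `crate::Result<T>` is a `Result` alias whose error type implements `Debug`
// (so it can be returned from tests). All numeric values are illustrative.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn momentum_accumulates_velocity() -> crate::Result<()> {
        let mut w: Array1<f64> = array![1.0, 2.0];
        let g: Array1<f64> = array![0.5, 0.5];
        let mut v: Array1<f64> = Array1::zeros(2);

        // First step: v = 0.9·0 + 0.1·0.5 = 0.05, then w ← w − (0.1 / 2) · v.
        w.apply_gradient_with_momentum(&g, 0.1, 0.9, &mut v)?;
        assert!((v[0] - 0.05).abs() < 1e-12);
        assert!((w[0] - (1.0 - 0.05 * 0.05)).abs() < 1e-12);
        Ok(())
    }

    #[test]
    fn shape_mismatch_is_rejected() {
        let mut w: Array1<f64> = array![1.0, 2.0];
        let g: Array1<f64> = array![0.5];
        assert!(w.apply_gradient(&g, 0.1).is_err());
    }
}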