// concision_core/traits/apply.rs
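
/// Applies a gradient to the implementing parameter container, scaled by the
/// learning rate `lr`; the `with_decay` variant adds an L2-style weight-decay
/// term (`decay * self`) to the gradient before the update.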
pub trait ApplyGradient<Grad, A> {
    type Output;

    fn apply_gradient(&mut self, grad: &Grad, lr: A) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output>;
}

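/// Extends [`ApplyGradient`] with momentum-based updates that maintain an
/// exponential moving average of past gradients in a caller-supplied
/// `Velocity` buffer.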
pub trait ApplyGradientExt<Grad, A>: ApplyGradient<Grad, A> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

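// Generic implementations over `ndarray` arrays: the gradient is averaged
// over the leading (batch) axis and applied as a descent step.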
impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        // Normalize by the length of the leading axis, treated as the batch
        // dimension; zero-dimensional (scalar) gradients are left unscaled.
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Descend along the gradient: `w <- w - (lr / batch_size) * grad`.
        self.scaled_add(-(lr / batch_size), grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Weight decay adds `decay * w` to the gradient, so the descent step
        // shrinks the parameters toward zero.
        self.scaled_add(-(lr / batch_size), &(grad + &*self * decay));
        Ok(())
    }
}
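
// Momentum variants: the caller owns the velocity buffer, so the same
// parameters can be driven by independent optimizer states.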
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Exponential moving average of the gradient:
        // `v <- momentum * v + (1 - momentum) * grad`.
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };

        // Apply weight decay to the raw gradient, then fold the adjusted
        // gradient into the moving average before stepping.
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }
}
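
#[cfg(test)]
mod tests {
    // A minimal usage sketch (illustrative, not exhaustive): it assumes the
    // descent convention used above and that the error type behind
    // `crate::Result` implements `Debug`, as `expect` requires.
    use super::*;
    use ndarray::array;

    #[test]
    fn gradient_updates_descend() {
        let mut weights = array![1.0_f64, 2.0];
        let grad = array![0.5_f64, 0.5];

        // batch_size is the leading axis length (2), so the step is
        // w <- w - (0.5 / 2) * 0.5 = w - 0.125 elementwise (exact in f64).
        weights.apply_gradient(&grad, 0.5).expect("shapes match");
        assert_eq!(weights, array![0.875, 1.875]);

        // Momentum state lives in a caller-owned velocity buffer.
        let mut velocity = ndarray::Array::zeros(weights.raw_dim());
        weights
            .apply_gradient_with_momentum(&grad, 0.5, 0.9, &mut velocity)
            .expect("shapes match");
        // After one step the velocity is an EMA of the (positive) gradient.
        assert!(velocity.iter().all(|&v| v > 0.0));
    }
}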