concision_core/traits/

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ErrorKind, ScalarOperand, ShapeError};
use num_traits::{Float, FromPrimitive};

/// Computes a gradient (delta) for `self` from a companion delta `rhs`,
/// returning a delta of the same associated shape.
pub trait Gradient<T, D> {
    /// The underlying representation of an element of type `_A`.
    type Repr<_A>;
    /// The container used to carry a gradient for a storage `_S` and a
    /// dimension `_D`.
    type Delta<_S, _D>;

    fn grad(&self, rhs: &Self::Delta<Self::Repr<T>, D>) -> Self::Delta<Self::Repr<T>, D>;
}

/// Applies a pre-computed gradient `grad` to the parameters in `self`,
/// scaled by the learning rate `lr`.
pub trait ApplyGradient<Delta, T> {
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> crate::Result<Self::Output>;

    /// Like `apply_gradient`, but folds a weight-decay term into the
    /// effective gradient before stepping.
    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
    ) -> crate::Result<Self::Output>;
}

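// A minimal sketch (hypothetical, not part of the crate) of the trait
// contract for a single scalar parameter; the `Scalar` newtype exists only
// for illustration. Note that, matching the `ndarray` impls below, the
// scaled gradient is *added* to the parameter.
pub struct Scalar(pub f64);

impl ApplyGradient<f64, f64> for Scalar {
    type Output = ();

    fn apply_gradient(&mut self, grad: &f64, lr: f64) -> crate::Result<Self::Output> {
        // A scalar has no batch dimension, so the full learning rate applies.
        self.0 = self.0 + lr * grad;
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &f64,
        lr: f64,
        decay: f64,
    ) -> crate::Result<Self::Output> {
        // Weight decay folds the current value into the effective gradient.
        self.0 = self.0 + lr * (grad + self.0 * decay);
        Ok(())
    }
}
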
/// Extends [`ApplyGradient`] with momentum-based updates that track their
/// state in an external velocity buffer.
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    /// The buffer used to accumulate the momentum (velocity) term.
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

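// Summary of the update rules implemented below for `ndarray` containers,
// where `batch` is the leading dimension of the parameter array (or 1 for a
// zero-dimensional gradient). As written, the scaled step is *added* to the
// parameters, so callers performing descent should negate `lr` or the
// gradient:
//
//   plain:            w += (lr / batch) * g
//   decay:            w += (lr / batch) * (g + decay * w)
//   momentum:         v = momentum * v + (1 - momentum) * g
//                     w += (lr / batch) * v
//   decay + momentum: v = momentum * v + (1 - momentum) * (g + decay * w)
//                     w += (lr / batch) * v
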
/// Returns the batch size used to scale the learning rate: the leading
/// dimension of the array, or one for a zero-dimensional gradient.
fn batch_size<A, S, D>(arr: &ArrayBase<S, D>) -> A
where
    A: Float + FromPrimitive,
    S: Data<Elem = A>,
    D: Dimension,
{
    if arr.ndim() > 0 {
        A::from_usize(arr.shape()[0]).unwrap()
    } else {
        A::one()
    }
}

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        }
        // Average the step over the batch: `self += (lr / batch) * grad`.
        self.scaled_add(lr / batch_size(grad), grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        }
        // Fold the decay term into the gradient before stepping.
        self.scaled_add(lr / batch_size(grad), &(grad + &*self * decay));
        Ok(())
    }
}

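// A hedged usage sketch, gated behind `cfg(test)` so it only compiles with
// the test harness. The numeric expectations follow directly from the update
// rule `w += (lr / batch) * g` with `batch = 2` (the leading dimension).
#[cfg(test)]
mod apply_gradient_sketch {
    use super::*;
    use ndarray::array;

    #[test]
    fn plain_and_decayed_updates() {
        let mut w = array![1.0_f64, 2.0];
        let g = array![0.5_f64, 0.5];

        // lr / batch = 0.1 / 2 = 0.05, so each weight moves by 0.05 * 0.5.
        assert!(w.apply_gradient(&g, 0.1).is_ok());
        assert!((w[0] - 1.025).abs() < 1e-12);
        assert!((w[1] - 2.025).abs() < 1e-12);

        // With decay, the effective gradient is `g + decay * w`.
        let mut w = array![1.0_f64, 2.0];
        assert!(w.apply_gradient_with_decay(&g, 0.1, 0.01).is_ok());
        assert!((w[0] - 1.0255).abs() < 1e-12);
        assert!((w[1] - 2.026).abs() < 1e-12);
    }
}
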
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        }
        // Exponential moving average of the gradient:
        // `v = momentum * v + (1 - momentum) * grad`.
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size(grad), velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        }
        // Fold the decay term into the gradient, then update the moving
        // average as in `apply_gradient_with_momentum`.
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size(grad), velocity);
        Ok(())
    }
}
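
// A hedged usage sketch for the momentum variant, again `cfg(test)` only.
// Starting from a zero velocity, one step gives `v = (1 - momentum) * g` and
// `w += (lr / batch) * v`.
#[cfg(test)]
mod momentum_sketch {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn momentum_update() {
        let mut w = array![1.0_f64, 2.0];
        let g = array![0.5_f64, 0.5];
        let mut v: Array1<f64> = Array1::zeros(2);

        // v = 0.9 * 0 + 0.1 * 0.5 = 0.05; step = (0.1 / 2) * 0.05 = 0.0025.
        assert!(w.apply_gradient_with_momentum(&g, 0.1, 0.9, &mut v).is_ok());
        assert!((v[0] - 0.05).abs() < 1e-12);
        assert!((w[0] - 1.0025).abs() < 1e-12);
        assert!((w[1] - 2.0025).abs() < 1e-12);
    }
}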