concision_traits/gradient.rs

/// A trait for computing gradients, with the delta representation expressed
/// through the generic associated type [`Delta`](Gradient::Delta).
pub trait Gradient<T> {
    type Delta<_U>;

    fn grad(&self, rhs: &Self::Delta<T>) -> Self::Delta<T>;
}
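
// A minimal sketch of an implementor (hypothetical `Scalar` newtype, not part
// of the crate): the delta is plain `f64`, so `grad` reduces to scaling.
#[cfg(test)]
struct Scalar(f64);

#[cfg(test)]
impl Gradient<f64> for Scalar {
    type Delta<_U> = f64;

    fn grad(&self, rhs: &f64) -> f64 {
        // illustrative rule: scale the incoming delta by the stored value
        self.0 * rhs
    }
}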

/// In-place gradient updates; each method returns `None` when the update
/// cannot be applied (e.g. on a shape mismatch) and `Some(Output)` otherwise.
pub trait ApplyGradient<Delta, T> {
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> Option<Self::Output>;

    fn apply_gradient_with_decay(&mut self, grad: &Delta, lr: T, decay: T) -> Option<Self::Output>;
}

/// Extends [`ApplyGradient`] with momentum-based updates that maintain a
/// caller-provided velocity buffer.
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output>;
}
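
// Continuing the hypothetical `Scalar` sketch above: a direct implementation
// of the base trait, mirroring the ndarray impls below (the update *adds*
// `lr * grad`, with decay folding the current value into the gradient).
#[cfg(test)]
impl ApplyGradient<f64, f64> for Scalar {
    type Output = ();

    fn apply_gradient(&mut self, grad: &f64, lr: f64) -> Option<Self::Output> {
        self.0 = self.0 + lr * grad;
        Some(())
    }

    fn apply_gradient_with_decay(&mut self, grad: &f64, lr: f64, decay: f64) -> Option<Self::Output> {
        self.0 = self.0 + lr * (grad + decay * self.0);
        Some(())
    }
}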

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> Option<Self::Output> {
        // reject mismatched shapes rather than panicking inside ndarray
        if self.shape() != grad.shape() {
            return None;
        }
        // average over the leading axis; scalars (0-d arrays) use a batch of 1
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // note: this *adds* `(lr / batch_size) * grad`; pass a negative `lr`
        // to take a descent step
        self.scaled_add(lr / batch_size, grad);
        Some(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // L2 weight decay: fold `decay * self` into the effective gradient
        let rhs = grad + &*self * decay;
        self.scaled_add(lr / batch_size, &rhs);
        Some(())
    }
}
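
// A quick check of the impl above (a sketch; expected values follow from
// `step = lr / batch_size`, where the batch size is the leading axis length).
#[cfg(test)]
#[test]
fn apply_gradient_scales_by_leading_axis() {
    use ndarray::array;
    let mut w = array![[0.0_f64, 0.0], [0.0, 0.0]];
    let g = array![[1.0, 1.0], [1.0, 1.0]];
    // the leading axis has length 2, so the effective step is 0.5 / 2 = 0.25
    w.apply_gradient(&g, 0.5).expect("shapes match");
    assert_eq!(w, array![[0.25, 0.25], [0.25, 0.25]]);
}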

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: DataMut<Elem = A>,
    T: Data<Elem = A>,
    D: Dimension,
{
    type Velocity = Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // exponential moving average of the gradient:
        // v <- momentum * v + (1 - momentum) * grad
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, velocity);
        Some(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> Option<Self::Output> {
        if self.shape() != grad.shape() {
            return None;
        }
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // apply weight decay before blending into the velocity
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, velocity);
        Some(())
    }
}
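
// Sketches exercising the momentum path and the shape-mismatch guard
// (illustrative values; momentum = 0.5 keeps the arithmetic exact in f64).
#[cfg(test)]
mod gradient_sketches {
    use super::*;
    use ndarray::array;

    #[test]
    fn momentum_blends_gradient_into_velocity() {
        let mut w = array![0.0_f64];
        let g = array![1.0];
        let mut v = Array::zeros(w.raw_dim());
        // v = 0.5 * 0 + 0.5 * 1 = 0.5, then w += (lr / 1) * v = 0.5
        w.apply_gradient_with_momentum(&g, 1.0, 0.5, &mut v).unwrap();
        assert_eq!(v, array![0.5]);
        assert_eq!(w, array![0.5]);
    }

    #[test]
    fn shape_mismatch_returns_none() {
        let mut w = array![0.0_f64, 0.0];
        let g = array![1.0, 1.0, 1.0];
        assert!(w.apply_gradient(&g, 0.1).is_none());
    }
}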