// concision_core/traits/apply.rs

use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

pub trait ApplyGradient<Delta> {
    type Elem;
    type Output;

    fn apply_gradient(&mut self, grad: &Delta, lr: Self::Elem) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        decay: Self::Elem,
    ) -> crate::Result<Self::Output>;
}

pub trait ApplyGradientExt<Delta>: ApplyGradient<Delta> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        decay: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}
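
// The ndarray implementations below realize the usual stochastic gradient
// updates, sketched here for reference (`b` is the length of the leading
// axis of the parameter array, used as a batch size; `b = 1` for
// zero-dimensional arrays):
//
//     apply_gradient:            w <- w - (lr / b) * g
//     apply_gradient_with_decay: w <- w - (lr / b) * (g + decay * w)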
impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Elem = A;
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        // scale the step by the length of the leading axis (treated as the batch size)
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // descent step: w <- w - (lr / b) * g
        self.scaled_add(-(lr / batch_size), grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // weight decay adds `decay * w` to the gradient before the descent step
        self.scaled_add(-(lr / batch_size), &(grad + &*self * decay));
        Ok(())
    }
}
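
// The momentum variants keep an exponential moving average of the gradient:
// v <- momentum * v + (1 - momentum) * g, followed by w <- w - (lr / b) * v
// (EMA-style momentum rather than the classical accumulator
// v <- momentum * v + g).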
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // update the velocity as an exponential moving average of the gradient
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        // descent step along the smoothed gradient
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(
                ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into(),
            );
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };

        // apply weight decay to the gradient, then smooth it with momentum
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), &*velocity);
        Ok(())
    }
}
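
// A minimal usage sketch: these tests assume the descent convention noted
// above (`w <- w - (lr / b) * g`, with `b` the length of the leading axis)
// and exercise the ndarray implementations on small 1-D arrays; the numeric
// values are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn apply_gradient_descends() {
        let mut w = array![1.0_f64, 2.0];
        let g = array![0.5_f64, 1.0];
        // the leading axis has length 2, so the effective step is lr / 2 = 0.05
        assert!(w.apply_gradient(&g, 0.1).is_ok());
        assert!((w[0] - 0.975).abs() < 1e-12);
        assert!((w[1] - 1.95).abs() < 1e-12);
    }

    #[test]
    fn momentum_smooths_the_gradient() {
        let mut w = array![1.0_f64, 2.0];
        let g = array![0.5_f64, 1.0];
        let mut v = ndarray::Array1::zeros(2);
        // v <- 0.9 * v + 0.1 * g = [0.05, 0.10]; then w <- w - (0.1 / 2) * v
        assert!(w.apply_gradient_with_momentum(&g, 0.1, 0.9, &mut v).is_ok());
        assert!((v[0] - 0.05).abs() < 1e-12);
        assert!((w[0] - 0.9975).abs() < 1e-12);
        assert!((w[1] - 1.995).abs() < 1e-12);
    }

    #[test]
    fn shape_mismatch_is_rejected() {
        let mut w = array![1.0_f64, 2.0];
        let g = array![1.0_f64, 2.0, 3.0];
        assert!(w.apply_gradient(&g, 0.1).is_err());
    }
}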