/// Apply a gradient update to the implementor, scaled by the learning rate
/// `lr`, optionally combined with an L2 weight-decay term.
pub trait ApplyGradient<Grad, A> {
    type Output;

    fn apply_gradient(&mut self, grad: &Grad, lr: A) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output>;
}
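// Note (added): the ndarray implementation below realizes these methods as a
// batch-normalized descent step (w: parameters, g: gradient),
//
//     w <- w - (lr / batch_size) * g,
//
// where `apply_gradient_with_decay` first augments the gradient with an L2
// term, g <- g + decay * w.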
/// Extends [`ApplyGradient`] with momentum-based updates that keep a
/// caller-supplied velocity buffer between steps.
pub trait ApplyGradientExt<Grad, A>: ApplyGradient<Grad, A> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}
/// A single backward pass: consume the forward `input` and the incoming
/// `delta`, updating internal state and producing an output.
pub trait Backward<X, Y> {
    type HParam;
    type Output;

    fn backward(
        &mut self,
        input: &X,
        delta: &Y,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output>;
}
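// Note (added): `gamma` is a method-specific hyper-parameter; for the
// SGD-style updates in this module it would typically be a learning rate,
// but implementors are free to interpret it otherwise.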
/// Iterative training against a single `(input, target)` pair.
pub trait Train<X, Y> {
    type Output;

    fn train(&mut self, input: &X, target: &Y) -> crate::Result<Self::Output>;

    /// Runs [`train`](Train::train) `epochs` times, returning the output of
    /// the final epoch. Errors abort the loop immediately; calling this with
    /// `epochs == 0` yields a `TrainingFailed` error since no output exists.
    fn train_for(&mut self, input: &X, target: &Y, epochs: usize) -> crate::Result<Self::Output> {
        let mut output = None;

        for _ in 0..epochs {
            output = match self.train(input, target) {
                Ok(out) => Some(out),
                Err(err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Training failed: {err}");
                    return Err(err);
                }
            };
        }
        output.ok_or_else(|| crate::error::Error::TrainingFailed("no epochs were run".into()))
    }
}
use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};
impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        // Normalize by the size of the leading (batch) axis; scalars
        // (zero-dimensional arrays) are left unscaled.
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Descend along the gradient: w <- w - (lr / batch_size) * g.
        self.scaled_add(-(lr / batch_size), grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Fold the L2 (weight-decay) term into the gradient before stepping:
        // g <- g + decay * w.
        self.scaled_add(-(lr / batch_size), &(grad + &*self * decay));
        Ok(())
    }
}
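// Note (added): the momentum variants below keep an exponential moving
// average of the gradient in the velocity buffer,
//
//     v <- momentum * v + (1 - momentum) * g
//     w <- w - (lr / batch_size) * v,
//
// with the decay variant substituting g <- g + decay * w first.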
impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Update the velocity buffer, then descend along it.
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        let batch_size = if grad.ndim() > 0 {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // Fold the weight-decay term into the gradient, then update the
        // velocity buffer and descend along it.
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(-(lr / batch_size), velocity);
        Ok(())
    }
}
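// Illustrative tests (added note): a minimal sketch of how the traits above
// are used, assuming only items this file already references
// (`crate::Result`, `crate::error::Error`, the ndarray implementations).
// `Counter` is a hypothetical implementor defined purely for the example.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn apply_gradient_steps_against_the_gradient() {
        let mut weights: Array1<f64> = array![1.0, 2.0];
        let grad: Array1<f64> = array![0.5, 0.5];
        // batch_size is taken from the leading axis, here 2.
        assert!(weights.apply_gradient(&grad, 0.1).is_ok());
        assert!(weights[0] < 1.0 && weights[1] < 2.0);
    }

    #[test]
    fn momentum_accumulates_an_ema_of_gradients() {
        let mut weights: Array1<f64> = array![1.0, 2.0];
        let mut velocity: Array1<f64> = Array1::zeros(2);
        let grad: Array1<f64> = array![0.5, 0.5];
        assert!(weights
            .apply_gradient_with_momentum(&grad, 0.1, 0.9, &mut velocity)
            .is_ok());
        // v = 0.9 * 0 + 0.1 * 0.5 = 0.05 in each coordinate.
        assert!(velocity.iter().all(|&v| v > 0.0));
        assert!(weights[0] < 1.0 && weights[1] < 2.0);
    }

    /// Minimal `Train` implementor used only to exercise the default
    /// `train_for` loop.
    struct Counter(usize);

    impl Train<(), ()> for Counter {
        type Output = usize;

        fn train(&mut self, _input: &(), _target: &()) -> crate::Result<usize> {
            self.0 += 1;
            Ok(self.0)
        }
    }

    #[test]
    fn train_for_runs_once_per_epoch() {
        let mut counter = Counter(0);
        let out = counter.train_for(&(), &(), 3);
        assert!(matches!(out, Ok(3)));
    }
}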