1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
use crate::array::*;
use crate::numbers::*;
use crate::optimizer::Optimizer;
/// Plain gradient-descent optimizer.
///
/// On each [`Optimizer::update`], every parameter that currently has a
/// gradient is moved against that gradient, scaled by a fixed `learning_rate`.
pub struct GradientDescent {
    /// Step size multiplied into every gradient before it is subtracted.
    learning_rate: Float,
}

impl GradientDescent {
    /// Builds an optimizer that uses `learning_rate` as its fixed step size.
    pub fn new(learning_rate: Float) -> Self {
        Self { learning_rate }
    }
}
impl Optimizer for GradientDescent {
    /// Performs one gradient-descent step on each parameter in `parameters`.
    ///
    /// For a parameter whose `gradient()` is `Some(g)`:
    /// 1. Tracking is switched off with `stop_tracking`.
    /// 2. The parameter is replaced by `parameter - g * learning_rate`.
    /// 3. Tracking is switched back on with `start_tracking`.
    /// 4. The stored gradient is reset to `None`.
    ///
    /// Parameters whose gradient is `None` are skipped unchanged.
    fn update(&self, parameters: Vec<&mut Array>) {
        // `Float` is `Copy` (the original multiplied by `self.learning_rate`
        // through `&self`), so reading it once up front is equivalent.
        let step = self.learning_rate;

        for param in parameters {
            let grad = match param.gradient() {
                Some(grad) => grad,
                // Nothing to apply for parameters without a gradient.
                None => continue,
            };

            // Keep the update itself out of tracking, then re-enable tracking.
            param.stop_tracking();
            *param = &*param - &(&grad * step);
            param.start_tracking();

            // Clear the gradient so it is not re-applied on the next step.
            *param.gradient_mut() = None;
        }
    }
}