use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use serde::{Deserialize, Serialize};
/// Momentum coefficient used when no explicit value is supplied
/// (see [`Optimizer::default_momentum`]).
pub const DEFAULT_MOMENTUM: f32 = 0.9;
/// User-facing optimizer selection, serializable as part of a model config.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum Optimizer {
// Plain (vanilla) gradient descent.
GD,
// Gradient descent with momentum; the payload is the momentum coefficient.
Momentum(f32),
}
impl Optimizer {
/// Convenience constructor: a momentum optimizer configured with
/// [`DEFAULT_MOMENTUM`].
pub fn default_momentum() -> Self {
Self::Momentum(DEFAULT_MOMENTUM)
}
}
impl Default for Optimizer {
fn default() -> Self {
Optimizer::GD
}
}
/// Internal per-layer optimizer state. Unlike the public [`Optimizer`] enum,
/// this variant carries the mutable velocity buffers sized for one layer.
pub(crate) enum OptimizerType {
// Stateless plain gradient descent.
GD,
// Momentum optimizer: coefficient plus per-parameter velocity buffers.
Momentum {
// Momentum coefficient (typically in [0, 1)).
momentum: f32,
// Running velocity for the weight matrix; same shape as the weights.
weights_momentum: Array2<f32>,
// Running velocity for the bias vector; same length as the biases.
biases_momentum: Array1<f32>,
},
}
impl OptimizerType {
/// Creates a momentum optimizer with zero-initialized velocity buffers
/// matching the layer's weight and bias dimensions.
pub(crate) fn new_momentum(
momentum: f32,
weights_dim: (usize, usize),
biases_dim: usize,
) -> Self {
OptimizerType::Momentum {
// Field-init shorthand (was the redundant `momentum: momentum`).
momentum,
weights_momentum: Array2::zeros(weights_dim),
biases_momentum: Array1::zeros(biases_dim),
}
}
/// Applies one in-place update step to `weights` and `biases`.
///
/// * `GD`: `p -= learning_rate * grad`.
/// * `Momentum`: `v = momentum * v - learning_rate * grad; p += v`.
///   With zero-initialized velocities the first step equals a GD step.
pub fn optimize(
&mut self,
weights: &mut Array2<f32>,
biases: &mut Array1<f32>,
weights_gradient: &ArrayView2<f32>,
output_gradient: &ArrayView1<f32>,
learning_rate: f32,
) {
match self {
OptimizerType::GD => {
*weights -= &(weights_gradient * learning_rate);
// `&ArrayView1 * f32` already yields an owned Array1 (same as the
// weights line above); the former `.to_owned()` allocated an
// extra intermediate copy for no benefit.
*biases -= &(output_gradient * learning_rate);
}
OptimizerType::Momentum {
momentum,
weights_momentum,
biases_momentum,
} => {
// v <- m*v - lr*grad. `.view()` borrows the old buffer so the
// right-hand side is fully computed before the assignment
// overwrites the velocity.
*weights_momentum =
*momentum * &weights_momentum.view() - learning_rate * weights_gradient;
*biases_momentum =
*momentum * &biases_momentum.view() - learning_rate * output_gradient;
// p <- p + v
*weights += &*weights_momentum;
*biases += &*biases_momentum;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
/// One plain-GD step must subtract `lr * grad` from weights and biases.
#[test]
fn test_gradient_descent() {
let mut w = array![[0.5, -0.5], [0.5, -0.5]];
let mut b = array![0.5, -0.5];
let w_grad = array![[0.1, -0.1], [0.1, -0.1]];
let b_grad = array![0.1, -0.1];
let mut opt = OptimizerType::GD;
opt.optimize(&mut w, &mut b, &w_grad.view(), &b_grad.view(), 0.1);
assert_eq!(w, array![[0.49, -0.49], [0.49, -0.49]]);
assert_eq!(b, array![0.49, -0.49]);
}
/// With zero-initialized velocity, the first momentum step is identical
/// to a plain GD step.
#[test]
fn test_momentum() {
let mut w = array![[0.5, -0.5], [0.5, -0.5]];
let mut b = array![0.5, -0.5];
let w_grad = array![[0.1, -0.1], [0.1, -0.1]];
let b_grad = array![0.1, -0.1];
let mut opt = OptimizerType::new_momentum(0.9, (2, 2), 2);
opt.optimize(&mut w, &mut b, &w_grad.view(), &b_grad.view(), 0.1);
assert_eq!(w, array![[0.49, -0.49], [0.49, -0.49]]);
assert_eq!(b, array![0.49, -0.49]);
}
/// The `Default` impl selects plain gradient descent.
#[test]
fn test_default_optimizer() {
assert_eq!(Optimizer::default(), Optimizer::GD);
}
/// `default_momentum` wraps the crate-wide default coefficient.
#[test]
fn test_default_momentum() {
assert_eq!(
Optimizer::default_momentum(),
Optimizer::Momentum(DEFAULT_MOMENTUM)
);
}
}