extern crate coaster as co;
extern crate juice;
// NOTE(review): `whatever` is not a cfg flag that a plain `cargo test` sets, so
// `all(test, whatever)` compiles this whole module out and none of these tests
// run unless the build passes `--cfg whatever`. This is presumably deliberate
// (see the note on `instantiate_solver_sgd_momentum`). Confirm that the tests
// still match the current `juice::solver` API before switching this to
// `#[cfg(test)]`.
#[cfg(all(test, whatever))]
mod solver_specs {
    use co::backend::Backend;
    use co::frameworks::Native;
    use juice::solver::*;

    /// `LRPolicy::Fixed` is expected to return `base_lr` at every iteration.
    /// It ignores `gamma`, even though `gamma` is set here.
    #[test]
    fn lr_fixed() {
        let cfg = SolverConfig {
            lr_policy: LRPolicy::Fixed,
            base_lr: 5f32,
            gamma: 0.5f32,
            ..SolverConfig::default()
        };
        assert_eq!(cfg.get_learning_rate(0), 5f32);
        assert_eq!(cfg.get_learning_rate(100), 5f32);
        assert_eq!(cfg.get_learning_rate(1000), 5f32);
    }

    /// `LRPolicy::Step` is expected to multiply `base_lr` by `gamma` once per
    /// completed `stepsize` iterations. With base 5, gamma 0.5 and stepsize 10,
    /// this gives 5 at iteration 0, 2.5 at 10 and 1.25 at 20.
    #[test]
    fn lr_step() {
        let cfg = SolverConfig {
            lr_policy: LRPolicy::Step,
            base_lr: 5f32,
            gamma: 0.5f32,
            stepsize: 10,
            ..SolverConfig::default()
        };
        // The expected values are 5 times powers of two, which are exactly
        // representable in f32. Exact equality is therefore intended here
        // rather than an epsilon comparison.
        assert_eq!(cfg.get_learning_rate(0), 5f32);
        assert_eq!(cfg.get_learning_rate(10), 2.5f32);
        assert_eq!(cfg.get_learning_rate(20), 1.25f32);
    }

    /// `LRPolicy::Exp` is expected to yield `base_lr * gamma^iter`. This is
    /// checked for two different `gamma` values.
    #[test]
    fn lr_exp() {
        let cfg = SolverConfig {
            lr_policy: LRPolicy::Exp,
            base_lr: 5f32,
            gamma: 0.5f32,
            ..SolverConfig::default()
        };
        assert_eq!(cfg.get_learning_rate(0), 5f32);
        assert_eq!(cfg.get_learning_rate(1), 2.5f32);
        assert_eq!(cfg.get_learning_rate(2), 1.25f32);
        assert_eq!(cfg.get_learning_rate(3), 0.625f32);

        let cfg2 = SolverConfig {
            lr_policy: LRPolicy::Exp,
            base_lr: 5f32,
            gamma: 0.25f32,
            ..SolverConfig::default()
        };
        assert_eq!(cfg2.get_learning_rate(0), 5f32);
        assert_eq!(cfg2.get_learning_rate(1), 1.25f32);
        assert_eq!(cfg2.get_learning_rate(2), 0.3125f32);
    }

    /// Smoke test: building a `Solver` from an SGD-with-momentum config should
    /// not panic. The constructed solver is discarded.
    // NOTE(review): verify that these type parameters and the single-argument
    // `from_config(&cfg)` call still match the current `Solver` API. A mismatch
    // here may be why the module is gated behind `whatever`.
    #[test]
    fn instantiate_solver_sgd_momentum() {
        let cfg = SolverConfig {
            solver: SolverKind::SGD(SGDKind::Momentum),
            ..SolverConfig::default()
        };
        Solver::<Box<dyn ISolver<Backend<Native>>>, Backend<Native>>::from_config(&cfg);
    }
}