use std::io::{Read, Write};
use crate::tensor::Result;
mod sgd;
mod adam;
mod rmsprop;
mod adagrad;
mod radam;
mod nadam;
pub use sgd::{SGD, SGDBuilder};
pub use adam::{Adam, AdamBuilder, AdamW, AdamWBuilder};
pub use rmsprop::{RMSprop, RMSpropBuilder};
pub use adagrad::{Adagrad, AdagradBuilder};
pub use radam::RAdam;
pub use nadam::NAdam;
/// Common interface shared by the optimizers exported from this module.
///
/// Implementors supply [`step`](Optimizer::step), [`zero_grad`](Optimizer::zero_grad)
/// and the learning-rate accessor pair; [`set_group_lr`](Optimizer::set_group_lr) and
/// [`scale_lr`](Optimizer::scale_lr) come with defaults built on [`lr`](Optimizer::lr)
/// and [`set_lr`](Optimizer::set_lr).
pub trait Optimizer {
    /// Performs one update step.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports.
    fn step(&mut self) -> Result<()>;

    /// Presumably clears the gradients of the optimized parameters; the exact
    /// behaviour is defined by each implementation.
    fn zero_grad(&self);

    /// Returns the current learning rate.
    fn lr(&self) -> f64;

    /// Replaces the current learning rate with `lr`.
    fn set_lr(&mut self, lr: f64);

    /// Sets the learning rate for a single parameter group.
    ///
    /// The default implementation ignores the group index and forwards `lr` to
    /// [`set_lr`](Optimizer::set_lr); optimizers with real per-group rates should
    /// override it.
    fn set_group_lr(&mut self, _group: usize, lr: f64) {
        Self::set_lr(self, lr)
    }

    /// Multiplies the current learning rate by `factor`.
    ///
    /// The default reads [`lr`](Optimizer::lr) and writes the product back through
    /// [`set_lr`](Optimizer::set_lr).
    fn scale_lr(&mut self, factor: f64) {
        let scaled = self.lr() * factor;
        self.set_lr(scaled);
    }
}
/// Metadata for one parameter group: a half-open index range plus the learning
/// rate applied to it.
struct GroupMeta {
    /// Half-open index range belonging to this group (presumably into a flat
    /// parameter list kept by the optimizer — confirm against the submodules).
    range: std::ops::Range<usize>,
    /// Learning rate for this group.
    lr: f64,
}
pub trait Stateful {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()>;
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()>;
fn save_state_file(&self, path: &str) -> Result<()> {
let f = std::fs::File::create(path).map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
if path.ends_with(".gz") {
let mut w = flate2::write::GzEncoder::new(f, flate2::Compression::default());
self.save_state(&mut w)?;
w.finish().map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
Ok(())
} else {
let mut w = std::io::BufWriter::new(f);
self.save_state(&mut w)
}
}
fn load_state_file(&mut self, path: &str) -> Result<()> {
let f = std::fs::File::open(path).map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
if path.ends_with(".gz") {
let mut r = flate2::read::GzDecoder::new(f);
self.load_state(&mut r)
} else {
let mut r = std::io::BufReader::new(f);
self.load_state(&mut r)
}
}
}
#[cfg(test)]
mod test_helpers {
    use crate::nn::parameter::Parameter;
    use crate::tensor::{test_device, DType, Tensor, TensorOptions};

    /// Builds a parameter called `name` whose data comes from `Tensor::randn` with
    /// the given `shape`, as `Float32` on the test device.
    ///
    /// Panics (via `unwrap`) if the tensor cannot be created.
    pub(super) fn make_param(name: &str, shape: &[i64]) -> Parameter {
        let options = TensorOptions {
            dtype: DType::Float32,
            device: test_device(),
        };
        let data = Tensor::randn(shape, options).unwrap();
        Parameter::new(data, name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::test_helpers::make_param;
    use crate::nn::parameter::Parameter;

    /// Runs one `step` (which must succeed) followed by `zero_grad`.
    fn step_then_zero<O: Optimizer>(mut opt: O) {
        opt.step().unwrap();
        opt.zero_grad();
    }

    /// Sets the learning rate to 0.42 and asserts it reads back; `label` names the
    /// optimizer in the failure message.
    fn check_set_lr<O: Optimizer>(label: &str, mut opt: O) {
        opt.set_lr(0.42);
        assert!((opt.lr() - 0.42).abs() < 1e-12, "{} set_lr failed", label);
    }

    #[test]
    fn test_empty_params_optimizers_no_panic() {
        let no_params: &[Parameter] = &[];
        step_then_zero(Adam::new(no_params, 0.001));
        step_then_zero(SGD::new(no_params, 0.01, 0.9));
        step_then_zero(AdamW::new(no_params, 0.001, 0.01));
        step_then_zero(RMSprop::new(no_params, 0.01));
        step_then_zero(Adagrad::new(no_params, 0.01));
        step_then_zero(RAdam::new(no_params, 0.01));
        step_then_zero(NAdam::new(no_params, 0.01));
    }

    #[test]
    fn test_step_after_zero_grad_on_fresh_optimizer() {
        let p = make_param("w", &[3, 2]);
        let params = std::slice::from_ref(&p);
        let mut adam = Adam::new(params, 0.001);
        let mut sgd = SGD::new(params, 0.01, 0.9);

        adam.zero_grad();
        adam.step().unwrap();
        sgd.zero_grad();
        sgd.step().unwrap();

        // No backward pass ran, so every value must still be finite.
        let vals = p.variable.data().to_f32_vec().unwrap();
        for (idx, val) in vals.iter().copied().enumerate() {
            assert!(
                val.is_finite(),
                "param[{}] should be finite after step-without-backward: {}",
                idx,
                val
            );
        }
    }

    #[test]
    fn test_set_lr_all_optimizers() {
        let p = make_param("w", &[2]);
        let params = std::slice::from_ref(&p);
        check_set_lr("Adam", Adam::new(params, 0.001));
        check_set_lr("SGD", SGD::new(params, 0.01, 0.0));
        check_set_lr("AdamW", AdamW::new(params, 0.001, 0.01));
        check_set_lr("RMSprop", RMSprop::new(params, 0.01));
        check_set_lr("NAdam", NAdam::new(params, 0.01));
        check_set_lr("RAdam", RAdam::new(params, 0.01));
        check_set_lr("Adagrad", Adagrad::new(params, 0.01));
    }
}