use super::{BaseOptimizer, BuildOptimizer, Optimizer, OptimizerStepError};
use crate::logging::StatsLogger;
use serde::{Deserialize, Serialize};
use tch::{COptimizer, TchError, Tensor};
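
// Implementations of this module's optimizer traits directly on tch's
// C-backed `COptimizer`.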
impl BaseOptimizer for COptimizer {
    fn zero_grad(&mut self) {
        Self::zero_grad(self).unwrap();
    }
}

impl Optimizer for COptimizer {
    fn backward_step(
        &mut self,
        loss_fn: &mut dyn FnMut() -> Tensor,
        _: &mut dyn StatsLogger,
    ) -> Result<Tensor, OptimizerStepError> {
        let loss = loss_fn();
        BaseOptimizer::zero_grad(self);
        loss.backward();
        // Errors from the underlying C optimizer step are unexpected here;
        // panic rather than mapping them into `OptimizerStepError`.
        self.step().unwrap();
        Ok(loss)
    }
}
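
/// Build a [`COptimizer`] from any config convertible into one, registering
/// each variable as a trainable parameter.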
impl<T> BuildOptimizer for T
where
    for<'a> &'a Self: TryInto<COptimizer, Error = TchError>,
{
    type Optimizer = COptimizer;
    type Error = TchError;

    fn build_optimizer<'a, I>(&self, variables: I) -> Result<Self::Optimizer, Self::Error>
    where
        I: IntoIterator<Item = &'a Tensor>,
    {
        let mut optimizer: COptimizer = self.try_into()?;
        for tensor in variables {
            // All variables are registered in parameter group 0.
            optimizer.add_parameters(tensor, 0)?;
        }
        Ok(optimizer)
    }
}
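
/// Configuration for the stochastic gradient descent (SGD) optimizer.
///
/// The fields follow the parameters of [`COptimizer::sgd`].
///
/// # Example
///
/// A minimal usage sketch (the quadratic loss here is illustrative only, and
/// `&mut ()` is used as the stats logger, as in this module's tests):
///
/// ```ignore
/// use tch::{Device, Kind, Tensor};
///
/// let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
/// let mut optimizer = SgdConfig::default().build_optimizer([&x]).unwrap();
/// // Take a few steps minimizing ||x - 1||^2.
/// for _ in 0..10 {
///     let _loss = optimizer
///         .backward_step(&mut || (&x - 1.0).square().sum(Kind::Float), &mut ())
///         .unwrap();
/// }
/// ```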
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SgdConfig {
    pub learning_rate: f64,
    pub momentum: f64,
    pub weight_decay: f64,
    pub dampening: f64,
    pub nesterov: bool,
}

impl Default for SgdConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-2,
            momentum: 0.0,
            weight_decay: 0.0,
            dampening: 0.0,
            nesterov: false,
        }
    }
}

impl TryFrom<&SgdConfig> for COptimizer {
    type Error = TchError;

    fn try_from(config: &SgdConfig) -> Result<Self, Self::Error> {
        Self::sgd(
            config.learning_rate,
            config.momentum,
            config.dampening,
            config.weight_decay,
            config.nesterov,
        )
    }
}
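
/// Configuration for the RMSProp optimizer.
///
/// The fields follow the parameters of [`COptimizer::rms_prop`].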
#[allow(clippy::doc_markdown)]
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RmsPropConfig {
    pub learning_rate: f64,
    pub momentum: f64,
    pub alpha: f64,
    pub eps: f64,
    pub centered: bool,
    pub weight_decay: f64,
}

impl Default for RmsPropConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-2,
            momentum: 0.0,
            alpha: 0.99,
            eps: 1e-8,
            centered: false,
            weight_decay: 0.0,
        }
    }
}

impl TryFrom<&RmsPropConfig> for COptimizer {
    type Error = TchError;

    fn try_from(config: &RmsPropConfig) -> Result<Self, Self::Error> {
        Self::rms_prop(
            config.learning_rate,
            config.alpha,
            config.eps,
            config.weight_decay,
            config.momentum,
            config.centered,
        )
    }
}
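
/// Configuration for the Adam optimizer.
///
/// The fields follow the parameters of [`COptimizer::adam`].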
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct AdamConfig {
    pub learning_rate: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub weight_decay: f64,
}

impl Default for AdamConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.0,
        }
    }
}

impl TryFrom<&AdamConfig> for COptimizer {
    type Error = TchError;

    fn try_from(config: &AdamConfig) -> Result<Self, Self::Error> {
        Self::adam(
            config.learning_rate,
            config.beta1,
            config.beta2,
            config.weight_decay,
        )
    }
}
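
/// Configuration for the AdamW optimizer.
///
/// The fields follow the parameters of [`COptimizer::adamw`].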
#[allow(clippy::doc_markdown)]
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct AdamWConfig {
    pub learning_rate: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub weight_decay: f64,
}

impl Default for AdamWConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.0,
        }
    }
}

impl TryFrom<&AdamWConfig> for COptimizer {
    type Error = TchError;

    fn try_from(config: &AdamWConfig) -> Result<Self, Self::Error> {
        Self::adamw(
            config.learning_rate,
            config.beta1,
            config.beta2,
            config.weight_decay,
        )
    }
}

#[cfg(test)]
#[allow(clippy::module_inception)]
mod coptimizer {
    use super::super::{testing, Optimizer};
    use super::*;
    use tch::{Device, Kind};

    #[test]
    fn sgd_optimizes_quadratic() {
        let config = SgdConfig {
            learning_rate: 1e-1,
            ..SgdConfig::default()
        };
        testing::check_optimizes_quadratic(&config, 500);
    }

    #[test]
    fn sgd_nan_loss() {
        // A NaN loss (here 0 / 0) must not make `backward_step` return an error.
        let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let mut optimizer = SgdConfig::default().build_optimizer([&x]).unwrap();
        #[allow(clippy::eq_op)]
        let _ = optimizer
            .backward_step(&mut (|| (&x / &x).sum(Kind::Float)), &mut ())
            .unwrap();
    }

    #[test]
    fn rms_prop_optimizes_quadratic() {
        let config = RmsPropConfig {
            learning_rate: 1e-1,
            ..RmsPropConfig::default()
        };
        testing::check_optimizes_quadratic(&config, 500);
    }

    #[test]
    fn adam_optimizes_quadratic() {
        let config = AdamConfig {
            learning_rate: 1e-1,
            ..AdamConfig::default()
        };
        testing::check_optimizes_quadratic(&config, 500);
    }

    #[test]
    fn adam_w_optimizes_quadratic() {
        let config = AdamWConfig {
            learning_rate: 1e-1,
            ..AdamWConfig::default()
        };
        testing::check_optimizes_quadratic(&config, 500);
    }
}