use anyhow::Result;
use core::f64;
use serde::{Deserialize, Serialize};
use tch::{
nn::{Adam, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore},
Tensor,
};
/// Optimizer hyper-parameter configuration.
///
/// This is the variant compiled when the `adam_eps` feature is disabled:
/// only plain Adam (learning rate only) is available.
#[cfg(not(feature = "adam_eps"))]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// Adam with the given learning rate; every other setting keeps the
    /// `tch` default.
    Adam { lr: f64 },
}
/// Optimizer hyper-parameter configuration.
///
/// This is the variant compiled when the `adam_eps` feature is enabled;
/// it additionally allows overriding Adam's epsilon term.
#[cfg(feature = "adam_eps")]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// Adam with the given learning rate; every other setting keeps the
    /// `tch` default.
    Adam { lr: f64 },
    /// Adam with an explicit epsilon in addition to the learning rate.
    AdamEps { lr: f64, eps: f64 },
}
#[cfg(not(feature = "adam_eps"))]
impl OptimizerConfig {
    /// Builds a concrete `tch` optimizer over the variables in `vs`.
    ///
    /// # Errors
    /// Propagates any error returned by `tch` while constructing the
    /// optimizer.
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        // Single-variant enum under this cfg, so the pattern is irrefutable.
        let Self::Adam { lr } = self;
        let inner = Adam::default().build(vs, *lr)?;
        Ok(Optimizer::Adam(inner))
    }
}
#[cfg(feature = "adam_eps")]
impl OptimizerConfig {
    /// Builds a concrete `tch` optimizer over the variables in `vs`.
    ///
    /// Both config variants produce an Adam-based optimizer; `AdamEps`
    /// additionally overrides the epsilon field before building.
    ///
    /// # Errors
    /// Propagates any error returned by `tch` while constructing the
    /// optimizer.
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        let inner = match self {
            Self::Adam { lr } => Adam::default().build(vs, *lr)?,
            Self::AdamEps { lr, eps } => {
                let mut cfg = Adam::default();
                cfg.eps = *eps;
                cfg.build(vs, *lr)?
            }
        };
        Ok(Optimizer::Adam(inner))
    }
}
/// Wrapper around the concrete `tch` optimizer produced by
/// [`OptimizerConfig::build`].
pub enum Optimizer {
    /// Adam-based optimizer.
    Adam(Optimizer_),
}
impl Optimizer {
    /// Delegates a backward pass and parameter update for `loss` to the
    /// wrapped `tch` optimizer's `backward_step`.
    pub fn backward_step(&mut self, loss: &Tensor) {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::Adam(inner) = self;
        inner.backward_step(loss);
    }
}