//! Gradient-based optimizers and their configurations.
mod conjugate_gradient;
mod coptimizer;

pub use conjugate_gradient::{ConjugateGradientOptimizer, ConjugateGradientOptimizerConfig};
pub use coptimizer::{AdamConfig, AdamWConfig, RmsPropConfig, SgdConfig};

use crate::logging::StatsLogger;
use log::warn;
use std::error::Error;
use tch::Tensor;
use thiserror::Error;

/// Functionality common to all optimizers.
pub trait BaseOptimizer {
    /// Zero the gradients of all optimized tensors.
    fn zero_grad(&mut self);
}

/// An optimizer that minimizes a loss function.
pub trait Optimizer: BaseOptimizer {
    /// Perform a backward pass of the loss then take an optimization step.
    ///
    /// Evaluates `loss_fn`, backpropagates its gradients, and updates the
    /// optimized tensors. On success, returns the evaluated loss.
    fn backward_step(
        &mut self,
        loss_fn: &mut dyn FnMut() -> Tensor,
        logger: &mut dyn StatsLogger,
    ) -> Result<Tensor, OptimizerStepError>;
}
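
// A minimal sketch of how a conforming `backward_step` implementation is
// shaped (`MyOptimizer` is hypothetical; the real implementations live in
// `coptimizer` and `conjugate_gradient`):
//
//     impl Optimizer for MyOptimizer {
//         fn backward_step(
//             &mut self,
//             loss_fn: &mut dyn FnMut() -> Tensor,
//             _logger: &mut dyn StatsLogger,
//         ) -> Result<Tensor, OptimizerStepError> {
//             self.zero_grad();
//             let loss = loss_fn();
//             loss.backward();
//             // ... update each tracked tensor from its accumulated gradient ...
//             Ok(loss)
//         }
//     }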

/// An optimizer that constrains each step to a trust region.
pub trait TrustRegionOptimizer: BaseOptimizer {
    /// Take an optimization step subject to a distance constraint.
    ///
    /// `loss_distance_fn` evaluates the current `(loss, distance)` pair as
    /// differentiable scalar tensors, where `distance` measures how far the
    /// parameters have moved within this step. The step is constrained so
    /// that the distance does not exceed `max_distance`.
    fn trust_region_backward_step(
        &mut self,
        loss_distance_fn: &mut dyn FnMut() -> (Tensor, Tensor),
        max_distance: f64,
        logger: &mut dyn StatsLogger,
    ) -> Result<f64, OptimizerStepError>;
}
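
// The trust-region step approximately solves the constrained problem
// (notation ours, not part of the trait definition):
//
//     minimize    loss(theta)
//     subject to  distance(theta, theta_old) <= max_distance
//
// where `loss_distance_fn` evaluates both quantities at the current
// parameter values.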

/// Error performing an optimization step.
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum OptimizerStepError {
    /// The loss failed to improve over the course of the step.
    #[error("loss is not improving: (new) {loss} >= (prev) {loss_before}")]
    LossNotImproving { loss: f64, loss_before: f64 },
    /// The step would violate the constraint.
    #[error(
        "constraint is violated: (val) {constraint_val} >= (threshold) {max_constraint_value}"
    )]
    ConstraintViolated {
        constraint_val: f64,
        max_constraint_value: f64,
    },
    /// The loss evaluated to NaN.
    #[error("loss is NaN")]
    NaNLoss,
    /// The constraint value evaluated to NaN.
    #[error("constraint is NaN")]
    NaNConstraint,
}

impl OptimizerStepError {
    /// Whether optimization can continue after this error by skipping the
    /// failed step.
    ///
    /// NaN values are treated as recoverable (the step is abandoned), while a
    /// non-improving loss or a violated constraint indicates a more
    /// fundamental problem with the optimization.
    #[must_use]
    #[inline]
    pub const fn can_continue(self) -> bool {
        matches!(self, Self::NaNLoss | Self::NaNConstraint)
    }
}
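
// These cases illustrate the recoverable/fatal split documented on
// `can_continue`. This small test module is ours, added for illustration.
#[cfg(test)]
mod step_error_tests {
    use super::*;

    #[test]
    fn nan_errors_are_recoverable() {
        assert!(OptimizerStepError::NaNLoss.can_continue());
        assert!(OptimizerStepError::NaNConstraint.can_continue());
    }

    #[test]
    fn non_improvement_and_violation_are_fatal() {
        assert!(!OptimizerStepError::LossNotImproving {
            loss: 1.0,
            loss_before: 0.5
        }
        .can_continue());
        assert!(!OptimizerStepError::ConstraintViolated {
            constraint_val: 2.0,
            max_constraint_value: 1.0
        }
        .can_continue());
    }
}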

/// Unwrap an optimizer step result, logging a warning on a recoverable error
/// and panicking on a fatal one.
///
/// Returns `Some` on success and `None` if the step failed but optimization
/// [can continue](OptimizerStepError::can_continue).
pub fn opt_expect_ok_log<T>(result: Result<T, OptimizerStepError>, msg: &str) -> Option<T> {
    match result {
        Ok(x) => Some(x),
        Err(err) => {
            if err.can_continue() {
                warn!("{msg}\ncaused by: {err}");
                None
            } else {
                panic!("{msg}\ncaused by: {err}");
            }
        }
    }
}
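
// A minimal sketch of using `opt_expect_ok_log` inside a training loop
// (`optimizer`, `loss_fn`, and `logger` are assumed to exist in the caller):
//
//     let loss = opt_expect_ok_log(
//         optimizer.backward_step(&mut loss_fn, &mut logger),
//         "failed to take an optimization step",
//     );
//     // `None` means a recoverable NaN step was skipped; training continues.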

/// Build an optimizer for a set of variables.
pub trait BuildOptimizer {
    /// Type of optimizer produced.
    type Optimizer;
    /// Type of error that may occur when building the optimizer.
    type Error: Error;

    /// Build an optimizer that updates the given variables.
    fn build_optimizer<'a, I>(&self, variables: I) -> Result<Self::Optimizer, Self::Error>
    where
        I: IntoIterator<Item = &'a Tensor>;
}
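
// A minimal sketch of building an optimizer from a config, assuming (as for
// the configs re-exported above) that `SgdConfig` implements both
// `BuildOptimizer` and `Default`:
//
//     let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
//     let mut optimizer = SgdConfig::default().build_optimizer([&x]).unwrap();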

#[cfg(test)]
mod testing {
    use super::*;
    use tch::{Device, Kind};

    /// Check that an optimizer built from `optimizer_config` minimizes a
    /// convex quadratic within `num_steps` steps.
    ///
    /// The objective is `f(x) = x^T M x / 2 + b^T x` with
    /// `M = [[1, -1], [-1, 2]]` (symmetric positive definite) and
    /// `b = [2, -3]`, whose unique minimum is at `x = [-1, 1]`.
    pub fn check_optimizes_quadratic<OC>(optimizer_config: &OC, num_steps: u64)
    where
        OC: BuildOptimizer,
        OC::Optimizer: Optimizer,
    {
        let m = Tensor::of_slice(&[1.0_f32, -1.0, -1.0, 2.0]).reshape(&[2, 2]);
        let b = Tensor::of_slice(&[2.0_f32, -3.0]);
        let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let mut optimizer = optimizer_config.build_optimizer([&x]).unwrap();
        let mut loss_fn = || m.mv(&x).dot(&x) / 2 + b.dot(&x);

        for _ in 0..num_steps {
            let _ = optimizer.backward_step(&mut loss_fn, &mut ()).unwrap();
        }

        let expected = Tensor::of_slice(&[-1.0, 1.0]);
        assert!(
            f64::from((&x - &expected).norm()) < 1e-3,
            "expected: {:?}, actual: {:?}",
            expected,
            x
        );
    }
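
    // Worked derivation of the expected minimum (notation ours):
    //     grad f(x) = M x + b = 0          =>  x* = -M^{-1} b
    //     det M = (1)(2) - (-1)(-1) = 1    =>  M^{-1} = [[2, 1], [1, 1]]
    //     x* = -[[2, 1], [1, 1]] [2, -3]^T = -[1, -1]^T = [-1, 1]^T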

    /// Check that a trust-region optimizer built from `optimizer_config`
    /// minimizes the same quadratic as [`check_optimizes_quadratic`], using
    /// the squared step length as the distance measure.
    pub fn check_trust_region_optimizes_quadratic<OC>(optimizer_config: &OC, num_steps: u64)
    where
        OC: BuildOptimizer,
        OC::Optimizer: TrustRegionOptimizer,
    {
        let m = Tensor::of_slice(&[1.0_f32, -1.0, -1.0, 2.0]).reshape(&[2, 2]);
        let b = Tensor::of_slice(&[2.0_f32, -3.0]);
        let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let mut optimizer = optimizer_config.build_optimizer([&x]).unwrap();

        // Parameter values at the start of the current step.
        let x_last = x.detach().copy();
        let mut loss_distance_fn = || {
            let loss = m.mv(&x).dot(&x) / 2 + b.dot(&x);
            let distance = (&x - &x_last).square().sum(Kind::Float);
            (loss, distance)
        };

        for _ in 0..num_steps {
            // Record the pre-step parameters without tracking gradients.
            x_last.detach().copy_(&x);
            let result =
                optimizer.trust_region_backward_step(&mut loss_distance_fn, 0.001, &mut ());
            match result {
                // Stop early once the loss can no longer be improved.
                Err(OptimizerStepError::LossNotImproving {
                    loss: _,
                    loss_before: _,
                }) => break,
                r => r.unwrap(),
            };
        }

        let expected = Tensor::of_slice(&[-1.0, 1.0]);
        assert!(
            f64::from((&x - &expected).norm()) < 1e-3,
            "expected: {:?}, actual: {:?}",
            expected,
            x
        );
    }
}