use super::super::utils;
use super::{BaseOptimizer, BuildOptimizer, OptimizerStepError, TrustRegionOptimizer};
use crate::logging::StatsLogger;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::convert::Infallible;
use tch::Tensor;
/// Configuration for [`ConjugateGradientOptimizer`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ConjugateGradientOptimizerConfig {
    /// Number of conjugate-gradient iterations used to solve for the step direction.
    pub iterations: u64,
    /// Maximum number of times the line search shrinks the step before giving up.
    pub max_backtracks: u64,
    /// Multiplicative factor applied to the step scale on each backtrack.
    pub backtrack_ratio: f64,
    /// Regularization coefficient added to the Hessian-vector product
    /// (scales the identity; keeps the CG system well-conditioned).
    pub hpv_reg_coeff: f64,
    /// If `true`, accept a step even when the trust-region constraint is violated.
    pub accept_violation: bool,
}
impl Default for ConjugateGradientOptimizerConfig {
    fn default() -> Self {
        Self {
            iterations: 10,
            max_backtracks: 15,
            backtrack_ratio: 0.8,
            hpv_reg_coeff: 1e-5,
            accept_violation: false,
        }
    }
}
impl BuildOptimizer for ConjugateGradientOptimizerConfig {
    type Optimizer = ConjugateGradientOptimizer;
    type Error = Infallible;

    /// Build an optimizer over shallow clones of `variables` (the clones share
    /// storage with the originals, so updates are visible to the caller).
    fn build_optimizer<'a, I>(&self, variables: I) -> Result<Self::Optimizer, Self::Error>
    where
        I: IntoIterator<Item = &'a Tensor>,
    {
        let params: Vec<Tensor> = variables.into_iter().map(Tensor::shallow_clone).collect();
        Ok(ConjugateGradientOptimizer::new(params, *self))
    }
}
/// Trust-region optimizer that finds a step direction with the conjugate-gradient
/// method and a step size with a backtracking line search.
#[derive(Debug, PartialEq)]
pub struct ConjugateGradientOptimizer {
    /// Parameters to optimize (shallow clones sharing storage with the model).
    params: Vec<Tensor>,
    config: ConjugateGradientOptimizerConfig,
}
impl ConjugateGradientOptimizer {
    /// Create a new optimizer over the given parameter tensors.
    #[must_use]
    pub fn new(variables: Vec<Tensor>, config: ConjugateGradientOptimizerConfig) -> Self {
        Self {
            params: variables,
            config,
        }
    }
}
impl BaseOptimizer for ConjugateGradientOptimizer {
    /// Clear the accumulated gradients of every managed parameter.
    fn zero_grad(&mut self) {
        self.params.iter_mut().for_each(|param| param.zero_grad());
    }
}
impl TrustRegionOptimizer for ConjugateGradientOptimizer {
    /// Take one trust-region step: solve for a natural-gradient-style descent
    /// direction with conjugate gradient, scale it to the trust region, then
    /// backtrack until the step improves the loss within the distance bound.
    ///
    /// Returns the loss value *before* the step.
    ///
    /// # Errors
    /// Propagates any [`OptimizerStepError`] from the line search.
    fn trust_region_backward_step(
        &mut self,
        loss_distance_fn: &mut dyn FnMut() -> (Tensor, Tensor),
        max_distance: f64,
        logger: &mut dyn StatsLogger,
    ) -> Result<f64, OptimizerStepError> {
        let (loss, distance) = loss_distance_fn();
        // Loss gradients w.r.t. all parameters; keep_graph=true so `distance`'s
        // graph survives for the Hessian-vector products below.
        let loss_grads = Tensor::run_backward(&[&loss], &self.params, true, false);
        // Keep only parameters that actually participate in the loss
        // (run_backward leaves the rest with undefined gradients).
        let (params, loss_grads): (Vec<&Tensor>, Vec<Tensor>) = self
            .params
            .iter()
            .zip(loss_grads.into_iter())
            .filter(|(_, grad)| grad.defined())
            .unzip();
        let flat_loss_grads = utils::flatten_tensors(loss_grads);
        // Implicit (regularized) Hessian of `distance`, exposed as H·v products.
        // NOTE: was mis-encoded as `¶ms` — fixed to `&params`.
        let hvp_fn = HessianVectorProduct::new(&distance, &params, self.config.hpv_reg_coeff);
        // Approximately solve H x = g for the step direction.
        let mut step_dir =
            solve_conjugate_gradient(&hvp_fn, &flat_loss_grads, self.config.iterations, 1e-10);
        // Degenerate curvature can produce NaNs; zero them out.
        let _ = step_dir.nan_to_num_(0.0, None, None);
        // Scale so the quadratic model of the distance equals `max_distance`:
        // step_size = sqrt(2 * max_distance / (s^T H s)); fall back to 1.0 on NaN.
        let step_size = match ((f64::from(step_dir.dot(&hvp_fn.mat_vec_mul(&step_dir))) + 1e-8)
            .recip()
            * max_distance
            * 2.0)
            .sqrt()
        {
            x if x.is_nan() => 1.0,
            x => x,
        };
        logger.log_scalar("step_size", step_size);
        let descent_step = step_size * step_dir;
        let initial_loss: f64 = loss.into();
        self.backtracking_line_search(
            &params,
            &descent_step,
            loss_distance_fn,
            max_distance,
            initial_loss,
            logger,
        )?;
        Ok(initial_loss)
    }
}
impl ConjugateGradientOptimizer {
    /// Shrink `descent_step` geometrically until it both improves the loss and
    /// satisfies the constraint, writing the accepted parameters in place.
    ///
    /// On any failure the parameters are rolled back to their original values.
    ///
    /// # Errors
    /// * [`OptimizerStepError::NaNLoss`] / [`OptimizerStepError::NaNConstraint`]
    ///   if the last evaluation produced NaN.
    /// * [`OptimizerStepError::LossNotImproving`] if no backtrack level improved
    ///   the loss.
    /// * [`OptimizerStepError::ConstraintViolated`] if the constraint is still
    ///   violated and `accept_violation` is unset.
    fn backtracking_line_search(
        &self,
        params: &[&Tensor],
        descent_step: &Tensor,
        loss_constraint_fn: &mut dyn FnMut() -> (Tensor, Tensor),
        max_constraint_value: f64,
        initial_loss: f64,
        logger: &mut dyn StatsLogger,
    ) -> Result<(), OptimizerStepError> {
        // Detach so the in-place parameter updates below are not tracked.
        let mut params: Vec<_> = params.iter().map(|t| t.detach()).collect();
        let prev_params: Vec<_> = params.iter().map(Tensor::copy).collect();
        let param_shapes: Vec<_> = params.iter().map(Tensor::size).collect();
        // NOTE: was mis-encoded as `¶m_shapes` — fixed to `&param_shapes`.
        let descent_step = utils::unflatten_tensors(descent_step, &param_shapes);
        let mut loss = initial_loss;
        let mut constraint_val = f64::INFINITY;
        logger.log_scalar("loss_initial", loss);
        for i in 0..self.config.max_backtracks {
            // Geometrically shrinking scale: ratio = backtrack_ratio^i.
            let ratio = self.config.backtrack_ratio.powi(i.try_into().unwrap());
            // params = prev_params - ratio * descent_step
            for ((step, prev_param), param) in descent_step
                .iter()
                .zip(prev_params.iter())
                .zip(params.iter_mut())
            {
                param.copy_(&(prev_param - ratio * step));
            }
            let (loss_tensor, constraint_tensor) = loss_constraint_fn();
            loss = loss_tensor.into();
            constraint_val = constraint_tensor.into();
            if loss < initial_loss && constraint_val <= max_constraint_value {
                // Lossless cast: i < max_backtracks, far below f64's integer range.
                logger.log_scalar("num_backtracks", i as f64);
                logger.log_scalar("step_scale", ratio);
                break;
            }
        }
        logger.log_scalar("loss_final", loss);
        logger.log_scalar("constraint_val_final", constraint_val);
        // NOTE(review): the loop accepts `constraint_val == max_constraint_value`
        // but this check rejects equality (`>=`) — confirm the boundary is intended.
        let result = if loss.is_nan() {
            Err(OptimizerStepError::NaNLoss)
        } else if constraint_val.is_nan() {
            Err(OptimizerStepError::NaNConstraint)
        } else if loss >= initial_loss {
            Err(OptimizerStepError::LossNotImproving {
                loss,
                loss_before: initial_loss,
            })
        } else if constraint_val >= max_constraint_value && !self.config.accept_violation {
            Err(OptimizerStepError::ConstraintViolated {
                constraint_val,
                max_constraint_value,
            })
        } else {
            Ok(())
        };
        // Restore the original parameters on any failure.
        if result.is_err() {
            for (param, prev_param) in params.iter_mut().zip(prev_params) {
                param.copy_(&prev_param);
            }
        }
        result
    }
}
/// Implicit Hessian of a scalar `output` w.r.t. `params`, exposed only through
/// Hessian-vector products (the full matrix is never materialized).
struct HessianVectorProduct<'a, T> {
    params: &'a [T],
    /// Coefficient of the identity added to the product (Tikhonov regularization).
    reg_coeff: f64,
    /// Shapes of `params`, used to unflatten incoming vectors.
    param_shapes: Vec<Vec<i64>>,
    /// First-order gradients of `output` w.r.t. `params`, kept differentiable.
    grads: Vec<Tensor>,
}
impl<'a, T> HessianVectorProduct<'a, T>
where
    T: Borrow<Tensor>,
{
    /// Build a Hessian-vector-product operator for `output` w.r.t. `params`.
    pub fn new(output: &Tensor, params: &'a [T], reg_coeff: f64) -> Self {
        let param_shapes: Vec<Vec<i64>> = params.iter().map(|p| p.borrow().size()).collect();
        // First-order gradients with create_graph=true so a second backward
        // pass through them yields Hessian-vector products.
        let grads: Vec<Tensor> = Tensor::run_backward(&[output], params, true, true)
            .into_iter()
            .zip(params)
            .map(|(grad, param)| {
                if grad.defined() {
                    grad
                } else {
                    // Parameters unreachable from `output` contribute zero.
                    param.borrow().zeros_like()
                }
            })
            .collect();
        Self {
            params,
            reg_coeff,
            param_shapes,
            grads,
        }
    }
}
impl<'a, T> MatrixVectorProduct for HessianVectorProduct<'a, T>
where
    T: Borrow<Tensor>,
{
    type Vector = Tensor;

    /// Compute `(H + reg_coeff * I) * vector` without materializing `H`.
    fn mat_vec_mul(&self, vector: &Tensor) -> Tensor {
        let vector_parts = utils::unflatten_tensors(vector, &self.param_shapes);
        assert_eq!(self.grads.len(), vector_parts.len());
        // Per-parameter dot products g_i . v_i, summed into a scalar g . v.
        let dot_terms: Vec<_> = self
            .grads
            .iter()
            .zip(&vector_parts)
            .map(|(grad, part)| utils::flat_dot(grad, part))
            .collect();
        let grad_vector_product = Tensor::stack(&dot_terms, 0).sum(vector.kind());
        // Differentiating (g . v) w.r.t. the parameters yields H v.
        let hessian_vec = Tensor::run_backward(&[grad_vector_product], self.params, true, false);
        // Regularize: add reg_coeff * v (i.e. reg_coeff * I on the diagonal).
        utils::flatten_tensors(&hessian_vec).g_add(&vector.g_mul_scalar(self.reg_coeff))
    }
}
/// A linear operator defined by its matrix-vector product.
pub trait MatrixVectorProduct {
    type Vector;
    /// Multiply this operator by `vector`.
    fn mat_vec_mul(&self, vector: &Self::Vector) -> Self::Vector;
}
impl MatrixVectorProduct for Tensor {
    type Vector = Self;
    /// Dense matrix-vector product (`self` is the matrix).
    fn mat_vec_mul(&self, vector: &Self::Vector) -> Self::Vector {
        self.mv(vector)
    }
}
/// Approximately solve `A x = b` with the conjugate-gradient method.
///
/// `f_Ax` supplies matrix-vector products with `A`; the matrix itself is never
/// materialized. Runs at most `iterations` steps, stopping early once the
/// squared residual norm falls below `residual_tol`.
#[allow(non_snake_case)]
fn solve_conjugate_gradient<T: MatrixVectorProduct<Vector = Tensor>>(
    f_Ax: &T,
    b: &Tensor,
    iterations: u64,
    residual_tol: f64,
) -> Tensor {
    let mut solution = b.zeros_like();
    let mut residual = b.copy();
    let mut direction = b.copy();
    // Squared norm of the residual, r . r.
    let mut rho = residual.dot(&residual);
    for _ in 0..iterations {
        let a_dir = f_Ax.mat_vec_mul(&direction);
        // Step length minimizing the quadratic along `direction`.
        let alpha = &rho / direction.dot(&a_dir);
        let _ = solution.addcmul_(&alpha, &direction);
        let _ = residual.addcmul_(&(-alpha), &a_dir);
        let rho_next = residual.dot(&residual);
        if f64::from(&rho_next) < residual_tol {
            break;
        }
        // direction = residual + (rho_next / rho) * direction
        let beta = &rho_next / &rho;
        let _ = direction.g_mul_(&beta);
        let _ = direction.g_add_(&residual);
        rho = rho_next;
    }
    solution
}
#[cfg(test)]
mod cg_optimizer {
    use super::super::testing;
    use super::*;
    use tch::{Device, Kind};

    #[test]
    fn optimizes_quadratic() {
        let config = ConjugateGradientOptimizerConfig::default();
        testing::check_trust_region_optimizes_quadratic(&config, 500);
    }

    /// Run up to `num_steps` trust-region steps, invoking `on_step` before each
    /// one, and stop early once the optimizer reports the loss can no longer improve.
    fn trpo_run<F, G>(
        optimizer: &mut ConjugateGradientOptimizer,
        mut loss_distance_fn: F,
        mut on_step: G,
        num_steps: u64,
        max_distance: f64,
    ) where
        F: FnMut() -> (Tensor, Tensor),
        G: FnMut(),
    {
        for _ in 0..num_steps {
            on_step();
            // `&mut ()` — presumably `()` implements `StatsLogger` as a no-op
            // logger; confirm against the logging module.
            let result =
                optimizer.trust_region_backward_step(&mut loss_distance_fn, max_distance, &mut ());
            match result {
                // Treated as convergence, not failure.
                Err(OptimizerStepError::LossNotImproving {
                    loss: _,
                    loss_before: _,
                }) => break,
                r => r.unwrap(),
            };
        }
    }

    /// The loss and distance tensors share the intermediate `y`; checks that the
    /// shared graph survives the double backward pass inside the optimizer.
    #[test]
    fn shared_loss_distance_computation() {
        let config = ConjugateGradientOptimizerConfig::default();
        let x = Tensor::ones(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let mut optimizer = config.build_optimizer([&x]).unwrap();
        let y_prev = x.square().mean(Kind::Float).detach();
        let loss_distance_fn = || {
            let y = x.square().mean(Kind::Float);
            let loss = &y + 1.0;
            let distance = (&y - &y_prev).square();
            (loss, distance)
        };
        trpo_run(
            &mut optimizer,
            loss_distance_fn,
            || {
                // Refresh the reference point before each step.
                y_prev.detach().copy_(&x.square().mean(Kind::Float));
            },
            100,
            0.001,
        );
        // Minimizing x^2 should drive x toward the origin.
        let expected = Tensor::of_slice(&[0.0, 0.0]);
        assert!(
            f64::from((&x - &expected).norm()) < 0.1,
            "expected: {:?}, actual: {:?}",
            expected,
            x
        );
    }

    /// Parameters that never appear in the loss must not break the step
    /// (their gradients come back undefined and are filtered out).
    #[test]
    fn unused_params() {
        let config = ConjugateGradientOptimizerConfig::default();
        let x = Tensor::ones(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let unused = Tensor::zeros(&[3], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let mut optimizer = config.build_optimizer([&x, &unused]).unwrap();
        let x_prev = x.detach().copy();
        let loss_distance_fn = || {
            let loss = x.square().sum(Kind::Float);
            let distance = (&x - &x_prev).square().sum(Kind::Float);
            (loss, distance)
        };
        trpo_run(
            &mut optimizer,
            loss_distance_fn,
            || x_prev.detach().copy_(&x),
            100,
            0.1,
        );
        let expected = Tensor::of_slice(&[0.0, 0.0]);
        assert!(
            f64::from((&x - &expected).norm()) < 0.1,
            "expected: {:?}, actual: {:?}",
            expected,
            x
        );
    }
}
#[cfg(test)]
mod hessian_vector_product {
    use super::*;
    use tch::{Cuda, Device, Kind};

    /// For y = x^T M x / 2 + b^T x the Hessian is (M + M^T) / 2 scaled
    /// appropriately; with the symmetric M below, H v = M v exactly.
    #[test]
    fn quadratic_hessian() {
        // Result ignored — presumably forces lazy CUDA/library initialization
        // before the test body; confirm whether this call is still needed.
        Cuda::is_available();
        let m = Tensor::of_slice(&[1.0_f32, -1.0, -1.0, 2.0]).reshape(&[2, 2]);
        let b = Tensor::of_slice(&[2.0_f32, -3.0]);
        let x = Tensor::zeros(&[2], (Kind::Float, Device::Cpu)).requires_grad_(true);
        let y = m.mv(&x).dot(&x) / 2 + b.dot(&x);
        let params = [&x];
        let hvp = HessianVectorProduct::new(&y, &params, 0.0);
        // Multiplying by the basis vectors recovers the columns of M.
        assert_eq!(
            hvp.mat_vec_mul(&Tensor::of_slice(&[1.0_f32, 0.0])),
            Tensor::of_slice(&[1.0_f32, -1.0])
        );
        assert_eq!(
            hvp.mat_vec_mul(&Tensor::of_slice(&[0.0_f32, 1.0])),
            Tensor::of_slice(&[-1.0_f32, 2.0])
        );
    }
}
#[cfg(test)]
#[allow(clippy::module_inception)]
mod conjugate_gradient {
    use super::*;

    /// CG on a 2x2 symmetric positive-definite system; the exact solution of
    /// A x = b with A = [[1, -1], [-1, 2]], b = [-1, 4] is x = [2, 3].
    #[test]
    fn solve_2x2() {
        let a = Tensor::of_slice(&[1.0_f64, -1.0, -1.0, 2.0]).reshape(&[2, 2]);
        let b = Tensor::of_slice(&[-1.0_f64, 4.0]);
        let tol = 1e-4;
        let x = solve_conjugate_gradient(&a, &b, 10, tol);
        let expected = Tensor::of_slice(&[2.0_f64, 3.0]);
        assert!(
            f64::from((&x - &expected).norm()) < tol,
            "expected: {:?}, actual: {:?}",
            expected,
            x
        );
    }
}