use crate::error::OptimizeError;
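/// Stochastic Variance Reduced Gradient (SVRG) optimizer.
///
/// Keeps a snapshot of the parameters together with the full gradient at
/// that snapshot. Each inner step follows the variance-reduced direction
/// `g = ∇f_i(x) - ∇f_i(x_snap) + ∇F(x_snap)`.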
#[derive(Debug, Clone)]
pub struct SvrgOptimizer {
pub lr: f64,
pub inner_iters: usize,
snapshot: Vec<f64>,
full_grad: Vec<f64>,
inner_step: usize,
}
impl SvrgOptimizer {
pub fn new(lr: f64, inner_iters: usize) -> Self {
Self {
lr,
inner_iters,
snapshot: Vec::new(),
full_grad: Vec::new(),
inner_step: 0,
}
}
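/// Stores a new snapshot of the parameters and the full-batch gradient
/// evaluated at that snapshot, and resets the inner-step counter.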
pub fn update_snapshot(&mut self, params: &[f64], full_gradient: Vec<f64>) {
self.snapshot = params.to_vec();
self.full_grad = full_gradient;
self.inner_step = 0;
}
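/// Performs one inner SVRG step. `stoch_grad_current` and
/// `stoch_grad_snapshot` must be gradients of the same sampled component,
/// evaluated at the current parameters and at the stored snapshot.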
pub fn step(
&mut self,
params: &mut Vec<f64>,
stoch_grad_current: &[f64],
stoch_grad_snapshot: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if stoch_grad_current.len() != n || stoch_grad_snapshot.len() != n {
return Err(OptimizeError::ValueError(format!(
"length mismatch: params={}, current_grad={}, snapshot_grad={}",
n,
stoch_grad_current.len(),
stoch_grad_snapshot.len()
)));
}
if self.full_grad.len() != n {
return Err(OptimizeError::ValueError(
"Snapshot not initialised; call update_snapshot first".to_string(),
));
}
for i in 0..n {
let g = stoch_grad_current[i] - stoch_grad_snapshot[i] + self.full_grad[i];
params[i] -= self.lr * g;
}
self.inner_step += 1;
Ok(())
}
pub fn epoch_done(&self) -> bool {
self.inner_step >= self.inner_iters
}
pub fn reset(&mut self) {
self.snapshot.clear();
self.full_grad.clear();
self.inner_step = 0;
}
pub fn snapshot(&self) -> &[f64] {
&self.snapshot
}
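/// Runs one inner epoch: refreshes the snapshot with `full_gradient` and
/// performs `inner_iters` variance-reduced steps.
/// `grad_fn(current, snapshot)` should return the sampled component's
/// gradient at both points.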
pub fn run_epoch<F>(
&mut self,
params: &mut Vec<f64>,
full_gradient: Vec<f64>,
mut grad_fn: F,
) -> Result<(), OptimizeError>
where
F: FnMut(&[f64], &[f64]) -> (Vec<f64>, Vec<f64>),
{
self.update_snapshot(params, full_gradient);
let snapshot_copy = self.snapshot.clone();
for _ in 0..self.inner_iters {
let (gc, gs) = grad_fn(params, &snapshot_copy);
self.step(params, &gc, &gs)?;
}
Ok(())
}
}
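/// SARAH (stochastic recursive gradient) optimizer.
///
/// Maintains the recursive estimate
/// `v_t = ∇f_i(x_t) - ∇f_i(x_{t-1}) + v_{t-1}`,
/// seeded with the full-batch gradient at the start of each inner epoch.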
#[derive(Debug, Clone)]
pub struct SarahOptimizer {
pub lr: f64,
pub inner_iters: usize,
v: Vec<f64>,
prev_params: Vec<f64>,
inner_step: usize,
}
impl SarahOptimizer {
pub fn new(lr: f64, inner_iters: usize) -> Self {
Self {
lr,
inner_iters,
v: Vec::new(),
prev_params: Vec::new(),
inner_step: 0,
}
}
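/// Starts a new inner epoch: seeds the recursive estimate with the
/// full-batch gradient at `params` and resets the inner-step counter.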
pub fn init_epoch(&mut self, params: &[f64], full_gradient: Vec<f64>) {
self.prev_params = params.to_vec();
self.v = full_gradient;
self.inner_step = 0;
}
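/// Performs one SARAH step. `grad_current` and `grad_prev` must be
/// gradients of the same sampled component, evaluated at the current
/// parameters and at the previous iterate (`prev_params`). After the
/// update, `prev_params` holds the iterate the step started from.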
pub fn step(
&mut self,
params: &mut Vec<f64>,
grad_current: &[f64],
grad_prev: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if grad_current.len() != n || grad_prev.len() != n {
return Err(OptimizeError::ValueError(format!(
"length mismatch: params={}, grad_current={}, grad_prev={}",
n,
grad_current.len(),
grad_prev.len()
)));
}
if self.v.len() != n {
return Err(OptimizeError::ValueError(
"SARAH not initialised; call init_epoch first".to_string(),
));
}
// Record the pre-update iterate: the next step's `grad_prev` must be
// evaluated at the point this update starts from, not at the new point.
let prev = params.clone();
for i in 0..n {
self.v[i] = grad_current[i] - grad_prev[i] + self.v[i];
params[i] -= self.lr * self.v[i];
}
self.prev_params = prev;
self.inner_step += 1;
Ok(())
}
pub fn epoch_done(&self) -> bool {
self.inner_step >= self.inner_iters
}
pub fn reset(&mut self) {
self.v.clear();
self.prev_params.clear();
self.inner_step = 0;
}
pub fn prev_params(&self) -> &[f64] {
&self.prev_params
}
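/// Runs one inner epoch of `inner_iters` SARAH steps.
/// `grad_fn(current, previous)` should return the sampled component's
/// gradient at both points.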
pub fn run_epoch<F>(
&mut self,
params: &mut Vec<f64>,
full_gradient: Vec<f64>,
mut grad_fn: F,
) -> Result<(), OptimizeError>
where
F: FnMut(&[f64], &[f64]) -> (Vec<f64>, Vec<f64>),
{
self.init_epoch(params, full_gradient);
for _ in 0..self.inner_iters {
let prev = self.prev_params.clone();
let (gc, gp) = grad_fn(params, &prev);
self.step(params, &gc, &gp)?;
}
Ok(())
}
}
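/// SPIDER (Stochastic Path-Integrated Differential EstimatoR) optimizer.
///
/// Uses the same recursive estimator as SARAH,
/// `v_k = ∇f_i(x_k) - ∇f_i(x_{k-1}) + v_{k-1}`,
/// refreshed periodically with a large-batch gradient.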
#[derive(Debug, Clone)]
pub struct SpiderOptimizer {
pub lr: f64,
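/// Mini-batch size the caller is expected to use when forming stochastic
/// gradients; it is not used internally by this struct.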
pub batch_size: usize,
pub inner_iters: usize,
v: Vec<f64>,
prev_params: Vec<f64>,
inner_step: usize,
}
impl SpiderOptimizer {
pub fn new(lr: f64, batch_size: usize, inner_iters: usize) -> Self {
Self {
lr,
batch_size,
inner_iters,
v: Vec::new(),
prev_params: Vec::new(),
inner_step: 0,
}
}
pub fn refresh(&mut self, params: &[f64], large_batch_grad: Vec<f64>) {
self.prev_params = params.to_vec();
self.v = large_batch_grad;
self.inner_step = 0;
}
pub fn refresh_needed(&self) -> bool {
self.inner_step >= self.inner_iters || self.v.is_empty()
}
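/// Performs one SPIDER step. `grad_current` and `grad_prev` must be
/// mini-batch gradients of the same sampled batch, evaluated at the current
/// parameters and at the previous iterate (`prev_params`). After the
/// update, `prev_params` holds the iterate the step started from.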
pub fn step(
&mut self,
params: &mut Vec<f64>,
grad_current: &[f64],
grad_prev: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if grad_current.len() != n || grad_prev.len() != n {
return Err(OptimizeError::ValueError(format!(
"length mismatch: params={}, grad_current={}, grad_prev={}",
n,
grad_current.len(),
grad_prev.len()
)));
}
if self.v.len() != n {
return Err(OptimizeError::ValueError(
"SPIDER not initialised; call refresh first".to_string(),
));
}
// Record the pre-update iterate so the next step's `grad_prev` is
// evaluated at the previous point rather than the freshly updated one.
let prev = params.clone();
for i in 0..n {
self.v[i] = grad_current[i] - grad_prev[i] + self.v[i];
params[i] -= self.lr * self.v[i];
}
self.prev_params = prev;
self.inner_step += 1;
Ok(())
}
pub fn reset(&mut self) {
self.v.clear();
self.prev_params.clear();
self.inner_step = 0;
}
pub fn prev_params(&self) -> &[f64] {
&self.prev_params
}
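/// Runs one outer iteration: refreshes the estimator with
/// `large_grad_fn(params)`, then performs `inner_iters` steps using
/// `mini_grad_fn(current, previous)`.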
pub fn run_outer_iter<F, G>(
&mut self,
params: &mut Vec<f64>,
mut large_grad_fn: F,
mut mini_grad_fn: G,
) -> Result<(), OptimizeError>
where
F: FnMut(&[f64]) -> Vec<f64>,
G: FnMut(&[f64], &[f64]) -> (Vec<f64>, Vec<f64>),
{
let large_grad = large_grad_fn(params);
self.refresh(params, large_grad);
for _ in 0..self.inner_iters {
let prev = self.prev_params.clone();
let (gc, gp) = mini_grad_fn(params, &prev);
self.step(params, &gc, &gp)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
fn full_grad(x: &[f64]) -> Vec<f64> {
x.iter().map(|&xi| 2.0 * xi).collect()
}
#[test]
fn test_svrg_converges() {
let mut opt = SvrgOptimizer::new(0.05, 20);
let mut params = vec![2.0, -1.5];
for _epoch in 0..50 {
let fg = full_grad(&params);
opt.run_epoch(&mut params, fg, |p, snap| {
(full_grad(p), full_grad(snap))
})
.expect("epoch failed");
}
for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_svrg_epoch_done() {
let mut opt = SvrgOptimizer::new(0.1, 5);
let params = vec![1.0, 1.0];
let fg = full_grad(&params);
opt.update_snapshot(&params, fg);
assert!(!opt.epoch_done());
for _ in 0..5 {
let snap = opt.snapshot().to_vec();
let mut p = params.clone();
let fg = full_grad(&p);
let sg = full_grad(&snap);
opt.step(&mut p, &fg, &sg)
.expect("step failed");
}
assert!(opt.epoch_done());
}
#[test]
fn test_sarah_converges() {
let mut opt = SarahOptimizer::new(0.05, 20);
let mut params = vec![3.0, -1.0];
for _epoch in 0..50 {
let fg = full_grad(&params);
opt.run_epoch(&mut params, fg, |p, prev| {
(full_grad(p), full_grad(prev))
})
.expect("epoch failed");
}
for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-2);
}
}
#[test]
fn test_spider_converges() {
let mut opt = SpiderOptimizer::new(0.05, 10, 10);
let mut params = vec![2.0, -2.0];
for _outer in 0..50 {
opt.run_outer_iter(
&mut params,
|p| full_grad(p),
|p, prev| (full_grad(p), full_grad(prev)),
)
.expect("outer iter failed");
}
for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-2);
}
}
#[test]
fn test_svrg_length_mismatch() {
let mut opt = SvrgOptimizer::new(0.1, 5);
let params = vec![1.0, 2.0];
opt.update_snapshot(&params, vec![0.0, 0.0]);
let mut p = params.clone();
let result = opt.step(&mut p, &[0.1], &[0.1, 0.2]);
assert!(result.is_err());
}
}