use rand::seq::SliceRandom;
use crate::core::math::{Scalar, ScaleInPlace, ScaledAdd};
use crate::core::problem::{CostFunction, MiniBatchGradient, Problem};
use crate::core::rng::{ChaCha8Rng, SeedableRng};
use crate::core::solver::Solver;
use crate::core::state::BasicState;
use crate::core::termination::TerminationReason;
/// Mini-batch stochastic gradient descent solver with optional momentum.
///
/// `V` is the parameter/gradient vector type and `F` the scalar type
/// (defaulting to `f64`). The first four fields are configuration supplied
/// through the constructor and builder methods; the remaining fields are
/// per-run bookkeeping that `Solver::init` rebuilds at the start of each run.
pub struct Sgd<V, F = f64> {
/// Step size: every update moves along `-alpha * gradient`.
alpha: F,
/// Requested mini-batch size (asserted non-zero by `new`).
batch_size: usize,
/// Seed used to construct the RNG in `init`.
seed: u64,
/// Momentum coefficient; a value of zero selects the plain SGD update.
beta: F,
/// Momentum buffer; `None` until the first momentum step of a run.
velocity: Option<V>,
/// Shuffling RNG; `None` until `init` has seeded it.
rng: Option<ChaCha8Rng>,
/// Current permutation of the sample indices `0..n_samples`.
perm: Vec<usize>,
/// Offset into `perm` at which the next mini-batch starts.
cursor: usize,
/// `batch_size` clamped to the problem's sample count (computed in `init`).
effective_batch: usize,
/// User-requested cost refresh period in iterations, if one was set.
cost_eval_every: Option<usize>,
/// Resolved cost refresh period: `cost_eval_every` if set, otherwise the
/// number of full batches per epoch (computed in `init`).
cost_period: usize,
/// Iterations performed since the cost was last evaluated.
iters_since_cost: usize,
}
impl<V, F: Scalar> Sgd<V, F> {
    /// Builds a solver with step size `alpha`, requested mini-batch size
    /// `batch_size` and shuffle seed `seed`.
    ///
    /// Momentum starts disabled (`beta` is zero) and no explicit cost refresh
    /// period is set.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero.
    pub fn new(alpha: F, batch_size: usize, seed: u64) -> Self {
        assert!(batch_size > 0, "Sgd: batch_size must be > 0");
        Self {
            // Configuration.
            alpha,
            beta: F::zero(),
            batch_size,
            seed,
            cost_eval_every: None,
            // Per-run bookkeeping; `Solver::init` fills these in.
            velocity: None,
            rng: None,
            perm: Vec::new(),
            cursor: 0,
            effective_batch: 0,
            cost_period: 1,
            iters_since_cost: 0,
        }
    }

    /// Returns the solver with its momentum coefficient replaced by `beta`.
    pub fn with_momentum(self, beta: F) -> Self {
        Self { beta, ..self }
    }

    /// Returns the solver configured to refresh the cost every `period`
    /// iterations.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    pub fn with_cost_eval_every(self, period: usize) -> Self {
        assert!(period > 0, "Sgd: cost_eval_every period must be > 0");
        Self {
            cost_eval_every: Some(period),
            ..self
        }
    }
}
/// [`Solver`] implementation: seeded mini-batch SGD with optional momentum.
///
/// Each iteration draws the next `effective_batch` indices from a shuffled
/// permutation of the sample indices, asks the problem for the gradient on
/// that batch, and applies either a plain step (`x += -alpha * g`) or a
/// momentum step (`v = beta * v - alpha * g; x += v`).
impl<P, V, F> Solver<P, BasicState<V, F>> for Sgd<V, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F> + MiniBatchGradient<Gradient = V>,
    V: ScaledAdd<F> + ScaleInPlace<F> + Clone,
{
    type Error = P::Error;

    /// Resets all per-run state and records the initial cost in `state`.
    ///
    /// The momentum buffer is cleared, the RNG is re-created from the stored
    /// seed and the index permutation is rebuilt and shuffled, so calling
    /// `init` again on the same solver starts from identical solver state.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the problem's cost evaluation.
    ///
    /// # Panics
    ///
    /// Panics if the problem reports zero samples.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V, F>,
    ) -> Result<BasicState<V, F>, Self::Error> {
        self.velocity = None;
        let n = problem.inner().n_samples();
        assert!(
            n > 0,
            "Sgd: problem.n_samples() == 0; no batches to draw from",
        );
        // A requested batch larger than the data set degrades to full batch.
        self.effective_batch = self.batch_size.min(n);
        let mut rng = ChaCha8Rng::seed_from_u64(self.seed);
        self.perm = (0..n).collect();
        self.perm.as_mut_slice().shuffle(&mut rng);
        self.cursor = 0;
        self.rng = Some(rng);
        // Without an explicit period the cost is refreshed once per epoch,
        // where an epoch is the number of full batches that fit in `n`.
        let batches_per_epoch = (n / self.effective_batch).max(1);
        self.cost_period = self.cost_eval_every.unwrap_or(batches_per_epoch);
        self.iters_since_cost = 0;
        let cost = problem.cost(&state.param)?;
        state.cost = Some(cost);
        Ok(state)
    }

    /// Performs one mini-batch update and, every `cost_period` iterations,
    /// refreshes `state.cost`.
    ///
    /// When fewer than `effective_batch` indices remain in the current
    /// permutation, the leftover indices are skipped, the permutation is
    /// reshuffled and the cursor restarts at zero. This solver never
    /// requests termination itself; the returned reason is always `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the problem's batch gradient or cost
    /// evaluation.
    ///
    /// # Panics
    ///
    /// Panics if `Solver::init` has not been run on this solver.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V, F>,
    ) -> Result<(BasicState<V, F>, Option<TerminationReason>), Self::Error> {
        // Look the RNG up unconditionally. On a never-initialised solver
        // `effective_batch` is 0 and `perm` is empty, so the reshuffle
        // condition below is false; checking only inside that branch would
        // let the call proceed with an empty batch instead of panicking.
        let rng = self
            .rng
            .as_mut()
            .expect("rng not set: Solver::init must run before next_iter");
        let bs = self.effective_batch;
        let n = self.perm.len();
        if self.cursor + bs > n {
            // Not enough indices left for a full batch: start a new epoch.
            self.perm.as_mut_slice().shuffle(rng);
            self.cursor = 0;
        }
        let batch = &self.perm[self.cursor..self.cursor + bs];
        let grad = problem.batch_gradient(&state.param, batch)?;
        self.cursor += bs;
        if self.beta == F::zero() {
            // Plain SGD: x <- x - alpha * g.
            state.param.scaled_add(-self.alpha, &grad);
        } else {
            let velocity = match self.velocity.take() {
                // v <- beta * v - alpha * g.
                Some(mut v) => {
                    v.scale_in_place(self.beta);
                    v.scaled_add(-self.alpha, &grad);
                    v
                }
                // First momentum step of the run: v <- -alpha * g, reusing
                // the gradient's storage as the velocity buffer.
                None => {
                    let mut v = grad;
                    v.scale_in_place(-self.alpha);
                    v
                }
            };
            // x <- x + v.
            state.param.scaled_add(F::one(), &velocity);
            self.velocity = Some(velocity);
        }
        self.iters_since_cost += 1;
        if self.iters_since_cost >= self.cost_period {
            let cost = problem.cost(&state.param)?;
            state.cost = Some(cost);
            self.iters_since_cost = 0;
        }
        Ok((state, None))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::state::State;
use crate::{BasicState, Executor, MaxIter};
/// Test problem holding a list of equally sized center points.
struct FiniteSumQuadratic {
    centers: Vec<Vec<f64>>,
}

impl FiniteSumQuadratic {
    /// Returns the coordinate-wise mean of all centers.
    ///
    /// The dimension is taken from the first center, so this panics when
    /// `centers` is empty.
    fn centroid(&self) -> Vec<f64> {
        let count = self.centers.len() as f64;
        let dim = self.centers[0].len();
        // Accumulate per-coordinate totals center by center, then average.
        let totals = self
            .centers
            .iter()
            .fold(vec![0.0_f64; dim], |mut acc, center| {
                acc.iter_mut().zip(center).for_each(|(slot, &v)| *slot += v);
                acc
            });
        totals.into_iter().map(|total| total / count).collect()
    }
}
impl CostFunction for FiniteSumQuadratic {
    type Param = Vec<f64>;
    type Output = f64;
    type Error = std::convert::Infallible;

    /// Sum of squared coordinate differences between `x` and every center,
    /// divided by the number of centers.
    fn cost(&self, x: &Vec<f64>) -> Result<f64, Self::Error> {
        // Thread one running total through every center so the additions
        // happen in center order, then coordinate order.
        let total = self.centers.iter().fold(0.0_f64, |running, center| {
            x.iter().zip(center).fold(running, |partial, (xi, ci)| {
                let diff = xi - ci;
                partial + diff * diff
            })
        });
        let count = self.centers.len() as f64;
        Ok(total / count)
    }
}
impl MiniBatchGradient for FiniteSumQuadratic {
    type Gradient = Vec<f64>;

    /// One sample per stored center.
    fn n_samples(&self) -> usize {
        self.centers.len()
    }

    /// Accumulates `(2 / batch.len()) * (x - center)` over the centers
    /// selected by `batch`.
    fn batch_gradient(&self, x: &Vec<f64>, batch: &[usize]) -> Result<Vec<f64>, Self::Error> {
        let scale = 2.0 / batch.len() as f64;
        let mut grad = vec![0.0_f64; x.len()];
        for center in batch.iter().map(|&idx| &self.centers[idx]) {
            grad.iter_mut()
                .zip(x.iter().zip(center))
                .for_each(|(slot, (xj, cj))| *slot += scale * (xj - cj));
        }
        Ok(grad)
    }
}
/// Builds the shared five-center, two-dimensional fixture used by the tests.
fn problem_5_centers() -> FiniteSumQuadratic {
    let points: [[f64; 2]; 5] = [
        [1.0, 0.0],
        [2.0, 1.0],
        [0.0, 2.0],
        [-1.0, 1.0],
        [3.0, -1.0],
    ];
    FiniteSumQuadratic {
        centers: points.iter().map(|point| point.to_vec()).collect(),
    }
}
/// Plain mini-batch SGD (batch size 2, no momentum) must end within `5e-2`
/// of the centroid of the five-center fixture under a 3000-iteration cap.
#[test]
fn converges_to_centroid_without_momentum() {
    let problem = problem_5_centers();
    // Capture the target before the problem is moved into the executor.
    let centroid = problem.centroid();
    let initial = BasicState::new(vec![0.0, 0.0]);
    let outcome = Executor::new(problem, Sgd::new(0.01, 2, 0xABCDEF), initial)
        .terminate_on(MaxIter(3_000))
        .run()
        .unwrap();
    let x = outcome.param();
    x.iter().zip(centroid.iter()).for_each(|(xi, ci)| {
        assert!((xi - ci).abs() < 5e-2, "x = {x:?}, centroid = {centroid:?}",);
    });
}
/// With the batch size equal to the sample count every step uses all
/// samples, so the run must reach the centroid to within `1e-6`.
#[test]
fn full_batch_recovers_deterministic_gradient_descent() {
    let problem = problem_5_centers();
    let centroid = problem.centroid();
    let sample_count = problem.n_samples();
    let solver = Sgd::new(0.1, sample_count, 0);
    let result = Executor::new(problem, solver, BasicState::new(vec![0.0, 0.0]))
        .terminate_on(MaxIter(500))
        .run()
        .unwrap();
    let x = result.param();
    for (xi, ci) in x.iter().zip(&centroid) {
        let err = (xi - ci).abs();
        assert!(err < 1e-6, "x={x:?}, centroid={centroid:?}");
    }
}
/// Two runs with the same seed, start point and problem must end at the same
/// parameters (to within `1e-15` per coordinate).
#[test]
fn same_seed_same_trajectory() {
    let final_param = |p: FiniteSumQuadratic| {
        let solver = Sgd::new(0.05, 2, 12345);
        let result = Executor::new(p, solver, BasicState::new(vec![0.5, -0.5]))
            .terminate_on(MaxIter(50))
            .run()
            .unwrap();
        result.param().clone()
    };
    // Run the identical configuration on two independent problem instances.
    let [xa, xb] = [problem_5_centers(), problem_5_centers()].map(final_param);
    xa.iter().zip(xb.iter()).for_each(|(a, b)| {
        assert!((a - b).abs() < 1e-15, "xa={xa:?}, xb={xb:?}");
    });
}
/// Runs that differ only in the seed must end at measurably different
/// parameters (L1 distance above `1e-6`).
#[test]
fn different_seeds_diverge() {
    let final_param = |seed: u64| {
        let solver = Sgd::new(0.05, 2, seed);
        let initial = BasicState::new(vec![0.5, -0.5]);
        Executor::new(problem_5_centers(), solver, initial)
            .terminate_on(MaxIter(20))
            .run()
            .unwrap()
            .param()
            .clone()
    };
    let (xa, xb) = (final_param(1), final_param(2));
    // L1 distance between the two end points.
    let mut diff: f64 = 0.0;
    for (a, b) in xa.iter().zip(&xb) {
        diff += (a - b).abs();
    }
    assert!(diff > 1e-6, "seeds 1 and 2 produced identical trajectory");
}
/// Re-running `init` on a reused momentum solver must discard the previous
/// run's state: two back-to-back runs have to end at the same parameters.
#[test]
fn momentum_resets_between_runs() {
    let start = vec![1.0, 1.0];
    let mut sgd = Sgd::new(0.03, 2, 7).with_momentum(0.85);
    let run_once = |solver: &mut Sgd<Vec<f64>>| {
        let mut p = Problem::new(problem_5_centers());
        let initial = solver.init(&mut p, BasicState::new(start.clone())).unwrap();
        // Thread the state through fifteen solver iterations.
        let last = (0..15).fold(initial, |state, _| {
            let (next, _) = solver.next_iter(&mut p, state).unwrap();
            next
        });
        last.param().clone()
    };
    let first = run_once(&mut sgd);
    let second = run_once(&mut sgd);
    first.iter().zip(second.iter()).for_each(|(a, b)| {
        assert!((a - b).abs() < 1e-15, "first={first:?}, second={second:?}");
    });
}
/// Seven samples with batches of three: the cursor advances 0 -> 3 -> 6, the
/// third step cannot fit a full batch, so the solver wraps to the start of
/// the permutation and the cursor ends at 3.
#[test]
fn reshuffles_at_epoch_boundary() {
    let centers: Vec<Vec<f64>> = (0..7)
        .map(|i| {
            let v = i as f64;
            vec![v, -v]
        })
        .collect();
    let mut sgd = Sgd::new(0.01, 3, 99);
    let mut p = Problem::new(FiniteSumQuadratic { centers });
    let initial = sgd.init(&mut p, BasicState::new(vec![0.0, 0.0])).unwrap();
    // Only the solver's internal cursor matters here, not the final state.
    let _final_state = (0..3).fold(initial, |state, _| {
        let (next, _) = sgd.next_iter(&mut p, state).unwrap();
        next
    });
    assert_eq!(sgd.cursor, 3);
}
/// A requested batch size (10) larger than the sample count (3) must still
/// run and converge to the centroid.
#[test]
fn batch_size_clamped_to_n_samples() {
    let centers = [1.0, 2.0, 3.0].iter().map(|&c| vec![c]).collect();
    let problem = FiniteSumQuadratic { centers };
    let centroid = problem.centroid();
    let initial = BasicState::new(vec![0.0]);
    let result = Executor::new(problem, Sgd::new(0.05, 10, 13), initial)
        .terminate_on(MaxIter(500))
        .run()
        .unwrap();
    assert!((result.param()[0] - centroid[0]).abs() < 1e-3);
}
/// Five samples with batches of two give two full batches per epoch, so by
/// default the recorded cost stays at its initial value after the first
/// iteration and is refreshed after the second.
#[test]
fn cost_refresh_default_is_epoch_boundary() {
    let x0 = vec![10.0, 10.0];
    let problem = problem_5_centers();
    let initial_cost = problem.cost(&x0).unwrap();
    let mut sgd = Sgd::new(0.05, 2, 42);
    let mut p = Problem::new(problem);
    let at_init = sgd.init(&mut p, BasicState::new(x0)).unwrap();
    assert_eq!(at_init.cost(), initial_cost);
    let (after_first, _) = sgd.next_iter(&mut p, at_init).unwrap();
    assert_eq!(
        after_first.cost(),
        initial_cost,
        "default schedule must hold state.cost stale within an epoch",
    );
    let (after_second, _) = sgd.next_iter(&mut p, after_first).unwrap();
    assert_ne!(
        after_second.cost(),
        initial_cost,
        "default schedule must refresh at the epoch boundary (iter 2)",
    );
}
/// With an explicit refresh period of one, the recorded cost must already
/// differ from the initial cost after a single iteration.
#[test]
fn with_cost_eval_every_one_refreshes_per_iter() {
    let x0 = vec![10.0, 10.0];
    let problem = problem_5_centers();
    let initial_cost = problem.cost(&x0).unwrap();
    let mut sgd = Sgd::new(0.05, 2, 42).with_cost_eval_every(1);
    let mut p = Problem::new(problem);
    let at_init = sgd.init(&mut p, BasicState::new(x0)).unwrap();
    let (after_step, _) = sgd.next_iter(&mut p, at_init).unwrap();
    assert_ne!(
        after_step.cost(),
        initial_cost,
        "with_cost_eval_every(1) must refresh state.cost after every step",
    );
}
/// A solver configured with `with_momentum(0.0)` must take exactly the plain
/// full-batch step `x - alpha * 2 * (x - centroid)` on its first iteration
/// and must not request termination.
#[test]
fn zero_momentum_matches_plain_sgd_branch() {
    let alpha = 0.1;
    let x0 = [1.0, 1.0];
    let problem = problem_5_centers();
    let centroid = problem.centroid();
    // Closed-form result of one full-batch step from `x0`.
    let expected: Vec<f64> = x0
        .iter()
        .zip(centroid.iter())
        .map(|(x, c)| x - alpha * 2.0 * (x - c))
        .collect();
    let mut sgd = Sgd::new(alpha, 5, 0).with_momentum(0.0);
    let mut p = Problem::new(problem);
    let initial = sgd.init(&mut p, BasicState::new(x0.to_vec())).unwrap();
    let (state, reason) = sgd.next_iter(&mut p, initial).unwrap();
    assert!(reason.is_none());
    state.param().iter().zip(expected.iter()).for_each(|(xi, ei)| {
        assert!(
            (xi - ei).abs() < 1e-12,
            "got {:?}, expected {:?}",
            state.param(),
            expected
        );
    });
}
}