use crate::core::inner::WarmStart;
use crate::core::math::{NegInPlace, ScaleInPlace, ScaledAdd};
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::core::solver::Solver;
use crate::core::state::BasicState;
use crate::core::termination::TerminationReason;
use crate::line_search::{Constant, LineSearch};
/// Gradient-descent solver with an optional momentum term.
///
/// `L` is the step-length strategy stored in `line_search`; `V` is the vector
/// type used for the stored momentum velocity.
#[derive(Debug, Clone)]
pub struct GradientDescent<L, V> {
    /// Strategy that supplies the step length for each iteration.
    line_search: L,
    /// Momentum coefficient. A value of exactly `0.0` selects the plain
    /// descent update and leaves `velocity` untouched.
    beta: f64,
    /// Update applied by the previous momentum iteration. `None` until the
    /// first momentum step has run, and again after the solver is
    /// re-initialised.
    velocity: Option<V>,
}
impl<V> GradientDescent<Constant, V> {
    /// Creates a solver that uses the fixed step length `alpha` on every
    /// iteration.
    ///
    /// Momentum starts disabled (`beta == 0.0`) and no velocity is stored.
    pub fn new(alpha: f64) -> Self {
        // A constant step is just one particular line-search strategy, so the
        // general constructor supplies the remaining defaults.
        Self::with_line_search(Constant(alpha))
    }
}
impl<L, V> GradientDescent<L, V> {
    /// Creates a solver that asks `line_search` for the step length.
    ///
    /// Momentum starts disabled (`beta == 0.0`) and no velocity is stored.
    pub fn with_line_search(line_search: L) -> Self {
        Self {
            velocity: None,
            beta: 0.0,
            line_search,
        }
    }

    /// Returns the solver with its momentum coefficient replaced by `beta`.
    ///
    /// The line search and any stored velocity are carried over unchanged.
    /// `beta` is stored as given; no range check is performed here.
    pub fn with_momentum(self, beta: f64) -> Self {
        Self { beta, ..self }
    }
}
/// Gradient descent as a [`Solver`] over [`BasicState`], with an optional
/// momentum term controlled by `beta`.
impl<P, V, L> Solver<P, BasicState<V>> for GradientDescent<L, V>
where
    P: CostFunction<Param = V, Output = f64> + Gradient<Gradient = V>,
    V: ScaledAdd<f64> + NegInPlace + ScaleInPlace + Clone,
    L: LineSearch<P, V, Error = P::Error>,
{
    type Error = P::Error;

    /// Prepares a run: drops any velocity left over from a previous run and
    /// stores the cost and gradient at the starting parameter in `state`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cost/gradient evaluation.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<BasicState<V>, Self::Error> {
        // Reset first so a reused solver never carries momentum across runs.
        self.velocity = None;

        let (initial_cost, initial_gradient) = problem.cost_and_gradient(&state.param)?;
        state.gradient = Some(initial_gradient);
        state.cost = Some(initial_cost);
        Ok(state)
    }

    /// Performs one descent step and re-evaluates cost and gradient at the
    /// new parameter.
    ///
    /// With `beta == 0.0` the parameter is moved by `alpha * d`, where
    /// `d` is the negated gradient and `alpha` comes from the line search.
    /// Otherwise the velocity `v = beta * v_prev + alpha * d` is formed
    /// (just `alpha * d` when no previous velocity exists), added to the
    /// parameter, and kept for the next iteration.
    ///
    /// This method itself never reports a termination reason.
    ///
    /// # Panics
    ///
    /// Panics if `state` carries no gradient or no cost, i.e. if `init` was
    /// not run first.
    ///
    /// # Errors
    ///
    /// Propagates errors from the line search and from the cost/gradient
    /// evaluation.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<(BasicState<V>, Option<TerminationReason>), Self::Error> {
        let gradient = state
            .gradient
            .take()
            .expect("gradient not set: Solver::init must run before next_iter");
        let current_cost = state
            .cost
            .expect("cost not set: Solver::init must run before next_iter");

        // Search direction is the negated gradient; the gradient itself is
        // still needed by the line search, hence the copy.
        let mut descent = gradient.clone();
        descent.neg_in_place();

        let step = self
            .line_search
            .next(problem, &state.param, current_cost, &gradient, &descent)?;

        if self.beta == 0.0 {
            state.param.scaled_add(step, &descent);
        } else {
            let update = if let Some(mut previous) = self.velocity.take() {
                // Decay the old velocity, then add the fresh scaled step.
                previous.scale_in_place(self.beta);
                previous.scaled_add(step, &descent);
                previous
            } else {
                // First momentum step: the velocity is the scaled direction,
                // which can be reused in place instead of allocating.
                descent.scale_in_place(step);
                descent
            };
            state.param.scaled_add(1.0, &update);
            self.velocity = Some(update);
        }

        let (new_cost, new_gradient) = problem.cost_and_gradient(&state.param)?;
        state.cost = Some(new_cost);
        state.gradient = Some(new_gradient);
        Ok((state, None))
    }
}
/// Lets the solver be seeded from a caller-supplied starting point.
impl<L, V> WarmStart<V> for GradientDescent<L, V>
where
    V: Clone,
{
    type State = BasicState<V>;

    /// Builds a fresh state whose parameter is a copy of `x`.
    fn seed(&self, x: &V) -> Self::State {
        let start = x.clone();
        BasicState::new(start)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::state::State;
    use crate::{BasicState, Executor};

    /// Test problem: the sum of squares of all coordinates.
    struct Quadratic;

    impl CostFunction for Quadratic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        fn cost(&self, x: &Vec<f64>) -> Result<f64, Self::Error> {
            let total: f64 = x.iter().map(|xi| xi * xi).sum();
            Ok(total)
        }
    }

    impl Gradient for Quadratic {
        type Gradient = Vec<f64>;

        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, Self::Error> {
            let doubled = x.iter().map(|xi| 2.0 * xi).collect();
            Ok(doubled)
        }
    }

    /// Test problem `x0^2 + 100 * x1^2`: the second coordinate is 100 times
    /// steeper than the first.
    struct IllConditioned;

    impl CostFunction for IllConditioned {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        fn cost(&self, x: &Vec<f64>) -> Result<f64, Self::Error> {
            let (shallow, steep) = (x[0], x[1]);
            Ok(shallow * shallow + 100.0 * steep * steep)
        }
    }

    impl Gradient for IllConditioned {
        type Gradient = Vec<f64>;

        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, Self::Error> {
            let (shallow, steep) = (x[0], x[1]);
            Ok(vec![2.0 * shallow, 200.0 * steep])
        }
    }

    /// With `beta == 0.0` a single step from 1.0 with step 0.1 lands on
    /// `1.0 - 0.1 * 2.0 = 0.8`.
    #[test]
    fn with_momentum_zero_is_plain_descent_first_step() {
        let mut solver = GradientDescent::new(0.1).with_momentum(0.0);
        let mut problem = Problem::new(Quadratic);

        let initial = solver
            .init(&mut problem, BasicState::new(vec![1.0]))
            .unwrap();
        let (stepped, reason) = solver.next_iter(&mut problem, initial).unwrap();

        assert!(reason.is_none());
        assert!((stepped.param()[0] - 0.8).abs() < 1e-12);
    }

    /// For the same step length and iteration budget, momentum must reach a
    /// lower cost than plain descent on the ill-conditioned problem.
    #[test]
    fn momentum_accelerates_over_plain_descent_when_ill_conditioned() {
        let start = vec![1.0, 1.0];
        let iters = 200;
        let alpha = 0.004;

        // Both runs share everything except the solver configuration.
        let optimize = |solver: GradientDescent<Constant, Vec<f64>>| {
            Executor::new(IllConditioned, solver, BasicState::new(start.clone()))
                .max_iter(iters)
                .run()
                .unwrap()
        };

        let plain = optimize(GradientDescent::new(alpha));
        let momentum = optimize(GradientDescent::new(alpha).with_momentum(0.9));

        assert!(
            momentum.cost() < plain.cost(),
            "momentum cost {} should beat plain {}",
            momentum.cost(),
            plain.cost()
        );
    }

    /// Re-running `init` on a used solver must clear the stored velocity, so
    /// two identical runs on the same solver end at the same point.
    #[test]
    fn momentum_velocity_resets_between_runs() {
        let start = vec![2.0, -1.0];
        let mut solver = GradientDescent::new(0.05).with_momentum(0.8);

        let ten_steps = |solver: &mut GradientDescent<Constant, Vec<f64>>| {
            let mut problem = Problem::new(Quadratic);
            let seeded = solver
                .init(&mut problem, BasicState::new(start.clone()))
                .unwrap();
            let finished = (0..10).fold(seeded, |state, _| {
                solver.next_iter(&mut problem, state).unwrap().0
            });
            finished.param().clone()
        };

        let first = ten_steps(&mut solver);
        let second = ten_steps(&mut solver);

        for (a, b) in first.iter().zip(&second) {
            assert!((a - b).abs() < 1e-12, "first={a}, second={b}");
        }
    }
}