use crate::core::math::{NegInPlace, ScaledAdd};
use crate::core::problem::{CostFunction, Gradient};
use crate::core::solver::Solver;
use crate::core::state::BasicState;
use crate::core::termination::TerminationReason;
use crate::line_search::{Constant, LineSearch};
/// Steepest-descent solver: each iteration moves the parameter along the
/// negated gradient, with the step length chosen by the line search `S`.
///
/// Construct with [`GradientDescent::new`] for a [`Constant`] step size, or
/// with [`GradientDescent::with_line_search`] to supply another strategy.
#[derive(Debug, Clone)]
pub struct GradientDescent<S> {
    /// Strategy that picks the step length along the descent direction.
    line_search: S,
}
impl GradientDescent<Constant> {
    /// Creates a gradient-descent solver whose line search is
    /// [`Constant`]`(alpha)`. Presumably this means a fixed step of `alpha`
    /// along the negated gradient at every iteration.
    ///
    /// `alpha` is not validated here. NOTE(review): it presumably should be
    /// positive and finite; confirm whether `Constant` enforces this.
    pub fn new(alpha: f64) -> Self {
        let line_search = Constant(alpha);
        Self { line_search }
    }
}
impl<S> GradientDescent<S> {
pub fn with_line_search(line_search: S) -> Self {
Self { line_search }
}
}
impl<P, V, S> Solver<P, BasicState<V>> for GradientDescent<S>
where
    P: CostFunction<Param = V, Output = f64> + Gradient<Param = V, Gradient = V>,
    V: ScaledAdd<f64> + NegInPlace + Clone,
    S: LineSearch<P, V>,
{
    /// Evaluates cost and gradient at the starting parameter, so that
    /// `next_iter` finds both populated.
    fn init(&mut self, problem: &P, mut state: BasicState<V>) -> BasicState<V> {
        refresh_cost_and_gradient(problem, &mut state);
        state
    }

    /// Performs one steepest-descent step.
    ///
    /// The search direction is the negated gradient at the current parameter.
    /// The configured line search returns a step whose `alpha` is passed to
    /// `scaled_add` together with that direction, presumably
    /// `param += alpha * direction`. Cost and gradient are then re-evaluated
    /// at the new parameter. Evaluation counts reported by the line search
    /// are added to the state's counters.
    ///
    /// Never requests termination itself: the second tuple element is always
    /// `None`.
    ///
    /// # Panics
    ///
    /// Panics if the state's gradient or cost is unset, i.e. when
    /// [`Solver::init`] has not run first.
    fn next_iter(
        &mut self,
        problem: &P,
        mut state: BasicState<V>,
    ) -> (BasicState<V>, Option<TerminationReason>) {
        // Move the gradient out; a fresh one is stored at the new point below.
        let current_grad = state
            .gradient
            .take()
            .expect("gradient not set: Solver::init must run before next_iter");
        let current_cost = state
            .cost
            .expect("cost not set: Solver::init must run before next_iter");

        // Steepest-descent direction: the negated gradient.
        let descent_dir = {
            let mut d = current_grad.clone();
            d.neg_in_place();
            d
        };

        let step = self.line_search.next(
            problem,
            &state.param,
            current_cost,
            &current_grad,
            &descent_dir,
        );
        state.cost_evals += step.cost_evals;
        state.gradient_evals += step.gradient_evals;

        state.param.scaled_add(step.alpha, &descent_dir);
        refresh_cost_and_gradient(problem, &mut state);

        (state, None)
    }
}

/// Evaluates cost and gradient at `state.param`, stores both in `state`, and
/// increments the cost and gradient evaluation counters by one each.
///
/// The bounds mirror the `Solver` impl so that any bounds `BasicState`
/// places on `V` are satisfied here as well.
fn refresh_cost_and_gradient<P, V>(problem: &P, state: &mut BasicState<V>)
where
    P: CostFunction<Param = V, Output = f64> + Gradient<Param = V, Gradient = V>,
    V: ScaledAdd<f64> + NegInPlace + Clone,
{
    state.cost = Some(problem.cost(&state.param));
    state.gradient = Some(problem.gradient(&state.param));
    state.cost_evals += 1;
    state.gradient_evals += 1;
}