use crate::primitives::Vector;
use super::line_search::{LineSearch, WolfeLineSearch};
use super::{ConvergenceStatus, OptimizationResult, Optimizer};
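/// Formula for the conjugacy coefficient `beta` in nonlinear CG.
///
/// With `g_new` and `g_old` the current and previous gradients, `d_old` the
/// previous search direction, and `y = g_new - g_old`:
///
/// - Fletcher-Reeves:        `beta = (g_new . g_new) / (g_old . g_old)`
/// - Polak-Ribiere (PR+):    `beta = max(0, (g_new . y) / (g_old . g_old))`
/// - Hestenes-Stiefel (HS+): `beta = max(0, (g_new . y) / (d_old . y))`
///
/// All three coincide on strictly convex quadratics with exact line searches;
/// they differ in robustness on general nonlinear objectives.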
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CGBetaFormula {
    FletcherReeves,
    PolakRibiere,
    HestenesStiefel,
}
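/// Nonlinear conjugate gradient optimizer with a Wolfe line search.
///
/// Each search direction is `d = -g + beta * d_prev`; the optimizer falls
/// back to steepest descent on the first iteration, on scheduled restarts,
/// and whenever the CG direction fails the descent test.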
#[derive(Debug, Clone)]
pub struct ConjugateGradient {
    pub(crate) max_iter: usize,
    pub(crate) tol: f32,
    pub(crate) beta_formula: CGBetaFormula,
    /// Restart from steepest descent every `restart_interval` iterations;
    /// 0 disables scheduled restarts.
    pub(crate) restart_interval: usize,
    pub(crate) line_search: WolfeLineSearch,
    pub(crate) prev_direction: Option<Vector<f32>>,
    pub(crate) prev_gradient: Option<Vector<f32>>,
    pub(crate) iter_count: usize,
}
impl ConjugateGradient {
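    /// Creates a conjugate gradient optimizer with the given iteration
    /// budget, gradient-norm tolerance, and `beta` formula.
    ///
    /// The inner line search is constructed as `WolfeLineSearch::new(1e-4, 0.1, 50)`;
    /// assuming those arguments are `(c1, c2, max_iters)`, the loose curvature
    /// constant `0.1` matches the usual recommendation for CG (quasi-Newton
    /// methods typically use `0.9`).
    ///
    /// A minimal usage sketch (the import path is an assumption; adjust it to
    /// wherever this module is re-exported):
    ///
    /// ```ignore
    /// use crate::optim::{CGBetaFormula, ConjugateGradient};
    ///
    /// let cg = ConjugateGradient::new(200, 1e-6, CGBetaFormula::PolakRibiere)
    ///     .with_restart_interval(50);
    /// ```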
    #[must_use]
    pub fn new(max_iter: usize, tol: f32, beta_formula: CGBetaFormula) -> Self {
        Self {
            max_iter,
            tol,
            beta_formula,
            restart_interval: 0,
            line_search: WolfeLineSearch::new(1e-4, 0.1, 50),
            prev_direction: None,
            prev_gradient: None,
            iter_count: 0,
        }
    }
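    /// Sets a periodic restart: every `interval` iterations the search
    /// direction is reset to steepest descent, a standard safeguard for
    /// nonlinear CG. An interval of 0 (the default) disables scheduled
    /// restarts.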
    #[must_use]
    pub fn with_restart_interval(mut self, interval: usize) -> Self {
        self.restart_interval = interval;
        self
    }
    /// Computes the conjugacy coefficient `beta` for the configured formula.
    fn compute_beta(
        &self,
        grad_new: &Vector<f32>,
        grad_old: &Vector<f32>,
        d_old: &Vector<f32>,
    ) -> f32 {
        let n = grad_new.len();
        match self.beta_formula {
            CGBetaFormula::FletcherReeves => {
                let mut numerator = 0.0;
                let mut denominator = 0.0;
                for i in 0..n {
                    numerator += grad_new[i] * grad_new[i];
                    denominator += grad_old[i] * grad_old[i];
                }
                // ||g_old||^2 is nonnegative, so clamping only guards
                // against division by zero.
                numerator / denominator.max(1e-12)
            }
            CGBetaFormula::PolakRibiere => {
                let mut numerator = 0.0;
                let mut denominator = 0.0;
                for i in 0..n {
                    numerator += grad_new[i] * (grad_new[i] - grad_old[i]);
                    denominator += grad_old[i] * grad_old[i];
                }
                let beta = numerator / denominator.max(1e-12);
                // PR+ variant: clamping a negative beta at zero acts as an
                // automatic restart.
                beta.max(0.0)
            }
            CGBetaFormula::HestenesStiefel => {
                let mut numerator = 0.0;
                let mut denominator = 0.0;
                for i in 0..n {
                    let y_i = grad_new[i] - grad_old[i];
                    numerator += grad_new[i] * y_i;
                    denominator += d_old[i] * y_i;
                }
                // Unlike FR/PR, the HS denominator `d_old . y` can be
                // negative, so `denominator.max(1e-12)` would silently
                // replace a negative value with a tiny positive one; treat a
                // near-zero denominator as a restart (beta = 0) instead.
                if denominator.abs() < 1e-12 {
                    0.0
                } else {
                    (numerator / denominator).max(0.0)
                }
            }
        }
    }
    /// Euclidean norm of `v`.
    fn norm(v: &Vector<f32>) -> f32 {
        let mut sum = 0.0;
        for i in 0..v.len() {
            sum += v[i] * v[i];
        }
        sum.sqrt()
    }
    /// Returns the negative gradient, i.e. the steepest-descent direction.
    fn steepest_descent(grad: &Vector<f32>, n: usize) -> Vector<f32> {
        let mut d = Vector::zeros(n);
        for i in 0..n {
            d[i] = -grad[i];
        }
        d
    }
    fn compute_search_direction(&self, grad: &Vector<f32>, n: usize) -> Vector<f32> {
        // First iteration (or after a reset): no previous state yet.
        let (d_old, g_old) = match (&self.prev_direction, &self.prev_gradient) {
            (Some(d), Some(g)) => (d, g),
            _ => return Self::steepest_descent(grad, n),
        };
        // Scheduled restart.
        if self.restart_interval > 0 && self.iter_count.is_multiple_of(self.restart_interval) {
            return Self::steepest_descent(grad, n);
        }
        let beta = self.compute_beta(grad, g_old, d_old);
        let mut d_new = Vector::zeros(n);
        for i in 0..n {
            d_new[i] = -grad[i] + beta * d_old[i];
        }
        // Safeguard: if the CG direction is not a descent direction, fall
        // back to steepest descent.
        let mut grad_dot_d = 0.0;
        for i in 0..n {
            grad_dot_d += grad[i] * d_new[i];
        }
        if grad_dot_d >= 0.0 {
            return Self::steepest_descent(grad, n);
        }
        d_new
    }
    /// Assembles an `OptimizationResult`; CG is unconstrained, so the
    /// constraint violation is always zero.
    fn make_result(
        solution: Vector<f32>,
        objective_value: f32,
        iterations: usize,
        status: ConvergenceStatus,
        gradient_norm: f32,
        elapsed_time: std::time::Duration,
    ) -> OptimizationResult {
        OptimizationResult {
            solution,
            objective_value,
            iterations,
            status,
            gradient_norm,
            constraint_violation: 0.0,
            elapsed_time,
        }
    }
}
impl Optimizer for ConjugateGradient {
    /// Always panics: conjugate gradient is a batch method with no
    /// per-sample update; use `minimize` instead.
    fn step(&mut self, _params: &mut Vector<f32>, _gradients: &Vector<f32>) {
        panic!(
            "Conjugate Gradient does not support stochastic updates (step). Use minimize() for batch optimization."
        )
    }
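    /// Runs batch CG from `x0` until the gradient norm drops below `tol`,
    /// the line search stalls, a non-finite objective value appears, or the
    /// iteration budget is exhausted.
    ///
    /// A minimal end-to-end sketch on the quadratic `f(x) = x[0]^2 + x[1]^2`
    /// (uses only `Vector::zeros` and indexing, as elsewhere in this module;
    /// the import paths are assumptions):
    ///
    /// ```ignore
    /// use crate::optim::{CGBetaFormula, ConjugateGradient, Optimizer};
    /// use crate::primitives::Vector;
    ///
    /// let mut cg = ConjugateGradient::new(100, 1e-6, CGBetaFormula::FletcherReeves);
    /// let f = |x: &Vector<f32>| x[0] * x[0] + x[1] * x[1];
    /// let g = |x: &Vector<f32>| {
    ///     let mut grad = Vector::zeros(2);
    ///     grad[0] = 2.0 * x[0];
    ///     grad[1] = 2.0 * x[1];
    ///     grad
    /// };
    /// let mut x0 = Vector::zeros(2);
    /// x0[0] = 3.0;
    /// x0[1] = -4.0;
    /// let result = cg.minimize(f, g, x0);
    /// assert!(result.gradient_norm < 1e-6);
    /// ```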
    fn minimize<F, G>(&mut self, objective: F, gradient: G, x0: Vector<f32>) -> OptimizationResult
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>,
    {
        let start_time = std::time::Instant::now();
        let n = x0.len();
        // Discard any state left over from a previous run.
        self.prev_direction = None;
        self.prev_gradient = None;
        self.iter_count = 0;
        let mut x = x0;
        let mut fx = objective(&x);
        let mut grad = gradient(&x);
        let mut grad_norm = Self::norm(&grad);
        for iter in 0..self.max_iter {
            if grad_norm < self.tol {
                return Self::make_result(
                    x,
                    fx,
                    iter,
                    ConvergenceStatus::Converged,
                    grad_norm,
                    start_time.elapsed(),
                );
            }
            let d = self.compute_search_direction(&grad, n);
            let alpha = self.line_search.search(&objective, &gradient, &x, &d);
            // A vanishing step length means the line search could not make
            // progress along `d`.
            if alpha < 1e-12 {
                return Self::make_result(
                    x,
                    fx,
                    iter,
                    ConvergenceStatus::Stalled,
                    grad_norm,
                    start_time.elapsed(),
                );
            }
            let mut x_new = Vector::zeros(n);
            for i in 0..n {
                x_new[i] = x[i] + alpha * d[i];
            }
            let fx_new = objective(&x_new);
            let grad_new = gradient(&x_new);
            // Report the last finite iterate rather than the diverged one.
            if !fx_new.is_finite() {
                return Self::make_result(
                    x,
                    fx,
                    iter,
                    ConvergenceStatus::NumericalError,
                    grad_norm,
                    start_time.elapsed(),
                );
            }
            self.prev_direction = Some(d);
            self.prev_gradient = Some(grad);
            x = x_new;
            fx = fx_new;
            grad = grad_new;
            grad_norm = Self::norm(&grad);
            self.iter_count += 1;
        }
        Self::make_result(
            x,
            fx,
            self.max_iter,
            ConvergenceStatus::MaxIterations,
            grad_norm,
            start_time.elapsed(),
        )
    }
    fn reset(&mut self) {
        self.prev_direction = None;
        self.prev_gradient = None;
        self.iter_count = 0;
    }
}
#[cfg(test)]
#[path = "conjugate_gradient_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "tests_cg_contract.rs"]
mod tests_cg_contract;