use super::*;
/// Spectral projected gradient (SPG) solver for box-constrained minimization.
///
/// Each iteration scales the gradient by the spectral coefficient `lambda`, projects
/// the resulting trial point onto the box `[lower_bound, upper_bound]`, and uses the
/// difference between that projection and the current iterate as the search
/// direction for a line search. After every step `lambda` is refreshed with the
/// ratio `s's / s'y` and kept inside `[lambda_min, lambda_max]`.
#[derive(derive_getters::Getters)]
pub struct SpectralProjectedGradient {
    // --- iteration state ---
    /// Current iterate (kept inside the box).
    x: DVector<Floating>,
    /// Iteration counter; starts at 0.
    k: usize,

    // --- feasible box ---
    /// Component-wise lower bounds of the feasible box.
    lower_bound: DVector<Floating>,
    /// Component-wise upper bounds of the feasible box.
    upper_bound: DVector<Floating>,

    // --- spectral step coefficient and its safeguards ---
    /// Current spectral coefficient used to scale the gradient.
    lambda: Floating,
    /// Smallest value `lambda` may take.
    lambda_min: Floating,
    /// Largest value `lambda` may take.
    lambda_max: Floating,

    // --- stopping criterion ---
    /// Convergence threshold on the infinity norm of the projected gradient.
    grad_tol: Floating,
}
impl SpectralProjectedGradient {
    /// Overrides the safeguard interval `[lambda_min, lambda_max]` for the spectral
    /// coefficient.
    ///
    /// The current `lambda` was computed by [`Self::new`] with the default interval
    /// `[1e-3, 1e3]`. It is re-clamped here into the new interval, so the very first
    /// iteration already respects the caller's safeguards.
    ///
    /// Callers are expected to pass `0 < lambda_min <= lambda_max`; this is not checked.
    pub fn with_lambdas(mut self, lambda_min: Floating, lambda_max: Floating) -> Self {
        self.lambda_min = lambda_min;
        self.lambda_max = lambda_max;
        // `min` then `max` (same order as in `new` and the lambda update) instead of
        // `clamp`, which would panic on an inverted or NaN interval.
        self.lambda = self.lambda.min(lambda_max).max(lambda_min);
        self
    }

    /// Creates a solver whose starting point is `x0` projected onto the box
    /// `[lower_bound, upper_bound]`.
    ///
    /// `oracle` is evaluated once at the projected starting point to choose the
    /// initial spectral coefficient `lambda = 1 / ||P(x0 - g(x0)) - x0||_inf`. Here `P`
    /// is the box projection. The value is clamped to the default safeguards
    /// `[1e-3, 1e3]`; use [`Self::with_lambdas`] to change them. If the starting point
    /// is already stationary, the norm is zero, the ratio is `+inf`, and the clamp
    /// yields `lambda_max`.
    pub fn new(
        grad_tol: Floating,
        x0: DVector<Floating>,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        lower_bound: DVector<Floating>,
        upper_bound: DVector<Floating>,
    ) -> Self {
        // Start from a feasible point.
        let x0 = x0.box_projection(&lower_bound, &upper_bound);
        let lambda_min = 1e-3;
        let lambda_max = 1e3;
        let eval0 = oracle(&x0);
        // Projected gradient step with unit scaling: P(x0 - g(x0)) - x0.
        let direction0 = &x0 - eval0.g();
        let direction0 = direction0.box_projection(&lower_bound, &upper_bound);
        let direction0 = direction0 - &x0;
        let lambda = (1. / direction0.infinity_norm())
            .min(lambda_max)
            .max(lambda_min);
        Self {
            grad_tol,
            x: x0,
            k: 0,
            lower_bound,
            upper_bound,
            lambda,
            lambda_min,
            lambda_max,
        }
    }
}
/// Read/write access to the feasible box `[lower_bound, upper_bound]`.
///
/// The setters only replace the stored bounds; they do not re-project the current
/// iterate onto the new box.
impl HasBounds for SpectralProjectedGradient {
    // --- getters ---
    fn lower_bound(&self) -> &DVector<Floating> {
        &self.lower_bound
    }
    fn upper_bound(&self) -> &DVector<Floating> {
        &self.upper_bound
    }

    // --- setters ---
    fn set_lower_bound(&mut self, lower_bound: DVector<Floating>) {
        self.lower_bound = lower_bound;
    }
    fn set_upper_bound(&mut self, upper_bound: DVector<Floating>) {
        self.upper_bound = upper_bound;
    }
}
/// Search direction of the spectral projected gradient method.
impl ComputeDirection for SpectralProjectedGradient {
    /// Returns `d = P(x - lambda * g) - x`.
    ///
    /// Here `P` is `box_projection` onto `[lower_bound, upper_bound]`, `x` is the
    /// current iterate, and `g` is the gradient stored in `eval`. This implementation
    /// never returns `Err`; the `Result` only satisfies the trait signature.
    fn compute_direction(
        &mut self,
        eval: &FuncEvalMultivariate,
    ) -> Result<DVector<Floating>, SolverError> {
        let scaled_gradient = self.lambda * eval.g();
        let trial_point = &self.x - scaled_gradient;
        let projected = trial_point.box_projection(&self.lower_bound, &self.upper_bound);
        Ok(projected - &self.x)
    }
}
impl LineSearchSolver for SpectralProjectedGradient {
    /// Converged once the infinity norm of the projected gradient is strictly below
    /// `grad_tol`.
    fn has_converged(&self, eval: &FuncEvalMultivariate) -> bool {
        self.projected_gradient(eval).infinity_norm() < self.grad_tol
    }

    fn xk(&self) -> &DVector<Floating> {
        &self.x
    }
    fn xk_mut(&mut self) -> &mut DVector<Floating> {
        &mut self.x
    }
    fn k(&self) -> &usize {
        &self.k
    }
    fn k_mut(&mut self) -> &mut usize {
        &mut self.k
    }

    /// Takes a line-search step along `direction`, then refreshes the spectral
    /// coefficient.
    ///
    /// With `s = x_{k+1} - x_k` and `y = g(x_{k+1}) - g(x_k)`:
    /// - if `s'y > 0`, `lambda` becomes `s's / s'y`, clamped to
    ///   `[lambda_min, lambda_max]`;
    /// - if `s'y <= 0`, `lambda` is reset to `lambda_max`.
    ///
    /// `oracle` is evaluated once more at the accepted point to obtain `y`.
    fn update_next_iterate<LS: LineSearch>(
        &mut self,
        line_search: &mut LS,
        eval_x_k: &FuncEvalMultivariate,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        direction: &DVector<Floating>,
        max_iter_line_search: usize,
    ) -> Result<(), SolverError> {
        let step = line_search.compute_step_len(
            self.xk(),
            eval_x_k,
            direction,
            oracle,
            max_iter_line_search,
        );

        // Build the new iterate once and reuse it for the log line.
        let next_iterate = self.xk() + step * direction;
        debug!(
            target: "spectral_projected_gradient",
            "ITERATE: {} + {} * {} = {}",
            self.xk(),
            step,
            direction,
            next_iterate
        );

        // Secant pair: actual displacement and the matching gradient change.
        let displacement = &next_iterate - self.xk();
        let gradient_change = oracle(&next_iterate).g() - eval_x_k.g();
        self.x = next_iterate;

        let curvature = displacement.dot(&gradient_change);
        // Keep the `<= 0.` test (not `> 0.`) so a NaN curvature takes the same path
        // as before: the ratio below, which the min/max clamp turns into lambda_max.
        self.lambda = if curvature <= 0. {
            debug!(
                target: "spectral_projected_gradient",
                "skyk = {} <= 0. Resetting lambda to lambda_max",
                curvature
            );
            self.lambda_max
        } else {
            let squared_step = displacement.dot(&displacement);
            (squared_step / curvature)
                .min(self.lambda_max)
                .max(self.lambda_min)
        };
        Ok(())
    }
}
#[cfg(test)]
mod spg_test {
    use super::*;

    /// Minimizes the ill-conditioned quadratic `0.5 * (x0^2 + gamma * x1^2)` over the
    /// box `x0 >= -1`, `x1 >= 47` (unbounded above), starting from `(180, 152)`.
    ///
    /// The constrained minimizer is `(0, 47)`, where the projected gradient vanishes.
    /// The test therefore requires the solver to report convergence, not just to
    /// finish without panicking.
    #[test]
    pub fn constrained_spg_backtracking() {
        std::env::set_var("RUST_LOG", "info");
        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 1e9;
        let mut f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let lower_bounds = DVector::from_vec(vec![-1., 47.]);
        let upper_bounds = DVector::from_vec(vec![f64::INFINITY, f64::INFINITY]);
        let c1 = 1e-4;
        let m = 10;
        let mut ls = GLLQuadratic::new(c1, m);
        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd =
            SpectralProjectedGradient::new(tol, x_0, &mut f_and_g, lower_bounds, upper_bounds);
        let max_iter_solver = 10000;
        let max_iter_line_search = 1000;
        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();
        println!("Iterate: {:?}", gd.xk());
        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!(
            "Projected Gradient norm: {:?}",
            gd.projected_gradient(&eval).norm()
        );
        println!("tol: {:?}", tol);
        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);
        assert!(
            convergence,
            "SPG did not reach the projected-gradient tolerance {}",
            tol
        );
    }
}