use crate::primitives::Vector;
use super::{ConvergenceStatus, OptimizationResult, Optimizer};
/// Configuration for the FISTA (accelerated proximal-gradient) solver.
///
/// The struct only stores settings; `minimize` reads them and keeps all
/// iteration state in local variables.
#[derive(Clone, Debug)]
pub struct FISTA {
    /// Upper bound on the number of proximal-gradient iterations.
    pub(crate) max_iter: usize,
    /// Step size used for the gradient step; also passed to the proximal
    /// operator as its second argument.
    pub(crate) step_size: f32,
    /// Convergence threshold on the Euclidean norm of the change between
    /// consecutive iterates.
    pub(crate) tol: f32,
}
impl FISTA {
    /// Creates a FISTA solver configuration.
    ///
    /// * `max_iter` - maximum number of iterations `minimize` will run.
    /// * `step_size` - step used for the gradient step and handed to the
    ///   proximal operator.
    /// * `tol` - stop once the Euclidean norm of the change between two
    ///   consecutive iterates drops strictly below this value.
    ///
    /// No validation is performed on the arguments.
    #[must_use]
    pub fn new(max_iter: usize, step_size: f32, tol: f32) -> Self {
        Self {
            max_iter,
            step_size,
            tol,
        }
    }

    /// Runs accelerated proximal-gradient iterations starting from `x0`.
    ///
    /// Each iteration:
    /// 1. takes a gradient step at the extrapolated point `y`:
    ///    `y - step_size * grad_smooth(y)`;
    /// 2. applies `prox(., step_size)` to obtain the next iterate;
    /// 3. stops if `||x_new - x||_2 < tol`;
    /// 4. otherwise updates the momentum scalar `t` and extrapolates `y`.
    ///
    /// # Arguments
    /// * `smooth` - evaluates the smooth part of the objective. It is only
    ///   called once, on the returned solution.
    /// * `grad_smooth` - gradient of the smooth part.
    /// * `prox` - proximal operator; receives the gradient-step point and
    ///   `step_size`.
    /// * `x0` - starting point. Only its first `x0.len()` coordinates are ever
    ///   read from the vectors returned by `grad_smooth` and `prox`, so both
    ///   are expected to return vectors of that same length.
    ///
    /// # Returns
    /// An `OptimizationResult` where
    /// * `objective_value` is `smooth(solution)` only (the non-smooth term is
    ///   not available to this function and is not included);
    /// * `gradient_norm` holds the norm of the last step `||x_new - x||_2`
    ///   (a step norm, not a true gradient norm). It is `0.0` only when no
    ///   iteration ran (`max_iter == 0`);
    /// * `iterations` is the zero-based index of the converging iteration on
    ///   success, or `max_iter` when the budget is exhausted;
    /// * `constraint_violation` is always `0.0`.
    ///
    /// # Panics
    /// Presumably panics through `Vector` indexing if `grad_smooth` or `prox`
    /// return fewer than `x0.len()` elements (depends on `Vector`'s `Index`
    /// implementation, which is defined elsewhere).
    ///
    /// NOTE(review): nothing here checks that `step_size` is compatible with
    /// the smoothness of the objective; choosing it is the caller's job.
    pub fn minimize<F, G, P>(
        &mut self,
        smooth: F,
        grad_smooth: G,
        prox: P,
        x0: Vector<f32>,
    ) -> OptimizationResult
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>,
        P: Fn(&Vector<f32>, f32) -> Vector<f32>,
    {
        let start_time = std::time::Instant::now();
        let dim = x0.len();

        let mut x = x0.clone();
        // Extrapolated point; starts at x0 and is updated in place.
        let mut y = x0;
        // Momentum scalar t_k, with t_1 = 1.
        let mut t = 1.0_f32;
        // Scratch buffer for `y - step_size * grad(y)`, reused every iteration
        // instead of allocating a fresh vector each time.
        let mut gradient_step = Vector::zeros(dim);
        // Norm of the most recent step, reported if the budget runs out.
        let mut last_step_norm = 0.0_f32;

        for iter in 0..self.max_iter {
            let grad_y = grad_smooth(&y);
            for i in 0..dim {
                gradient_step[i] = y[i] - self.step_size * grad_y[i];
            }
            let x_new = prox(&gradient_step, self.step_size);

            // Sequential accumulation from 0.0, coordinate by coordinate.
            let step_norm = (0..dim)
                .fold(0.0_f32, |acc, i| {
                    let delta = x_new[i] - x[i];
                    acc + delta * delta
                })
                .sqrt();
            last_step_norm = step_norm;

            if step_norm < self.tol {
                let final_obj = smooth(&x_new);
                return OptimizationResult {
                    solution: x_new,
                    objective_value: final_obj,
                    iterations: iter,
                    status: ConvergenceStatus::Converged,
                    gradient_norm: step_norm,
                    constraint_violation: 0.0,
                    elapsed_time: start_time.elapsed(),
                };
            }

            // t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
            let t_new = f32::midpoint(1.0_f32, (1.0_f32 + 4.0_f32 * t * t).sqrt());
            let beta = (t - 1.0_f32) / t_new;
            // Extrapolate along the direction of the last step.
            for i in 0..dim {
                y[i] = x_new[i] + beta * (x_new[i] - x[i]);
            }
            x = x_new;
            t = t_new;
        }

        let final_obj = smooth(&x);
        OptimizationResult {
            solution: x,
            objective_value: final_obj,
            iterations: self.max_iter,
            status: ConvergenceStatus::MaxIterations,
            // Report the last measured step norm rather than a hard-coded 0.0,
            // which would look like perfect stationarity on a failed run.
            gradient_norm: last_step_norm,
            constraint_violation: 0.0,
            elapsed_time: start_time.elapsed(),
        }
    }
}
impl Optimizer for FISTA {
    /// Always panics: this solver has no per-step update rule and is driven
    /// through `minimize` instead.
    ///
    /// # Panics
    /// Unconditionally, with a message pointing the caller at `minimize()`.
    fn step(&mut self, _params: &mut Vector<f32>, _gradients: &Vector<f32>) {
        panic!(
            "FISTA does not support stochastic updates (step). \
             Use minimize() for batch optimization with proximal operators."
        );
    }

    /// Does nothing; there is no state carried between calls to clear.
    fn reset(&mut self) {
        // Intentionally empty.
    }
}
// Unit tests for the FISTA solver: constructor, trait methods, result fields,
// and a handful of small problems solved through `minimize`.
#[cfg(test)]
mod tests {
use super::*;
use crate::optim::prox;
// Smooth part 0.5 * (x - 5)^2 combined with `prox::soft_threshold` at
// threshold 2 * alpha; the solution is expected within 0.5 of 3.0.
// NOTE(review): 3.0 is the minimizer if soft_threshold is the prox of an
// L1 penalty with weight 2 - confirm against `prox::soft_threshold`.
#[test]
fn test_fista_l1_regularized() {
let smooth = |x: &Vector<f32>| 0.5 * (x[0] - 5.0).powi(2);
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[x[0] - 5.0]);
let proximal = |v: &Vector<f32>, alpha: f32| prox::soft_threshold(v, 2.0 * alpha);
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[0.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!(
(result.solution[0] - 3.0).abs() < 0.5,
"Expected ~3.0, got {}",
result.solution[0]
);
}
// Smooth part (x + 1)^2 (unconstrained minimum at -1) with
// `prox::nonnegative`; the solution is expected to land at 0.
// Presumably `prox::nonnegative` projects onto x >= 0 - verify in `prox`.
#[test]
fn test_fista_nonnegative() {
let smooth = |x: &Vector<f32>| (x[0] + 1.0).powi(2);
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * (x[0] + 1.0)]);
let proximal = |v: &Vector<f32>, _alpha: f32| prox::nonnegative(v);
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[5.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!(result.solution[0].abs() < 1e-4);
}
// The constructor stores its three arguments in the matching fields.
#[test]
fn test_fista_new() {
let fista = FISTA::new(500, 0.05, 1e-4);
assert_eq!(fista.max_iter, 500);
assert!((fista.step_size - 0.05).abs() < 1e-10);
assert!((fista.tol - 1e-4).abs() < 1e-10);
}
// `reset` can be called without panicking; nothing else is asserted.
#[test]
fn test_fista_reset() {
let mut fista = FISTA::new(100, 0.1, 1e-5);
fista.reset(); }
// `Optimizer::step` must panic with a message containing the expected text.
#[test]
#[should_panic(expected = "does not support stochastic updates")]
fn test_fista_step_unimplemented() {
let mut fista = FISTA::new(100, 0.1, 1e-5);
let mut params = Vector::from_slice(&[1.0, 2.0]);
let grad = Vector::from_slice(&[0.1, 0.2]);
fista.step(&mut params, &grad);
}
// With only 2 iterations, a tiny step (1e-4) and a very tight tolerance
// (1e-10), the run starting at (100, 100) must exhaust its budget and report
// `MaxIterations` with `iterations == 2`.
#[test]
fn test_fista_max_iterations() {
let smooth = |x: &Vector<f32>| x[0] * x[0] + x[1] * x[1];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0], 2.0 * x[1]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(2, 0.0001, 1e-10); let x0 = Vector::from_slice(&[100.0, 100.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert_eq!(result.status, ConvergenceStatus::MaxIterations);
assert_eq!(result.iterations, 2);
}
// Unconstrained 2-D quadratic (identity prox): converges to the origin.
#[test]
fn test_fista_2d_quadratic() {
let smooth = |x: &Vector<f32>| x[0] * x[0] + x[1] * x[1];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0], 2.0 * x[1]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[5.0, -3.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert_eq!(result.status, ConvergenceStatus::Converged);
assert!(result.solution[0].abs() < 1e-4);
assert!(result.solution[1].abs() < 1e-4);
}
// Same as the 2-D case but in three dimensions, with a looser tolerance
// (1e-5) and looser solution bounds (1e-3).
#[test]
fn test_fista_3d() {
let smooth = |x: &Vector<f32>| x[0] * x[0] + x[1] * x[1] + x[2] * x[2];
let grad_smooth =
|x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0], 2.0 * x[1], 2.0 * x[2]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-5);
let x0 = Vector::from_slice(&[5.0, -3.0, 2.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert_eq!(result.status, ConvergenceStatus::Converged);
assert!(result.solution[0].abs() < 1e-3);
assert!(result.solution[1].abs() < 1e-3);
assert!(result.solution[2].abs() < 1e-3);
}
// `objective_value` is the smooth function evaluated at the solution; for
// (x - 2)^2 it must be close to 0 and the solution close to 2.
#[test]
fn test_fista_objective_value() {
let smooth = |x: &Vector<f32>| (x[0] - 2.0).powi(2);
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * (x[0] - 2.0)]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[0.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!(result.objective_value < 1e-6);
assert!((result.solution[0] - 2.0).abs() < 1e-3);
}
// The reported `gradient_norm` must be small (< 1e-5) after solving x^2 with
// tolerance 1e-6.
#[test]
fn test_fista_gradient_norm() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[5.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!(result.gradient_norm < 1e-5);
}
// `constraint_violation` is reported as zero.
#[test]
fn test_fista_constraint_violation_zero() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[5.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!((result.constraint_violation - 0.0).abs() < 1e-10);
}
// Only checks that `elapsed_time` is populated and readable; no bound on the
// duration is asserted.
#[test]
fn test_fista_elapsed_time() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[5.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
let _ = result.elapsed_time.as_nanos();
}
// The derived `Clone` copies every field and the derived `Debug` output
// contains the type name.
#[test]
fn test_fista_debug_clone() {
let fista = FISTA::new(100, 0.1, 1e-5);
let cloned = fista.clone();
assert_eq!(fista.max_iter, cloned.max_iter);
assert!((fista.step_size - cloned.step_size).abs() < 1e-10);
assert!((fista.tol - cloned.tol).abs() < 1e-10);
let debug_str = format!("{:?}", fista);
assert!(debug_str.contains("FISTA"));
}
// Starting at 1e-8 on x^2, the very first step is already far below the
// tolerance, so the run must report `Converged`.
#[test]
fn test_fista_already_at_optimum() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[1e-8]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert_eq!(result.status, ConvergenceStatus::Converged);
}
// (x - 5)^2 with `prox::project_l2_ball(v, 2.0)`; the solution is expected
// within 0.1 of 2.0.
// Presumably the second argument is the ball radius - verify in `prox`.
#[test]
fn test_fista_l2_ball_constraint() {
let smooth = |x: &Vector<f32>| (x[0] - 5.0).powi(2);
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * (x[0] - 5.0)]);
let proximal = |v: &Vector<f32>, _alpha: f32| prox::project_l2_ball(v, 2.0);
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[0.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!((result.solution[0] - 2.0).abs() < 0.1);
}
// (x - 5)^2 with `prox::project_box` on bounds [0, 1]; the solution is
// expected within 0.1 of the upper bound 1.0. The bounds are moved into the
// closure.
#[test]
fn test_fista_box_constraint() {
let smooth = |x: &Vector<f32>| (x[0] - 5.0).powi(2);
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * (x[0] - 5.0)]);
let lower = Vector::from_slice(&[0.0]);
let upper = Vector::from_slice(&[1.0]);
let proximal = move |v: &Vector<f32>, _alpha: f32| prox::project_box(v, &lower, &upper);
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[0.5]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!((result.solution[0] - 1.0).abs() < 0.1);
}
// The same x^2 problem must converge with both a small (0.01) and a larger
// (0.4) step size. The closures capture nothing, so they can be passed to
// `minimize` twice.
#[test]
fn test_fista_different_step_sizes() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista1 = FISTA::new(1000, 0.01, 1e-6);
let x0 = Vector::from_slice(&[5.0]);
let result1 = fista1.minimize(smooth, grad_smooth, proximal, x0.clone());
let mut fista2 = FISTA::new(1000, 0.4, 1e-6);
let result2 = fista2.minimize(smooth, grad_smooth, proximal, x0);
assert_eq!(result1.status, ConvergenceStatus::Converged);
assert_eq!(result2.status, ConvergenceStatus::Converged);
}
// The reported iteration count never exceeds the configured maximum.
#[test]
fn test_fista_iterations_tracked() {
let smooth = |x: &Vector<f32>| x[0] * x[0];
let grad_smooth = |x: &Vector<f32>| Vector::from_slice(&[2.0 * x[0]]);
let proximal = |v: &Vector<f32>, _alpha: f32| v.clone();
let mut fista = FISTA::new(1000, 0.1, 1e-6);
let x0 = Vector::from_slice(&[10.0]);
let result = fista.minimize(smooth, grad_smooth, proximal, x0);
assert!(result.iterations <= 1000);
}
}