use crate::error::OptimizeError;
/// Result of a proximal-gradient optimization run (ISTA or FISTA).
#[derive(Debug, Clone)]
pub struct ProxOptResult {
    /// Final iterate.
    pub x: Vec<f64>,
    /// Value of the supplied objective `f` evaluated at the final iterate.
    pub fun: f64,
    /// Number of iterations performed.
    pub nit: usize,
    /// Combined count of objective and gradient evaluations.
    pub nfev: usize,
    /// `true` if the ‖Δx‖ convergence criterion was met before `max_iter`.
    pub success: bool,
    /// Human-readable termination reason.
    pub message: String,
}
/// ISTA (Iterative Shrinkage-Thresholding Algorithm) configuration:
/// proximal gradient descent with a fixed step size.
pub struct IstaOptimizer {
    /// Step size used for the gradient step.
    pub lr: f64,
    /// Proximal operator applied after each gradient step.
    pub prox: Box<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync>,
    /// Convergence tolerance on the Euclidean norm of the iterate change.
    pub tol: f64,
}
impl IstaOptimizer {
    /// Creates an ISTA optimizer with step size `lr`, proximal operator
    /// `prox`, and the default tolerance `1e-6`.
    pub fn new(
        lr: f64,
        prox: Box<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync>,
    ) -> Self {
        Self {
            lr,
            prox,
            tol: 1e-6,
        }
    }

    /// Builder-style setter for the convergence tolerance.
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Minimizes a composite objective by iterating
    /// `x_{k+1} = prox(x_k - lr * grad_f(x_k))` until
    /// `‖x_{k+1} - x_k‖₂ < tol` or `max_iter` iterations elapse.
    ///
    /// `f` is the smooth objective (used only for reporting `fun`), and
    /// `grad_f` is its gradient; `self.prox` handles the nonsmooth term.
    ///
    /// # Errors
    /// Returns `OptimizeError::ComputationError` if an iterate contains a
    /// NaN or infinite entry.
    pub fn minimize<F, G>(
        &self,
        f: F,
        grad_f: G,
        x0: Vec<f64>,
        max_iter: usize,
    ) -> Result<ProxOptResult, OptimizeError>
    where
        F: Fn(&[f64]) -> f64,
        G: Fn(&[f64]) -> Vec<f64>,
    {
        // Fixed: removed unused `let n = x0.len();` binding (dead code).
        let mut x = x0;
        let mut nfev = 0usize;
        for iter in 0..max_iter {
            let g = grad_f(&x);
            nfev += 1;
            // Forward (gradient) step followed by backward (proximal) step.
            let x_grad: Vec<f64> = x
                .iter()
                .zip(g.iter())
                .map(|(&xi, &gi)| xi - self.lr * gi)
                .collect();
            let x_new = (self.prox)(&x_grad);
            if x_new.iter().any(|v| !v.is_finite()) {
                return Err(OptimizeError::ComputationError(
                    "ISTA: NaN or Inf encountered".to_string(),
                ));
            }
            // ‖x_new − x‖₂, the stopping criterion.
            let diff: f64 = x
                .iter()
                .zip(x_new.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum::<f64>()
                .sqrt();
            x = x_new;
            if diff < self.tol {
                let fun = f(&x);
                nfev += 1;
                return Ok(ProxOptResult {
                    x,
                    fun,
                    nit: iter + 1,
                    nfev,
                    success: true,
                    message: format!("ISTA converged: ‖Δx‖={:.2e} < tol={:.2e}", diff, self.tol),
                });
            }
        }
        let fun = f(&x);
        nfev += 1;
        Ok(ProxOptResult {
            x,
            fun,
            nit: max_iter,
            nfev,
            success: false,
            message: format!("ISTA: reached max_iter={}", max_iter),
        })
    }
}
/// FISTA (Fast ISTA) configuration: proximal gradient descent with
/// Nesterov-style momentum and optional adaptive restart.
pub struct FistaOptimizer {
    /// Step size used for the gradient step.
    pub lr: f64,
    /// Proximal operator applied after each gradient step.
    pub prox: Box<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync>,
    /// Initial momentum parameter; clamped to at least 1.0 in `minimize`.
    pub momentum: f64,
    /// Convergence tolerance on the Euclidean norm of the iterate change.
    pub tol: f64,
    /// When `true`, momentum is reset whenever the objective increases.
    pub restart: bool,
}
impl FistaOptimizer {
    /// Creates a FISTA optimizer with step size `lr`, proximal operator
    /// `prox`, initial momentum `momentum`, default tolerance `1e-6`,
    /// and adaptive restart disabled.
    pub fn new(
        lr: f64,
        prox: Box<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync>,
        momentum: f64,
    ) -> Self {
        Self {
            lr,
            prox,
            momentum,
            tol: 1e-6,
            restart: false,
        }
    }

    /// Enables function-value adaptive restart (momentum is reset when the
    /// objective increases between iterations).
    pub fn with_restart(mut self) -> Self {
        self.restart = true;
        self
    }

    /// Builder-style setter for the convergence tolerance.
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Minimizes a composite objective with accelerated proximal gradient
    /// steps: extrapolate `y = x + beta*(x - x_prev)`, then
    /// `x_new = prox(y - lr * grad_f(y))`, stopping when
    /// `‖x_new - x‖₂ < tol` or `max_iter` iterations elapse.
    ///
    /// # Errors
    /// Returns `OptimizeError::ComputationError` if an iterate contains a
    /// NaN or infinite entry.
    pub fn minimize<F, G>(
        &self,
        f: F,
        grad_f: G,
        x0: Vec<f64>,
        max_iter: usize,
    ) -> Result<ProxOptResult, OptimizeError>
    where
        F: Fn(&[f64]) -> f64,
        G: Fn(&[f64]) -> Vec<f64>,
    {
        // Fixed: removed unused `let n = x0.len();` binding (dead code).
        let mut x = x0.clone();
        let mut x_prev = x0;
        // FISTA requires t >= 1; clamp the user-supplied initial momentum.
        let mut t = self.momentum.max(1.0);
        let mut nfev = 0usize;
        let mut prev_fun = f64::INFINITY;
        for iter in 0..max_iter {
            // Nesterov momentum schedule: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
            let t_next = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
            let beta = (t - 1.0) / t_next;
            // Extrapolated point.
            let y: Vec<f64> = x
                .iter()
                .zip(x_prev.iter())
                .map(|(&xi, &xp)| xi + beta * (xi - xp))
                .collect();
            let g = grad_f(&y);
            nfev += 1;
            let y_grad: Vec<f64> = y.iter().zip(g.iter()).map(|(&yi, &gi)| yi - self.lr * gi).collect();
            let x_new = (self.prox)(&y_grad);
            if x_new.iter().any(|v| !v.is_finite()) {
                return Err(OptimizeError::ComputationError(
                    "FISTA: NaN or Inf encountered".to_string(),
                ));
            }
            let cur_fun = f(&x_new);
            nfev += 1;
            // Adaptive restart: an objective increase resets t to 1, which
            // makes beta = 0 (no momentum) on the next iteration.
            // Fixed: the original built identical `x.clone()` values in both
            // branches of this decision; only `t` differs between branches.
            t = if self.restart && cur_fun > prev_fun {
                1.0
            } else {
                t_next
            };
            prev_fun = cur_fun;
            // ‖x_new − x‖₂, the stopping criterion.
            let diff: f64 = x
                .iter()
                .zip(x_new.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum::<f64>()
                .sqrt();
            // Shift iterates without an extra clone.
            x_prev = std::mem::replace(&mut x, x_new);
            if diff < self.tol {
                // Fixed: reuse `cur_fun` (= f(x), just computed) instead of
                // re-evaluating the objective on the convergence path.
                return Ok(ProxOptResult {
                    x,
                    fun: cur_fun,
                    nit: iter + 1,
                    nfev,
                    success: true,
                    message: format!("FISTA converged: ‖Δx‖={:.2e} < tol={:.2e}", diff, self.tol),
                });
            }
        }
        let fun = f(&x);
        nfev += 1;
        Ok(ProxOptResult {
            x,
            fun,
            nit: max_iter,
            nfev,
            success: false,
            message: format!("FISTA: reached max_iter={}", max_iter),
        })
    }
}
/// Convenience wrapper: runs ISTA with step size `lr` and the default
/// tolerance, boxing `prox` into an [`IstaOptimizer`].
pub fn ista_minimize<F, G, P>(
    f: F,
    grad_f: G,
    prox: P,
    x0: Vec<f64>,
    lr: f64,
    max_iter: usize,
) -> Result<ProxOptResult, OptimizeError>
where
    F: Fn(&[f64]) -> f64,
    G: Fn(&[f64]) -> Vec<f64>,
    P: Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static,
{
    IstaOptimizer::new(lr, Box::new(prox)).minimize(f, grad_f, x0, max_iter)
}
/// Convenience wrapper: runs FISTA with step size `lr`, initial momentum
/// `1.0`, and the default tolerance, boxing `prox` into a
/// [`FistaOptimizer`].
pub fn fista_minimize<F, G, P>(
    f: F,
    grad_f: G,
    prox: P,
    x0: Vec<f64>,
    lr: f64,
    max_iter: usize,
) -> Result<ProxOptResult, OptimizeError>
where
    F: Fn(&[f64]) -> f64,
    G: Fn(&[f64]) -> Vec<f64>,
    P: Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static,
{
    FistaOptimizer::new(lr, Box::new(prox), 1.0).minimize(f, grad_f, x0, max_iter)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proximal::operators::prox_l1;
    use approx::assert_abs_diff_eq;

    /// Smooth quadratic objective 0.5‖x‖², minimized at the origin.
    fn smooth_f(x: &[f64]) -> f64 {
        0.5 * x.iter().map(|&xi| xi * xi).sum::<f64>()
    }

    /// Gradient of `smooth_f`: identity.
    fn smooth_grad(x: &[f64]) -> Vec<f64> {
        x.to_vec()
    }

    #[test]
    fn test_ista_lasso_converges() {
        let lambda = 0.1;
        let x0 = vec![2.0, -3.0, 0.5];
        let prox = move |v: &[f64]| prox_l1(v, lambda);
        let result = ista_minimize(smooth_f, smooth_grad, prox, x0, 0.5, 1000)
            .expect("ISTA failed");
        // L1-regularized quadratic with this lambda drives all coordinates to 0.
        for &xi in &result.x {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_fista_lasso_converges() {
        let lambda = 0.1;
        let x0 = vec![2.0, -3.0, 0.5];
        let prox = move |v: &[f64]| prox_l1(v, lambda);
        let result = fista_minimize(smooth_f, smooth_grad, prox, x0, 0.5, 500)
            .expect("FISTA failed");
        for &xi in &result.x {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_fista_converges_faster_than_ista() {
        let lambda = 0.05;
        let x0 = vec![5.0, -4.0, 3.0, -2.0];
        let prox_f = move |v: &[f64]| prox_l1(v, lambda);
        let prox_i = move |v: &[f64]| prox_l1(v, lambda);
        let fista_res = fista_minimize(smooth_f, smooth_grad, prox_f, x0.clone(), 0.5, 2000)
            .expect("FISTA failed");
        let ista_res = ista_minimize(smooth_f, smooth_grad, prox_i, x0, 0.5, 2000)
            .expect("ISTA failed");
        // Fixed: the original `assert!(a || b || true)` was vacuously true and
        // tested nothing. Both methods converge well within 2000 iterations on
        // this strongly convex problem.
        assert!(fista_res.success, "FISTA did not converge");
        assert!(ista_res.success, "ISTA did not converge");
        // FISTA should be at least as good by objective value or iteration count.
        assert!(fista_res.fun <= ista_res.fun + 1e-6 || fista_res.nit <= ista_res.nit);
    }

    #[test]
    fn test_ista_quadratic_no_prox() {
        let x0 = vec![3.0, -2.0];
        // Identity prox reduces ISTA to plain gradient descent.
        let prox = |v: &[f64]| v.to_vec();
        let result = ista_minimize(smooth_f, smooth_grad, prox, x0, 0.5, 200)
            .expect("ISTA failed");
        for &xi in &result.x {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-3);
        }
    }
}