use crate::error::OptimizeError;
/// Run a fixed number of Douglas–Rachford splitting iterations for `min_x f(x) + g(x)`.
///
/// Starting from `z = x0`, each iteration computes
///
/// ```text
/// y = prox_g(z)
/// x = prox_f(2y - z)
/// z = z + x - y
/// ```
///
/// and the function returns `prox_g(z)` for the final `z`, so `max_iter == 0` yields
/// `prox_g(x0)`. There is no early stopping; `douglas_rachford_tracked` is the variant
/// with a convergence test.
///
/// # Parameters
/// - `prox_f`, `prox_g`: proximal operators. They receive only the input vector, so any
///   step size must already be built into the closures by the caller.
/// - `x0`: starting value of the governing sequence `z`.
/// - `gamma`: step size. It is not read by this function (the closures carry their own
///   step size); it is kept for interface compatibility.
/// - `max_iter`: exact number of iterations to perform.
///
/// Vectors are combined element-wise with `zip`, so if a proximal operator returns a
/// vector of a different length, the shorter length wins. NOTE(review): callers are
/// presumably expected to return vectors of the same length as their input.
pub fn douglas_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    // Accepted for API symmetry with the other splitting routines; silences the
    // unused-parameter warning without changing the public signature.
    let _ = gamma;
    let mut z = x0;
    for _ in 0..max_iter {
        let y = prox_g(&z);
        // Reflection of z through prox_g: 2·prox_g(z) - z.
        let reflected: Vec<f64> = y.iter().zip(&z).map(|(&yi, &zi)| 2.0 * yi - zi).collect();
        let x = prox_f(&reflected);
        z = z
            .iter()
            .zip(x.iter().zip(&y))
            .map(|(&zk, (&xk, &yk))| zk + xk - yk)
            .collect();
    }
    prox_g(&z)
}
/// Douglas–Rachford splitting with a stopping test on the governing sequence `z`.
///
/// Uses the same iteration as `douglas_rachford`:
///
/// ```text
/// y = prox_g(z)
/// x = prox_f(2y - z)
/// z = z + x - y
/// ```
///
/// After each iteration the Euclidean step `‖z_new - z_old‖₂` is measured. The run stops
/// as soon as that step is `< tol`.
///
/// # Parameters
/// - `prox_f`, `prox_g`: proximal operators with any step size already built in.
/// - `x0`: starting value of `z`.
/// - `gamma`: accepted for interface compatibility but not read.
/// - `max_iter`: maximum number of iterations.
/// - `tol`: threshold on the step norm of `z`.
///
/// # Returns
/// A [`DRResult`] whose `x` is `prox_g(z)` at the last `z`.
/// - `nit` is the number of iterations performed.
/// - `converged` says whether the tolerance was met.
/// - `final_residual` is the last measured step norm. It is `f64::INFINITY` when
///   `max_iter == 0`, because no step was measured.
pub fn douglas_rachford_tracked(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    gamma: f64,
    max_iter: usize,
    tol: f64,
) -> DRResult {
    let _ = gamma;
    let mut z = x0;
    // Stays infinite only if the loop body never runs (max_iter == 0).
    let mut last_step = f64::INFINITY;
    for iter in 0..max_iter {
        let y = prox_g(&z);
        let reflected: Vec<f64> = y.iter().zip(&z).map(|(&yi, &zi)| 2.0 * yi - zi).collect();
        let x = prox_f(&reflected);
        let z_next: Vec<f64> = z
            .iter()
            .zip(x.iter().zip(&y))
            .map(|(&zk, (&xk, &yk))| zk + xk - yk)
            .collect();
        // Compare against the old z before replacing it, so no clone is needed.
        last_step = z_next
            .iter()
            .zip(&z)
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        z = z_next;
        if last_step < tol {
            return DRResult {
                x: prox_g(&z),
                nit: iter + 1,
                converged: true,
                final_residual: last_step,
            };
        }
    }
    // Not converged: report the residual actually observed, not a fake 0.0.
    DRResult {
        x: prox_g(&z),
        nit: max_iter,
        converged: false,
        final_residual: last_step,
    }
}

/// Outcome of [`douglas_rachford_tracked`].
#[derive(Debug, Clone)]
pub struct DRResult {
    /// Primal estimate `prox_g(z)` evaluated at the final governing iterate `z`.
    pub x: Vec<f64>,
    /// Number of iterations performed.
    pub nit: usize,
    /// `true` if the step norm of `z` fell below `tol` before `max_iter` was exhausted.
    pub converged: bool,
    /// Euclidean norm of the last step of `z`; `f64::INFINITY` if no iteration ran.
    pub final_residual: f64,
}
/// Peaceman–Rachford splitting: repeatedly apply the composition of reflected proximal
/// maps, `z ← R_f(R_g(z))` with `R_h(v) = 2·prox_h(v) − v`.
///
/// Runs exactly `max_iter` iterations starting from `z = x0` and returns `prox_g(z)` at
/// the final `z`. The `_gamma` argument is not read; the proximal closures are expected
/// to carry their own step size.
pub fn peaceman_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    _gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    /// Element-wise reflection `2·p − v` of `v` given its proximal image `p`.
    fn reflect(prox_image: &[f64], point: &[f64]) -> Vec<f64> {
        prox_image
            .iter()
            .zip(point)
            .map(|(&p, &v)| 2.0 * p - v)
            .collect()
    }

    let governing = (0..max_iter).fold(x0, |state, _| {
        let through_g = reflect(&prox_g(&state), &state);
        reflect(&prox_f(&through_g), &through_g)
    });
    prox_g(&governing)
}
/// Proximal-gradient (forward–backward) iteration for `min_x f(x) + g(x)`.
///
/// Each step is an explicit gradient step on `f` followed by the proximal map of `g`:
/// `x ← prox_g(x − alpha · ∇f(x))`.
///
/// The loop ends after `max_iter` steps, or earlier when the Euclidean distance between
/// two consecutive iterates is below `tol`. In either case the newest iterate is
/// returned; with `max_iter == 0` that is `x0` unchanged.
pub fn forward_backward(
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    alpha: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut iterate = x0;
    for _ in 0..max_iter {
        let gradient = grad_f(&iterate);
        let forward_point: Vec<f64> = iterate
            .iter()
            .zip(&gradient)
            .map(|(&xi, &gi)| xi - alpha * gi)
            .collect();
        let candidate = prox_g(&forward_point);
        let step_len = iterate
            .iter()
            .zip(&candidate)
            .map(|(&old, &new)| {
                let delta = old - new;
                delta * delta
            })
            .sum::<f64>()
            .sqrt();
        iterate = candidate;
        if step_len < tol {
            return iterate;
        }
    }
    iterate
}
/// First-order primal–dual algorithm of Chambolle and Pock for `min_x f(x) + g(Kx)`.
///
/// Starting from `x̄ = x0`, each iteration computes
///
/// ```text
/// y  = prox_g_conj(y + σ · K x̄)
/// x' = prox_f(x − τ · Kᵀ y)
/// x̄  = x' + θ (x' − x)
/// x  = x'
/// ```
///
/// # Parameters
/// - `prox_f`, `prox_g_conj`: proximal maps of `f` and of the convex conjugate `g*`. They
///   receive only a vector, so the caller must build the `τ` / `σ` scaling of the
///   proximal maps into them. `tau` and `sigma` here scale the linear-operator terms.
/// - `k_op`, `kt_op`: the linear operator `K` and its adjoint `Kᵀ`.
/// - `x0`, `y0`: initial primal and dual points.
/// - `tau`, `sigma`: primal and dual step sizes. The standard convergence analysis
///   requires `τ·σ·‖K‖² < 1`. With `tau == sigma == 1.0` the linear terms are unscaled.
/// - `theta`: extrapolation parameter; `1.0` is the usual choice.
/// - `max_iter`: exact number of iterations to perform.
///
/// Returns `(x, y)` after `max_iter` iterations.
// The argument list mirrors the algorithm's inputs one-to-one; bundling them into a
// struct would break the public signature.
#[allow(clippy::too_many_arguments)]
pub fn primal_dual_chambolle_pock(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g_conj: &dyn Fn(&[f64]) -> Vec<f64>,
    k_op: &dyn Fn(&[f64]) -> Vec<f64>,
    kt_op: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    y0: Vec<f64>,
    tau: f64,
    sigma: f64,
    theta: f64,
    max_iter: usize,
) -> (Vec<f64>, Vec<f64>) {
    let mut x = x0;
    let mut y = y0;
    let mut x_bar = x.clone();
    for _ in 0..max_iter {
        // Dual ascent step with step size sigma.
        let kx_bar = k_op(&x_bar);
        let y_input: Vec<f64> = y
            .iter()
            .zip(&kx_bar)
            .map(|(&yi, &kxi)| yi + sigma * kxi)
            .collect();
        y = prox_g_conj(&y_input);
        // Primal descent step with step size tau.
        let kty = kt_op(&y);
        let x_input: Vec<f64> = x
            .iter()
            .zip(&kty)
            .map(|(&xi, &kti)| xi - tau * kti)
            .collect();
        let x_new = prox_f(&x_input);
        // Over-relaxation uses the previous x, which is still held in `x` here.
        x_bar = x_new
            .iter()
            .zip(&x)
            .map(|(&xn, &xo)| xn + theta * (xn - xo))
            .collect();
        x = x_new;
    }
    (x, y)
}
/// Summary returned by the validated splitting entry points such as `dr_split`.
#[derive(Clone, Debug)]
pub struct SplittingResult {
    /// Final iterate produced by the underlying solver.
    pub x: Vec<f64>,
    /// Number of iterations the underlying solver performed.
    pub nit: usize,
    /// Whether the underlying solver reported convergence.
    pub converged: bool,
}
pub fn dr_split(
prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
x0: Vec<f64>,
gamma: f64,
max_iter: usize,
tol: f64,
) -> Result<SplittingResult, OptimizeError> {
if gamma <= 0.0 {
return Err(OptimizeError::ValueError(
"gamma must be positive for Douglas-Rachford".to_string(),
));
}
let res = douglas_rachford_tracked(prox_f, prox_g, x0, gamma, max_iter, tol);
Ok(SplittingResult {
x: res.x,
nit: res.nit,
converged: res.converged,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proximal::operators::{prox_l1, prox_l2};
    use approx::assert_abs_diff_eq;

    /// Identity map, used as the proximal operator of the zero function.
    fn prox_id(v: &[f64]) -> Vec<f64> {
        Vec::from(v)
    }

    #[test]
    fn test_douglas_rachford_l1_l2() {
        let (lambda_f, lambda_g) = (0.5, 0.5);
        let l1_prox = |v: &[f64]| prox_l1(v, lambda_f);
        let l2_prox = |v: &[f64]| prox_l2(v, lambda_g);
        let solution = douglas_rachford(&l1_prox, &l2_prox, vec![3.0, -2.0, 1.0], 1.0, 500);
        for xi in solution.iter().copied() {
            assert!(xi.abs() < 1.0, "DR solution out of expected range: {}", xi);
        }
    }

    #[test]
    fn test_douglas_rachford_identity_prox() {
        let l1_prox = |v: &[f64]| prox_l1(v, 1.0);
        let solution = douglas_rachford(&l1_prox, &prox_id, vec![2.0, -3.0], 1.0, 1000);
        for xi in solution.iter().copied() {
            assert!(xi.abs() <= 1.0 + 1e-8, "not in expected set: {}", xi);
        }
    }

    #[test]
    fn test_dr_tracked_convergence() {
        let l1_prox = |v: &[f64]| prox_l1(v, 0.3);
        let l2_prox = |v: &[f64]| prox_l2(v, 0.3);
        let outcome =
            douglas_rachford_tracked(&l1_prox, &l2_prox, vec![2.0, -1.0], 1.0, 2000, 1e-8);
        assert!(outcome.converged, "DR should converge within 2000 iters");
        assert!(outcome.nit < 2000, "DR should converge before max_iter");
    }

    #[test]
    fn test_forward_backward_quadratic() {
        // The gradient closure returns its input, i.e. the gradient of f(x) = ½‖x‖².
        let grad_f = |x: &[f64]| x.to_vec();
        let minimiser = forward_backward(&grad_f, &prox_id, vec![3.0, -2.0], 0.5, 500, 1e-8);
        for xi in minimiser.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_peaceman_rachford_converges() {
        let f_prox = |v: &[f64]| prox_l2(v, 0.5);
        let g_prox = |v: &[f64]| prox_l2(v, 0.5);
        let solution = peaceman_rachford(&f_prox, &g_prox, vec![2.0, -1.5], 1.0, 500);
        for xi in solution.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_dr_split_negative_gamma() {
        let passthrough = |v: &[f64]| v.to_vec();
        let outcome = dr_split(&passthrough, &passthrough, vec![1.0], -1.0, 10, 1e-6);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_primal_dual_basic() {
        let f_prox = |v: &[f64]| prox_l2(v, 0.5);
        let g_conj_prox = |v: &[f64]| prox_l2(v, 0.5);
        // K is the identity, so K and its adjoint are the same map.
        let k_op = |x: &[f64]| x.to_vec();
        let kt_op = |y: &[f64]| y.to_vec();
        let (primal, _dual) = primal_dual_chambolle_pock(
            &f_prox,
            &g_conj_prox,
            &k_op,
            &kt_op,
            vec![2.0, -1.0],
            vec![0.0, 0.0],
            0.5,
            0.5,
            1.0,
            500,
        );
        for xi in primal.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 0.1);
        }
    }
}