//! Gradient-based solvers for smooth minimax (saddle-point) problems of the
//! form `min_x max_y f(x, y)`, using forward-difference gradient
//! approximations.

use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, ArrayView1};

/// Configuration shared by the gradient descent-ascent (GDA) and
/// extragradient solvers.
#[derive(Debug, Clone)]
pub struct MinimaxConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the norm of the combined update.
    pub tol: f64,
    /// Step size for the descent (minimizing) variable `x`.
    pub step_size_x: f64,
    /// Step size for the ascent (maximizing) variable `y`.
    pub step_size_y: f64,
    /// Step used by the forward-difference gradient approximation.
    pub fd_step: f64,
    /// Log progress every `print_every` iterations; `0` disables logging.
    pub print_every: usize,
}

impl Default for MinimaxConfig {
    fn default() -> Self {
        Self {
            max_iter: 5_000,
            tol: 1e-6,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            fd_step: 1e-5,
            print_every: 0,
        }
    }
}

/// Result of a minimax solve.
#[derive(Debug, Clone)]
pub struct MinimaxResult {
    /// Final iterate of the minimizing variable.
    pub x: Array1<f64>,
    /// Final iterate of the maximizing variable.
    pub y: Array1<f64>,
    /// Objective value `f(x, y)` at the final iterate.
    pub fun: f64,
    /// Number of iterations actually performed.
    pub n_iter: usize,
    /// Stationarity gap `‖∇_x f‖ + ‖∇_y f‖` at the final iterate.
    pub gap: f64,
    /// Whether the tolerance was met before `max_iter` was exhausted.
    pub converged: bool,
    /// Human-readable status message.
    pub message: String,
}

/// Forward-difference approximation of the gradient of `f` with respect to
/// `x`, holding `y` fixed (O(h) accurate).
fn grad_x<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let n = x.len();
    let f0 = f(x, y);
    let mut g = Array1::<f64>::zeros(n);
    let mut x_fwd = x.to_owned();
    for i in 0..n {
        x_fwd[i] += h;
        g[i] = (f(&x_fwd.view(), y) - f0) / h;
        // Restore the perturbed coordinate before moving on.
        x_fwd[i] = x[i];
    }
    g
}

/// Forward-difference approximation of the gradient of `f` with respect to
/// `y`, holding `x` fixed (O(h) accurate).
fn grad_y<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let m = y.len();
    let f0 = f(x, y);
    let mut g = Array1::<f64>::zeros(m);
    let mut y_fwd = y.to_owned();
    for i in 0..m {
        y_fwd[i] += h;
        g[i] = (f(x, &y_fwd.view()) - f0) / h;
        // Restore the perturbed coordinate before moving on.
        y_fwd[i] = y[i];
    }
    g
}

/// Euclidean norm of a vector.
#[inline]
fn vec_norm(v: &Array1<f64>) -> f64 {
    v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
}

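/// Solves `min_x max_y f(x, y)` by simultaneous gradient descent-ascent
/// (GDA): each iteration takes a gradient-descent step in `x` and a
/// gradient-ascent step in `y`, both evaluated at the current iterate. GDA
/// is simple but can orbit rather than converge on strongly coupled (e.g.
/// purely bilinear) objectives; see [`extragradient_solve`] for a more
/// robust alternative.
///
/// A minimal usage sketch (marked `ignore` because the import path depends
/// on how this module is mounted in the crate):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, ArrayView1};
///
/// // Convex-concave objective with its saddle point at the origin.
/// let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| {
///     x[0] * x[0] - y[0] * y[0] + x[0] * y[0]
/// };
/// let config = MinimaxConfig::default();
/// let result = minimax_solve(&f, &array![1.0].view(), &array![1.0].view(), &config)
///     .expect("GDA solve failed");
/// println!("x* = {:?}, gap = {:.2e}", result.x, result.gap);
/// ```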
pub fn minimax_solve<F>(
    f: &F,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &MinimaxConfig,
) -> OptimizeResult<MinimaxResult>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }
    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    let mut converged = false;
    // Overwritten with the actual count if the solver stops early.
    let mut n_iter = config.max_iter;
    let h = config.fd_step;
    for k in 0..config.max_iter {
        let gx = grad_x(f, &x.view(), &y.view(), h);
        let gy = grad_y(f, &x.view(), &y.view(), h);
        // Simultaneous update: descend in x, ascend in y.
        let mut dx_norm = 0.0_f64;
        let mut dy_norm = 0.0_f64;
        for i in 0..nx {
            let step = config.step_size_x * gx[i];
            x[i] -= step;
            dx_norm += step * step;
        }
        for i in 0..ny {
            let step = config.step_size_y * gy[i];
            y[i] += step;
            dy_norm += step * step;
        }
        let delta = dx_norm.sqrt() + dy_norm.sqrt();
        if delta < config.tol {
            converged = true;
            n_iter = k + 1;
            if config.print_every > 0 {
                eprintln!("[GDA] converged at iteration {}", k + 1);
            }
            break;
        }
        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
            eprintln!("[GDA] iter {}: delta={:.2e}", k + 1, delta);
        }
    }
    let fun = f(&x.view(), &y.view());
    let gap = compute_gap(f, &x.view(), &y.view(), h);
    Ok(MinimaxResult {
        x,
        y,
        fun,
        n_iter,
        gap,
        converged,
        message: if converged {
            "GDA converged".to_string()
        } else {
            "GDA reached maximum iterations".to_string()
        },
    })
}

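/// Solves `min_x max_y f(x, y)` with the extragradient method: each
/// iteration first computes a trial point `(x_bar, y_bar)` via a plain GDA
/// step, then updates the current iterate using the gradients evaluated at
/// the trial point. This one-step lookahead damps the rotational dynamics
/// that make plain GDA cycle on bilinear couplings.
///
/// A minimal usage sketch (marked `ignore` because the import path depends
/// on how this module is mounted in the crate):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, ArrayView1};
///
/// // Pure bilinear coupling: plain GDA spirals away from the saddle at the
/// // origin, while extragradient converges toward it.
/// let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * y[0];
/// let config = MinimaxConfig::default();
/// let result = extragradient_solve(&f, &array![1.0].view(), &array![1.0].view(), &config)
///     .expect("extragradient solve failed");
/// println!("x* = {:?}", result.x);
/// ```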
pub fn extragradient_solve<F>(
    f: &F,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &MinimaxConfig,
) -> OptimizeResult<MinimaxResult>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }
    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    let mut converged = false;
    // Overwritten with the actual count if the solver stops early.
    let mut n_iter = config.max_iter;
    let h = config.fd_step;
    for k in 0..config.max_iter {
        // Extrapolation step: a plain GDA step to the trial point.
        let gx_k = grad_x(f, &x.view(), &y.view(), h);
        let gy_k = grad_y(f, &x.view(), &y.view(), h);
        let x_bar: Array1<f64> = x
            .iter()
            .zip(gx_k.iter())
            .map(|(&xi, &gi)| xi - config.step_size_x * gi)
            .collect();
        let y_bar: Array1<f64> = y
            .iter()
            .zip(gy_k.iter())
            .map(|(&yi, &gi)| yi + config.step_size_y * gi)
            .collect();
        // Update step: apply the trial-point gradients to the current iterate.
        let gx_bar = grad_x(f, &x_bar.view(), &y_bar.view(), h);
        let gy_bar = grad_y(f, &x_bar.view(), &y_bar.view(), h);
        let mut delta = 0.0_f64;
        for i in 0..nx {
            let step = config.step_size_x * gx_bar[i];
            x[i] -= step;
            delta += step * step;
        }
        for i in 0..ny {
            let step = config.step_size_y * gy_bar[i];
            y[i] += step;
            delta += step * step;
        }
        if delta.sqrt() < config.tol {
            converged = true;
            n_iter = k + 1;
            if config.print_every > 0 {
                eprintln!("[EG] converged at iteration {}", k + 1);
            }
            break;
        }
        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
            eprintln!("[EG] iter {}: delta={:.2e}", k + 1, delta.sqrt());
        }
    }
    let fun = f(&x.view(), &y.view());
    let gap = compute_gap(f, &x.view(), &y.view(), h);
    Ok(MinimaxResult {
        x,
        y,
        fun,
        n_iter,
        gap,
        converged,
        message: if converged {
            "Extragradient converged".to_string()
        } else {
            "Extragradient reached maximum iterations".to_string()
        },
    })
}

/// Configuration for the primal-dual gradient scheme.
#[derive(Debug, Clone)]
pub struct PrimalDualConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the norm of the combined update.
    pub tol: f64,
    /// Primal (descent) step size.
    pub sigma: f64,
    /// Dual (ascent) step size.
    pub tau: f64,
    /// Finite-difference step; not used by [`primal_dual`], which takes
    /// explicit gradient oracles.
    pub fd_step: f64,
}

impl Default for PrimalDualConfig {
    fn default() -> Self {
        Self {
            max_iter: 5_000,
            tol: 1e-6,
            sigma: 1e-3,
            tau: 1e-3,
            fd_step: 1e-5,
        }
    }
}

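/// Runs a primal-dual gradient scheme: `y` takes an ascent step using its
/// gradient oracle, `x` a descent step using its oracle, and an extrapolated
/// iterate `x_bar = 2 * x_new - x_old` is maintained in the style of
/// Chambolle-Pock. The oracles are decoupled (each sees only its own
/// variable), so `x_bar` is currently not consumed by the dual update.
///
/// A minimal usage sketch for the quadratic saddle `min_x max_y ‖x‖² - ‖y‖²`
/// (marked `ignore` because the import path depends on crate layout):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, ArrayView1};
///
/// let primal_grad = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi); // ∇ of ‖x‖²
/// let dual_grad = |y: &ArrayView1<f64>| y.mapv(|yi| -2.0 * yi); // ∇ of -‖y‖²
/// let config = PrimalDualConfig::default();
/// let (x_star, y_star) =
///     primal_dual(&primal_grad, &dual_grad, &array![1.0].view(), &array![1.0].view(), &config)
///         .expect("primal-dual solve failed");
/// ```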
pub fn primal_dual<Px, Py>(
    primal_fn: &Px,
    dual_fn: &Py,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &PrimalDualConfig,
) -> OptimizeResult<(Array1<f64>, Array1<f64>)>
where
    Px: Fn(&ArrayView1<f64>) -> Array1<f64>,
    Py: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }
    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    let mut x_bar = x.clone();
    for _k in 0..config.max_iter {
        // Dual ascent step.
        let gy = dual_fn(&y.view());
        let y_new: Array1<f64> = y
            .iter()
            .zip(gy.iter())
            .map(|(&yi, &gyi)| yi + config.tau * gyi)
            .collect();
        // Primal descent step.
        let gx = primal_fn(&x.view());
        let x_new: Array1<f64> = x
            .iter()
            .zip(gx.iter())
            .map(|(&xi, &gxi)| xi - config.sigma * gxi)
            .collect();
        // Chambolle-Pock-style extrapolation. Maintained but unused here,
        // since the decoupled oracles cannot consume it (see doc comment).
        let x_bar_new: Array1<f64> = x_new
            .iter()
            .zip(x.iter())
            .map(|(&xn, &xo)| 2.0 * xn - xo)
            .collect();
        let dx = vec_norm(&(x_new.clone() - &x));
        let dy = vec_norm(&(y_new.clone() - &y));
        let delta = dx + dy;
        x = x_new;
        y = y_new;
        x_bar = x_bar_new;
        if delta < config.tol {
            break;
        }
    }
    let _ = x_bar;
    Ok((x, y))
}

/// First-order stationarity measure `‖∇_x f‖ + ‖∇_y f‖`, computed with
/// finite differences; it vanishes at an interior saddle point.
fn compute_gap<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> f64
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let gx = grad_x(f, x, y, h);
    let gy = grad_y(f, x, y, h);
    vec_norm(&gx) + vec_norm(&gy)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Bilinear coupling `f(x, y) = x·y` with a saddle point at the origin.
    fn bilinear(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum()
    }

    /// Convex-concave quadratic `‖x‖² - ‖y‖² + x·y`, saddle at the origin.
    fn convex_concave(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        let quad_x: f64 = x.iter().map(|xi| xi * xi).sum();
        let quad_y: f64 = y.iter().map(|yi| yi * yi).sum();
        let cross: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
        quad_x - quad_y + cross
    }

    #[test]
    fn test_minimax_gda_bilinear() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-4,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            ..Default::default()
        };
        // Note: plain GDA does not converge on a purely bilinear objective;
        // the iterates follow a slowly expanding orbit around the saddle.
        // This only checks that the orbit is near the origin after the
        // configured number of steps.
        let result =
            minimax_solve(&bilinear, &x0.view(), &y0.view(), &config).expect("GDA solve failed");
        let norm_x = result.x.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        let norm_y = result.y.iter().map(|yi| yi * yi).sum::<f64>().sqrt();
        assert!(
            norm_x < 0.5,
            "GDA bilinear: ‖x‖ should be small, got {}",
            norm_x
        );
        assert!(
            norm_y < 0.5,
            "GDA bilinear: ‖y‖ should be small, got {}",
            norm_y
        );
    }

    #[test]
    fn test_extragradient_convex_concave() {
        let x0 = array![2.0];
        let y0 = array![2.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = extragradient_solve(&f, &x0.view(), &y0.view(), &config)
            .expect("extragradient solve failed");
        assert!(
            result.x[0].abs() < 0.3,
            "EG: expected x* ≈ 0, got {}",
            result.x[0]
        );
        assert!(
            result.y[0].abs() < 0.3,
            "EG: expected y* ≈ 0, got {}",
            result.y[0]
        );
    }

    #[test]
    fn test_extragradient_convex_concave_2d() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let result = extragradient_solve(&convex_concave, &x0.view(), &y0.view(), &config)
            .expect("extragradient solve failed");
        let norm = result.x.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        assert!(norm < 1.5, "EG 2D: ‖x‖={} should be < 1.5", norm);
    }

    #[test]
    fn test_primal_dual_gradient() {
        let x0 = array![3.0, -2.0];
        let y0 = array![1.0, 4.0];
        let config = PrimalDualConfig {
            max_iter: 20_000,
            tol: 1e-5,
            sigma: 5e-4,
            tau: 5e-4,
            ..Default::default()
        };
        // Gradient oracles for the quadratic saddle ‖x‖² - ‖y‖².
        let primal_fn = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);
        let dual_fn = |y: &ArrayView1<f64>| y.mapv(|yi| -2.0 * yi);
        let (x_star, y_star) = primal_dual(&primal_fn, &dual_fn, &x0.view(), &y0.view(), &config)
            .expect("primal-dual solve failed");
        let xn = x_star.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        let yn = y_star.iter().map(|yi| yi * yi).sum::<f64>().sqrt();
        assert!(xn < 0.5, "PD: ‖x*‖={} should be < 0.5", xn);
        assert!(yn < 0.5, "PD: ‖y*‖={} should be < 0.5", yn);
    }

    #[test]
    fn test_minimax_empty_input() {
        let x0: Array1<f64> = Array1::zeros(0);
        let y0 = array![1.0];
        let config = MinimaxConfig::default();
        assert!(minimax_solve(&bilinear, &x0.view(), &y0.view(), &config).is_err());
    }
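
    // A small check added alongside the early-stopping fix above: on a
    // well-conditioned convex-concave objective the solver should meet the
    // tolerance well before max_iter and report the actual iteration count.
    // The step sizes below are illustrative choices, not tuned values.
    #[test]
    fn test_minimax_reports_early_convergence() {
        let x0 = array![1.0];
        let y0 = array![1.0];
        let config = MinimaxConfig {
            max_iter: 100_000,
            tol: 1e-4,
            step_size_x: 1e-2,
            step_size_y: 1e-2,
            ..Default::default()
        };
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = minimax_solve(&f, &x0.view(), &y0.view(), &config).expect("GDA solve failed");
        assert!(result.converged, "expected early convergence");
        assert!(result.n_iter < config.max_iter);
    }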
}