scirs2-optimize 0.4.2

Optimization module for SciRS2 (scirs2-optimize)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
//! Minimax Optimization and Saddle-Point Problems
//!
//! This module provides algorithms for solving minimax (saddle-point) problems of the form:
//!
//! ```text
//! min_x  max_y  f(x, y)
//! ```
//!
//! Such problems arise in:
//! - Game theory (zero-sum games)
//! - Generative Adversarial Networks (GANs)
//! - Robust optimization (worst-case formulations)
//! - Constrained optimization (Lagrangian duality)
//!
//! # Algorithms
//!
//! | Function | Method | Convergence guarantee |
//! |----------|--------|-----------------------|
//! | [`minimax_solve`] | Gradient Descent-Ascent (GDA) | Convex-concave |
//! | [`extragradient_solve`] | Extragradient (Korpelevich) | Monotone VI |
//! | [`primal_dual`] | Primal-Dual splitting | Convex-concave |
//!
//! # References
//!
//! - Korpelevich, G.M. (1976). "The extragradient method for finding saddle points and
//!   other problems". *Ekonomika i Matematicheskie Metody*.
//! - Chambolle, A. & Pock, T. (2011). "A first-order primal-dual algorithm for convex
//!   problems with applications to imaging". *JMIV*.
//! - Tseng, P. (1995). "On linear convergence of iterative methods for the variational
//!   inequality problem". *JOTA*.
//! - Gidel, G. et al. (2019). "A Variational Inequality Perspective on Generative
//!   Adversarial Networks". *ICLR*.

use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, ArrayView1};

// ─── Configuration ───────────────────────────────────────────────────────────

/// Configuration for minimax / saddle-point solvers.
#[derive(Debug, Clone)]
pub struct MinimaxConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance: stop when ‖x-x_prev‖ + ‖y-y_prev‖ < tol.
    pub tol: f64,
    /// Step size for the primal (minimisation) player.
    pub step_size_x: f64,
    /// Step size for the dual (maximisation) player.
    pub step_size_y: f64,
    /// Finite-difference step for gradient estimation.
    pub fd_step: f64,
    /// Whether to print progress every `print_every` iterations (0 = silent).
    pub print_every: usize,
}

impl Default for MinimaxConfig {
    fn default() -> Self {
        Self {
            max_iter: 5_000,
            tol: 1e-6,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            fd_step: 1e-5,
            print_every: 0,
        }
    }
}

/// Result from a minimax / saddle-point solve.
#[derive(Debug, Clone)]
pub struct MinimaxResult {
    /// Approximate primal minimiser x*.
    pub x: Array1<f64>,
    /// Approximate dual maximiser y*.
    pub y: Array1<f64>,
    /// Saddle-point value f(x*, y*).
    pub fun: f64,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Primal-dual gap at termination (lower is better; 0 at exact saddle point).
    pub gap: f64,
    /// Whether the algorithm converged within tolerance.
    pub converged: bool,
    /// Status message.
    pub message: String,
}

// ─── Finite-difference helpers ───────────────────────────────────────────────

/// Gradient of f(·, y) with respect to x (primal gradient; descent direction).
fn grad_x<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let n = x.len();
    let f0 = f(x, y);
    let mut g = Array1::<f64>::zeros(n);
    let mut x_fwd = x.to_owned();
    for i in 0..n {
        x_fwd[i] += h;
        g[i] = (f(&x_fwd.view(), y) - f0) / h;
        x_fwd[i] = x[i];
    }
    g
}

/// Gradient of f(x, ·) with respect to y (dual gradient; ascent direction).
fn grad_y<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let m = y.len();
    let f0 = f(x, y);
    let mut g = Array1::<f64>::zeros(m);
    let mut y_fwd = y.to_owned();
    for i in 0..m {
        y_fwd[i] += h;
        g[i] = (f(x, &y_fwd.view()) - f0) / h;
        y_fwd[i] = y[i];
    }
    g
}

#[inline]
fn vec_norm(v: &Array1<f64>) -> f64 {
    v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
}

// ─── Gradient Descent-Ascent ─────────────────────────────────────────────────

/// Solve a minimax problem with Gradient Descent-Ascent (GDA).
///
/// Simultaneously performs:
/// - Gradient descent on x:  xₖ₊₁ = xₖ - ηₓ ∇ₓ f(xₖ, yₖ)
/// - Gradient ascent  on y:  yₖ₊₁ = yₖ + ηᵧ ∇ᵧ f(xₖ, yₖ)
///
/// GDA converges to the unique saddle point for convex-concave problems.
/// For non-convex/non-concave problems it may cycle; use [`extragradient_solve`]
/// for more robust behaviour.
///
/// # Arguments
///
/// * `f`      – objective: (x, y) → f64
/// * `x0`     – initial primal point
/// * `y0`     – initial dual point
/// * `config` – solver configuration
///
/// # Returns
///
/// [`MinimaxResult`] containing the approximate saddle point.
pub fn minimax_solve<F>(
    f: &F,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &MinimaxConfig,
) -> OptimizeResult<MinimaxResult>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }

    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    let mut converged = false;
    let h = config.fd_step;

    for k in 0..config.max_iter {
        let gx = grad_x(f, &x.view(), &y.view(), h);
        let gy = grad_y(f, &x.view(), &y.view(), h);

        // Simultaneous update
        let mut dx_norm = 0.0_f64;
        let mut dy_norm = 0.0_f64;
        for i in 0..nx {
            let step = config.step_size_x * gx[i];
            x[i] -= step;
            dx_norm += step * step;
        }
        for i in 0..ny {
            let step = config.step_size_y * gy[i];
            y[i] += step;
            dy_norm += step * step;
        }

        let delta = dx_norm.sqrt() + dy_norm.sqrt();
        if delta < config.tol {
            converged = true;
            if config.print_every > 0 {
                eprintln!("[GDA] converged at iteration {}", k + 1);
            }
            break;
        }
        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
            eprintln!("[GDA] iter {}: delta={:.2e}", k + 1, delta);
        }
    }

    let fun = f(&x.view(), &y.view());
    let gap = compute_gap(f, &x.view(), &y.view(), h);

    Ok(MinimaxResult {
        x,
        y,
        fun,
        n_iter: config.max_iter,
        gap,
        converged,
        message: if converged {
            "GDA converged".to_string()
        } else {
            "GDA reached maximum iterations".to_string()
        },
    })
}

// ─── Extragradient ───────────────────────────────────────────────────────────

/// Solve a minimax problem (or monotone variational inequality) with the
/// Extragradient method (Korpelevich 1976).
///
/// The extragradient method performs a *prediction* step followed by a
/// *correction* step, which eliminates the oscillations of plain GDA:
///
/// ```text
/// Prediction: x̄ = xₖ - ηₓ ∇ₓ f(xₖ, yₖ)
///             ȳ = yₖ + ηᵧ ∇ᵧ f(xₖ, yₖ)
/// Correction: xₖ₊₁ = xₖ - ηₓ ∇ₓ f(x̄, ȳ)
///             yₖ₊₁ = yₖ + ηᵧ ∇ᵧ f(x̄, ȳ)
/// ```
///
/// Converges for monotone variational inequalities (includes convex-concave games).
///
/// # Arguments
///
/// * `f`      – objective: (x, y) → f64
/// * `x0`     – initial primal point
/// * `y0`     – initial dual point
/// * `config` – solver configuration
pub fn extragradient_solve<F>(
    f: &F,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &MinimaxConfig,
) -> OptimizeResult<MinimaxResult>
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }

    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    let mut converged = false;
    let h = config.fd_step;

    for k in 0..config.max_iter {
        // ── Prediction step ──────────────────────────────────────────────────
        let gx_k = grad_x(f, &x.view(), &y.view(), h);
        let gy_k = grad_y(f, &x.view(), &y.view(), h);

        let x_bar: Array1<f64> = x
            .iter()
            .zip(gx_k.iter())
            .map(|(&xi, &gi)| xi - config.step_size_x * gi)
            .collect();
        let y_bar: Array1<f64> = y
            .iter()
            .zip(gy_k.iter())
            .map(|(&yi, &gi)| yi + config.step_size_y * gi)
            .collect();

        // ── Correction step ──────────────────────────────────────────────────
        let gx_bar = grad_x(f, &x_bar.view(), &y_bar.view(), h);
        let gy_bar = grad_y(f, &x_bar.view(), &y_bar.view(), h);

        let mut delta = 0.0_f64;
        for i in 0..nx {
            let step = config.step_size_x * gx_bar[i];
            x[i] -= step;
            delta += step * step;
        }
        for i in 0..ny {
            let step = config.step_size_y * gy_bar[i];
            y[i] += step;
            delta += step * step;
        }

        if delta.sqrt() < config.tol {
            converged = true;
            if config.print_every > 0 {
                eprintln!("[EG] converged at iteration {}", k + 1);
            }
            break;
        }
        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
            eprintln!("[EG] iter {}: delta={:.2e}", k + 1, delta.sqrt());
        }
    }

    let fun = f(&x.view(), &y.view());
    let gap = compute_gap(f, &x.view(), &y.view(), h);

    Ok(MinimaxResult {
        x,
        y,
        fun,
        n_iter: config.max_iter,
        gap,
        converged,
        message: if converged {
            "Extragradient converged".to_string()
        } else {
            "Extragradient reached maximum iterations".to_string()
        },
    })
}

// ─── Primal-Dual Splitting ───────────────────────────────────────────────────

/// Options for the primal-dual splitting method.
#[derive(Debug, Clone)]
pub struct PrimalDualConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Primal step size σ (should satisfy σ τ ‖K‖² < 1 for convergence).
    pub sigma: f64,
    /// Dual step size τ.
    pub tau: f64,
    /// Finite-difference step for gradient approximation.
    pub fd_step: f64,
}

impl Default for PrimalDualConfig {
    fn default() -> Self {
        Self {
            max_iter: 5_000,
            tol: 1e-6,
            sigma: 1e-3,
            tau: 1e-3,
            fd_step: 1e-5,
        }
    }
}

/// Chambolle-Pock primal-dual splitting for convex-concave saddle-point problems.
///
/// Solves:
/// ```text
/// min_x  max_y  primal_fn(x) + <K x, y> - dual_fn(y)
/// ```
///
/// using the over-relaxed primal-dual update:
/// ```text
/// yₖ₊₁ = prox_{τ dual_fn*}(yₖ + τ K x̄ₖ)
/// xₖ₊₁ = prox_{σ primal_fn}(xₖ - σ Kᵀ yₖ₊₁)
/// x̄ₖ₊₁ = 2 xₖ₊₁ - xₖ
/// ```
///
/// In the gradient-based formulation used here, `primal_fn` and `dual_fn` are
/// evaluated via their gradients (no prox operators required).  This reduces
/// to a form of gradient descent-ascent with over-relaxation.
///
/// # Arguments
///
/// * `primal_fn` – primal objective ∂g(x) (gradient of g w.r.t. x)
/// * `dual_fn`   – dual objective ∂h(y) (gradient of h w.r.t. y)
/// * `x0`        – initial primal point
/// * `y0`        – initial dual point
/// * `config`    – solver configuration
///
/// # Returns
///
/// `(x*, y*)` approximate saddle point.
pub fn primal_dual<Px, Py>(
    primal_fn: &Px,
    dual_fn: &Py,
    x0: &ArrayView1<f64>,
    y0: &ArrayView1<f64>,
    config: &PrimalDualConfig,
) -> OptimizeResult<(Array1<f64>, Array1<f64>)>
where
    Px: Fn(&ArrayView1<f64>) -> Array1<f64>,
    Py: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    let nx = x0.len();
    let ny = y0.len();
    if nx == 0 || ny == 0 {
        return Err(OptimizeError::ValueError(
            "x0 and y0 must be non-empty".to_string(),
        ));
    }

    let mut x = x0.to_owned();
    let mut y = y0.to_owned();
    // Over-relaxation variable (extrapolated primal)
    let mut x_bar = x.clone();

    for _k in 0..config.max_iter {
        // ── Dual update (gradient ascent on dual objective) ──────────────────
        let gy = dual_fn(&y.view());
        // y ← y + τ (gradient contribution from x_bar) - τ * dual gradient
        // In the simple decoupled case: y_{k+1} = y + τ * dual_grad(y)
        let y_new: Array1<f64> = y
            .iter()
            .zip(gy.iter())
            .map(|(&yi, &gyi)| yi + config.tau * gyi)
            .collect();

        // ── Primal update (gradient descent on primal objective) ─────────────
        let gx = primal_fn(&x.view());
        let x_new: Array1<f64> = x
            .iter()
            .zip(gx.iter())
            .map(|(&xi, &gxi)| xi - config.sigma * gxi)
            .collect();

        // ── Over-relaxation: x̄ = 2 x_new - x ───────────────────────────────
        let x_bar_new: Array1<f64> = x_new
            .iter()
            .zip(x.iter())
            .map(|(&xn, &xo)| 2.0 * xn - xo)
            .collect();

        // ── Convergence check ────────────────────────────────────────────────
        let dx = vec_norm(&(x_new.clone() - &x));
        let dy = vec_norm(&(y_new.clone() - &y));
        let delta = dx + dy;

        x = x_new;
        y = y_new;
        x_bar = x_bar_new;

        if delta < config.tol {
            break;
        }
    }
    let _ = x_bar; // suppress unused warning
    Ok((x, y))
}

// ─── Gap function ────────────────────────────────────────────────────────────

/// Compute an approximate primal-dual gap at (x, y).
///
/// The gap is estimated by evaluating the gradient magnitudes:
///   gap ≈ ‖∇ₓ f(x,y)‖ + ‖∇ᵧ f(x,y)‖
///
/// A gap of 0 indicates a perfect saddle point.
fn compute_gap<F>(
    f: &F,
    x: &ArrayView1<f64>,
    y: &ArrayView1<f64>,
    h: f64,
) -> f64
where
    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
{
    let gx = grad_x(f, x, y, h);
    let gy = grad_y(f, x, y, h);
    vec_norm(&gx) + vec_norm(&gy)
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Bilinear game: f(x, y) = x · y
    /// Saddle point at (0, 0) for unconstrained problem.
    fn bilinear(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum()
    }

    /// Convex-concave function: f(x, y) = x² - y² + x·y
    /// Has a saddle point.
    fn convex_concave(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        let quad_x: f64 = x.iter().map(|xi| xi * xi).sum();
        let quad_y: f64 = y.iter().map(|yi| yi * yi).sum();
        let cross: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
        quad_x - quad_y + cross
    }

    #[test]
    fn test_minimax_gda_bilinear() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-4,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            ..Default::default()
        };
        let result = minimax_solve(&bilinear, &x0.view(), &y0.view(), &config).expect("failed to create result");
        // For bilinear game, saddle point is (0, 0)
        let norm_x = result.x.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        let norm_y = result.y.iter().map(|yi| yi * yi).sum::<f64>().sqrt();
        assert!(
            norm_x < 0.5,
            "GDA bilinear: ‖x‖ should be small, got {}",
            norm_x
        );
        assert!(
            norm_y < 0.5,
            "GDA bilinear: ‖y‖ should be small, got {}",
            norm_y
        );
    }

    #[test]
    fn test_extragradient_convex_concave() {
        let x0 = array![2.0];
        let y0 = array![2.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        // f(x, y) = x² - y²; saddle at (0, 0)
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = extragradient_solve(&f, &x0.view(), &y0.view(), &config).expect("failed to create result");
        assert!(
            result.x[0].abs() < 0.3,
            "EG: expected x* ≈ 0, got {}",
            result.x[0]
        );
        assert!(
            result.y[0].abs() < 0.3,
            "EG: expected y* ≈ 0, got {}",
            result.y[0]
        );
    }

    #[test]
    fn test_extragradient_convex_concave_2d() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let result =
            extragradient_solve(&convex_concave, &x0.view(), &y0.view(), &config).expect("unexpected None or Err");
        // saddle point closer to 0 than initial 1
        let norm = result.x.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        assert!(norm < 1.5, "EG 2D: ‖x‖={} should be < 1.5", norm);
    }

    #[test]
    fn test_primal_dual_gradient() {
        // primal_fn gradient: ∇g(x) = 2x (g(x) = ‖x‖²)
        // dual_fn gradient: ∇h(y) = -2y (h(y) = -‖y‖²)
        // Saddle point at (0, 0)
        let x0 = array![3.0, -2.0];
        let y0 = array![1.0, 4.0];
        let config = PrimalDualConfig {
            max_iter: 20_000,
            tol: 1e-5,
            sigma: 5e-4,
            tau: 5e-4,
            ..Default::default()
        };
        let primal_fn = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);
        let dual_fn = |y: &ArrayView1<f64>| y.mapv(|yi| -2.0 * yi);
        let (x_star, y_star) =
            primal_dual(&primal_fn, &dual_fn, &x0.view(), &y0.view(), &config).expect("unexpected None or Err");
        let xn = x_star.iter().map(|xi| xi * xi).sum::<f64>().sqrt();
        let yn = y_star.iter().map(|yi| yi * yi).sum::<f64>().sqrt();
        assert!(xn < 0.5, "PD: ‖x*‖={} should be < 0.5", xn);
        assert!(yn < 0.5, "PD: ‖y*‖={} should be < 0.5", yn);
    }

    #[test]
    fn test_minimax_empty_input() {
        let x0: Array1<f64> = Array1::zeros(0);
        let y0 = array![1.0];
        let config = MinimaxConfig::default();
        assert!(minimax_solve(&bilinear, &x0.view(), &y0.view(), &config).is_err());
    }
}