1use crate::error::{OptimizeError, OptimizeResult};
35use scirs2_core::ndarray::{Array1, ArrayView1};
36
/// Tuning parameters shared by [`minimax_solve`] and [`extragradient_solve`].
#[derive(Debug, Clone)]
pub struct MinimaxConfig {
    /// Upper bound on the number of outer iterations.
    pub max_iter: usize,
    /// Convergence threshold on the length of a single update step.
    pub tol: f64,
    /// Step size of the descent update applied to `x`.
    pub step_size_x: f64,
    /// Step size of the ascent update applied to `y`.
    pub step_size_y: f64,
    /// Increment `h` used by the forward finite-difference gradient estimates.
    pub fd_step: f64,
    /// Emit a progress line on stderr every `print_every` iterations;
    /// `0` disables all diagnostic output.
    pub print_every: usize,
}

impl Default for MinimaxConfig {
    /// 5000 iterations, `tol = 1e-6`, both step sizes `1e-3`,
    /// `fd_step = 1e-5`, logging disabled.
    fn default() -> Self {
        let step = 1.0e-3;
        MinimaxConfig {
            max_iter: 5000,
            tol: 1.0e-6,
            step_size_x: step,
            step_size_y: step,
            fd_step: 1.0e-5,
            print_every: 0,
        }
    }
}
68
69#[derive(Debug, Clone)]
71pub struct MinimaxResult {
72 pub x: Array1<f64>,
74 pub y: Array1<f64>,
76 pub fun: f64,
78 pub n_iter: usize,
80 pub gap: f64,
82 pub converged: bool,
84 pub message: String,
86}
87
88fn grad_x<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
92where
93 F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
94{
95 let n = x.len();
96 let f0 = f(x, y);
97 let mut g = Array1::<f64>::zeros(n);
98 let mut x_fwd = x.to_owned();
99 for i in 0..n {
100 x_fwd[i] += h;
101 g[i] = (f(&x_fwd.view(), y) - f0) / h;
102 x_fwd[i] = x[i];
103 }
104 g
105}
106
107fn grad_y<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
109where
110 F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
111{
112 let m = y.len();
113 let f0 = f(x, y);
114 let mut g = Array1::<f64>::zeros(m);
115 let mut y_fwd = y.to_owned();
116 for i in 0..m {
117 y_fwd[i] += h;
118 g[i] = (f(x, &y_fwd.view()) - f0) / h;
119 y_fwd[i] = y[i];
120 }
121 g
122}
123
124#[inline]
125fn vec_norm(v: &Array1<f64>) -> f64 {
126 v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
127}
128
129pub fn minimax_solve<F>(
152 f: &F,
153 x0: &ArrayView1<f64>,
154 y0: &ArrayView1<f64>,
155 config: &MinimaxConfig,
156) -> OptimizeResult<MinimaxResult>
157where
158 F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
159{
160 let nx = x0.len();
161 let ny = y0.len();
162 if nx == 0 || ny == 0 {
163 return Err(OptimizeError::ValueError(
164 "x0 and y0 must be non-empty".to_string(),
165 ));
166 }
167
168 let mut x = x0.to_owned();
169 let mut y = y0.to_owned();
170 let mut converged = false;
171 let h = config.fd_step;
172
173 for k in 0..config.max_iter {
174 let gx = grad_x(f, &x.view(), &y.view(), h);
175 let gy = grad_y(f, &x.view(), &y.view(), h);
176
177 let mut dx_norm = 0.0_f64;
179 let mut dy_norm = 0.0_f64;
180 for i in 0..nx {
181 let step = config.step_size_x * gx[i];
182 x[i] -= step;
183 dx_norm += step * step;
184 }
185 for i in 0..ny {
186 let step = config.step_size_y * gy[i];
187 y[i] += step;
188 dy_norm += step * step;
189 }
190
191 let delta = dx_norm.sqrt() + dy_norm.sqrt();
192 if delta < config.tol {
193 converged = true;
194 if config.print_every > 0 {
195 eprintln!("[GDA] converged at iteration {}", k + 1);
196 }
197 break;
198 }
199 if config.print_every > 0 && (k + 1) % config.print_every == 0 {
200 eprintln!("[GDA] iter {}: delta={:.2e}", k + 1, delta);
201 }
202 }
203
204 let fun = f(&x.view(), &y.view());
205 let gap = compute_gap(f, &x.view(), &y.view(), h);
206
207 Ok(MinimaxResult {
208 x,
209 y,
210 fun,
211 n_iter: config.max_iter,
212 gap,
213 converged,
214 message: if converged {
215 "GDA converged".to_string()
216 } else {
217 "GDA reached maximum iterations".to_string()
218 },
219 })
220}
221
222pub fn extragradient_solve<F>(
246 f: &F,
247 x0: &ArrayView1<f64>,
248 y0: &ArrayView1<f64>,
249 config: &MinimaxConfig,
250) -> OptimizeResult<MinimaxResult>
251where
252 F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
253{
254 let nx = x0.len();
255 let ny = y0.len();
256 if nx == 0 || ny == 0 {
257 return Err(OptimizeError::ValueError(
258 "x0 and y0 must be non-empty".to_string(),
259 ));
260 }
261
262 let mut x = x0.to_owned();
263 let mut y = y0.to_owned();
264 let mut x_avg = Array1::<f64>::zeros(nx);
270 let mut y_avg = Array1::<f64>::zeros(ny);
271 let mut converged = false;
272 let h = config.fd_step;
273 let mut n_iters_done = 0usize;
274 let mut converged_at = config.max_iter; for k in 0..config.max_iter {
277 let gx_k = grad_x(f, &x.view(), &y.view(), h);
279 let gy_k = grad_y(f, &x.view(), &y.view(), h);
280
281 let x_bar: Array1<f64> = x
282 .iter()
283 .zip(gx_k.iter())
284 .map(|(&xi, &gi)| xi - config.step_size_x * gi)
285 .collect();
286 let y_bar: Array1<f64> = y
287 .iter()
288 .zip(gy_k.iter())
289 .map(|(&yi, &gi)| yi + config.step_size_y * gi)
290 .collect();
291
292 let gx_bar = grad_x(f, &x_bar.view(), &y_bar.view(), h);
294 let gy_bar = grad_y(f, &x_bar.view(), &y_bar.view(), h);
295
296 let mut delta = 0.0_f64;
297 for i in 0..nx {
298 let step = config.step_size_x * gx_bar[i];
299 x[i] -= step;
300 delta += step * step;
301 }
302 for i in 0..ny {
303 let step = config.step_size_y * gy_bar[i];
304 y[i] += step;
305 delta += step * step;
306 }
307
308 let t = (k + 1) as f64;
310 for i in 0..nx {
311 x_avg[i] += (x[i] - x_avg[i]) / t;
312 }
313 for i in 0..ny {
314 y_avg[i] += (y[i] - y_avg[i]) / t;
315 }
316
317 n_iters_done = k + 1;
318
319 if !converged && delta.sqrt() < config.tol {
320 converged = true;
321 converged_at = k + 1;
322 if config.print_every > 0 {
323 eprintln!("[EG] converged at iteration {}", k + 1);
324 }
325 }
329 if config.print_every > 0 && (k + 1) % config.print_every == 0 {
330 eprintln!("[EG] iter {}: delta={:.2e}", k + 1, delta.sqrt());
331 }
332 }
333 let _ = converged_at; let x_out = x_avg;
341 let y_out = y_avg;
342
343 let fun = f(&x_out.view(), &y_out.view());
344 let gap = compute_gap(f, &x_out.view(), &y_out.view(), h);
345
346 Ok(MinimaxResult {
347 x: x_out,
348 y: y_out,
349 fun,
350 n_iter: n_iters_done,
351 gap,
352 converged,
353 message: if converged {
354 "Extragradient converged".to_string()
355 } else {
356 "Extragradient reached maximum iterations".to_string()
357 },
358 })
359}
360
/// Tuning parameters for [`primal_dual`].
#[derive(Debug, Clone)]
pub struct PrimalDualConfig {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// Stop once `‖Δx‖ + ‖Δy‖` for a single iteration falls below this value.
    pub tol: f64,
    /// Step size of the descent update on the primal variable `x`.
    pub sigma: f64,
    /// Step size of the ascent update on the dual variable `y`.
    pub tau: f64,
    /// Finite-difference increment. NOTE(review): `primal_dual` does not read
    /// this field; it receives gradients directly from its oracles.
    pub fd_step: f64,
}

impl Default for PrimalDualConfig {
    /// 5000 iterations, `tol = 1e-6`, `sigma = tau = 1e-3`, `fd_step = 1e-5`.
    fn default() -> Self {
        let step = 1.0e-3;
        PrimalDualConfig {
            max_iter: 5000,
            tol: 1.0e-6,
            sigma: step,
            tau: step,
            fd_step: 1.0e-5,
        }
    }
}
389
390pub fn primal_dual<Px, Py>(
420 primal_fn: &Px,
421 dual_fn: &Py,
422 x0: &ArrayView1<f64>,
423 y0: &ArrayView1<f64>,
424 config: &PrimalDualConfig,
425) -> OptimizeResult<(Array1<f64>, Array1<f64>)>
426where
427 Px: Fn(&ArrayView1<f64>) -> Array1<f64>,
428 Py: Fn(&ArrayView1<f64>) -> Array1<f64>,
429{
430 let nx = x0.len();
431 let ny = y0.len();
432 if nx == 0 || ny == 0 {
433 return Err(OptimizeError::ValueError(
434 "x0 and y0 must be non-empty".to_string(),
435 ));
436 }
437
438 let mut x = x0.to_owned();
439 let mut y = y0.to_owned();
440 let mut x_bar = x.clone();
442
443 for _k in 0..config.max_iter {
444 let gy = dual_fn(&y.view());
446 let y_new: Array1<f64> = y
449 .iter()
450 .zip(gy.iter())
451 .map(|(&yi, &gyi)| yi + config.tau * gyi)
452 .collect();
453
454 let gx = primal_fn(&x.view());
456 let x_new: Array1<f64> = x
457 .iter()
458 .zip(gx.iter())
459 .map(|(&xi, &gxi)| xi - config.sigma * gxi)
460 .collect();
461
462 let x_bar_new: Array1<f64> = x_new
464 .iter()
465 .zip(x.iter())
466 .map(|(&xn, &xo)| 2.0 * xn - xo)
467 .collect();
468
469 let dx = vec_norm(&(x_new.clone() - &x));
471 let dy = vec_norm(&(y_new.clone() - &y));
472 let delta = dx + dy;
473
474 x = x_new;
475 y = y_new;
476 x_bar = x_bar_new;
477
478 if delta < config.tol {
479 break;
480 }
481 }
482 let _ = x_bar; Ok((x, y))
484}
485
486fn compute_gap<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> f64
495where
496 F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
497{
498 let gx = grad_x(f, x, y, h);
499 let gy = grad_y(f, x, y, h);
500 vec_norm(&gx) + vec_norm(&gy)
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// `f(x, y) = x·y`: bilinear, with its unique saddle point at the origin.
    fn bilinear(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum()
    }

    /// `f(x, y) = ‖x‖² − ‖y‖² + x·y`: strongly convex in `x`, strongly concave
    /// in `y`, saddle point at the origin.
    fn convex_concave(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        let quad_x: f64 = x.iter().map(|xi| xi * xi).sum();
        let quad_y: f64 = y.iter().map(|yi| yi * yi).sum();
        let cross: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
        quad_x - quad_y + cross
    }

    /// Euclidean norm used by the assertions below.
    fn l2(v: &Array1<f64>) -> f64 {
        v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
    }

    // Plain GDA spirals outward on bilinear objectives, so this case exercises
    // the extragradient solver.
    #[test]
    fn test_extragradient_bilinear() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-4,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            ..Default::default()
        };
        let result = extragradient_solve(&bilinear, &x0.view(), &y0.view(), &config)
            .expect("extragradient on bilinear should not fail");
        let norm_x = l2(&result.x);
        let norm_y = l2(&result.y);
        assert!(
            norm_x < 0.5,
            "Extragradient bilinear: ‖x‖ should be small, got {}",
            norm_x
        );
        assert!(
            norm_y < 0.5,
            "Extragradient bilinear: ‖y‖ should be small, got {}",
            norm_y
        );
    }

    #[test]
    fn test_minimax_gda_convex_concave() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            ..Default::default()
        };
        let result = minimax_solve(&convex_concave, &x0.view(), &y0.view(), &config)
            .expect("GDA on a convex-concave problem should not fail");
        assert!(result.converged, "GDA should converge: {}", result.message);
        let norm_x = l2(&result.x);
        let norm_y = l2(&result.y);
        assert!(norm_x < 0.05, "GDA: ‖x‖={} should be < 0.05", norm_x);
        assert!(norm_y < 0.05, "GDA: ‖y‖={} should be < 0.05", norm_y);
    }

    #[test]
    fn test_extragradient_convex_concave() {
        let x0 = array![2.0];
        let y0 = array![2.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = extragradient_solve(&f, &x0.view(), &y0.view(), &config)
            .expect("extragradient on x^2 - y^2 should not fail");
        assert!(
            result.x[0].abs() < 0.3,
            "EG: expected x* ≈ 0, got {}",
            result.x[0]
        );
        assert!(
            result.y[0].abs() < 0.3,
            "EG: expected y* ≈ 0, got {}",
            result.y[0]
        );
    }

    #[test]
    fn test_extragradient_convex_concave_2d() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let result = extragradient_solve(&convex_concave, &x0.view(), &y0.view(), &config)
            .expect("extragradient on 2-D convex-concave problem should not fail");
        let norm = l2(&result.x);
        assert!(norm < 1.5, "EG 2D: ‖x‖={} should be < 1.5", norm);
    }

    #[test]
    fn test_primal_dual_gradient() {
        let x0 = array![3.0, -2.0];
        let y0 = array![1.0, 4.0];
        let config = PrimalDualConfig {
            max_iter: 20_000,
            tol: 1e-5,
            sigma: 5e-4,
            tau: 5e-4,
            ..Default::default()
        };
        let primal_fn = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);
        let dual_fn = |y: &ArrayView1<f64>| y.mapv(|yi| -2.0 * yi);
        let (x_star, y_star) = primal_dual(&primal_fn, &dual_fn, &x0.view(), &y0.view(), &config)
            .expect("primal-dual with well-formed oracles should not fail");
        let xn = l2(&x_star);
        let yn = l2(&y_star);
        assert!(xn < 0.5, "PD: ‖x*‖={} should be < 0.5", xn);
        assert!(yn < 0.5, "PD: ‖y*‖={} should be < 0.5", yn);
    }

    #[test]
    fn test_minimax_empty_input() {
        let empty: Array1<f64> = Array1::zeros(0);
        let one = array![1.0];
        let config = MinimaxConfig::default();
        assert!(minimax_solve(&bilinear, &empty.view(), &one.view(), &config).is_err());
        assert!(minimax_solve(&bilinear, &one.view(), &empty.view(), &config).is_err());
        assert!(extragradient_solve(&bilinear, &empty.view(), &one.view(), &config).is_err());
        assert!(extragradient_solve(&bilinear, &one.view(), &empty.view(), &config).is_err());

        let identity = |v: &ArrayView1<f64>| v.to_owned();
        let pd_config = PrimalDualConfig::default();
        assert!(
            primal_dual(&identity, &identity, &empty.view(), &one.view(), &pd_config).is_err()
        );
    }
}