1use crate::error::{OptimizeError, OptimizeResult};
28use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
29
30pub struct MinimaxProblem {
52 pub funcs: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>>,
54}
55
56impl MinimaxProblem {
57 pub fn new(funcs: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>>) -> Self {
59 Self { funcs }
60 }
61
62 pub fn eval_max(&self, x: &ArrayView1<f64>) -> f64 {
64 self.funcs
65 .iter()
66 .map(|f| f(x))
67 .fold(f64::NEG_INFINITY, f64::max)
68 }
69
70 pub fn eval_all(&self, x: &ArrayView1<f64>) -> Vec<f64> {
72 self.funcs.iter().map(|f| f(x)).collect()
73 }
74
75 pub fn argmax(&self, x: &ArrayView1<f64>) -> usize {
77 let vals = self.eval_all(x);
78 vals.iter()
79 .enumerate()
80 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
81 .map(|(i, _)| i)
82 .unwrap_or(0)
83 }
84
85 pub fn num_funcs(&self) -> usize {
87 self.funcs.len()
88 }
89}
90
/// Tuning parameters shared by the minimax solvers in this module.
#[derive(Debug, Clone)]
pub struct MinimaxSolverConfig {
    /// Iteration cap for the subgradient, bundle and smoothed solvers.
    pub max_iter: usize,
    /// Tolerance used by the solvers' stopping tests.
    pub tol: f64,
    /// Base step length: initial subgradient step, proximal parameter of the
    /// bundle method, and fixed gradient step of the smoothed solver.
    pub step_size: f64,
    /// Step decay factor.
    /// NOTE(review): no solver visible in this file reads this field — confirm
    /// the intended use.
    pub step_decay: f64,
    /// Increment used by the forward-difference gradient approximation.
    pub fd_step: f64,
    /// Smoothing parameter of the log-sum-exp approximation; `smooth_minimax`
    /// rejects non-positive values.
    pub smoothing_mu: f64,
    /// Iteration cap for `game_theoretic_minimax`.
    pub fictitious_play_iter: usize,
}

impl Default for MinimaxSolverConfig {
    /// Returns the baseline configuration.
    fn default() -> Self {
        MinimaxSolverConfig {
            // Iteration budgets.
            max_iter: 2_000,
            fictitious_play_iter: 1_000,
            // Stopping tolerance.
            tol: 1e-6,
            // Step-size controls.
            step_size: 1e-2,
            step_decay: 1.0,
            // Numerical differentiation and smoothing.
            fd_step: 1e-5,
            smoothing_mu: 0.1,
        }
    }
}
123
124#[derive(Debug, Clone)]
126pub struct MinimaxSolveResult {
127 pub x: Array1<f64>,
129 pub fun: f64,
131 pub active_index: usize,
133 pub n_iter: usize,
135 pub converged: bool,
137 pub message: String,
139}
140
141fn fd_gradient<F>(f: &F, x: &ArrayView1<f64>, h: f64) -> Array1<f64>
145where
146 F: Fn(&ArrayView1<f64>) -> f64,
147{
148 let n = x.len();
149 let f0 = f(x);
150 let mut g = Array1::<f64>::zeros(n);
151 let mut x_fwd = x.to_owned();
152 for i in 0..n {
153 x_fwd[i] += h;
154 g[i] = (f(&x_fwd.view()) - f0) / h;
155 x_fwd[i] = x[i];
156 }
157 g
158}
159
160#[inline]
162fn l2_norm(v: &Array1<f64>) -> f64 {
163 v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
164}
165
166pub fn minimax_subgradient(
191 problem: &MinimaxProblem,
192 x0: &ArrayView1<f64>,
193 config: &MinimaxSolverConfig,
194) -> OptimizeResult<MinimaxSolveResult> {
195 if problem.num_funcs() == 0 {
196 return Err(OptimizeError::ValueError(
197 "MinimaxProblem must contain at least one function".to_string(),
198 ));
199 }
200 let n = x0.len();
201 if n == 0 {
202 return Err(OptimizeError::ValueError(
203 "x0 must be non-empty".to_string(),
204 ));
205 }
206
207 let mut x = x0.to_owned();
208 let mut x_best = x.clone();
210 let mut val_best = problem.eval_max(&x.view());
211 let h = config.fd_step;
212 let mut converged = false;
213
214 for k in 0..config.max_iter {
215 let vals = problem.eval_all(&x.view());
216 let max_val = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
217
218 let active_idx = vals
220 .iter()
221 .enumerate()
222 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
223 .map(|(i, _)| i)
224 .unwrap_or(0);
225
226 if max_val < val_best {
227 val_best = max_val;
228 x_best = x.clone();
229 }
230
231 let subgrad = fd_gradient(
233 &|v: &ArrayView1<f64>| problem.funcs[active_idx](v),
234 &x.view(),
235 h,
236 );
237 let sg_norm = l2_norm(&subgrad);
238
239 if sg_norm < config.tol {
240 converged = true;
241 break;
242 }
243
244 let alpha = config.step_size / ((k as f64 + 1.0).sqrt());
246
247 for i in 0..n {
249 x[i] -= alpha * subgrad[i];
250 }
251 }
252
253 let active_idx = problem.argmax(&x_best.view());
254 Ok(MinimaxSolveResult {
255 fun: val_best,
256 active_index: active_idx,
257 n_iter: config.max_iter,
258 converged,
259 message: if converged {
260 "Subgradient descent converged".to_string()
261 } else {
262 "Subgradient descent reached maximum iterations".to_string()
263 },
264 x: x_best,
265 })
266}
267
268#[derive(Debug, Clone)]
272struct BundleCut {
273 point: Array1<f64>,
275 value: f64,
277 subgrad: Array1<f64>,
279}
280
281pub fn minimax_bundle(
307 problem: &MinimaxProblem,
308 x0: &ArrayView1<f64>,
309 config: &MinimaxSolverConfig,
310) -> OptimizeResult<MinimaxSolveResult> {
311 if problem.num_funcs() == 0 {
312 return Err(OptimizeError::ValueError(
313 "MinimaxProblem must contain at least one function".to_string(),
314 ));
315 }
316 let n = x0.len();
317 if n == 0 {
318 return Err(OptimizeError::ValueError(
319 "x0 must be non-empty".to_string(),
320 ));
321 }
322
323 let h = config.fd_step;
324 let max_bundle_size = 20_usize;
325 let prox_t = config.step_size;
327
328 let mut x = x0.to_owned();
329 let mut x_best = x.clone();
330 let mut val_best = problem.eval_max(&x.view());
331 let mut bundle: Vec<BundleCut> = Vec::with_capacity(max_bundle_size);
332 let mut converged = false;
333
334 let eval_model = |x_cur: &Array1<f64>, d: &Array1<f64>, cuts: &[BundleCut]| -> f64 {
337 cuts.iter()
338 .map(|cut| {
339 let diff: f64 = x_cur
340 .iter()
341 .zip(d.iter())
342 .zip(cut.point.iter())
343 .map(|((&xc, &dc), &yj)| xc + dc - yj)
344 .zip(cut.subgrad.iter())
345 .map(|(delta, &gj)| delta * gj)
346 .sum::<f64>();
347 cut.value + diff
348 })
349 .fold(f64::NEG_INFINITY, f64::max)
350 };
351
352 for k in 0..config.max_iter {
353 let vals = problem.eval_all(&x.view());
354 let max_val = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
355
356 if max_val < val_best {
357 val_best = max_val;
358 x_best = x.clone();
359 }
360
361 let threshold = max_val - config.tol;
363 let new_cuts: Vec<BundleCut> = vals
364 .iter()
365 .enumerate()
366 .filter(|(_, &v)| v >= threshold)
367 .map(|(i, _)| {
368 let subgrad = fd_gradient(&|v: &ArrayView1<f64>| problem.funcs[i](v), &x.view(), h);
369 BundleCut {
370 point: x.clone(),
371 value: vals[i],
372 subgrad,
373 }
374 })
375 .collect();
376
377 bundle.extend(new_cuts);
378
379 if bundle.len() > max_bundle_size {
381 let start = bundle.len() - max_bundle_size;
382 bundle.drain(..start);
383 }
384
385 if bundle.is_empty() {
386 break;
387 }
388
389 let mut d = Array1::<f64>::zeros(n);
393 let inner_steps = 100_usize;
394 let inner_step = prox_t * 0.5;
395
396 for _ in 0..inner_steps {
397 let active_cut = bundle
399 .iter()
400 .enumerate()
401 .max_by(|(_, a), (_, b)| {
402 let va = a.value
403 + a.point
404 .iter()
405 .zip(d.iter())
406 .zip(x.iter())
407 .map(|((&yj, &dj), &xc)| {
408 a.subgrad[{
409 0
411 }] * (xc + dj - yj)
412 })
413 .sum::<f64>();
414 let vb = b.value
415 + b.point
416 .iter()
417 .zip(d.iter())
418 .zip(x.iter())
419 .map(|((&yj, &dj), &xc)| b.subgrad[0] * (xc + dj - yj))
420 .sum::<f64>();
421 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
422 })
423 .map(|(i, _)| i)
424 .unwrap_or(0);
425
426 let active_idx = bundle
428 .iter()
429 .enumerate()
430 .max_by(|(_, ca), (_, cb)| {
431 let va: f64 = ca.value
432 + ca.subgrad
433 .iter()
434 .zip(x.iter().zip(d.iter()).zip(ca.point.iter()))
435 .map(|(&gj, ((&xc, &dc), &yj))| gj * (xc + dc - yj))
436 .sum::<f64>();
437 let vb: f64 = cb.value
438 + cb.subgrad
439 .iter()
440 .zip(x.iter().zip(d.iter()).zip(cb.point.iter()))
441 .map(|(&gj, ((&xc, &dc), &yj))| gj * (xc + dc - yj))
442 .sum::<f64>();
443 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
444 })
445 .map(|(i, _)| i)
446 .unwrap_or(active_cut);
447
448 let cut = &bundle[active_idx];
449 let grad_d: Array1<f64> = cut
453 .subgrad
454 .iter()
455 .zip(d.iter())
456 .map(|(&gj, &dj)| gj + dj / prox_t)
457 .collect();
458
459 let step_norm = l2_norm(&grad_d);
460 if step_norm < config.tol * 0.1 {
461 break;
462 }
463 for i in 0..n {
464 d[i] -= inner_step * grad_d[i];
465 }
466 }
467
468 let d_norm = l2_norm(&d);
470 if d_norm < config.tol {
471 converged = true;
472 break;
473 }
474
475 for i in 0..n {
477 x[i] += d[i];
478 }
479
480 let new_max = problem.eval_max(&x.view());
482 if (new_max - max_val).abs() < config.tol && k > 10 {
483 converged = true;
484 break;
485 }
486 }
487
488 let active_idx = problem.argmax(&x_best.view());
489 Ok(MinimaxSolveResult {
490 x: x_best,
491 fun: val_best,
492 active_index: active_idx,
493 n_iter: config.max_iter,
494 converged,
495 message: if converged {
496 "Bundle method converged".to_string()
497 } else {
498 "Bundle method reached maximum iterations".to_string()
499 },
500 })
501}
502
503pub fn smooth_minimax(
533 problem: &MinimaxProblem,
534 x0: &ArrayView1<f64>,
535 config: &MinimaxSolverConfig,
536) -> OptimizeResult<MinimaxSolveResult> {
537 if problem.num_funcs() == 0 {
538 return Err(OptimizeError::ValueError(
539 "MinimaxProblem must contain at least one function".to_string(),
540 ));
541 }
542 let n = x0.len();
543 if n == 0 {
544 return Err(OptimizeError::ValueError(
545 "x0 must be non-empty".to_string(),
546 ));
547 }
548 let mu = config.smoothing_mu;
549 if mu <= 0.0 {
550 return Err(OptimizeError::ValueError(format!(
551 "smoothing_mu must be positive, got {}",
552 mu
553 )));
554 }
555
556 let k = problem.num_funcs() as f64;
557 let h = config.fd_step;
558
559 let smooth_obj = |x: &ArrayView1<f64>| -> f64 {
562 let vals = problem.eval_all(x);
563 let max_val = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
565 if max_val.is_infinite() {
566 return max_val;
567 }
568 let sum_exp: f64 = vals.iter().map(|&v| ((v - max_val) / mu).exp()).sum();
569 mu * (sum_exp.ln() + (max_val / mu)) - mu * k.ln()
570 };
571
572 let mut x = x0.to_owned();
573 let mut x_best = x.clone();
574 let mut val_best = problem.eval_max(&x.view());
575 let mut converged = false;
576
577 let mut y = x.clone(); let mut t_k = 1.0_f64; for _ in 0..config.max_iter {
582 let grad = fd_gradient(&smooth_obj, &y.view(), h);
584 let grad_norm = l2_norm(&grad);
585
586 if grad_norm < config.tol {
587 converged = true;
588 break;
589 }
590
591 let x_new: Array1<f64> = y
593 .iter()
594 .zip(grad.iter())
595 .map(|(&yi, &gi)| yi - config.step_size * gi)
596 .collect();
597
598 let t_new = (1.0 + (1.0 + 4.0 * t_k * t_k).sqrt()) / 2.0;
600 let mom = (t_k - 1.0) / t_new;
601
602 let y_new: Array1<f64> = x_new
603 .iter()
604 .zip(x.iter())
605 .map(|(&xn, &xo)| xn + mom * (xn - xo))
606 .collect();
607
608 let max_val = problem.eval_max(&x_new.view());
610 if max_val < val_best {
611 val_best = max_val;
612 x_best = x_new.clone();
613 }
614
615 x = x_new;
616 y = y_new;
617 t_k = t_new;
618 }
619
620 let active_idx = problem.argmax(&x_best.view());
621 Ok(MinimaxSolveResult {
622 x: x_best,
623 fun: val_best,
624 active_index: active_idx,
625 n_iter: config.max_iter,
626 converged,
627 message: if converged {
628 "Smooth minimax (Nesterov) converged".to_string()
629 } else {
630 "Smooth minimax reached maximum iterations".to_string()
631 },
632 })
633}
634
635#[derive(Debug, Clone)]
639pub struct GameMinimaxResult {
640 pub x: Array1<f64>,
642 pub fun: f64,
644 pub maximizer_strategy: Array1<f64>,
647 pub n_iter: usize,
649 pub converged: bool,
651 pub message: String,
653}
654
655pub fn game_theoretic_minimax(
690 problem: &MinimaxProblem,
691 x0: &ArrayView1<f64>,
692 step_size: f64,
693 config: &MinimaxSolverConfig,
694) -> OptimizeResult<GameMinimaxResult> {
695 if problem.num_funcs() == 0 {
696 return Err(OptimizeError::ValueError(
697 "MinimaxProblem must contain at least one function".to_string(),
698 ));
699 }
700 let n = x0.len();
701 if n == 0 {
702 return Err(OptimizeError::ValueError(
703 "x0 must be non-empty".to_string(),
704 ));
705 }
706 let k = problem.num_funcs();
707 let h = config.fd_step;
708 let max_fp_iter = config.fictitious_play_iter;
709
710 let mut counts = vec![0_usize; k];
712 let mut x = x0.to_owned();
713 let mut x_best = x.clone();
714 let mut val_best = problem.eval_max(&x.view());
715 let mut converged = false;
716
717 let mut cumulative_grad = Array1::<f64>::zeros(n);
719 let mut cumulative_count = 0_usize;
720
721 for iter in 0..max_fp_iter {
722 let vals = problem.eval_all(&x.view());
725
726 let active_i = vals
728 .iter()
729 .enumerate()
730 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
731 .map(|(i, _)| i)
732 .unwrap_or(0);
733
734 counts[active_i] += 1;
735 cumulative_count += 1;
736
737 let max_val = vals[active_i];
738 if max_val < val_best {
739 val_best = max_val;
740 x_best = x.clone();
741 }
742
743 let grad_i = fd_gradient(
746 &|v: &ArrayView1<f64>| problem.funcs[active_i](v),
747 &x.view(),
748 h,
749 );
750 for i in 0..n {
751 cumulative_grad[i] += grad_i[i];
752 }
753
754 let avg_grad_norm: f64 =
756 cumulative_grad.iter().map(|&g| g * g).sum::<f64>().sqrt() / cumulative_count as f64;
757
758 if avg_grad_norm < config.tol {
759 converged = true;
760 break;
761 }
762
763 let alpha = step_size / ((iter as f64 + 1.0).sqrt());
765 for i in 0..n {
766 x[i] -= alpha * cumulative_grad[i] / cumulative_count as f64;
767 }
768
769 if iter > 10 && (max_val - problem.eval_max(&x.view())).abs() < config.tol {
771 converged = true;
772 break;
773 }
774 }
775
776 let total = counts.iter().sum::<usize>().max(1) as f64;
778 let maximizer_strategy: Array1<f64> = counts.iter().map(|&c| c as f64 / total).collect();
779
780 Ok(GameMinimaxResult {
781 x: x_best,
782 fun: val_best,
783 maximizer_strategy,
784 n_iter: max_fp_iter,
785 converged,
786 message: if converged {
787 "Fictitious play converged".to_string()
788 } else {
789 "Fictitious play reached maximum iterations".to_string()
790 },
791 })
792}
793
794pub fn smooth_max_value<F>(funcs: &[F], x: &ArrayView1<f64>, mu: f64) -> OptimizeResult<f64>
812where
813 F: Fn(&ArrayView1<f64>) -> f64,
814{
815 if funcs.is_empty() {
816 return Err(OptimizeError::ValueError(
817 "funcs must be non-empty".to_string(),
818 ));
819 }
820 if mu <= 0.0 {
821 return Err(OptimizeError::ValueError(format!(
822 "mu must be positive, got {}",
823 mu
824 )));
825 }
826 let vals: Vec<f64> = funcs.iter().map(|f| f(x)).collect();
827 let max_val = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
828 if max_val.is_infinite() {
829 return Ok(max_val);
830 }
831 let sum_exp: f64 = vals.iter().map(|&v| ((v - max_val) / mu).exp()).sum();
832 let k = funcs.len() as f64;
833 Ok(mu * (sum_exp.ln() + (max_val / mu)) - mu * k.ln())
834}
835
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two parabolas `(x - 1)²` and `(x + 1)²`; their maximum is minimised at
    /// `x = 0` with value `1`.
    fn build_two_func_problem() -> MinimaxProblem {
        MinimaxProblem::new(vec![
            Box::new(|x: &ArrayView1<f64>| (x[0] - 1.0).powi(2)),
            Box::new(|x: &ArrayView1<f64>| (x[0] + 1.0).powi(2)),
        ])
    }

    #[test]
    fn test_minimax_problem_eval() {
        let p = build_two_func_problem();
        let x = array![0.0];
        assert_eq!(p.eval_max(&x.view()), 1.0);
        let x2 = array![1.0];
        assert!((p.eval_max(&x2.view()) - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_subgradient_basic() {
        let p = build_two_func_problem();
        let x0 = array![3.0];
        let config = MinimaxSolverConfig {
            max_iter: 3_000,
            tol: 1e-4,
            step_size: 0.5,
            ..Default::default()
        };
        let result = minimax_subgradient(&p, &x0.view(), &config).expect("failed to create result");
        assert!(
            result.fun <= 1.5,
            "subgradient minimax value {} should be ≤ 1.5",
            result.fun
        );
        assert!(
            result.x[0].abs() < 1.0,
            "subgradient minimizer {} should be near 0",
            result.x[0]
        );
    }

    #[test]
    fn test_smooth_minimax_basic() {
        let p = build_two_func_problem();
        let x0 = array![3.0];
        let config = MinimaxSolverConfig {
            max_iter: 3_000,
            tol: 1e-5,
            step_size: 1e-2,
            smoothing_mu: 0.05,
            ..Default::default()
        };
        let result = smooth_minimax(&p, &x0.view(), &config).expect("failed to create result");
        assert!(
            result.fun <= 2.0,
            "smooth minimax value {} should be ≤ 2.0",
            result.fun
        );
    }

    #[test]
    fn test_game_theoretic_minimax() {
        let p = build_two_func_problem();
        let x0 = array![2.0];
        let config = MinimaxSolverConfig {
            max_iter: 500,
            tol: 1e-4,
            fictitious_play_iter: 500,
            ..Default::default()
        };
        let result =
            game_theoretic_minimax(&p, &x0.view(), 0.1, &config).expect("failed to create result");
        assert!(
            result.x[0].abs() < 2.5,
            "game theoretic minimizer {} should move toward 0",
            result.x[0]
        );
        assert_eq!(result.maximizer_strategy.len(), 2);
        let strat_sum: f64 = result.maximizer_strategy.iter().sum();
        assert!((strat_sum - 1.0).abs() < 1e-9, "strategy should sum to 1");
    }

    #[test]
    fn test_smooth_max_value() {
        let funcs: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64>> = vec![
            Box::new(|_x: &ArrayView1<f64>| 1.0),
            Box::new(|_x: &ArrayView1<f64>| 2.0),
            Box::new(|_x: &ArrayView1<f64>| 3.0),
        ];
        let x = array![0.0];
        let val = smooth_max_value(&funcs, &x.view(), 0.01).expect("failed to create val");
        assert!((val - 3.0).abs() < 0.1, "smooth max ≈ 3.0, got {val}");
    }

    #[test]
    fn test_smooth_max_value_rejects_bad_input() {
        let x = array![0.0];

        let none: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64>> = Vec::new();
        assert!(smooth_max_value(&none, &x.view(), 0.1).is_err());

        let funcs: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64>> =
            vec![Box::new(|_x: &ArrayView1<f64>| 1.0)];
        assert!(smooth_max_value(&funcs, &x.view(), 0.0).is_err());
        assert!(smooth_max_value(&funcs, &x.view(), -1.0).is_err());
    }

    #[test]
    fn test_bundle_method_basic() {
        let p = build_two_func_problem();
        let x0 = array![2.0];
        let config = MinimaxSolverConfig {
            max_iter: 500,
            tol: 1e-4,
            step_size: 0.5,
            ..Default::default()
        };
        let result = minimax_bundle(&p, &x0.view(), &config).expect("failed to create result");
        assert!(
            result.fun <= 5.0,
            "bundle minimax value {} should be reasonable",
            result.fun
        );
    }

    #[test]
    fn test_empty_problem_error() {
        // Every solver must reject a problem without component functions.
        let p = MinimaxProblem::new(vec![]);
        let x0 = array![1.0];
        let config = MinimaxSolverConfig::default();
        assert!(minimax_subgradient(&p, &x0.view(), &config).is_err());
        assert!(minimax_bundle(&p, &x0.view(), &config).is_err());
        assert!(smooth_minimax(&p, &x0.view(), &config).is_err());
        assert!(game_theoretic_minimax(&p, &x0.view(), 0.1, &config).is_err());
    }

    #[test]
    fn test_empty_x0_error() {
        // Every solver must reject an empty starting point.
        let p = build_two_func_problem();
        let x0: Array1<f64> = Array1::zeros(0);
        let config = MinimaxSolverConfig::default();
        assert!(minimax_subgradient(&p, &x0.view(), &config).is_err());
        assert!(minimax_bundle(&p, &x0.view(), &config).is_err());
        assert!(smooth_minimax(&p, &x0.view(), &config).is_err());
        assert!(game_theoretic_minimax(&p, &x0.view(), 0.1, &config).is_err());
    }

    #[test]
    fn test_smooth_minimax_rejects_non_positive_mu() {
        let p = build_two_func_problem();
        let x0 = array![1.0];
        for bad_mu in [0.0, -0.5] {
            let config = MinimaxSolverConfig {
                smoothing_mu: bad_mu,
                ..Default::default()
            };
            assert!(smooth_minimax(&p, &x0.view(), &config).is_err());
        }
    }
}