use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use crate::error::OptimizeError;
/// Solution (exact or approximate) of a two-player zero-sum matrix game.
///
/// The row player maximises and the column player minimises the expected
/// payoff `x^T A y`, where `A` is the payoff matrix handed to the solver.
#[derive(Debug, Clone, PartialEq)]
pub struct MinimaxResult {
    /// Value of the game to the row player, as computed (or estimated) by the
    /// solver that produced this result.
    pub game_value: f64,
    /// Mixed strategy of the row player: one weight per row of the payoff matrix.
    pub row_player_strategy: Array1<f64>,
    /// Mixed strategy of the column player: one weight per column of the payoff matrix.
    pub col_player_strategy: Array1<f64>,
}
/// Solves a two-player zero-sum matrix game (row player maximises).
///
/// If [`find_saddle_point`] reports a pure saddle point, the result puts all
/// probability on that row and column and uses the saddle value. Otherwise the
/// game is solved in mixed strategies by [`linear_program_minimax`].
///
/// # Errors
///
/// Propagates any error returned by [`linear_program_minimax`] (for instance
/// for an empty payoff matrix).
pub fn minimax_solve(payoff: ArrayView2<f64>) -> Result<MinimaxResult, OptimizeError> {
    match find_saddle_point(payoff) {
        None => linear_program_minimax(payoff),
        Some((saddle_row, saddle_col, value)) => {
            // Pure strategies: one-hot vectors on the saddle row / column.
            let mut row_player_strategy = Array1::<f64>::zeros(payoff.nrows());
            let mut col_player_strategy = Array1::<f64>::zeros(payoff.ncols());
            row_player_strategy[saddle_row] = 1.0;
            col_player_strategy[saddle_col] = 1.0;
            Ok(MinimaxResult {
                game_value: value,
                row_player_strategy,
                col_player_strategy,
            })
        }
    }
}
/// Looks for a pure-strategy saddle point of a zero-sum payoff matrix.
///
/// The row player maximises and the column player minimises. A saddle point
/// exists when the pure maximin value (largest row minimum) equals the pure
/// minimax value (smallest column maximum). The two are compared with a
/// *relative* tolerance of `1e-10` times the larger of their magnitudes, so the
/// answer does not depend on the overall scale of the payoffs.
///
/// Returns `Some((row, col, value))` for the first qualifying cell in row-major
/// order, where `value` is the maximin value. Returns `None` if the matrix is
/// empty, has no saddle point, or the maximin/minimax values are not finite.
pub fn find_saddle_point(payoff: ArrayView2<f64>) -> Option<(usize, usize, f64)> {
    // An absolute tolerance would report a (wrong) pure saddle for any game whose
    // payoffs are all tiny, e.g. matching pennies scaled by 1e-11.
    const REL_TOL: f64 = 1e-10;

    let m = payoff.nrows();
    let n = payoff.ncols();
    if m == 0 || n == 0 {
        return None;
    }
    // Worst-case payoff of each pure row strategy.
    let row_min: Vec<f64> = (0..m)
        .map(|i| (0..n).map(|j| payoff[[i, j]]).fold(f64::INFINITY, f64::min))
        .collect();
    // Worst case (for the minimiser) of each pure column strategy.
    let col_max: Vec<f64> = (0..n)
        .map(|j| (0..m).map(|i| payoff[[i, j]]).fold(f64::NEG_INFINITY, f64::max))
        .collect();
    let max_row_min = row_min.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min_col_max = col_max.iter().copied().fold(f64::INFINITY, f64::min);
    // Infinite security levels never form a saddle (a relative tolerance of
    // infinity would otherwise accept everything).
    if !(max_row_min.is_finite() && min_col_max.is_finite()) {
        return None;
    }

    let tol = REL_TOL * max_row_min.abs().max(min_col_max.abs());
    let near_value = |x: f64| (x - max_row_min).abs() <= tol;
    if !near_value(min_col_max) {
        return None;
    }
    (0..m)
        .flat_map(|i| (0..n).map(move |j| (i, j)))
        .find(|&(i, j)| {
            near_value(payoff[[i, j]]) && near_value(row_min[i]) && near_value(col_max[j])
        })
        .map(|(i, j)| (i, j, max_row_min))
}
/// Iteratively eliminates weakly dominated pure strategies.
///
/// The row player maximises: row `r` is dropped when another surviving row is
/// `>=` it in every surviving column and `>` in at least one. The column player
/// minimises: column `c` is dropped when another surviving column is `<=` it in
/// every surviving row and `<` in at least one.
///
/// Each sweep does two passes:
/// - it first removes every dominated row, judged against the rows alive at the
///   start of the sweep;
/// - it then removes every dominated column, judged against the rows that
///   survived the first pass.
///
/// Sweeps repeat until a sweep removes nothing.
///
/// Returns the original indices of the surviving rows and columns, in
/// ascending order. `payoff` is only read; nothing is written through the
/// `&mut` reference.
///
/// Because *weak* dominance is used, some optimal strategies may be eliminated
/// (the value of a zero-sum game is unaffected).
pub fn remove_dominated_strategies(
    payoff: &mut Array2<f64>,
) -> (Vec<usize>, Vec<usize>) {
    let mut rows: Vec<usize> = (0..payoff.nrows()).collect();
    let mut cols: Vec<usize> = (0..payoff.ncols()).collect();
    loop {
        let rows_before = rows.len();
        let cols_before = cols.len();

        // Row pass: every candidate is tested against the full pre-pass set.
        let row_snapshot = rows.clone();
        rows.retain(|&r| {
            !row_snapshot.iter().any(|&other| {
                other != r
                    && cols.iter().all(|&c| payoff[[other, c]] >= payoff[[r, c]])
                    && cols.iter().any(|&c| payoff[[other, c]] > payoff[[r, c]])
            })
        });

        // Column pass: uses the rows that survived the row pass above.
        let col_snapshot = cols.clone();
        cols.retain(|&c| {
            !col_snapshot.iter().any(|&other| {
                other != c
                    && rows.iter().all(|&r| payoff[[r, other]] <= payoff[[r, c]])
                    && rows.iter().any(|&r| payoff[[r, other]] < payoff[[r, c]])
            })
        });

        if rows.len() == rows_before && cols.len() == cols_before {
            break;
        }
    }
    (rows, cols)
}
/// Returns the row player's pure best responses to a column weighting.
///
/// The expected payoff of row `i` is `sum_j payoff[i, j] * col_mixed[j]`.
/// Every row whose expected payoff lies within `1e-9` (absolute) of the maximum
/// is returned, in ascending index order. If `col_mixed` does not supply exactly
/// one weight per column, an empty vector is returned. The weights are not
/// checked to form a probability distribution.
pub fn row_best_response(payoff: ArrayView2<f64>, col_mixed: &[f64]) -> Vec<usize> {
    if col_mixed.len() != payoff.ncols() {
        return Vec::new();
    }
    let expected: Vec<f64> = payoff
        .outer_iter()
        .map(|row| {
            row.iter()
                .zip(col_mixed)
                .map(|(&a, &w)| a * w)
                .sum::<f64>()
        })
        .collect();
    let best = expected.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    expected
        .iter()
        .enumerate()
        .filter(|&(_, &v)| (v - best).abs() < 1e-9)
        .map(|(i, _)| i)
        .collect()
}
/// Solves a two-player zero-sum matrix game with linear programming.
///
/// The row player maximises and the column player minimises. Before solving,
/// the payoff matrix is mapped affinely onto `[1, 2]` via
/// `a -> (a - min) / scale + 1`. Here `scale = max - min`, or `1` for a
/// constant matrix. The map does two things:
///
/// - It makes every entry strictly positive, as required by the `u = x / v`
///   substitution in the row and column LPs.
/// - It keeps entries at unit magnitude, so the fixed big-M penalty and the
///   absolute tolerances in the simplex stay meaningful for payoffs of any size.
///   Without it, tiny positive payoffs made the LP fail and huge ones tripped
///   the zero-sum guard.
///
/// Optimal strategies are unchanged by a positive affine map of the payoffs;
/// the game value is mapped back to the original scale.
///
/// # Errors
///
/// * `OptimizeError::ValueError` if the matrix is empty, contains a non-finite
///   entry, or its value range overflows `f64`.
/// * Any error from the underlying LP solves (e.g. simplex non-convergence).
pub fn linear_program_minimax(
    payoff: ArrayView2<f64>,
) -> Result<MinimaxResult, OptimizeError> {
    let m = payoff.nrows();
    let n = payoff.ncols();
    if m == 0 || n == 0 {
        return Err(OptimizeError::ValueError(
            "Payoff matrix must be non-empty".to_string(),
        ));
    }
    // NaN/inf would silently corrupt the tableau (every comparison against NaN
    // is false), so reject them up front.
    if payoff.iter().any(|v| !v.is_finite()) {
        return Err(OptimizeError::ValueError(
            "Payoff matrix must contain only finite values".to_string(),
        ));
    }
    let (min_val, max_val) = payoff
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let range = max_val - min_val;
    if !range.is_finite() {
        return Err(OptimizeError::ValueError(
            "Payoff matrix values span a range too large to represent".to_string(),
        ));
    }
    // A constant matrix maps to all ones (any strategy pair is then optimal).
    let scale = if range > 0.0 { range } else { 1.0 };
    // `iter()` visits elements in logical row-major order, matching the
    // `i * n + j` indexing used by the LP builders.
    let normalized: Vec<f64> = payoff
        .iter()
        .map(|&v| (v - min_val) / scale + 1.0)
        .collect();
    let (row_strategy, normalized_value) = solve_row_lp(&normalized, m, n)?;
    let (col_strategy, _) = solve_col_lp(&normalized, m, n)?;
    // Undo the affine map v' = (v - min) / scale + 1.
    let game_value = (normalized_value - 1.0) * scale + min_val;
    Ok(MinimaxResult {
        game_value,
        row_player_strategy: Array1::from(row_strategy),
        col_player_strategy: Array1::from(col_strategy),
    })
}
/// Solves the row player's LP for a payoff matrix with positive entries.
///
/// The problem "maximise v s.t. x^T A >= v, sum(x) = 1" is rewritten with
/// `u = x / v` as "minimise sum(u) s.t. A^T u >= 1, u >= 0". It is solved with
/// a big-M tableau that has one surplus and one artificial variable per column
/// constraint; the artificials form the starting basis.
///
/// `shifted` is the payoff matrix in row-major order (`m` rows, `n` columns).
/// It is assumed to be strictly positive; callers must arrange that. Returns
/// `(x, v)`, with `v = 1 / sum(u)` and `x = u * v`.
///
/// # Errors
///
/// Propagates simplex failures and returns `ComputationError` if the optimal
/// `u` sums to less than `1e-12`.
fn solve_row_lp(
    shifted: &[f64],
    m: usize,
    n: usize,
) -> Result<(Vec<f64>, f64), OptimizeError> {
    // Penalty on artificial variables; it must dwarf the unit cost of each u_i.
    let big_m = 1e6_f64;
    // Column layout: [u_0..u_m | surplus_0..surplus_n | artificial_0..artificial_n | rhs]
    let surplus = m;
    let artificial = m + n;
    let total_vars = m + 2 * n;
    let n_rows = n + 1;
    let n_cols = total_vars + 1;
    let rhs_col = total_vars;

    let mut tab = vec![0.0_f64; n_rows * n_cols];
    {
        let (constraints, objective) = tab.split_at_mut(n * n_cols);
        // Constraint j: sum_i a_ij * u_i - surplus_j + artificial_j = 1.
        for (j, row) in constraints.chunks_exact_mut(n_cols).enumerate() {
            for (i, cell) in row[..m].iter_mut().enumerate() {
                *cell = shifted[i * n + j];
            }
            row[surplus + j] = -1.0;
            row[artificial + j] = 1.0;
            row[rhs_col] = 1.0;
        }
        // Objective: minimise sum(u) + big_m * sum(artificial).
        objective[..m].fill(1.0);
        objective[artificial..artificial + n].fill(big_m);
        // Price out the (initially basic) artificials so the objective row
        // holds valid reduced costs for the starting basis.
        for row in constraints.chunks_exact(n_cols) {
            for (obj, &coef) in objective.iter_mut().zip(row) {
                *obj -= big_m * coef;
            }
        }
    }

    let mut basis: Vec<usize> = (artificial..artificial + n).collect();
    simplex_method(&mut tab, &mut basis, n_rows, n_cols, total_vars)?;

    // Read basic u_i values from the RHS column; non-basic variables are zero.
    let mut u = vec![0.0_f64; m];
    for (row, &var) in tab.chunks_exact(n_cols).zip(basis.iter()) {
        if var < m {
            u[var] = row[rhs_col];
        }
    }
    let sum_u: f64 = u.iter().sum();
    if sum_u < 1e-12 {
        return Err(OptimizeError::ComputationError(
            "Row LP: zero sum of variables; game may be degenerate".to_string(),
        ));
    }
    let value = 1.0 / sum_u;
    let strategy: Vec<f64> = u.into_iter().map(|ui| ui * value).collect();
    Ok((strategy, value))
}
/// Solves the column player's LP for a payoff matrix with positive entries.
///
/// The problem "minimise v s.t. A y <= v, sum(y) = 1" is rewritten with
/// `w = y / v` as "maximise sum(w) s.t. A w <= 1, w >= 0". That is encoded as
/// minimising `-sum(w)`, with one slack variable per row constraint forming the
/// starting basis, so no artificials are needed.
///
/// `shifted` is the payoff matrix in row-major order (`m` rows, `n` columns).
/// It is assumed to be strictly positive; callers must arrange that. Returns
/// `(y, v)`, with `v = 1 / sum(w)` and `y = w * v`.
///
/// # Errors
///
/// Propagates simplex failures and returns `ComputationError` if the optimal
/// `w` sums to less than `1e-12`.
fn solve_col_lp(
    shifted: &[f64],
    m: usize,
    n: usize,
) -> Result<(Vec<f64>, f64), OptimizeError> {
    // Column layout: [w_0..w_n | slack_0..slack_m | rhs]
    let total_vars = n + m;
    let n_rows = m + 1;
    let n_cols = total_vars + 1;
    let rhs_col = total_vars;

    let mut tab = vec![0.0_f64; n_rows * n_cols];
    {
        let (constraints, objective) = tab.split_at_mut(m * n_cols);
        // Constraint i: sum_j a_ij * w_j + slack_i = 1.
        for (i, row) in constraints.chunks_exact_mut(n_cols).enumerate() {
            row[..n].copy_from_slice(&shifted[i * n..(i + 1) * n]);
            row[n + i] = 1.0;
            row[rhs_col] = 1.0;
        }
        // Maximise sum(w)  <=>  minimise -sum(w).
        objective[..n].fill(-1.0);
    }

    let mut basis: Vec<usize> = (n..n + m).collect();
    simplex_method(&mut tab, &mut basis, n_rows, n_cols, total_vars)?;

    // Read basic w_j values from the RHS column; non-basic variables are zero.
    let mut w = vec![0.0_f64; n];
    for (row, &var) in tab.chunks_exact(n_cols).zip(basis.iter()) {
        if var < n {
            w[var] = row[rhs_col];
        }
    }
    let sum_w: f64 = w.iter().sum();
    if sum_w < 1e-12 {
        return Err(OptimizeError::ComputationError(
            "Column LP: zero sum of variables; game may be degenerate".to_string(),
        ));
    }
    let value = 1.0 / sum_w;
    let strategy: Vec<f64> = w.into_iter().map(|wj| wj * value).collect();
    Ok((strategy, value))
}
/// Minimises the objective encoded in a dense simplex tableau, in place.
///
/// Tableau layout:
/// - `tab` is row-major with `n_rows x n_cols` entries.
/// - Rows `0..n_rows - 1` are constraints; the last row holds the reduced costs
///   of the objective.
/// - The last column holds the right-hand side.
/// - Only the first `n_vars` columns may enter the basis.
///
/// `basis[r]` is the variable basic in constraint row `r`; it is updated on
/// every pivot.
///
/// Pivoting follows Bland's rule:
/// - the entering column is the lowest-index column with reduced cost below `-1e-9`;
/// - in the ratio test, ties go to the row whose basic variable has the lowest index.
///
/// In exact arithmetic this rule prevents cycling on degenerate problems.
///
/// # Errors
///
/// * `ComputationError` if an entering column has no entry above `1e-9`
///   (the problem is unbounded).
/// * `ConvergenceError` if no optimum is reached within `10_000 * n_rows` pivots.
fn simplex_method(
    tab: &mut [f64],
    basis: &mut [usize],
    n_rows: usize,
    n_cols: usize,
    n_vars: usize,
) -> Result<(), OptimizeError> {
    let n_constraints = n_rows - 1;
    let obj_row = n_constraints;
    let rhs_col = n_cols - 1;
    let max_iter = 10_000 * n_rows;
    for _ in 0..max_iter {
        // Entering variable: first column with a negative reduced cost.
        let pivot_col = match (0..n_vars).find(|&j| tab[obj_row * n_cols + j] < -1e-9) {
            Some(c) => c,
            // No improving column left: the current basis is optimal.
            None => return Ok(()),
        };

        // Leaving variable: minimum-ratio test over rows with a positive entry
        // in the pivot column, ties broken by the lowest basic-variable index.
        let mut min_ratio = f64::INFINITY;
        let mut pivot_row = None;
        for i in 0..n_constraints {
            let element = tab[i * n_cols + pivot_col];
            if element > 1e-9 {
                let ratio = tab[i * n_cols + rhs_col] / element;
                if ratio < min_ratio - 1e-12 {
                    min_ratio = ratio;
                    pivot_row = Some(i);
                } else if (ratio - min_ratio).abs() < 1e-12 {
                    if let Some(prev_row) = pivot_row {
                        if basis[i] < basis[prev_row] {
                            pivot_row = Some(i);
                        }
                    }
                }
            }
        }
        let pivot_row = pivot_row.ok_or_else(|| {
            OptimizeError::ComputationError(
                "Simplex: problem is unbounded".to_string(),
            )
        })?;

        // Scale the pivot row so the pivot element becomes 1 ...
        let pivot_start = pivot_row * n_cols;
        let pivot_val = tab[pivot_start + pivot_col];
        for v in &mut tab[pivot_start..pivot_start + n_cols] {
            *v /= pivot_val;
        }
        // ... then eliminate the pivot column from every other row, including
        // the objective row.
        for i in 0..n_rows {
            if i == pivot_row {
                continue;
            }
            let factor = tab[i * n_cols + pivot_col];
            if factor.abs() < 1e-15 {
                continue;
            }
            for k in 0..n_cols {
                let pivot_k = tab[pivot_start + k];
                tab[i * n_cols + k] -= factor * pivot_k;
            }
        }
        basis[pivot_row] = pivot_col;
    }
    Err(OptimizeError::ConvergenceError(
        "Simplex method did not converge within maximum iterations".to_string(),
    ))
}
/// Approximates a zero-sum game's solution by simultaneous fictitious play.
///
/// Both players start on pure strategy 0. Each of the `n_iterations` rounds:
/// 1. Both players' current choices are recorded.
/// 2. Each player switches to a pure best response against the opponent's
///    empirical play frequencies so far. The row player maximises and the
///    column player minimises. The first index within `1e-12` (absolute) of
///    the best payoff is chosen.
///
/// The empirical frequencies are returned as the strategies. The game value is
/// estimated as `x^T A y` for those frequencies.
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` if the payoff matrix is empty or if
/// `n_iterations` is zero.
pub fn fictitious_play(
    payoff: ArrayView2<f64>,
    n_iterations: usize,
) -> Result<MinimaxResult, OptimizeError> {
    // Index of the first value within 1e-12 of `target`, or 0 if none is.
    fn first_near(values: &[f64], target: f64) -> usize {
        values
            .iter()
            .position(|&v| (v - target).abs() < 1e-12)
            .unwrap_or(0)
    }
    // Divides every play count by `total`.
    fn frequencies(counts: &[u64], total: f64) -> Vec<f64> {
        counts.iter().map(|&c| c as f64 / total).collect()
    }

    let m = payoff.nrows();
    let n = payoff.ncols();
    if m == 0 || n == 0 {
        return Err(OptimizeError::ValueError(
            "Payoff matrix must be non-empty".to_string(),
        ));
    }
    if n_iterations == 0 {
        return Err(OptimizeError::ValueError(
            "n_iterations must be positive".to_string(),
        ));
    }

    let mut row_plays = vec![0u64; m];
    let mut col_plays = vec![0u64; n];
    let mut row_choice = 0usize;
    let mut col_choice = 0usize;

    for round in 1..=n_iterations {
        row_plays[row_choice] += 1;
        col_plays[col_choice] += 1;
        let rounds_so_far = round as f64;

        // Row player: best response to the column player's history.
        let col_freq = frequencies(&col_plays, rounds_so_far);
        let row_values: Vec<f64> = (0..m)
            .map(|i| (0..n).map(|j| payoff[[i, j]] * col_freq[j]).sum::<f64>())
            .collect();
        let best_row_value = row_values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        row_choice = first_near(&row_values, best_row_value);

        // Column player: best response to the row player's history.
        let row_freq = frequencies(&row_plays, rounds_so_far);
        let col_values: Vec<f64> = (0..n)
            .map(|j| (0..m).map(|i| payoff[[i, j]] * row_freq[i]).sum::<f64>())
            .collect();
        let best_col_value = col_values.iter().copied().fold(f64::INFINITY, f64::min);
        col_choice = first_near(&col_values, best_col_value);
    }

    let row_total = row_plays.iter().sum::<u64>() as f64;
    let col_total = col_plays.iter().sum::<u64>() as f64;
    let row_strategy = frequencies(&row_plays, row_total);
    let col_strategy = frequencies(&col_plays, col_total);

    // Expected payoff x^T A y under the empirical mixed strategies.
    let game_value: f64 = row_strategy
        .iter()
        .enumerate()
        .map(|(i, &p)| {
            col_strategy
                .iter()
                .enumerate()
                .map(|(j, &q)| payoff[[i, j]] * p * q)
                .sum::<f64>()
        })
        .sum();

    Ok(MinimaxResult {
        game_value,
        row_player_strategy: Array1::from(row_strategy),
        col_player_strategy: Array1::from(col_strategy),
    })
}
/// Computes security strategies and values for both players.
///
/// Returns `(row_strategy, col_strategy, lower_value, upper_value)`.
///
/// If the pure maximin value (best row minimum) and the pure minimax value
/// (best column maximum) differ by more than `1e-9`, the mixed solution from
/// [`linear_program_minimax`] is returned, with the mixed game value in both
/// value slots. Otherwise the pure security strategies are returned together
/// with the maximin and minimax values.
///
/// On ties the maximin row is the *last* row attaining the best minimum
/// (`Iterator::max_by`), while the minimax column is the *first* column
/// attaining the best maximum (`Iterator::min_by`).
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` for an empty payoff matrix and
/// propagates errors from [`linear_program_minimax`].
pub fn security_strategies(
    payoff: ArrayView2<f64>,
) -> Result<(Vec<f64>, Vec<f64>, f64, f64), OptimizeError> {
    // Orders (index, value) pairs by value; incomparable (NaN) pairs count as equal.
    fn by_value(a: &(usize, f64), b: &(usize, f64)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
    }

    let m = payoff.nrows();
    let n = payoff.ncols();
    if m == 0 || n == 0 {
        return Err(OptimizeError::ValueError(
            "Payoff matrix must be non-empty".to_string(),
        ));
    }

    let best_row = (0..m)
        .map(|i| (0..n).map(|j| payoff[[i, j]]).fold(f64::INFINITY, f64::min))
        .enumerate()
        .max_by(by_value);
    let best_col = (0..n)
        .map(|j| (0..m).map(|i| payoff[[i, j]]).fold(f64::NEG_INFINITY, f64::max))
        .enumerate()
        .min_by(by_value);
    let (maximin_idx, maximin_val) =
        best_row.ok_or_else(|| OptimizeError::ComputationError("Empty row set".to_string()))?;
    let (minimax_idx, minimax_val) =
        best_col.ok_or_else(|| OptimizeError::ComputationError("Empty col set".to_string()))?;

    // No pure saddle: fall back to the mixed-strategy solution.
    if (maximin_val - minimax_val).abs() > 1e-9 {
        let mixed = linear_program_minimax(payoff)?;
        let value = mixed.game_value;
        return Ok((
            mixed.row_player_strategy.to_vec(),
            mixed.col_player_strategy.to_vec(),
            value,
            value,
        ));
    }

    let mut row_strategy = vec![0.0_f64; m];
    row_strategy[maximin_idx] = 1.0;
    let mut col_strategy = vec![0.0_f64; n];
    col_strategy[minimax_idx] = 1.0;
    Ok((row_strategy, col_strategy, maximin_val, minimax_val))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_find_saddle_point_exists() {
        let payoff = array![[3.0, 2.0, 3.0], [2.0, 1.0, 2.0], [3.0, 4.0, 4.0]];
        let saddle = find_saddle_point(payoff.view());
        assert!(saddle.is_some());
        let (r, c, v) = saddle.expect("saddle should not be None/Err");
        assert_eq!(r, 2);
        assert_eq!(c, 0);
        assert_relative_eq!(v, 3.0);
    }

    #[test]
    fn test_find_saddle_point_none() {
        let payoff = array![[1.0, -1.0], [-1.0, 1.0]];
        let saddle = find_saddle_point(payoff.view());
        assert!(saddle.is_none());
    }

    #[test]
    fn test_minimax_solve_mixed_2x2() {
        // No pure saddle (maximin 2 < minimax 3). The unique solution is
        // x = (3/4, 1/4), y = (1/2, 1/2) with value 5/2.
        let payoff = array![[3.0, 2.0], [1.0, 4.0]];
        let result = minimax_solve(payoff.view()).expect("solve succeeds");
        assert_relative_eq!(result.game_value, 2.5, epsilon = 1e-5);
        assert_relative_eq!(result.row_player_strategy[0], 0.75, epsilon = 1e-5);
        assert_relative_eq!(result.row_player_strategy[1], 0.25, epsilon = 1e-5);
        assert_relative_eq!(result.col_player_strategy[0], 0.5, epsilon = 1e-5);
        assert_relative_eq!(result.col_player_strategy[1], 0.5, epsilon = 1e-5);
    }

    #[test]
    fn test_minimax_solve_matching_pennies() {
        let payoff = array![[1.0, -1.0], [-1.0, 1.0]];
        let result = minimax_solve(payoff.view()).expect("solve");
        assert_relative_eq!(result.game_value, 0.0, epsilon = 1e-4);
        assert_relative_eq!(result.row_player_strategy[0], 0.5, epsilon = 1e-4);
        assert_relative_eq!(result.row_player_strategy[1], 0.5, epsilon = 1e-4);
        assert_relative_eq!(result.col_player_strategy[0], 0.5, epsilon = 1e-4);
        assert_relative_eq!(result.col_player_strategy[1], 0.5, epsilon = 1e-4);
    }

    #[test]
    fn test_minimax_solve_pure_saddle_via_minimax() {
        let payoff = array![[3.0, 5.0], [2.0, 4.0]];
        let result = minimax_solve(payoff.view()).expect("solve");
        assert_relative_eq!(result.game_value, 3.0, epsilon = 1e-6);
        assert_relative_eq!(result.row_player_strategy[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(result.col_player_strategy[0], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_linear_program_minimax_rps() {
        let payoff = array![[0.0, -1.0, 1.0], [1.0, 0.0, -1.0], [-1.0, 1.0, 0.0]];
        let result = linear_program_minimax(payoff.view()).expect("solve");
        assert_relative_eq!(result.game_value, 0.0, epsilon = 1e-3);
        for &v in result.row_player_strategy.iter() {
            assert_relative_eq!(v, 1.0 / 3.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn test_linear_program_minimax_constant_matrix() {
        // Every strategy pair is optimal; only the value is determined.
        let payoff = array![[2.0, 2.0], [2.0, 2.0]];
        let result = linear_program_minimax(payoff.view()).expect("solve");
        assert_relative_eq!(result.game_value, 2.0, epsilon = 1e-6);
        let sum_row: f64 = result.row_player_strategy.iter().sum();
        let sum_col: f64 = result.col_player_strategy.iter().sum();
        assert_relative_eq!(sum_row, 1.0, epsilon = 1e-6);
        assert_relative_eq!(sum_col, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_remove_dominated_strategies() {
        // Row 0 dominates row 1. Once row 1 is gone, column 1 (payoff 2) beats
        // column 0 (payoff 4) for the minimising column player.
        let mut payoff = array![[4.0, 2.0], [1.0, 1.0]];
        let (rows, cols) = remove_dominated_strategies(&mut payoff);
        assert_eq!(rows, vec![0]);
        assert_eq!(cols, vec![1]);
    }

    #[test]
    fn test_row_best_response() {
        let payoff = array![[3.0, 0.0], [0.0, 3.0]];
        let br = row_best_response(payoff.view(), &[1.0, 0.0]);
        assert_eq!(br, vec![0]);
        let br2 = row_best_response(payoff.view(), &[0.0, 1.0]);
        assert_eq!(br2, vec![1]);
    }

    #[test]
    fn test_row_best_response_invalid_length() {
        let payoff = array![[1.0, 2.0], [3.0, 4.0]];
        let br = row_best_response(payoff.view(), &[1.0]);
        assert!(br.is_empty());
    }

    #[test]
    fn test_fictitious_play_matching_pennies() {
        let payoff = array![[1.0, -1.0], [-1.0, 1.0]];
        let result = fictitious_play(payoff.view(), 10_000).expect("converges");
        assert_relative_eq!(result.row_player_strategy[0], 0.5, epsilon = 0.05);
        assert_relative_eq!(result.col_player_strategy[0], 0.5, epsilon = 0.05);
    }

    #[test]
    fn test_fictitious_play_pure_saddle() {
        let payoff = array![[3.0, 5.0], [2.0, 4.0]];
        let result = fictitious_play(payoff.view(), 1000).expect("converges");
        assert!(result.row_player_strategy[0] > 0.8);
        assert!(result.col_player_strategy[0] > 0.8);
    }

    #[test]
    fn test_fictitious_play_zero_iterations() {
        let payoff = array![[1.0, -1.0], [-1.0, 1.0]];
        assert!(fictitious_play(payoff.view(), 0).is_err());
    }

    #[test]
    fn test_security_strategies_saddle() {
        let payoff = array![[3.0, 5.0], [2.0, 4.0]];
        let (row_s, col_s, v_max, v_min) = security_strategies(payoff.view()).expect("ok");
        assert_relative_eq!(v_max, 3.0, epsilon = 1e-5);
        assert_relative_eq!(v_min, 3.0, epsilon = 1e-5);
        assert_relative_eq!(row_s[0], 1.0, epsilon = 1e-5);
        assert_relative_eq!(col_s[0], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_security_strategies_mixed() {
        let payoff = array![[1.0, -1.0], [-1.0, 1.0]];
        let (row_s, col_s, v_max, v_min) = security_strategies(payoff.view()).expect("ok");
        assert_relative_eq!(v_max, 0.0, epsilon = 1e-3);
        assert_relative_eq!(v_min, 0.0, epsilon = 1e-3);
        let sum_r: f64 = row_s.iter().sum();
        let sum_c: f64 = col_s.iter().sum();
        assert_relative_eq!(sum_r, 1.0, epsilon = 1e-5);
        assert_relative_eq!(sum_c, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_minimax_empty_matrix() {
        let payoff: Array2<f64> = Array2::zeros((0, 2));
        assert!(minimax_solve(payoff.view()).is_err());
    }

    #[test]
    fn test_saddle_point_1x1() {
        let payoff = array![[5.0]];
        let saddle = find_saddle_point(payoff.view());
        assert!(saddle.is_some());
        let (r, c, v) = saddle.expect("saddle should not be None/Err");
        assert_eq!(r, 0);
        assert_eq!(c, 0);
        assert_relative_eq!(v, 5.0);
    }

    #[test]
    fn test_minimax_solve_3x3_mixed() {
        // Nonsingular symmetric matrix with a fully mixed equaliser: the unique
        // solution is uniform for both players with value 1/3.
        let payoff = array![[2.0, -1.0, 0.0], [-1.0, 2.0, 0.0], [0.0, 0.0, 1.0]];
        let result = minimax_solve(payoff.view()).expect("solve");
        let sum_row: f64 = result.row_player_strategy.iter().sum();
        let sum_col: f64 = result.col_player_strategy.iter().sum();
        assert_relative_eq!(sum_row, 1.0, epsilon = 1e-4);
        assert_relative_eq!(sum_col, 1.0, epsilon = 1e-4);
        assert_relative_eq!(result.game_value, 1.0 / 3.0, epsilon = 1e-4);
        for &v in result.row_player_strategy.iter() {
            assert_relative_eq!(v, 1.0 / 3.0, epsilon = 1e-4);
        }
        for &v in result.col_player_strategy.iter() {
            assert_relative_eq!(v, 1.0 / 3.0, epsilon = 1e-4);
        }
    }
}