use serde::{Deserialize, Serialize};
use crate::error::{Result, SanghaError, validate_finite};
/// A cooperative game over `player_count` players, given by the worth of every coalition.
///
/// Coalitions are encoded as bitmasks: bit `i` set means player `i` is a member, and
/// `values[mask]` is the worth of that coalition (`values[0]` is the empty coalition).
/// Construct via [`CoalitionGame::new`], which requires `player_count <= 20`,
/// `values.len() == 2^player_count`, and values accepted by `validate_finite`.
/// The fields are public and can be mutated after construction; call
/// [`CoalitionGame::validate`] to re-check the invariants.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub struct CoalitionGame {
/// Number of players (at most 20 for a game built through `new`).
pub player_count: usize,
/// Coalition worths indexed by membership bitmask; length `2^player_count`.
pub values: Vec<f64>,
}
impl<'de> Deserialize<'de> for CoalitionGame {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D,
) -> core::result::Result<Self, D::Error> {
#[derive(Deserialize)]
struct Raw {
player_count: usize,
values: Vec<f64>,
}
let raw = Raw::deserialize(deserializer)?;
CoalitionGame::new(raw.player_count, raw.values).map_err(serde::de::Error::custom)
}
}
impl CoalitionGame {
    /// Creates a validated game.
    ///
    /// `values[mask]` is the worth of the coalition whose members are the set bits of
    /// `mask` (bit `i` = player `i`), so `values` must hold exactly `2^player_count`
    /// entries.
    ///
    /// # Errors
    ///
    /// Returns [`SanghaError::ComputationError`] if `player_count > 20` or if
    /// `values.len() != 2^player_count`, and propagates any error `validate_finite`
    /// reports for an entry of `values` (e.g. NaN).
    pub fn new(player_count: usize, values: Vec<f64>) -> Result<Self> {
        // Single source of truth for the invariants: build the value, then validate it.
        let game = Self {
            player_count,
            values,
        };
        game.validate()?;
        Ok(game)
    }

    /// Re-checks the invariants established by [`CoalitionGame::new`].
    ///
    /// Useful because the fields are public and may have been mutated after
    /// construction.
    ///
    /// # Errors
    ///
    /// Same conditions as [`CoalitionGame::new`], checked in the same order: player
    /// count limit, then `values` length, then each entry via `validate_finite`.
    pub fn validate(&self) -> Result<()> {
        if self.player_count > 20 {
            return Err(SanghaError::ComputationError(
                "player_count must be <= 20 (bitmask limit)".into(),
            ));
        }
        // Safe shift: `player_count <= 20` was checked above.
        let expected_len = 1usize << self.player_count;
        if self.values.len() != expected_len {
            return Err(SanghaError::ComputationError(format!(
                "values length {} != 2^{} = {expected_len}",
                self.values.len(),
                self.player_count
            )));
        }
        for (i, &v) in self.values.iter().enumerate() {
            validate_finite(v, &format!("values[{i}]"))?;
        }
        Ok(())
    }
}
/// Per-player Shapley values.
///
/// When produced by [`shapley_value`], `values[i]` is the Shapley value of player `i`
/// and `values.len()` equals the game's `player_count`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ShapleyValues {
/// One entry per player, indexed by player number.
pub values: Vec<f64>,
}
impl ShapleyValues {
#[inline]
#[must_use]
pub fn new(values: Vec<f64>) -> Self {
Self { values }
}
}
/// Outcome of [`is_core_stable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum StabilityStatus {
/// The allocation sums to the grand-coalition value and no non-empty coalition is
/// allocated less than its worth (both within a small numeric tolerance).
Stable,
/// The allocation total differs from the grand-coalition value, or some coalition is
/// allocated less than its worth (it could do better on its own).
Unstable,
}
/// A grouping of players into coalitions, as manipulated by [`merge_coalitions`] and
/// [`split_coalition`].
///
/// NOTE(review): presumably each inner vector holds player indices; nothing here checks
/// that indices are in range or that coalitions are disjoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CoalitionStructure {
/// The coalitions; each inner vector lists one coalition's members in insertion order.
pub coalitions: Vec<Vec<usize>>,
}
impl CoalitionStructure {
#[inline]
#[must_use]
pub fn new(coalitions: Vec<Vec<usize>>) -> Self {
Self { coalitions }
}
}
/// Computes the Shapley value of every player in `game`.
///
/// For player `i` this is
/// `phi_i = sum over S ⊆ N \ {i} of |S|! (n - |S| - 1)! / n! * (v(S ∪ {i}) - v(S))`.
/// Every subset of the other players is enumerated with the `s = (s - mask) & mask`
/// subset walk, so the cost is O(n · 2^(n-1)). A game with zero players yields an empty
/// result.
///
/// # Errors
///
/// Returns the error from [`CoalitionGame::validate`] if `game` is inconsistent. This is
/// possible because its fields are public and may have been mutated after construction;
/// without the check, the indexing below could panic.
#[must_use = "returns the Shapley values without side effects"]
pub fn shapley_value(game: &CoalitionGame) -> Result<ShapleyValues> {
    game.validate()?;
    let n = game.player_count;
    if n == 0 {
        return Ok(ShapleyValues::new(vec![]));
    }

    // k! for k = 0..=n; every value is exactly representable in f64 for n <= 20.
    let mut factorial = vec![1.0_f64; n + 1];
    for k in 1..=n {
        factorial[k] = factorial[k - 1] * k as f64;
    }
    let n_fact = factorial[n];
    // The Shapley weight depends only on |S|, so compute it once per subset size
    // instead of once per subset.
    let weights: Vec<f64> = (0..n)
        .map(|size| factorial[size] * factorial[n - size - 1] / n_fact)
        .collect();

    let full_mask = (1usize << n) - 1;
    let mut phi = vec![0.0; n];
    for (i, phi_i) in phi.iter_mut().enumerate() {
        let i_bit = 1usize << i;
        let mask_without_i = full_mask ^ i_bit;
        // Walk all subsets of `mask_without_i` in increasing order, starting from ∅.
        let mut s = 0usize;
        loop {
            let marginal = game.values[s | i_bit] - game.values[s];
            *phi_i += weights[s.count_ones() as usize] * marginal;
            if s == mask_without_i {
                break;
            }
            s = s.wrapping_sub(mask_without_i) & mask_without_i;
        }
    }
    Ok(ShapleyValues::new(phi))
}
/// Looks up the worth `v(S)` of the coalition `S` made of `members`.
///
/// Each member's bit is OR-ed into a mask, so duplicate entries are harmless. An empty
/// `members` slice yields the worth of the empty coalition, `values[0]`.
///
/// # Errors
///
/// Returns [`SanghaError::ComputationError`] in three cases:
/// - a member index is `>= game.player_count`;
/// - a member index does not fit in the `usize` bitmask;
/// - the resulting mask does not index into `game.values`.
///
/// The last two can only happen if the game's public fields were mutated into an
/// inconsistent state after construction. Previously they would panic, or silently
/// wrap the shift in release builds.
#[inline]
#[must_use = "returns the coalition value without side effects"]
pub fn coalition_value(game: &CoalitionGame, members: &[usize]) -> Result<f64> {
    let mut mask = 0usize;
    for &m in members {
        if m >= game.player_count {
            return Err(SanghaError::ComputationError(format!(
                "player index {m} out of bounds for {} players",
                game.player_count
            )));
        }
        // `1 << m` would overflow for m >= usize::BITS (panic in debug, wrap in release).
        if m >= usize::BITS as usize {
            return Err(SanghaError::ComputationError(format!(
                "player index {m} exceeds the {}-bit coalition bitmask",
                usize::BITS
            )));
        }
        mask |= 1usize << m;
    }
    game.values.get(mask).copied().ok_or_else(|| {
        SanghaError::ComputationError(format!(
            "coalition mask {mask} out of range for {} coalition values",
            game.values.len()
        ))
    })
}
/// Checks whether `allocation` lies in the core of `game`.
///
/// An allocation is in the core when it is *efficient*, meaning its total equals the
/// grand-coalition value. It must also be *coalitionally rational*: every non-empty
/// coalition `S` must receive at least `v(S)`.
///
/// Both comparisons use a tolerance of `1e-9 * max(1, |v|)`. This is absolute for
/// worths of magnitude up to 1 and relative beyond that, so large-valued games are not
/// rejected because of ordinary floating-point rounding.
///
/// Runs in O(n · 2^n) for `n = game.player_count`.
///
/// # Errors
///
/// - Returns the error from [`CoalitionGame::validate`] if `game` is inconsistent (its
///   fields are public and may have been mutated after construction).
/// - Returns [`SanghaError::ComputationError`] if `allocation.len() != game.player_count`.
/// - Propagates any error `validate_finite` reports for an allocation entry (e.g. NaN).
#[must_use = "returns the stability status without side effects"]
pub fn is_core_stable(game: &CoalitionGame, allocation: &[f64]) -> Result<StabilityStatus> {
    // Guarantees the shifts and `values[...]` indexing below cannot panic.
    game.validate()?;
    let n = game.player_count;
    if allocation.len() != n {
        return Err(SanghaError::ComputationError(format!(
            "allocation length {} != player_count {n}",
            allocation.len()
        )));
    }
    for (i, &a) in allocation.iter().enumerate() {
        validate_finite(a, &format!("allocation[{i}]"))?;
    }

    const EPS: f64 = 1e-9;
    // Scale-aware tolerance: identical to an absolute EPS for |value| <= 1.
    let tolerance = |value: f64| EPS * value.abs().max(1.0);

    let grand_mask = (1usize << n) - 1;
    let grand_value = game.values[grand_mask];
    let alloc_sum: f64 = allocation.iter().sum();
    if (alloc_sum - grand_value).abs() > tolerance(grand_value) {
        return Ok(StabilityStatus::Unstable);
    }
    for mask in 1..=grand_mask {
        let coalition_alloc: f64 = (0..n)
            .filter(|&i| mask & (1 << i) != 0)
            .map(|i| allocation[i])
            .sum();
        let worth = game.values[mask];
        if coalition_alloc < worth - tolerance(worth) {
            // `mask` is a blocking coalition: it could do better on its own.
            return Ok(StabilityStatus::Unstable);
        }
    }
    Ok(StabilityStatus::Stable)
}
/// Merges coalitions `i` and `j` of `structure` into one and returns the new structure.
///
/// The merged coalition stays at the lower of the two indices. Its own members come
/// first, followed by the members of the higher-indexed coalition. The higher-indexed
/// entry is removed, so later coalitions shift down by one. `structure` itself is not
/// modified.
///
/// # Errors
///
/// Returns [`SanghaError::ComputationError`] if either index is out of bounds (checked
/// first) or if `i == j`.
#[must_use = "returns the new coalition structure without side effects"]
pub fn merge_coalitions(
    structure: &CoalitionStructure,
    i: usize,
    j: usize,
) -> Result<CoalitionStructure> {
    let count = structure.coalitions.len();
    if i.max(j) >= count {
        return Err(SanghaError::ComputationError(
            "coalition index out of bounds".into(),
        ));
    }
    if i == j {
        return Err(SanghaError::ComputationError(
            "cannot merge a coalition with itself".into(),
        ));
    }

    let keep = i.min(j);
    let absorb = i.max(j);
    let mut merged = structure.coalitions.clone();
    // Remove the higher index first so `keep` still points at the same coalition.
    let absorbed = merged.remove(absorb);
    merged[keep].extend(absorbed);
    Ok(CoalitionStructure::new(merged))
}
/// Splits coalition `coalition_idx` at `split_point` and returns the new structure.
///
/// The head `coalition[..split_point]` stays in place. The tail
/// `coalition[split_point..]` is appended as a new coalition at the end. `structure`
/// itself is not modified.
///
/// # Errors
///
/// Returns [`SanghaError::ComputationError`] if `coalition_idx` is out of bounds, or if
/// `split_point` is not strictly between `0` and the coalition's length (both halves
/// must be non-empty).
#[must_use = "returns the new coalition structure without side effects"]
pub fn split_coalition(
    structure: &CoalitionStructure,
    coalition_idx: usize,
    split_point: usize,
) -> Result<CoalitionStructure> {
    let coalition = match structure.coalitions.get(coalition_idx) {
        Some(found) => found,
        None => {
            return Err(SanghaError::ComputationError(
                "coalition index out of bounds".into(),
            ))
        }
    };
    // Valid split points are 1..len; an empty or single-member coalition has none.
    if !(1..coalition.len()).contains(&split_point) {
        return Err(SanghaError::ComputationError(
            "split_point must be in (0, coalition.len())".into(),
        ));
    }

    let mut coalitions = structure.coalitions.clone();
    let tail = coalitions[coalition_idx].split_off(split_point);
    coalitions.push(tail);
    Ok(CoalitionStructure::new(coalitions))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-player simple majority game: any coalition of at least two players is
    /// worth 1, singletons and the empty coalition are worth 0.
    fn majority_game_3() -> CoalitionGame {
        let mut values = vec![0.0; 8];
        values[0b011] = 1.0;
        values[0b101] = 1.0;
        values[0b110] = 1.0;
        values[0b111] = 1.0;
        CoalitionGame::new(3, values).unwrap()
    }

    #[test]
    fn test_shapley_majority_game() {
        let game = majority_game_3();
        let sv = shapley_value(&game).unwrap();
        for &v in &sv.values {
            assert!((v - 1.0 / 3.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_shapley_efficiency() {
        let game = majority_game_3();
        let sv = shapley_value(&game).unwrap();
        let sum: f64 = sv.values.iter().sum();
        let grand = game.values[(1 << game.player_count) - 1];
        assert!((sum - grand).abs() < 1e-10);
    }

    #[test]
    fn test_shapley_single_player() {
        let game = CoalitionGame::new(1, vec![0.0, 5.0]).unwrap();
        let sv = shapley_value(&game).unwrap();
        assert!((sv.values[0] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_shapley_empty_game() {
        let game = CoalitionGame::new(0, vec![0.0]).unwrap();
        let sv = shapley_value(&game).unwrap();
        assert!(sv.values.is_empty());
    }

    #[test]
    fn test_shapley_additive_game() {
        // In an additive game v(S) = sum of w_i over S, each player's Shapley value is w_i.
        let weights = [1.0, 2.0, 3.0];
        let values: Vec<f64> = (0..8usize)
            .map(|mask| {
                (0..3usize)
                    .filter(|&i| mask & (1 << i) != 0)
                    .map(|i| weights[i])
                    .sum::<f64>()
            })
            .collect();
        let game = CoalitionGame::new(3, values).unwrap();
        let sv = shapley_value(&game).unwrap();
        assert_eq!(sv.values.len(), weights.len());
        for (got, want) in sv.values.iter().zip(weights.iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_shapley_asymmetric_dictator() {
        // Player 0 is a dictator: a coalition is worth 1 exactly when it contains player 0.
        let mut values = vec![0.0; 8];
        values[0b001] = 1.0;
        values[0b011] = 1.0;
        values[0b101] = 1.0;
        values[0b111] = 1.0;
        let game = CoalitionGame::new(3, values).unwrap();
        let sv = shapley_value(&game).unwrap();
        assert!((sv.values[0] - 1.0).abs() < 1e-10);
        assert!((sv.values[1] - 0.0).abs() < 1e-10);
        assert!((sv.values[2] - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_coalition_value_lookup() {
        let game = majority_game_3();
        let v = coalition_value(&game, &[0, 1]).unwrap();
        assert!((v - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_coalition_value_empty() {
        let game = majority_game_3();
        let v = coalition_value(&game, &[]).unwrap();
        assert!((v - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_coalition_value_duplicate_members() {
        // Duplicates collapse into the same bit, so {1, 1, 0} is the coalition {0, 1}.
        let game = majority_game_3();
        let v = coalition_value(&game, &[1, 1, 0]).unwrap();
        assert!((v - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_coalition_value_out_of_bounds() {
        let game = majority_game_3();
        assert!(coalition_value(&game, &[5]).is_err());
    }

    #[test]
    fn test_core_stable_equal_split() {
        // The 3-player majority game has an empty core: with 1/3 each, any two-player
        // coalition gets 2/3 < 1 and blocks, so even the symmetric split is unstable.
        let game = majority_game_3();
        let alloc = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let status = is_core_stable(&game, &alloc).unwrap();
        assert_eq!(status, StabilityStatus::Unstable);
    }

    #[test]
    fn test_core_stable_superadditive() {
        let mut values = vec![0.0; 8];
        values[0b001] = 1.0;
        values[0b010] = 2.0;
        values[0b100] = 3.0;
        values[0b011] = 3.0;
        values[0b101] = 4.0;
        values[0b110] = 5.0;
        values[0b111] = 6.0;
        let game = CoalitionGame::new(3, values).unwrap();
        let alloc = vec![1.0, 2.0, 3.0];
        let status = is_core_stable(&game, &alloc).unwrap();
        assert_eq!(status, StabilityStatus::Stable);
    }

    #[test]
    fn test_core_stable_wrong_total() {
        // Allocation sums to 1.5 while the grand coalition is worth 1 (not efficient).
        let game = majority_game_3();
        let alloc = vec![0.5, 0.5, 0.5];
        let status = is_core_stable(&game, &alloc).unwrap();
        assert_eq!(status, StabilityStatus::Unstable);
    }

    #[test]
    fn test_core_stable_wrong_length() {
        let game = majority_game_3();
        assert!(is_core_stable(&game, &[0.5, 0.5]).is_err());
    }

    #[test]
    fn test_core_stable_empty_game() {
        let game = CoalitionGame::new(0, vec![0.0]).unwrap();
        assert_eq!(
            is_core_stable(&game, &[]).unwrap(),
            StabilityStatus::Stable
        );
    }

    #[test]
    fn test_core_stable_nan_allocation_error() {
        let game = majority_game_3();
        assert!(is_core_stable(&game, &[f64::NAN, 0.0, 0.0]).is_err());
    }

    #[test]
    fn test_merge_coalitions() {
        let cs = CoalitionStructure::new(vec![vec![0, 1], vec![2, 3], vec![4]]);
        let merged = merge_coalitions(&cs, 0, 2).unwrap();
        assert_eq!(merged.coalitions.len(), 2);
        assert_eq!(merged.coalitions[0], vec![0, 1, 4]);
        assert_eq!(merged.coalitions[1], vec![2, 3]);
    }

    #[test]
    fn test_merge_coalitions_reversed_indices() {
        // Argument order does not matter: the lower index keeps its position.
        let cs = CoalitionStructure::new(vec![vec![0, 1], vec![2, 3], vec![4]]);
        let merged = merge_coalitions(&cs, 2, 0).unwrap();
        assert_eq!(merged.coalitions, vec![vec![0, 1, 4], vec![2, 3]]);
    }

    #[test]
    fn test_merge_coalitions_same_index_error() {
        let cs = CoalitionStructure::new(vec![vec![0], vec![1]]);
        assert!(merge_coalitions(&cs, 0, 0).is_err());
    }

    #[test]
    fn test_merge_coalitions_out_of_bounds() {
        let cs = CoalitionStructure::new(vec![vec![0]]);
        assert!(merge_coalitions(&cs, 0, 5).is_err());
    }

    #[test]
    fn test_split_coalition() {
        let cs = CoalitionStructure::new(vec![vec![0, 1, 2, 3]]);
        let split = split_coalition(&cs, 0, 2).unwrap();
        assert_eq!(split.coalitions.len(), 2);
        assert_eq!(split.coalitions[0], vec![0, 1]);
        assert_eq!(split.coalitions[1], vec![2, 3]);
    }

    #[test]
    fn test_split_coalition_boundary_error() {
        let cs = CoalitionStructure::new(vec![vec![0, 1, 2]]);
        assert!(split_coalition(&cs, 0, 0).is_err());
        assert!(split_coalition(&cs, 0, 3).is_err());
    }

    #[test]
    fn test_split_coalition_out_of_bounds() {
        let cs = CoalitionStructure::new(vec![vec![0, 1]]);
        assert!(split_coalition(&cs, 1, 1).is_err());
    }

    #[test]
    fn test_game_too_many_players() {
        assert!(CoalitionGame::new(21, vec![0.0; 1 << 21]).is_err());
    }

    #[test]
    fn test_game_wrong_values_length() {
        assert!(CoalitionGame::new(3, vec![0.0; 5]).is_err());
    }

    #[test]
    fn test_game_nan_values_error() {
        assert!(CoalitionGame::new(1, vec![0.0, f64::NAN]).is_err());
    }

    #[test]
    fn test_validate_detects_mutation() {
        // Fields are public, so a valid game can be broken after construction.
        let mut game = majority_game_3();
        assert!(game.validate().is_ok());
        game.values.pop();
        assert!(game.validate().is_err());
    }

    #[test]
    fn test_coalition_game_serde_roundtrip() {
        let game = majority_game_3();
        let json = serde_json::to_string(&game).unwrap();
        let back: CoalitionGame = serde_json::from_str(&json).unwrap();
        assert_eq!(game.player_count, back.player_count);
        assert_eq!(game.values, back.values);
    }

    #[test]
    fn test_coalition_game_deserialize_rejects_invalid() {
        // 2 players require 4 coalition values; only 2 are supplied.
        let json = r#"{"player_count":2,"values":[0.0,1.0]}"#;
        let result: core::result::Result<CoalitionGame, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_shapley_values_serde_roundtrip() {
        let sv = ShapleyValues::new(vec![1.0, 2.0, 3.0]);
        let json = serde_json::to_string(&sv).unwrap();
        let back: ShapleyValues = serde_json::from_str(&json).unwrap();
        assert_eq!(sv.values, back.values);
    }

    #[test]
    fn test_stability_status_serde_roundtrip() {
        let s = StabilityStatus::Stable;
        let json = serde_json::to_string(&s).unwrap();
        let back: StabilityStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(s, back);
    }

    #[test]
    fn test_coalition_structure_serde_roundtrip() {
        let cs = CoalitionStructure::new(vec![vec![0, 1], vec![2]]);
        let json = serde_json::to_string(&cs).unwrap();
        let back: CoalitionStructure = serde_json::from_str(&json).unwrap();
        assert_eq!(cs.coalitions, back.coalitions);
    }
}