use anyhow::anyhow;
use core::num::NonZero;
use dds_bridge::solver::Vulnerability;
use dds_bridge::{Bid, Contract, Penalty, Seat, Strain};
use pons::stats::{Accumulator, HistogramRow, HistogramTable, Statistics, average_ns_par};
/// `Statistics::new` stores the given mean and standard deviation as-is.
#[test]
fn test_statistics_new() {
    let stats = Statistics::new(3.0, 1.5);
    let (mean, sd) = (stats.mean(), stats.sd());
    assert!(mean.eq(&3.0));
    assert!(sd.eq(&1.5));
}
/// The default `Statistics` has a zero mean and a zero standard deviation.
#[test]
fn test_statistics_default() {
    let zeroed = Statistics::default();
    for value in [zeroed.mean(), zeroed.sd()] {
        assert!(value.eq(&0.0));
    }
}
/// The `Display` output mentions the mean, a `±` sign and the deviation.
#[test]
fn test_statistics_display() {
    let rendered = Statistics::new(3.0, 1.5).to_string();
    for needle in ["3", "±", "1.5"] {
        assert!(rendered.contains(needle));
    }
}
/// A freshly constructed accumulator has no samples and zeroed moments.
#[test]
fn test_accumulator_new() {
    let empty = Accumulator::new();
    assert_eq!(empty.count(), 0);
    let (mean, sdm) = (empty.mean(), empty.sdm());
    assert!(mean.eq(&0.0));
    assert!(sdm.eq(&0.0));
}
/// `Default` and `new` construct the same empty accumulator.
#[test]
fn test_accumulator_default() {
    let via_default: Accumulator = Default::default();
    assert_eq!(via_default, Accumulator::new());
}
/// One pushed value becomes the mean, with no spread around it.
#[test]
fn test_accumulator_single_value() {
    let mut single = Accumulator::new();
    single.push(5.0);
    assert_eq!(single.count(), 1);
    let (mean, sdm) = (single.mean(), single.sdm());
    assert!(mean.eq(&5.0));
    assert!(sdm.eq(&0.0));
}
/// Pushing 2 and 4 gives a mean of 3 and an `sdm` of 2.
#[test]
fn test_accumulator_two_values() {
    let mut pair = Accumulator::new();
    for x in [2.0, 4.0] {
        pair.push(x);
    }
    assert_eq!(pair.count(), 2);
    assert!(pair.mean().eq(&3.0));
    let sdm_error = (pair.sdm() - 2.0).abs();
    assert!(sdm_error < 1e-10);
}
/// Population statistics of {2, 4}: mean 3, standard deviation 1.
#[test]
fn test_accumulator_population() {
    let mut pair = Accumulator::new();
    for x in [2.0, 4.0] {
        pair.push(x);
    }
    let population = pair.population();
    assert!(population.mean().eq(&3.0));
    let sd_error = (population.sd() - 1.0).abs();
    assert!(sd_error < 1e-10);
}
/// Sample statistics of {2, 4}: mean 3, standard deviation √2.
#[test]
fn test_accumulator_sample() {
    let mut pair = Accumulator::new();
    for x in [2.0, 4.0] {
        pair.push(x);
    }
    let sample = pair.sample();
    assert!(sample.mean().eq(&3.0));
    let sd_error = (sample.sd() - std::f64::consts::SQRT_2).abs();
    assert!(sd_error < 1e-10);
}
/// With no data, both population statistics are NaN.
#[test]
fn test_accumulator_empty_population() {
    let empty = Accumulator::new();
    let population = empty.population();
    assert!(population.mean().is_nan() && population.sd().is_nan());
}
/// With no data, both sample statistics are NaN.
#[test]
fn test_accumulator_empty_sample() {
    let empty = Accumulator::new();
    let sample = empty.sample();
    assert!(sample.mean().is_nan() && sample.sd().is_nan());
}
/// A single observation has a defined sample mean but a NaN sample
/// deviation (there are zero degrees of freedom).
#[test]
fn test_accumulator_single_sample_sd_nan() {
    let mut single = Accumulator::new();
    single.push(5.0);
    let sample = single.sample();
    let (mean, sd) = (sample.mean(), sample.sd());
    assert!(mean.eq(&5.0));
    assert!(sd.is_nan());
}
/// Pushing 1..=100 yields a count of 100, a mean of 50.5 and an `sdm` of
/// 83 325.
///
/// `sdm` is treated here as the sum of squared deviations from the mean. That
/// matches the other tests, which expect 0 for one value and 2 for {2, 4}.
/// For the integers 1..=n this sum is n(n² − 1) / 12, which is 83 325 for
/// n = 100.
#[test]
fn test_accumulator_many_values() {
    let mut acc = Accumulator::new();
    for i in 1..=100 {
        acc.push(f64::from(i));
    }
    assert_eq!(acc.count(), 100);
    assert!((acc.mean() - 50.5).abs() < 1e-10);
    // Exercises the spread over many pushes, not just the running mean.
    assert!(
        (acc.sdm() - 83_325.0).abs() < 1e-6,
        "got sdm {}",
        acc.sdm()
    );
}
/// An empty histogram row has no deal count.
#[test]
fn test_histogram_row_new() {
    let count = HistogramRow::new().count();
    assert!(count.is_none());
}
/// Writing through a strain index affects only that strain's bucket.
#[test]
fn test_histogram_row_index() {
    let slot = 7;
    let mut row = HistogramRow::new();
    row[Strain::Notrump][slot] = 5;
    let (written, untouched) = (row[Strain::Notrump][slot], row[Strain::Clubs][slot]);
    assert_eq!(written, 5);
    assert_eq!(untouched, 0);
}
/// A row's count reflects the number of deals recorded in it.
#[test]
fn test_histogram_row_count() {
    let mut row = HistogramRow::new();
    row[Strain::Clubs][6] = 3;
    let deals = row.count().map(NonZero::get);
    assert_eq!(deals, Some(3));
}
/// An empty histogram table has no deal count.
#[test]
fn test_histogram_table_new() {
    let count = HistogramTable::new().count();
    assert!(count.is_none());
}
/// Writing through a seat index affects only that seat's rows.
#[test]
fn test_histogram_table_index() {
    let slot = 7;
    let mut table = HistogramTable::new();
    table[Seat::North][Strain::Clubs][slot] = 10;
    let north = table[Seat::North][Strain::Clubs][slot];
    let south = table[Seat::South][Strain::Clubs][slot];
    assert_eq!(north, 10);
    assert_eq!(south, 0);
}
/// Builds a table in which every seat/strain pair records exactly one deal
/// taking `tricks` tricks.
fn uniform_hist(tricks: u8) -> HistogramTable {
    let bucket = usize::from(tricks);
    let mut table = HistogramTable::new();
    for seat in Seat::ALL {
        let per_seat = &mut table[seat];
        for strain in Strain::ASC {
            per_seat[strain][bucket] = 1;
        }
    }
    table
}
/// With no deals recorded there is nothing to average, so no par is reported.
#[test]
fn test_par_empty_histogram_returns_none() {
    let empty = HistogramTable::new();
    let par = average_ns_par(empty, Vulnerability::NONE, Seat::North);
    assert!(par.is_none());
}
/// When every seat takes only six tricks in every strain, no contract makes
/// and par is a pass-out scoring zero.
#[test]
fn test_par_pass_out_when_every_contract_loses() -> anyhow::Result<()> {
    let Some(par) = average_ns_par(uniform_hist(6), Vulnerability::NONE, Seat::North) else {
        return Err(anyhow!("expected a par result"));
    };
    assert!(par.score.eq(&0.0));
    assert_eq!(par.contract, None);
    Ok(())
}
/// The pass-out result does not depend on which seat deals.
#[test]
fn test_par_pass_out_independent_of_dealer() -> anyhow::Result<()> {
    let losing = uniform_hist(6);
    for dealer in Seat::ALL {
        let Some(par) = average_ns_par(losing, Vulnerability::NONE, dealer) else {
            return Err(anyhow!("expected a par result for dealer {dealer:?}"));
        };
        assert!(par.score.eq(&0.0), "dealer {dealer:?}");
        assert_eq!(par.contract, None, "dealer {dealer:?}");
    }
    Ok(())
}
/// North and South take seven tricks in notrump. Every other seat/strain pair
/// takes six, so par is an undoubled 1NT declared by North or South.
#[test]
fn test_par_ns_partial_1nt() -> anyhow::Result<()> {
    let mut hist = uniform_hist(6);
    for seat in [Seat::North, Seat::South] {
        // Replace the six-trick entry with a single seven-trick deal.
        hist[seat][Strain::Notrump] = [0; 14];
        hist[seat][Strain::Notrump][7] = 1;
    }
    let par = average_ns_par(hist, Vulnerability::NONE, Seat::North)
        .ok_or_else(|| anyhow!("expected a par result"))?;
    let one_nt = Contract {
        bid: Bid::new(1, Strain::Notrump),
        penalty: Penalty::Undoubled,
    };
    // There is only one deal, so the average is just 1NT making exactly.
    let expected = f64::from(one_nt.score(7, false));
    assert!(par.score.eq(&expected), "got {} vs {expected}", par.score);
    // Report a missing contract as an error instead of panicking, as is done
    // for the par result above.
    let (contract, declarer) = par
        .contract
        .ok_or_else(|| anyhow!("expected a par contract"))?;
    assert_eq!(contract, one_nt);
    assert!(
        matches!(declarer, Seat::North | Seat::South),
        "unexpected declarer {declarer:?}"
    );
    Ok(())
}
/// North and South take ten tricks in hearts. Every other seat/strain pair
/// takes six, so with North-South vulnerable par is an undoubled 4H declared
/// by North or South.
#[test]
fn test_par_ns_game_4h_vul() -> anyhow::Result<()> {
    let mut hist = uniform_hist(6);
    for seat in [Seat::North, Seat::South] {
        // Replace the six-trick entry with a single ten-trick deal.
        hist[seat][Strain::Hearts] = [0; 14];
        hist[seat][Strain::Hearts][10] = 1;
    }
    let par = average_ns_par(hist, Vulnerability::NS, Seat::North)
        .ok_or_else(|| anyhow!("expected a par result"))?;
    let four_h = Contract {
        bid: Bid::new(4, Strain::Hearts),
        penalty: Penalty::Undoubled,
    };
    let expected = f64::from(four_h.score(10, true));
    assert!(par.score.eq(&expected), "got {} vs {expected}", par.score);
    // Report a missing contract as an error instead of panicking, as is done
    // for the par result above.
    let (contract, declarer) = par
        .contract
        .ok_or_else(|| anyhow!("expected a par contract"))?;
    assert_eq!(contract, four_h);
    assert!(
        matches!(declarer, Seat::North | Seat::South),
        "unexpected declarer {declarer:?}"
    );
    Ok(())
}
/// North and South (vulnerable) make 4H with ten tricks. East and West
/// (scored non-vulnerable) take nine tricks in spades. Par is therefore 4S
/// doubled, declared by East or West and going down one. The expected
/// North-South score is the negation of East-West's result.
#[test]
fn test_par_ew_sacrifice_against_vulnerable_game() -> anyhow::Result<()> {
    let mut hist = uniform_hist(6);
    for seat in [Seat::North, Seat::South] {
        hist[seat][Strain::Hearts] = [0; 14];
        hist[seat][Strain::Hearts][10] = 1;
    }
    for seat in [Seat::East, Seat::West] {
        hist[seat][Strain::Spades] = [0; 14];
        hist[seat][Strain::Spades][9] = 1;
    }
    let par = average_ns_par(hist, Vulnerability::NS, Seat::North)
        .ok_or_else(|| anyhow!("expected a par result"))?;
    let four_sx = Contract {
        bid: Bid::new(4, Strain::Spades),
        penalty: Penalty::Doubled,
    };
    // `par.score` is from North-South's point of view, hence the negation.
    let expected = -f64::from(four_sx.score(9, false));
    assert!(par.score.eq(&expected), "got {} vs {expected}", par.score);
    // Report a missing contract as an error instead of panicking, as is done
    // for the par result above.
    let (contract, declarer) = par
        .contract
        .ok_or_else(|| anyhow!("expected a par contract"))?;
    assert_eq!(contract, four_sx);
    assert!(
        matches!(declarer, Seat::East | Seat::West),
        "unexpected declarer {declarer:?}"
    );
    Ok(())
}
/// Par is averaged over every deal in the histogram.
///
/// There are two deals. Every seat/strain pair takes six tricks on both,
/// except that North and South take six on one and seven on the other in
/// notrump. 1NT by North-South should therefore be par, scored as the mean
/// of its two outcomes.
#[test]
fn test_par_count_averages_across_deals() -> anyhow::Result<()> {
    // Build the two-deal histogram in one pass rather than overwriting a
    // uniform one. The resulting table is the same.
    let mut hist = HistogramTable::new();
    for seat in Seat::ALL {
        for strain in Strain::ASC {
            if matches!(seat, Seat::North | Seat::South) && strain == Strain::Notrump {
                hist[seat][strain][6] = 1;
                hist[seat][strain][7] = 1;
            } else {
                hist[seat][strain][6] = 2;
            }
        }
    }
    assert_eq!(hist.count().map(NonZero::get), Some(2));

    let par = average_ns_par(hist, Vulnerability::NONE, Seat::North)
        .ok_or_else(|| anyhow!("expected a par result"))?;
    let one_nt = Contract {
        bid: Bid::new(1, Strain::Notrump),
        penalty: Penalty::Undoubled,
    };
    let one_ntx = Contract {
        bid: Bid::new(1, Strain::Notrump),
        penalty: Penalty::Doubled,
    };
    // Total over both deals: making seven tricks on one and six on the other.
    let undoubled = one_nt.score(7, false) + one_nt.score(6, false);
    let doubled = one_ntx.score(7, false) + one_ntx.score(6, false);
    // Take the worse of the two totals for North-South, which models
    // East-West's option to double. Then average over the two deals.
    let expected = f64::from(undoubled.min(doubled)) * 0.5;

    // Report a missing contract as an error instead of panicking, as is done
    // for the par result above.
    let (contract, _) = par
        .contract
        .ok_or_else(|| anyhow!("expected a par contract"))?;
    assert_eq!(contract, one_nt);
    assert!(
        (par.score - expected).abs() < 1e-9,
        "got {} vs {expected}",
        par.score
    );
    Ok(())
}