use crate::convert::dds_suit_from_cb;
use crate::pos::Pos;
use crate::quick_tricks::{MAXNODE, MINNODE};
use crate::search::Engine;
use crate::tt::TransTable;
use contract_bridge::{FullDeal, Seat, Strain, Suit};
use std::sync::OnceLock;
/// Strains in `Strain::ASC` order; row `i` of every table produced in this
/// module holds the results for `STRAINS[i]`.
const STRAINS: [Strain; 5] = Strain::ASC;
/// Seats in `Seat::ALL` order; a seat's position here is also its hand index
/// in [`Pos`].
const SEATS: [Seat; 4] = Seat::ALL;

/// Builds the initial search position for `deal`.
///
/// For every seat (hand index = position in `SEATS`) and every suit, the
/// holding's bit set is shifted right by two and stored in `rank_in_suit`, at
/// the suit slot chosen by [`dds_suit_from_cb`]. All other fields keep their
/// `Pos::default()` values.
fn pos_from_deal(deal: &FullDeal) -> Pos {
    let mut pos = Pos::default();
    for (hand_idx, &seat) in SEATS.iter().enumerate() {
        let hand = deal[seat];
        let ranks = &mut pos.rank_in_suit[hand_idx];
        for suit in Suit::ASC {
            // Dropping the two low bits leaves the 13 ranks in bits 0..=12
            // (the tests expect a complete suit to become 0x1FFF).
            ranks[dds_suit_from_cb(suit)] = hand[suit].to_bits() >> 2;
        }
    }
    pos
}
/// Double-dummy trick counts for one deal: every strain times every declarer.
///
/// `tricks[strain][seat]` is read by [`TrickCountTable::get`] using the
/// `Strain` / `Seat` discriminants, while the solvers in this module fill row
/// `i` with `Strain::ASC[i]` and column `j` with `Seat::ALL[j]`.
/// NOTE(review): the two schemes agree only if the discriminant order matches
/// `Strain::ASC` / `Seat::ALL` — assumed here, exercised by the tests.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TrickCountTable {
    /// Tricks taken with the given seat as declarer, indexed `[strain][seat]`.
    pub tricks: [[u8; 4]; 5],
}
impl TrickCountTable {
    /// Tricks made when `seat` declares in `strain`.
    #[inline]
    #[must_use]
    pub const fn get(&self, strain: Strain, seat: Seat) -> u8 {
        let per_seat = &self.tricks[strain as usize];
        per_seat[seat as usize]
    }
}
/// Double-dummy solver working on one strain at a time.
///
/// Owns a search [`Engine`] and the transposition table it searches with. The
/// table is reset at the start of every [`Solver::solve`] call and then shared
/// by that call's four declarer searches.
pub struct Solver {
    engine: Engine,
    tt: TransTable,
}
impl Solver {
    /// Creates a solver for `strain` with a `TransTable::new()` table.
    #[must_use]
    pub fn new(strain: Strain) -> Self {
        Self {
            engine: Engine::new(strain),
            tt: TransTable::new(),
        }
    }
    /// Creates a solver for `strain` whose table is built with
    /// `TransTable::with_memory(default_mb, max_mb)`.
    #[must_use]
    pub fn with_memory(strain: Strain, default_mb: u32, max_mb: u32) -> Self {
        Self {
            engine: Engine::new(strain),
            tt: TransTable::with_memory(default_mb, max_mb),
        }
    }
    /// Sets the strain used by subsequent calls to [`Solver::solve`].
    pub fn set_strain(&mut self, strain: Strain) {
        self.engine.set_strain(strain);
    }
    /// Solves `deal` in the current strain for each possible declarer.
    ///
    /// Returns the tricks taken by declarer's side, indexed in `Seat::ALL`
    /// order. For each declarer the opening lead comes from declarer's left-hand
    /// opponent, and declarer's partnership is the maximising side.
    ///
    /// # Panics
    ///
    /// Panics if the engine reports a trick count outside `0..=13`, which would
    /// indicate a bug in the search rather than a property of the deal.
    #[must_use]
    pub fn solve(&mut self, deal: FullDeal) -> [u8; 4] {
        // 52 cards minus the four of the current trick: depth at the first lead.
        const INI_DEPTH: i32 = 48;
        self.tt.reset();
        let mut row = [0u8; 4];
        for (seat_idx, declarer) in SEATS.iter().enumerate() {
            let leader = declarer.lho() as usize;
            let node_types = if matches!(declarer, Seat::North | Seat::South) {
                [MAXNODE, MINNODE, MAXNODE, MINNODE]
            } else {
                [MINNODE, MAXNODE, MINNODE, MAXNODE]
            };
            self.engine.set_node_types(node_types);
            // The engine takes `&mut pos`, so start every declarer from a fresh
            // position rather than whatever the previous search left behind.
            let mut pos = pos_from_deal(&deal);
            pos.first[INI_DEPTH as usize] = leader as i32;
            self.engine.set_deal(&mut pos, &mut self.tt);
            let tricks = self.engine.search_target(&mut pos, &mut self.tt, INI_DEPTH);
            // Check the invariant in release builds too: a bare `as u8` would
            // silently wrap a bad value into a plausible-looking trick count.
            row[seat_idx] = u8::try_from(tricks)
                .ok()
                .filter(|&t| t <= 13)
                .expect("tricks out of range");
        }
        row
    }
}
/// Instrumentation: read and reset the engine's counters.
impl Solver {
    /// Engine counters as `(search_target_calls, bisection_iters)`.
    #[inline]
    #[must_use]
    pub const fn bisection_stats(&self) -> (u64, u64) {
        let engine = &self.engine;
        (engine.search_target_calls, engine.bisection_iters)
    }
    /// Zeroes both bisection counters and both timing accumulators
    /// (`iter1_nanos`, `later_nanos`).
    #[inline]
    pub const fn reset_bisection_stats(&mut self) {
        let engine = &mut self.engine;
        engine.search_target_calls = 0;
        engine.bisection_iters = 0;
        engine.iter1_nanos = 0;
        engine.later_nanos = 0;
    }
    /// Engine timing accumulators as `(iter1_nanos, later_nanos)` (presumably
    /// nanoseconds, going by the field names).
    #[inline]
    #[must_use]
    pub const fn bisection_timing(&self) -> (u128, u128) {
        let engine = &self.engine;
        (engine.iter1_nanos, engine.later_nanos)
    }
    /// A copy of the engine's search statistics.
    #[inline]
    #[must_use]
    pub const fn search_stats(&self) -> crate::search::SearchStats {
        let engine = &self.engine;
        engine.stats
    }
    /// Replaces the engine's search statistics with their default value.
    #[inline]
    pub fn reset_search_stats(&mut self) {
        let fresh: crate::search::SearchStats = Default::default();
        self.engine.stats = fresh;
    }
}
impl Default for Solver {
#[inline]
fn default() -> Self {
Self::new(Strain::Notrump)
}
}
/// Stack size, in bytes, for solver pool threads (16 MiB).
const SOLVER_STACK_SIZE: usize = 16 << 20;
/// Returns the process-wide rayon pool used for batch solving, building it on
/// first use.
///
/// Worker threads are named `pons-dds-solver-<index>` and get
/// [`SOLVER_STACK_SIZE`] bytes of stack.
///
/// # Panics
///
/// Panics if rayon fails to build the pool.
fn solver_pool() -> &'static rayon::ThreadPool {
    static POOL: OnceLock<rayon::ThreadPool> = OnceLock::new();
    POOL.get_or_init(|| {
        let builder = rayon::ThreadPoolBuilder::new()
            .thread_name(|i| format!("pons-dds-solver-{i}"))
            .stack_size(SOLVER_STACK_SIZE);
        builder
            .build()
            .expect("failed to build pons-dds solver thread pool")
    })
}
/// Whether tasks for `STRAINS[strain_idx]` are queued ahead of all others.
///
/// Only notrump is prioritised. NOTE(review): presumably because notrump
/// searches tend to be the slowest — confirm with profiling data.
///
/// Panics if `strain_idx` is out of range for `STRAINS`.
const fn dispatch_first(strain_idx: usize) -> bool {
    let strain = STRAINS[strain_idx];
    matches!(strain, Strain::Notrump)
}
/// Solves every deal in `deals` for all five strains on the shared solver pool.
///
/// Work is split into one task per `(deal, strain)` pair. Tasks for which
/// [`dispatch_first`] holds are queued first, and the original deal order is
/// kept within each group. The queue is cut into roughly eight chunks per pool
/// thread. Each pool thread caches one [`Solver`] in a thread-local and
/// rebuilds it only when `default_mb` / `max_mb` differ from the cached values.
///
/// Returns one [`TrickCountTable`] per input deal, in input order.
fn solve_deals_pooled(deals: &[FullDeal], default_mb: u32, max_mb: u32) -> Vec<TrickCountTable> {
    use rayon::iter::ParallelIterator;
    use rayon::slice::ParallelSlice;
    use std::cell::RefCell;

    thread_local! {
        // Per-thread solver, tagged with the memory settings it was built with.
        static SOLVER: RefCell<Option<(u32, u32, Solver)>> = const { RefCell::new(None) };
    }

    // Prioritised tasks first, then the rest; `partition` keeps relative order.
    let (mut tasks, deferred): (Vec<(usize, usize)>, Vec<(usize, usize)>) = (0..deals.len())
        .flat_map(|d| (0..STRAINS.len()).map(move |s| (d, s)))
        .partition(|&(_, s)| dispatch_first(s));
    tasks.extend(deferred);

    let pool = solver_pool();
    let chunk_count = pool.current_num_threads().saturating_mul(8).max(1);
    let chunk_len = tasks.len().div_ceil(chunk_count).max(1);

    let solved: Vec<Vec<(usize, usize, [u8; 4])>> = pool.install(|| {
        tasks
            .par_chunks(chunk_len)
            .map(|chunk| {
                SOLVER.with(|cell| {
                    let mut cached = cell.borrow_mut();
                    let reusable = matches!(
                        cached.as_ref(),
                        Some(&(d_mb, m_mb, _)) if d_mb == default_mb && m_mb == max_mb
                    );
                    if !reusable {
                        // Free the stale solver before allocating its replacement.
                        *cached = None;
                    }
                    let (_, _, solver) = cached.get_or_insert_with(|| {
                        (
                            default_mb,
                            max_mb,
                            Solver::with_memory(Strain::Notrump, default_mb, max_mb),
                        )
                    });
                    chunk
                        .iter()
                        .map(|&(d, s)| {
                            solver.set_strain(STRAINS[s]);
                            (d, s, solver.solve(deals[d]))
                        })
                        .collect::<Vec<_>>()
                })
            })
            .collect()
    });

    let mut tables = vec![TrickCountTable::default(); deals.len()];
    for (d, s, row) in solved.into_iter().flatten() {
        tables[d].tricks[s] = row;
    }
    tables
}
/// Solves every deal in `deals` for all strains and declarers, in parallel,
/// using the default `crate::tt` memory settings.
///
/// Returns one table per deal, in input order.
#[must_use]
pub fn solve_deals(deals: &[FullDeal]) -> Vec<TrickCountTable> {
    let default_mb = crate::tt::DEFAULT_MEMORY_MB;
    let max_mb = crate::tt::MAX_MEMORY_MB;
    solve_deals_with_memory(deals, default_mb, max_mb)
}
/// Like [`solve_deals`], but each worker's transposition table is created with
/// `TransTable::with_memory(default_mb, max_mb)`.
///
/// Pool workers keep their solver between calls and rebuild it only when these
/// settings change.
#[must_use]
pub fn solve_deals_with_memory(
    deals: &[FullDeal],
    default_mb: u32,
    max_mb: u32,
) -> Vec<TrickCountTable> {
    let tables = solve_deals_pooled(deals, default_mb, max_mb);
    tables
}
/// Solves a single deal for all strains and declarers on the shared pool.
#[must_use]
pub fn solve_deal(deal: FullDeal) -> TrickCountTable {
    let mut tables = solve_deals(&[deal]);
    tables.pop().unwrap_or_default()
}
#[must_use]
pub fn solve_deal_on(solver: &mut Solver, deal: FullDeal) -> TrickCountTable {
let mut table = TrickCountTable::default();
for (i, strain) in STRAINS.iter().enumerate() {
solver.set_strain(*strain);
table.tricks[i] = solver.solve(deal);
}
table
}
#[cfg(test)]
mod tests {
    use super::*;
    use contract_bridge::deal::Builder;
    use contract_bridge::hand::{Hand, Holding};

    /// A deal whose full double-dummy table is pinned in
    /// `solve_deal_matches_reference_pbn`.
    const REFERENCE_PBN: &str = "N:.63.AKQ987.A9732 A8654.KQ5.T.QJT6 \
                                 J973.J98742.3.K4 KQT2.AT.J6542.85";

    /// Solves every strain on a fresh [`Solver`], bypassing the thread pool.
    fn solve_deal_sequential(deal: FullDeal) -> TrickCountTable {
        solve_deal_on(&mut Solver::new(Strain::Notrump), deal)
    }

    /// Parses [`REFERENCE_PBN`].
    fn reference_deal() -> FullDeal {
        REFERENCE_PBN.parse().expect("reference PBN parses")
    }

    /// Each seat holds exactly one complete suit, a different suit per seat.
    fn each_hand_holds_one_suit_deal() -> FullDeal {
        let full = Holding::ALL;
        let empty = Holding::EMPTY;
        let n_hand = Hand::new(empty, empty, empty, full);
        let e_hand = Hand::new(empty, empty, full, empty);
        let s_hand = Hand::new(empty, full, empty, empty);
        let w_hand = Hand::new(full, empty, empty, empty);
        Builder::new()
            .north(n_hand)
            .east(e_hand)
            .south(s_hand)
            .west(w_hand)
            .build_full()
            .expect("each-suit fixture should be a valid full deal")
    }

    #[test]
    fn pos_from_deal_each_hand_one_suit() {
        const DDS_ALL: u16 = 0x1FFF;
        let deal = each_hand_holds_one_suit_deal();
        let pos = pos_from_deal(&deal);
        assert_eq!(pos.rank_in_suit[0][0], DDS_ALL);
        assert_eq!(pos.rank_in_suit[0][1], 0);
        assert_eq!(pos.rank_in_suit[0][2], 0);
        assert_eq!(pos.rank_in_suit[0][3], 0);
        assert_eq!(pos.rank_in_suit[1][1], DDS_ALL);
        assert_eq!(pos.rank_in_suit[2][2], DDS_ALL);
        assert_eq!(pos.rank_in_suit[3][3], DDS_ALL);
    }

    #[test]
    fn solve_deal_each_hand_one_suit_notrump() {
        let deal = each_hand_holds_one_suit_deal();
        let table = solve_deal_sequential(deal);
        for seat in Seat::ALL {
            assert_eq!(
                table.get(Strain::Notrump, seat),
                0,
                "declarer {seat} at NT should make 0 tricks (LHO runs their suit)"
            );
        }
    }

    #[test]
    fn solve_deal_each_hand_one_suit_trump_tables() {
        let deal = each_hand_holds_one_suit_deal();
        let table = solve_deal_sequential(deal);
        // (strain, tricks for N/S declarers, tricks for E/W declarers)
        let cases = [
            (Strain::Spades, 13, 0),
            (Strain::Hearts, 0, 13),
            (Strain::Diamonds, 13, 0),
            (Strain::Clubs, 0, 13),
        ];
        for (strain, ns, ew) in cases {
            assert_eq!(table.get(strain, Seat::North), ns, "N declaring {strain}");
            assert_eq!(table.get(strain, Seat::South), ns, "S declaring {strain}");
            assert_eq!(table.get(strain, Seat::East), ew, "E declaring {strain}");
            assert_eq!(table.get(strain, Seat::West), ew, "W declaring {strain}");
        }
    }

    #[test]
    fn solve_deals_matches_single_deal_solver() {
        let deal_a = each_hand_holds_one_suit_deal();
        let deal_b = reference_deal();
        let expected_a = solve_deal_sequential(deal_a);
        let expected_b = solve_deal_sequential(deal_b);
        // Distinct fixtures are required, otherwise a result scattered to the
        // wrong deal index would go unnoticed.
        assert_ne!(expected_a, expected_b, "fixtures must have different tables");
        let parallel = solve_deals(&[deal_a, deal_b, deal_a]);
        assert_eq!(parallel, vec![expected_a, expected_b, expected_a]);
    }

    #[test]
    fn solve_deals_empty_input() {
        assert!(solve_deals(&[]).is_empty());
    }

    #[test]
    fn solve_deal_matches_single_deal_solver() {
        let deal = each_hand_holds_one_suit_deal();
        assert_eq!(solve_deal(deal), solve_deal_sequential(deal));
    }

    #[test]
    fn solve_deal_matches_reference_pbn() {
        let got = solve_deal_sequential(reference_deal());
        let expected = TrickCountTable {
            tricks: [
                [8, 5, 8, 5],
                [8, 5, 8, 5],
                [6, 5, 6, 6],
                [4, 9, 4, 9],
                [5, 8, 5, 8],
            ],
        };
        assert_eq!(got, expected, "DD table mismatch for reference deal");
        // `get` indexes by discriminant while the solvers fill by
        // `STRAINS`/`SEATS` position; the two must agree.
        for (i, &strain) in STRAINS.iter().enumerate() {
            for (j, &seat) in SEATS.iter().enumerate() {
                assert_eq!(got.get(strain, seat), got.tricks[i][j], "{strain} / {seat}");
            }
        }
    }
}