use crate::contract::{Contract, Penalty};
use crate::deal::{Card, Deal, Holding, Rank, Seat};
use crate::{Strain, Suit};
use arrayvec::ArrayVec;
use core::ffi::c_int;
use core::fmt;
use core::ops::BitOr as _;
use core::str::FromStr;
use dds_bridge_sys as sys;
use parking_lot::Mutex;
use std::sync::LazyLock;
use thiserror::Error;
/// Aborts with the matching DDS error text when `status` signals a failure.
///
/// Every non-negative `status` counts as success and the function simply
/// returns.  A negative `status` is looked up among the `sys::RETURN_*` codes
/// and the corresponding `sys::TEXT_*` bytes become the panic message; codes
/// without a dedicated entry fall back to `sys::TEXT_UNKNOWN_FAULT`.
///
/// # Panics
/// Panics for every negative `status`.
const fn check(status: i32) {
    if status >= 0 {
        return;
    }
    let text: &[u8] = match status {
        sys::RETURN_ZERO_CARDS => sys::TEXT_ZERO_CARDS,
        sys::RETURN_TARGET_TOO_HIGH => sys::TEXT_TARGET_TOO_HIGH,
        sys::RETURN_DUPLICATE_CARDS => sys::TEXT_DUPLICATE_CARDS,
        sys::RETURN_TARGET_WRONG_LO => sys::TEXT_TARGET_WRONG_LO,
        sys::RETURN_TARGET_WRONG_HI => sys::TEXT_TARGET_WRONG_HI,
        sys::RETURN_SOLNS_WRONG_LO => sys::TEXT_SOLNS_WRONG_LO,
        sys::RETURN_SOLNS_WRONG_HI => sys::TEXT_SOLNS_WRONG_HI,
        sys::RETURN_TOO_MANY_CARDS => sys::TEXT_TOO_MANY_CARDS,
        sys::RETURN_SUIT_OR_RANK => sys::TEXT_SUIT_OR_RANK,
        sys::RETURN_PLAYED_CARD => sys::TEXT_PLAYED_CARD,
        sys::RETURN_CARD_COUNT => sys::TEXT_CARD_COUNT,
        sys::RETURN_THREAD_INDEX => sys::TEXT_THREAD_INDEX,
        sys::RETURN_MODE_WRONG_LO => sys::TEXT_MODE_WRONG_LO,
        sys::RETURN_MODE_WRONG_HI => sys::TEXT_MODE_WRONG_HI,
        sys::RETURN_TRUMP_WRONG => sys::TEXT_TRUMP_WRONG,
        sys::RETURN_FIRST_WRONG => sys::TEXT_FIRST_WRONG,
        sys::RETURN_PLAY_FAULT => sys::TEXT_PLAY_FAULT,
        sys::RETURN_PBN_FAULT => sys::TEXT_PBN_FAULT,
        sys::RETURN_TOO_MANY_BOARDS => sys::TEXT_TOO_MANY_BOARDS,
        sys::RETURN_THREAD_CREATE => sys::TEXT_THREAD_CREATE,
        sys::RETURN_THREAD_WAIT => sys::TEXT_THREAD_WAIT,
        sys::RETURN_THREAD_MISSING => sys::TEXT_THREAD_MISSING,
        sys::RETURN_NO_SUIT => sys::TEXT_NO_SUIT,
        sys::RETURN_TOO_MANY_TABLES => sys::TEXT_TOO_MANY_TABLES,
        sys::RETURN_CHUNK_SIZE => sys::TEXT_CHUNK_SIZE,
        _ => sys::TEXT_UNKNOWN_FAULT,
    };
    // SAFETY: NOTE(review): relies on every `sys::TEXT_*` constant being valid
    // UTF-8 (presumably plain ASCII literals) — confirm against `dds_bridge_sys`.
    let text = unsafe { core::str::from_utf8_unchecked(text) };
    // NOTE(review): the explicit `"{}"` + `&str` argument form is kept on
    // purpose because this is a `const fn`; check const-eval support before
    // switching to an inline capture.
    panic!("{}", text);
}
bitflags::bitflags! {
/// Set of strains, one bit per strain.
///
/// `Solver::solve_deals` takes a value of this type to choose which strains
/// are solved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StrainFlags : u8 {
/// Flag for clubs.
const CLUBS = 0x01;
/// Flag for diamonds.
const DIAMONDS = 0x02;
/// Flag for hearts.
const HEARTS = 0x04;
/// Flag for spades.
const SPADES = 0x08;
/// Flag for notrump.
const NOTRUMP = 0x10;
}
}
/// Trick counts of the four seats for one strain, packed into a `u16`.
///
/// Every seat owns one 4-bit field located at bit offset `4 * (seat as u8)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[repr(transparent)]
pub struct TricksRow(u16);

impl TricksRow {
    /// Packs the trick counts of North, East, South and West into one row.
    ///
    /// The arguments are not range-checked or masked: a value above 15 sets
    /// bits outside the 4-bit field of its own seat.
    #[must_use]
    pub const fn new(n: u8, e: u8, s: u8, w: u8) -> Self {
        let north = (n as u16) << (4 * Seat::North as u8);
        let east = (e as u16) << (4 * Seat::East as u8);
        let south = (s as u16) << (4 * Seat::South as u8);
        let west = (w as u16) << (4 * Seat::West as u8);
        Self(north | east | south | west)
    }

    /// Returns the 4-bit trick count stored for `seat`.
    #[must_use]
    pub const fn get(self, seat: Seat) -> u8 {
        let offset = 4 * seat as u8;
        let field = (self.0 >> offset) & 0xF;
        field as u8
    }

    /// Wraps this row together with `seat` into a [`TricksRowHex`] adapter
    /// for hexadecimal formatting.
    #[must_use]
    pub const fn hex(self, seat: Seat) -> TricksRowHex {
        TricksRowHex { seat, row: self }
    }
}
/// Formatting adapter produced by `TricksRow::hex`.
///
/// Its [`fmt::UpperHex`] implementation prints four upper-case hexadecimal
/// values: the trick counts of `seat`, `seat.lho()`, `seat.partner()` and
/// `seat.rho()`, in that order.
#[derive(Debug, Clone, Copy)]
pub struct TricksRowHex {
    row: TricksRow,
    seat: Seat,
}

impl fmt::UpperHex for TricksRowHex {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { row, seat } = *self;
        let own = row.get(seat);
        let lho = row.get(seat.lho());
        let partner = row.get(seat.partner());
        let rho = row.get(seat.rho());
        write!(f, "{:X}{:X}{:X}{:X}", own, lho, partner, rho)
    }
}
/// Trick counts for five strains, one [`TricksRow`] per strain.
///
/// Rows are stored at position `strain as usize`, which is also how the
/// `Index<Strain>` implementation looks them up.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[repr(transparent)]
pub struct TricksTable(pub [TricksRow; 5]);

impl core::ops::Index<Strain> for TricksTable {
    type Output = TricksRow;

    /// Returns the row stored at position `strain as usize`.
    fn index(&self, strain: Strain) -> &TricksRow {
        let Self(rows) = self;
        &rows[strain as usize]
    }
}

impl TricksTable {
    /// Bundles this table with `seat` and a list of `strains` into a
    /// [`TricksTableHex`] adapter for hexadecimal formatting.
    #[must_use]
    pub const fn hex<T>(self, seat: Seat, strains: T) -> TricksTableHex<T>
    where
        T: AsRef<[Strain]>,
    {
        TricksTableHex {
            strains,
            seat,
            table: self,
        }
    }
}
/// Formatting adapter produced by `TricksTable::hex`.
///
/// Its [`fmt::UpperHex`] implementation walks `strains` in order and, for
/// each one, writes the hexadecimal rendering of that strain's row as seen
/// from `seat`.
#[derive(Debug, Clone, Copy)]
pub struct TricksTableHex<T: AsRef<[Strain]>> {
    table: TricksTable,
    seat: Seat,
    strains: T,
}

impl<T: AsRef<[Strain]>> fmt::UpperHex for TricksTableHex<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Stops at the first formatting error, exactly like `?` in a loop.
        self.strains
            .as_ref()
            .iter()
            .try_for_each(|&strain| fmt::UpperHex::fmt(&self.table[strain].hex(self.seat), f))
    }
}
impl Strain {
/// Returns the position of this strain in `sys` arrays that are ordered
/// spades, hearts, diamonds, clubs, notrump (for example the rows of
/// `sys::ddTableResults::resTable`).
#[must_use]
const fn to_sys(self) -> usize {
match self {
Self::Spades => 0,
Self::Hearts => 1,
Self::Diamonds => 2,
Self::Clubs => 3,
Self::Notrump => 4,
}
}
}
impl From<sys::ddTableResults> for TricksTable {
fn from(table: sys::ddTableResults) -> Self {
const fn make_row(row: [c_int; 4]) -> TricksRow {
#[allow(clippy::cast_sign_loss)]
TricksRow::new(
(row[0] & 0xFF) as u8,
(row[1] & 0xFF) as u8,
(row[2] & 0xFF) as u8,
(row[3] & 0xFF) as u8,
)
}
Self([
make_row(table.resTable[Strain::Clubs.to_sys()]),
make_row(table.resTable[Strain::Diamonds.to_sys()]),
make_row(table.resTable[Strain::Hearts.to_sys()]),
make_row(table.resTable[Strain::Spades.to_sys()]),
make_row(table.resTable[Strain::Notrump.to_sys()]),
])
}
}
impl From<TricksTable> for sys::ddTableResults {
fn from(table: TricksTable) -> Self {
const fn make_row(row: TricksRow) -> [c_int; 4] {
[
row.get(Seat::North) as c_int,
row.get(Seat::East) as c_int,
row.get(Seat::South) as c_int,
row.get(Seat::West) as c_int,
]
}
Self {
resTable: [
make_row(table[Strain::Spades]),
make_row(table[Strain::Hearts]),
make_row(table[Strain::Diamonds]),
make_row(table[Strain::Clubs]),
make_row(table[Strain::Notrump]),
],
}
}
}
impl From<Deal> for sys::ddTableDeal {
fn from(deal: Deal) -> Self {
Self {
cards: Seat::ALL.map(|seat| {
let hand = deal[seat];
[
hand[Suit::Spades].to_bits().into(),
hand[Suit::Hearts].to_bits().into(),
hand[Suit::Diamonds].to_bits().into(),
hand[Suit::Clubs].to_bits().into(),
]
}),
}
}
}
bitflags::bitflags! {
/// Vulnerability of a board: one bit per partnership.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Vulnerability: u8 {
/// North-South are vulnerable.
const NS = 1;
/// East-West are vulnerable.
const EW = 2;
}
}
impl Vulnerability {
pub const ALL: Self = Self::all();
pub const NONE: Self = Self::empty();
#[must_use]
#[inline]
pub const fn to_sys(self) -> i32 {
const ALL: u8 = Vulnerability::all().bits();
const NS: u8 = Vulnerability::NS.bits();
const EW: u8 = Vulnerability::EW.bits();
match self.bits() {
0 => 0,
ALL => 1,
NS => 2,
EW => 3,
_ => unreachable!(),
}
}
#[must_use]
#[inline]
pub const fn rotate(self, condition: bool) -> Self {
Self::from_bits_truncate((self.bits() * 0x55) >> (condition as u8))
}
}
/// Error returned when a string cannot be parsed into a [`Vulnerability`].
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("Invalid vulnerability: expected one of none, ns, ew, both")]
pub struct ParseVulnerabilityError;
impl FromStr for Vulnerability {
type Err = ParseVulnerabilityError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"none" => Ok(Self::NONE),
"ns" => Ok(Self::NS),
"ew" => Ok(Self::EW),
"both" | "all" => Ok(Self::ALL),
_ => Err(ParseVulnerabilityError),
}
}
}
impl fmt::Display for Vulnerability {
    /// Writes `none`, `ns`, `ew` or `both`.
    ///
    /// # Panics
    /// Panics if the value equals none of `NONE`, `NS`, `EW` and `ALL`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = *self;
        let label = if value == Self::NONE {
            "none"
        } else if value == Self::NS {
            "ns"
        } else if value == Self::EW {
            "ew"
        } else if value == Self::ALL {
            "both"
        } else {
            unreachable!()
        };
        f.write_str(label)
    }
}
// Compile-time checks for `Vulnerability::rotate`: with `true` it swaps NS and
// EW (leaving "all" and "none" untouched); with `false` it is the identity.
const _: () = {
const ALL: Vulnerability = Vulnerability::all();
const NONE: Vulnerability = Vulnerability::empty();
// rotate(true): NS <-> EW, symmetric values unchanged.
assert!(matches!(ALL.rotate(true), ALL));
assert!(matches!(NONE.rotate(true), NONE));
assert!(matches!(Vulnerability::NS.rotate(true), Vulnerability::EW));
assert!(matches!(Vulnerability::EW.rotate(true), Vulnerability::NS));
// rotate(false): every value maps to itself.
assert!(matches!(ALL.rotate(false), ALL));
assert!(matches!(NONE.rotate(false), NONE));
assert!(matches!(Vulnerability::NS.rotate(false), Vulnerability::NS));
assert!(matches!(Vulnerability::EW.rotate(false), Vulnerability::EW));
};
/// One contract that yields the par score, together with its declarer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParContract {
/// The contract itself.
pub contract: Contract,
/// Seat declaring the contract.
pub declarer: Seat,
/// Tricks above the contract; negative when the contract goes down.
pub overtricks: i8,
}
/// Par score together with the contracts that achieve it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Par {
/// Par score as reported by DDS.
// NOTE(review): presumably from the point of view of North-South — confirm
// against the DDS documentation.
pub score: i32,
/// Contracts that lead to `score`; empty when DDS reports level 0 for the
/// first contract.
pub contracts: Vec<ParContract>,
}
impl Par {
    /// Tells whether two pars have the same score and the same set of
    /// (strain, declarer) pairs among their contracts.
    ///
    /// Level, penalty and overtricks of the individual contracts are not
    /// compared, and neither is the order or multiplicity of the entries.
    #[must_use]
    pub fn equivalent(&self, other: &Self) -> bool {
        /// Bit set with one bit per (strain, declarer) pair present.
        fn pair_mask(contracts: &[ParContract]) -> u32 {
            let bit = |entry: &ParContract| -> u32 {
                let slot = (entry.contract.bid.strain as u8) << 2 | entry.declarer as u8;
                1 << slot
            };
            contracts.iter().map(bit).fold(0, u32::bitor)
        }

        if self.score != other.score {
            return false;
        }
        pair_mask(&self.contracts) == pair_mask(&other.contracts)
    }
}
impl From<sys::parResultsMaster> for Par {
fn from(par: sys::parResultsMaster) -> Self {
#[allow(clippy::cast_sign_loss)]
let len = par.number as usize * usize::from(par.contracts[0].level != 0);
#[allow(clippy::cast_sign_loss)]
let contracts = par.contracts[..len]
.iter()
.flat_map(|contract| {
let strain = [
Strain::Notrump,
Strain::Spades,
Strain::Hearts,
Strain::Diamonds,
Strain::Clubs,
][contract.denom as usize];
let (penalty, overtricks) = if contract.underTricks > 0 {
debug_assert!(contract.underTricks <= 13);
(Penalty::Doubled, -((contract.underTricks & 0xFF) as i8))
} else {
debug_assert!(contract.overTricks >= 0 && contract.overTricks <= 13);
(Penalty::Undoubled, (contract.overTricks & 0xFF) as i8)
};
debug_assert_eq!(contract.level, contract.level & 7);
let seat = match contract.seats & 3 {
0 => Seat::North,
1 => Seat::East,
2 => Seat::South,
_ => Seat::West,
};
let is_pair = contract.seats >= 4;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
let contract = Contract::new(contract.level as u8, strain, penalty);
core::iter::once(ParContract {
contract,
declarer: seat,
overtricks,
})
.chain(if is_pair {
Some(ParContract {
contract,
declarer: seat.partner(),
overtricks,
})
} else {
None
})
})
.collect();
Self {
score: par.score,
contracts,
}
}
}
/// Computes the par result of a double-dummy table for a given dealer.
///
/// `tricks` is converted to a `sys::ddTableResults` and handed to
/// `sys::DealerParBin` together with the dealer and the vulnerability; the
/// resulting `sys::parResultsMaster` is converted into a [`Par`].
///
/// # Panics
/// Panics if DDS returns a negative status code.
#[must_use]
pub fn calculate_par(tricks: TricksTable, vul: Vulnerability, dealer: Seat) -> Par {
    let mut par = sys::parResultsMaster::default();
    // DDS declares `DealerParBin(tablep, presp, dealer, vulnerable)`: the
    // dealer comes BEFORE the vulnerability, the reverse of this function's
    // own parameter order.  Both are plain `c_int`, so the compiler cannot
    // catch a swap.
    // SAFETY: both pointers refer to values that stay alive for the whole call.
    let status = unsafe {
        sys::DealerParBin(
            &mut tricks.into(),
            &raw mut par,
            dealer as c_int,
            vul.to_sys(),
        )
    };
    check(status);
    par.into()
}
/// Computes the par results of a double-dummy table for both sides.
///
/// `tricks` is converted to a `sys::ddTableResults` and handed to
/// `sys::SidesParBin`, which fills a two-element array that is then converted
/// element-wise into [`Par`].
// NOTE(review): presumably index 0 is the North-South result and index 1 the
// East-West result — confirm against the DDS documentation.
///
/// # Panics
/// Panics if DDS returns a negative status code.
#[must_use]
pub fn calculate_pars(tricks: TricksTable, vul: Vulnerability) -> [Par; 2] {
    let mut pars = [sys::parResultsMaster::default(); 2];
    // DDS writes both elements, so hand it a pointer derived from the whole
    // array (`as_mut_ptr`) rather than one derived from `pars[0]` alone.
    // SAFETY: `pars` holds the two elements DDS fills in and, like the
    // temporary table, stays alive for the whole call.
    let status = unsafe { sys::SidesParBin(&mut tricks.into(), pars.as_mut_ptr(), vul.to_sys()) };
    check(status);
    pars.map(Into::into)
}
/// What the solver is asked to find for a board.
///
/// Each variant maps to a (`target`, `solutions`) pair of DDS arguments; see
/// `Target::target` and `Target::solutions`.
// NOTE(review): the exact meaning of the three modes is defined by DDS —
// confirm the wording below against its documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
/// Passes the given target with `solutions = 1` (presumably: find any one
/// card reaching the target).
Any(i8),
/// Passes the given target with `solutions = 2` (presumably: find all cards
/// reaching the target).
All(i8),
/// Passes target `-1` with `solutions = 3` (presumably: score every legal
/// card).
Legal,
}
impl Target {
    /// Returns the `target` argument for DDS: the wrapped value for
    /// `Any`/`All`, and `-1` for `Legal`.
    #[must_use]
    #[inline]
    pub const fn target(self) -> c_int {
        let goal = match self {
            Self::Legal => return -1,
            Self::Any(goal) | Self::All(goal) => goal,
        };
        goal as c_int
    }

    /// Returns the `solutions` argument for DDS: `1` for `Any`, `2` for `All`
    /// and `3` for `Legal`.
    #[must_use]
    #[inline]
    pub const fn solutions(self) -> c_int {
        match self {
            Self::Legal => 3,
            Self::All(_) => 2,
            Self::Any(_) => 1,
        }
    }
}
/// A position to be solved: trump, leader, the trick in progress and the
/// cards still held.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Board {
/// Trump strain of the contract.
pub trump: Strain,
/// Seat that led to the current trick.
pub lead: Seat,
/// Cards already played to the current trick (at most three).
pub current_cards: ArrayVec<Card, 3>,
/// Cards remaining in the four hands.
pub remaining: Deal,
}
impl From<Board> for sys::deal {
fn from(board: Board) -> Self {
let mut suits = [0; 3];
let mut ranks = [0; 3];
for (i, card) in board.current_cards.into_iter().enumerate() {
suits[i] = 3 - card.suit as c_int;
ranks[i] = c_int::from(card.rank.get());
}
Self {
trump: match board.trump {
Strain::Spades => 0,
Strain::Hearts => 1,
Strain::Diamonds => 2,
Strain::Clubs => 3,
Strain::Notrump => 4,
},
first: board.lead as c_int,
currentTrickSuit: suits,
currentTrickRank: ranks,
remainCards: sys::ddTableDeal::from(board.remaining).cards,
}
}
}
/// A board paired with what the solver should look for on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Objective {
/// Position to solve.
pub board: Board,
/// Search target applied to `board`.
pub target: Target,
}
/// One card reported by the solver together with its score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Play {
/// The card to play.
pub card: Card,
/// Holding taken from the DDS `equals` field for this card.
// NOTE(review): presumably the lower-ranked cards of the same suit that are
// equivalent to `card` — confirm against the DDS documentation.
pub equals: Holding,
/// Score taken from the DDS `score` field for this card.
// NOTE(review): presumably a trick count, with negative values carrying a
// special meaning — confirm against the DDS documentation.
pub score: i8,
}
/// Result of solving one board.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FoundPlays {
/// Plays reported by the solver, in the order DDS returned them.
pub plays: ArrayVec<Play, 13>,
/// Value of the DDS `nodes` field.
// NOTE(review): presumably the number of searched nodes — confirm.
pub nodes: u32,
}
impl From<sys::futureTricks> for FoundPlays {
    /// Converts the raw DDS result into [`FoundPlays`].
    ///
    /// The first `cards` entries of the parallel `suit`/`rank`/`equals`/
    /// `score` arrays are read.  Suits are mapped through `Suit::DESC`; ranks,
    /// scores and holdings are truncated to 8, 8 and 16 bits respectively.
    ///
    /// # Panics
    /// Panics if `cards` exceeds the length of the source arrays or a suit
    /// value is not a valid index into `Suit::DESC`.
    #[allow(clippy::cast_sign_loss)]
    fn from(future: sys::futureTricks) -> Self {
        let count = future.cards as usize;
        let plays = (0..count)
            .map(|i| {
                let equals = Holding::from_bits_truncate((future.equals[i] & 0xFFFF) as u16);
                let score = (future.score[i] & 0xFF) as i8;
                let suit = Suit::DESC[future.suit[i] as usize];
                let rank = Rank::new((future.rank[i] & 0xFF) as u8);
                Play {
                    card: Card { suit, rank },
                    equals,
                    score,
                }
            })
            .collect();
        Self {
            plays,
            nodes: future.nodes as u32,
        }
    }
}
/// Process-wide lock guarding access to the DDS solver.
///
/// On first use it calls `sys::SetMaxThreads(0)` before creating the mutex.
// NOTE(review): presumably `0` lets DDS pick its thread count automatically —
// confirm against the DDS documentation.
static THREAD_POOL: LazyLock<Mutex<()>> = LazyLock::new(|| {
unsafe { sys::SetMaxThreads(0) };
Mutex::new(())
});
/// Exclusive access token for the DDS solver.
///
/// A `Solver` owns the guard of the process-wide `THREAD_POOL` mutex, so at
/// most one exists at a time; dropping it releases the lock.
pub struct Solver(#[allow(dead_code)] parking_lot::MutexGuard<'static, ()>);

impl Solver {
    /// Blocks until the solver lock is available and returns the handle.
    #[must_use]
    pub fn lock() -> Self {
        Self(THREAD_POOL.lock())
    }

    /// Returns the handle if the solver lock is free right now, `None`
    /// otherwise.
    #[must_use]
    pub fn try_lock() -> Option<Self> {
        THREAD_POOL.try_lock().map(Self)
    }

    /// Computes the double-dummy table of one deal through
    /// `sys::CalcDDtable`.
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    #[must_use]
    pub fn solve_deal(&self, deal: Deal) -> TricksTable {
        let mut result = sys::ddTableResults::default();
        // SAFETY: `result` is a live local for the whole call.
        let status = unsafe { sys::CalcDDtable(deal.into(), &raw mut result) };
        check(status);
        result.into()
    }

    /// Solves one batch of deals with a single `sys::CalcAllTables` call.
    ///
    /// Strains absent from `flags` are filtered out; par calculation is
    /// disabled (mode `-1`).
    ///
    /// # Safety
    /// `deals.len() * flags.bits().count_ones()` must not exceed
    /// `sys::MAXNOOFBOARDS` (checked only in debug builds).
    // NOTE(review): presumably the caller must also hold the solver lock, as
    // this function has no `self` to prove it — confirm.
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    unsafe fn solve_deal_segment(deals: &[Deal], flags: StrainFlags) -> sys::ddTablesRes {
        debug_assert!(
            deals.len() * flags.bits().count_ones() as usize <= sys::MAXNOOFBOARDS as usize
        );
        let mut pack = sys::ddTableDeals {
            noOfTables: c_int::try_from(deals.len()).unwrap_or(c_int::MAX),
            ..Default::default()
        };
        for (i, &deal) in deals.iter().enumerate() {
            pack.deals[i] = deal.into();
        }
        // A non-zero entry tells DDS to skip that strain, hence the negation.
        let mut filter = [
            c_int::from(!flags.contains(StrainFlags::SPADES)),
            c_int::from(!flags.contains(StrainFlags::HEARTS)),
            c_int::from(!flags.contains(StrainFlags::DIAMONDS)),
            c_int::from(!flags.contains(StrainFlags::CLUBS)),
            c_int::from(!flags.contains(StrainFlags::NOTRUMP)),
        ];
        let mut res = sys::ddTablesRes::default();
        // SAFETY: every pointer refers to a value that outlives the call.
        let status = unsafe {
            sys::CalcAllTables(
                &raw mut pack,
                -1,
                filter.as_mut_ptr(),
                &raw mut res,
                &mut sys::allParResults::default(),
            )
        };
        check(status);
        res
    }

    /// Computes double-dummy tables for many deals, restricted to the strains
    /// in `flags`.
    ///
    /// The deals are processed in batches sized so that deals × strains stays
    /// within `sys::MAXNOOFBOARDS`.  The result has one table per deal, in
    /// input order.
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    #[must_use]
    pub fn solve_deals(&self, deals: &[Deal], flags: StrainFlags) -> Vec<TricksTable> {
        // With empty `flags` the strain count is zero; clamp the divisor so
        // the batch size stays well-defined instead of dividing by zero.  Any
        // complaint about the missing strains is then left to DDS itself.
        let strains = flags.bits().count_ones().max(1);
        let batch = (sys::MAXNOOFBOARDS / strains) as usize;

        let mut tables = Vec::with_capacity(deals.len());
        for chunk in deals.chunks(batch) {
            // SAFETY: `chunk.len() <= batch`, hence chunk length × strain
            // count cannot exceed `MAXNOOFBOARDS`.
            let solved = unsafe { Self::solve_deal_segment(chunk, flags) };
            tables.extend(
                solved.results[..chunk.len()]
                    .iter()
                    .copied()
                    .map(TricksTable::from),
            );
        }
        tables
    }

    /// Solves a single board through `sys::SolveBoard` (mode 0, thread
    /// index 0).
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    #[must_use]
    pub fn solve_board(&self, objective: Objective) -> FoundPlays {
        let mut result = sys::futureTricks::default();
        // SAFETY: `result` is a live local for the whole call.
        let status = unsafe {
            sys::SolveBoard(
                objective.board.into(),
                objective.target.target(),
                objective.target.solutions(),
                0,
                &raw mut result,
                0,
            )
        };
        check(status);
        FoundPlays::from(result)
    }

    /// Solves one batch of boards with a single `sys::SolveAllBoardsBin`
    /// call.
    ///
    /// # Safety
    /// `args.len()` must not exceed `sys::MAXNOOFBOARDS` (checked only in
    /// debug builds).
    // NOTE(review): presumably the caller must also hold the solver lock, as
    // this function has no `self` to prove it — confirm.
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    unsafe fn solve_board_segment(args: &[Objective]) -> sys::solvedBoards {
        debug_assert!(args.len() <= sys::MAXNOOFBOARDS as usize);
        let mut pack = sys::boards {
            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
            noOfBoards: args.len() as c_int,
            ..Default::default()
        };
        for (i, objective) in args.iter().enumerate() {
            pack.deals[i] = objective.board.clone().into();
            pack.target[i] = objective.target.target();
            pack.solutions[i] = objective.target.solutions();
        }
        let mut res = sys::solvedBoards::default();
        // SAFETY: both pointers refer to values that outlive the call.
        let status = unsafe { sys::SolveAllBoardsBin(&raw mut pack, &raw mut res) };
        check(status);
        res
    }

    /// Solves many boards in batches of at most `sys::MAXNOOFBOARDS`.
    ///
    /// The result has one entry per objective, in input order.
    ///
    /// # Panics
    /// Panics if DDS returns a negative status code.
    #[must_use]
    pub fn solve_boards(&self, args: &[Objective]) -> Vec<FoundPlays> {
        let mut solutions = Vec::with_capacity(args.len());
        for chunk in args.chunks(sys::MAXNOOFBOARDS as usize) {
            // SAFETY: `chunks` never yields more than `MAXNOOFBOARDS` items.
            let solved = unsafe { Self::solve_board_segment(chunk) };
            solutions.extend(
                solved.solvedBoard[..chunk.len()]
                    .iter()
                    .copied()
                    .map(FoundPlays::from),
            );
        }
        solutions
    }
}