pub mod array;
pub mod map;
pub mod trie;
pub use array::Array;
pub use map::Map;
pub use trie::Trie;
use core::borrow::Borrow;
use core::fmt::{self, Write as _};
use core::ops::Deref;
use core::str::FromStr;
use dds_bridge::{Bid, Hand, Penalty, Strain};
use thiserror::Error;
/// A single call in an auction: a pass, a double, a redouble, or a bid.
///
/// With the `serde` feature enabled, a call is (de)serialized through its
/// `Display` and `FromStr` text forms.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_with::SerializeDisplay, serde_with::DeserializeFromStr)
)]
pub enum Call {
    /// A pass, written `P`.
    Pass,
    /// A double, written `X`.
    Double,
    /// A redouble, written `XX`.
    Redouble,
    /// A contract bid, written in the bid's own text form.
    Bid(Bid),
}
impl From<Bid> for Call {
fn from(bid: Bid) -> Self {
Self::Bid(bid)
}
}
impl fmt::Display for Call {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Pass => f.write_char('P'),
Self::Double => f.write_char('X'),
Self::Redouble => f.write_str("XX"),
Self::Bid(bid) => bid.fmt(f),
}
}
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("Invalid call: expected pass, double, redouble, or a bid like '1NT' or '3♠'")]
pub struct ParseCallError;
/// Parses a call from text.
///
/// `P`/`PASS`, `X`/`DBL`/`DOUBLE` and `XX`/`RDBL`/`REDOUBLE` are recognised in
/// any ASCII letter case. Any other input is handed unchanged to `Bid`'s parser,
/// and a failure there is reported as [`ParseCallError`].
impl FromStr for Call {
    type Err = ParseCallError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Case-insensitive comparison against each accepted spelling.
        let is_any = |spellings: &[&str]| spellings.iter().any(|word| s.eq_ignore_ascii_case(word));

        if is_any(&["P", "PASS"]) {
            Ok(Self::Pass)
        } else if is_any(&["X", "DBL", "DOUBLE"]) {
            Ok(Self::Double)
        } else if is_any(&["XX", "RDBL", "REDOUBLE"]) {
            Ok(Self::Redouble)
        } else {
            s.parse::<Bid>().map(Self::Bid).map_err(|_| ParseCallError)
        }
    }
}
bitflags::bitflags! {
    /// Vulnerability expressed relative to one side ("we") and its opponents
    /// ("they"); each flag marks that side as vulnerable.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub struct RelativeVulnerability: u8 {
        /// Our side is vulnerable.
        const WE = 1 << 0;
        /// The opposing side is vulnerable.
        const THEY = 1 << 1;
    }
}
/// Named shorthands for the two extreme vulnerability states.
impl RelativeVulnerability {
    /// Neither side is vulnerable (no flags set).
    pub const NONE: Self = Self::empty();

    /// Both sides are vulnerable (every defined flag set).
    pub const ALL: Self = Self::all();
}
/// Formats the vulnerability as `none`, `we`, `they`, or `both`.
///
/// Only the `WE` and `THEY` flags are inspected. A value carrying other bits
/// (constructible with bitflags' `from_bits_retain`) is rendered according to
/// its known flags instead of panicking: a `Display` impl should never abort.
impl fmt::Display for RelativeVulnerability {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Matching on the membership pair is exhaustive, so no fallback arm
        // (and no `unreachable!`) is needed.
        let text = match (self.contains(Self::WE), self.contains(Self::THEY)) {
            (false, false) => "none",
            (true, false) => "we",
            (false, true) => "they",
            (true, true) => "both",
        };
        f.write_str(text)
    }
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("Invalid relative vulnerability: expected one of none, we, they, both, all")]
pub struct ParseRelativeVulnerabilityError;
/// Parses `none`, `we`, `they`, `both`, or `all` (the last two are synonyms),
/// ignoring ASCII letter case.
impl FromStr for RelativeVulnerability {
    type Err = ParseRelativeVulnerabilityError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let table = [
            ("none", Self::NONE),
            ("we", Self::WE),
            ("they", Self::THEY),
            ("both", Self::ALL),
            ("all", Self::ALL),
        ];
        table
            .iter()
            .find_map(|&(name, vul)| s.eq_ignore_ascii_case(name).then_some(vul))
            .ok_or(ParseRelativeVulnerabilityError)
    }
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum IllegalCall {
#[error("Law 27: insufficient bid")]
InsufficientBid {
this: Bid,
last: Option<Bid>,
},
#[error("Law 36: inadmissible doubles and redoubles")]
InadmissibleDouble(Penalty),
#[error("Law 39: call after the final pass")]
AfterFinalPass,
}
/// The calls of an auction, in the order they were made.
///
/// Outside this module, calls can only be added through legality-checked
/// methods, so each call was legal given the calls before it. Removing calls
/// from the end keeps that property. With the `serde` feature enabled, the
/// auction is (de)serialized as its space-separated text form.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_with::SerializeDisplay, serde_with::DeserializeFromStr)
)]
pub struct Auction(Vec<Call>);
/// Writes the calls separated by single spaces (e.g. `P P P P`); an empty
/// auction writes nothing.
impl fmt::Display for Auction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (position, call) in self.0.iter().enumerate() {
            if position != 0 {
                f.write_char(' ')?;
            }
            call.fmt(f)?;
        }
        Ok(())
    }
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum ParseAuctionError {
#[error(transparent)]
Call(#[from] ParseCallError),
#[error(transparent)]
Illegal(#[from] IllegalCall),
}
impl FromStr for Auction {
type Err = ParseAuctionError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut auction = Self::new();
for token in s.split_ascii_whitespace() {
auction.try_push(token.parse()?)?;
}
Ok(auction)
}
}
/// Exposes the auction as a read-only slice of calls.
impl Deref for Auction {
    type Target = [Call];

    fn deref(&self) -> &[Call] {
        self.0.as_slice()
    }
}
/// Borrows the calls as a slice.
impl AsRef<[Call]> for Auction {
    fn as_ref(&self) -> &[Call] {
        &self.0
    }
}
/// Borrows the calls as a slice; equality agrees with the slice's, since the
/// derived `PartialEq` compares the same underlying calls.
impl Borrow<[Call]> for Auction {
    fn borrow(&self) -> &[Call] {
        self.0.as_slice()
    }
}
impl From<Auction> for Vec<Call> {
fn from(auction: Auction) -> Self {
auction.0
}
}
impl IntoIterator for Auction {
type Item = Call;
type IntoIter = std::vec::IntoIter<Call>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
/// Iterates over borrowed calls in order.
impl<'a> IntoIterator for &'a Auction {
    type Item = &'a Call;
    type IntoIter = core::slice::Iter<'a, Call>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.as_slice().iter()
    }
}
impl Auction {
    /// Creates an auction with no calls.
    #[must_use]
    pub const fn new() -> Self {
        Self(Vec::new())
    }

    /// Whether the auction is over.
    ///
    /// True once at least four calls have been made and the last three are
    /// passes: either four opening passes, or three passes following any bid,
    /// double, or redouble.
    #[must_use]
    pub fn has_ended(&self) -> bool {
        self.len() >= 4 && self[self.len() - 3..] == [Call::Pass; 3]
    }

    /// The most recent non-pass call, provided it was made by an opponent of
    /// the player about to call.
    ///
    /// Calls alternate between the two sides, so counting backwards from the
    /// latest call (offset 0), even offsets belong to the opponents of the
    /// player to act and odd offsets to that player's own side.
    fn last_opposing_action(&self) -> Option<Call> {
        self.iter()
            .rev()
            .copied()
            .enumerate()
            .find(|&(_, call)| call != Call::Pass)
            .filter(|&(offset, _)| offset & 1 == 0)
            .map(|(_, call)| call)
    }

    /// A double is admissible only over an opponent's bid that has not already
    /// been doubled or redoubled.
    fn can_double(&self) -> Result<(), IllegalCall> {
        match self.last_opposing_action() {
            Some(Call::Bid(_)) => Ok(()),
            _ => Err(IllegalCall::InadmissibleDouble(Penalty::Doubled)),
        }
    }

    /// A redouble is admissible only over an opponent's double.
    fn can_redouble(&self) -> Result<(), IllegalCall> {
        match self.last_opposing_action() {
            Some(Call::Double) => Ok(()),
            _ => Err(IllegalCall::InadmissibleDouble(Penalty::Redoubled)),
        }
    }

    /// A bid must rank strictly above the last bid made so far, if any.
    ///
    /// NOTE(review): relies on `Bid`'s `Ord`, presumably level-then-strain;
    /// confirm in `dds_bridge`.
    fn can_bid(&self, bid: Bid) -> Result<(), IllegalCall> {
        let last = self.iter().rev().find_map(|&call| match call {
            Call::Bid(previous) => Some(previous),
            _ => None,
        });
        // `None` orders below every `Some`, so the first bid is always sufficient.
        if last >= Some(bid) {
            return Err(IllegalCall::InsufficientBid { this: bid, last });
        }
        Ok(())
    }

    /// Checks whether `call` may legally be appended to this auction.
    fn can_push(&self, call: Call) -> Result<(), IllegalCall> {
        if self.has_ended() {
            return Err(IllegalCall::AfterFinalPass);
        }
        match call {
            Call::Pass => Ok(()),
            Call::Double => self.can_double(),
            Call::Redouble => self.can_redouble(),
            Call::Bid(bid) => self.can_bid(bid),
        }
    }

    /// Appends a call that the caller knows to be legal.
    ///
    /// # Panics
    ///
    /// Panics if [`Auction::try_push`] would fail: the call is an insufficient
    /// bid, an inadmissible double or redouble, or comes after the auction
    /// has ended.
    pub fn push(&mut self, call: Call) {
        self.try_push(call)
            .expect("Auction::push called with an illegal call");
    }

    /// Appends a call if it is legal at this point of the auction.
    ///
    /// # Errors
    ///
    /// Returns the [`IllegalCall`] describing why the call is not allowed; the
    /// auction is left unchanged in that case.
    pub fn try_push(&mut self, call: Call) -> Result<(), IllegalCall> {
        self.can_push(call)?;
        self.0.push(call);
        Ok(())
    }

    /// Appends calls in order, stopping at the first illegal one.
    ///
    /// # Errors
    ///
    /// Returns the [`IllegalCall`] for the first call that cannot be made.
    /// Calls appended before it remain in the auction; the offending call and
    /// any calls after it are discarded.
    pub fn try_extend(&mut self, iter: impl IntoIterator<Item = Call>) -> Result<(), IllegalCall> {
        // Longest legal auction: three opening passes, then each of the 35 bids
        // followed by `P P X P P XX P P`, then the final pass (3 + 35 * 9 + 1).
        const MAX_LEN: usize = 319;

        let iter = iter.into_iter();
        // A size hint can be arbitrarily large (e.g. `repeat(Call::Pass).take(n)`).
        // Reserving it verbatim could panic on capacity overflow or abort on
        // allocation failure before the first illegal call is even reached, so
        // never reserve past the longest possible legal auction.
        let (lower, upper) = iter.size_hint();
        let room = MAX_LEN.saturating_sub(self.0.len());
        self.0.reserve(upper.unwrap_or(lower).min(room));

        for call in iter {
            self.try_push(call)?;
        }
        Ok(())
    }

    /// Removes and returns the last call, if any. Every prefix of a legal
    /// auction is itself legal, so this never breaks the auction.
    pub fn pop(&mut self) -> Option<Call> {
        self.0.pop()
    }

    /// Keeps only the first `len` calls; a no-op when `len` is at least the
    /// current length. Like [`Auction::pop`], this keeps the auction legal.
    pub fn truncate(&mut self, len: usize) {
        self.0.truncate(len);
    }

    /// Locates the declarer: the player on the side that made the last bid who
    /// first named that bid's strain.
    ///
    /// Returns `None` if no bid has been made. The auction need not have
    /// ended; for an auction still in progress, the answer reflects the
    /// current last bid.
    ///
    /// NOTE(review): the returned value is the *index into the auction* of
    /// that player's first call in the strain, not a seat in `0..4`. Reduce it
    /// modulo 4 to get the seat relative to the first caller, and confirm
    /// callers expect this.
    #[must_use]
    pub fn declarer(&self) -> Option<usize> {
        // The parity of the last bid's index identifies the declaring side.
        let (parity, strain) = self
            .iter()
            .copied()
            .enumerate()
            .rev()
            .find_map(|(index, call)| match call {
                Call::Bid(bid) => Some((index & 1, bid.strain)),
                _ => None,
            })?;
        // Scan only that side's calls (every other call, starting at `parity`)
        // and map the hit back to its index in the full auction.
        self.iter()
            .skip(parity)
            .step_by(2)
            .position(|call| matches!(call, Call::Bid(bid) if bid.strain == strain))
            .map(|position| (position << 1) | parity)
    }
}
/// A bidding system: something that, given a hand, the vulnerability, and the
/// auction so far, produces an [`array::Logits`] for the next call.
pub trait System {
    /// Classifies `hand` for the player about to call after `auction`.
    ///
    /// Returns `None` when the system has no entry for this situation; what
    /// that means exactly is up to each implementation.
    ///
    /// NOTE(review): the logits presumably score each candidate call; confirm
    /// against `array::Logits`.
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        auction: &[Call],
    ) -> Option<array::Logits>;
}
/// Looks up `auction` in the trie. If an entry is stored there, that entry
/// classifies the hand and is given `common_prefixes(auction)` as context;
/// otherwise the result is `None`.
impl System for Trie {
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        auction: &[Call],
    ) -> Option<array::Logits> {
        let entry = self.get(auction)?;
        // NOTE(review): `common_prefixes` presumably yields the entries stored
        // along prefixes of `auction`; confirm in `trie`.
        Some(entry.classify(hand, vul, self.common_prefixes(auction)))
    }
}
/// Picks the trie indexed by `vul` and delegates classification to it.
impl System for trie::Forest {
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        auction: &[Call],
    ) -> Option<array::Logits> {
        let trie = &self[vul];
        trie.classify(hand, vul, auction)
    }
}
#[cfg(test)]
mod tests;