pub mod array;
pub mod map;
pub mod trie;
pub use array::Array;
pub use map::Map;
pub use trie::Trie;
use core::borrow::Borrow;
use core::fmt::{self, Write as _};
use core::ops::Deref;
use core::str::FromStr;
use dds_bridge::{Bid, Hand, Penalty, Strain};
use thiserror::Error;
/// A single call made during an auction: a pass, a double, a redouble, or a bid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Call {
/// A pass; displayed as `P`.
Pass,
/// A double; displayed as `X`.
Double,
/// A redouble; displayed as `XX`.
Redouble,
/// A bid; displayed using the wrapped [`Bid`]'s own `Display` implementation.
Bid(Bid),
}
impl From<Bid> for Call {
fn from(bid: Bid) -> Self {
Self::Bid(bid)
}
}
impl fmt::Display for Call {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Pass => f.write_char('P'),
Self::Double => f.write_char('X'),
Self::Redouble => f.write_str("XX"),
Self::Bid(bid) => bid.fmt(f),
}
}
}
/// Error returned when a string cannot be parsed as a [`Call`].
///
/// Carries no payload; the message lists the accepted forms.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("Invalid call: expected pass, double, redouble, or a bid like '1NT' or '3♠'")]
pub struct ParseCallError;
/// Parses a call from text.
///
/// The keywords `P`/`PASS`, `X`/`DBL`/`DOUBLE` and `XX`/`RDBL`/`REDOUBLE` are
/// matched ASCII-case-insensitively. Anything else is handed, unchanged, to
/// the [`Bid`] parser; if that fails the result is [`ParseCallError`].
impl FromStr for Call {
    type Err = ParseCallError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let is_any_of =
            |aliases: &[&str]| aliases.iter().any(|alias| s.eq_ignore_ascii_case(alias));

        if is_any_of(&["P", "PASS"]) {
            return Ok(Self::Pass);
        }
        if is_any_of(&["X", "DBL", "DOUBLE"]) {
            return Ok(Self::Double);
        }
        if is_any_of(&["XX", "RDBL", "REDOUBLE"]) {
            return Ok(Self::Redouble);
        }

        match s.parse::<Bid>() {
            Ok(bid) => Ok(Self::Bid(bid)),
            Err(_) => Err(ParseCallError),
        }
    }
}
bitflags::bitflags! {
/// Flag set recording which sides are vulnerable, in relative
/// ("we" / "they") terms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RelativeVulnerability: u8 {
/// The "we" side is vulnerable (bit 0).
const WE = 1;
/// The "they" side is vulnerable (bit 1).
const THEY = 2;
}
}
impl RelativeVulnerability {
/// No flag set; displayed as `none`.
pub const NONE: Self = Self::empty();
/// Both `WE` and `THEY` set; displayed as `both`.
pub const ALL: Self = Self::all();
}
/// Formats the vulnerability as `none`, `we`, `they`, or `both`.
///
/// Only the `WE` and `THEY` flags are inspected. A value carrying unknown bits
/// (which can be built with `from_bits_retain`) is formatted according to its
/// known flags instead of panicking.
impl fmt::Display for RelativeVulnerability {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Matching on the two known flags is exhaustive, so no unreachable
        // fallback arm is needed.
        let name = match (self.contains(Self::WE), self.contains(Self::THEY)) {
            (false, false) => "none",
            (true, false) => "we",
            (false, true) => "they",
            (true, true) => "both",
        };
        f.write_str(name)
    }
}
/// Error returned when a string cannot be parsed as a [`RelativeVulnerability`].
///
/// Carries no payload; the message lists the accepted keywords.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("Invalid relative vulnerability: expected one of none, we, they, both, all")]
pub struct ParseRelativeVulnerabilityError;
/// Parses `none`, `we`, `they`, `both` or `all` (ASCII-case-insensitive).
///
/// `both` and `all` are synonyms for [`RelativeVulnerability::ALL`].
impl FromStr for RelativeVulnerability {
    type Err = ParseRelativeVulnerabilityError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let keywords = [
            ("none", Self::NONE),
            ("we", Self::WE),
            ("they", Self::THEY),
            ("both", Self::ALL),
            ("all", Self::ALL),
        ];
        keywords
            .into_iter()
            .find(|(keyword, _)| s.eq_ignore_ascii_case(keyword))
            .map(|(_, vulnerability)| vulnerability)
            .ok_or(ParseRelativeVulnerabilityError)
    }
}
/// Reason a call was rejected when appending it to an [`Auction`].
///
/// Marked `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum IllegalCall {
/// The bid does not rank above the most recent bid in the auction.
#[error("Law 27: insufficient bid")]
InsufficientBid {
/// The bid that was attempted.
this: Bid,
/// The most recent bid already in the auction.
last: Option<Bid>,
},
/// A double or redouble was attempted where it is not admissible.
///
/// The payload is `Penalty::Doubled` for a rejected double and
/// `Penalty::Redoubled` for a rejected redouble.
#[error("Law 36: inadmissible doubles and redoubles")]
InadmissibleDouble(Penalty),
/// A call was attempted after the auction had already ended.
#[error("Law 39: call after the final pass")]
AfterFinalPass,
}
/// A sequence of calls in the order they were made.
///
/// The inner vector is private; calls are appended through the checked
/// `try_push` / `push` / `try_extend` methods, and the contents are readable
/// as a `[Call]` slice through `Deref`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Auction(Vec<Call>);
/// Formats the auction as its calls separated by single spaces.
///
/// An empty auction produces an empty string.
impl fmt::Display for Auction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (position, call) in self.0.iter().enumerate() {
            // Separator goes between calls, never before the first one.
            if position > 0 {
                f.write_char(' ')?;
            }
            fmt::Display::fmt(call, f)?;
        }
        Ok(())
    }
}
/// Error returned when a string cannot be parsed as an [`Auction`].
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum ParseAuctionError {
/// A whitespace-separated token was not a valid [`Call`].
#[error(transparent)]
Call(#[from] ParseCallError),
/// A token parsed as a call, but appending it to the auction was rejected.
#[error(transparent)]
Illegal(#[from] IllegalCall),
}
impl FromStr for Auction {
type Err = ParseAuctionError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut auction = Self::new();
for token in s.split_ascii_whitespace() {
auction.try_push(token.parse()?)?;
}
Ok(auction)
}
}
/// Gives read-only slice access to the calls of the auction.
impl Deref for Auction {
    type Target = [Call];

    fn deref(&self) -> &[Call] {
        self.0.as_slice()
    }
}
/// Views the auction as a slice of calls.
impl AsRef<[Call]> for Auction {
    fn as_ref(&self) -> &[Call] {
        self.0.as_slice()
    }
}
/// Borrows the auction as a slice of calls.
impl Borrow<[Call]> for Auction {
    fn borrow(&self) -> &[Call] {
        self.0.as_slice()
    }
}
impl From<Auction> for Vec<Call> {
fn from(auction: Auction) -> Self {
auction.0
}
}
impl IntoIterator for Auction {
type Item = Call;
type IntoIter = std::vec::IntoIter<Call>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
/// Iterates over the calls of the auction by reference, in order.
impl<'a> IntoIterator for &'a Auction {
    type Item = &'a Call;
    type IntoIter = core::slice::Iter<'a, Call>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.as_slice().iter()
    }
}
impl Auction {
    /// Cap on how many slots [`Auction::try_extend`] reserves up front.
    ///
    /// This is only a preallocation limit; the vector still grows on demand,
    /// so correctness does not depend on the exact value. It is chosen as the
    /// commonly cited maximum length of a legal auction (319 calls).
    const MAX_PREALLOCATION: usize = 319;

    /// Creates an empty auction.
    #[must_use]
    pub const fn new() -> Self {
        Self(Vec::new())
    }

    /// Returns `true` once the auction holds at least four calls and the last
    /// three are all passes.
    #[must_use]
    pub fn has_ended(&self) -> bool {
        self.len() >= 4 && self.ends_with(&[Call::Pass; 3])
    }

    /// Finds the most recent call that is not a pass.
    ///
    /// Returns the call together with its distance from the end of the
    /// auction (0 for the last call). An even distance means the call was
    /// made by the side opposing whoever calls next.
    fn last_non_pass(&self) -> Option<(usize, Call)> {
        self.iter()
            .rev()
            .copied()
            .enumerate()
            .find(|&(_, call)| call != Call::Pass)
    }

    /// Checks that a double is admissible: the most recent non-pass call must
    /// be a bid made by the opposing side.
    fn can_double(&self) -> Result<(), IllegalCall> {
        match self.last_non_pass() {
            Some((distance, Call::Bid(_))) if distance % 2 == 0 => Ok(()),
            _ => Err(IllegalCall::InadmissibleDouble(Penalty::Doubled)),
        }
    }

    /// Checks that a redouble is admissible: the most recent non-pass call
    /// must be a double made by the opposing side.
    fn can_redouble(&self) -> Result<(), IllegalCall> {
        match self.last_non_pass() {
            Some((distance, Call::Double)) if distance % 2 == 0 => Ok(()),
            _ => Err(IllegalCall::InadmissibleDouble(Penalty::Redoubled)),
        }
    }

    /// Checks that `bid` ranks strictly above the most recent bid, if any.
    fn can_bid(&self, bid: Bid) -> Result<(), IllegalCall> {
        let last = self.iter().rev().find_map(|&call| match call {
            Call::Bid(previous) => Some(previous),
            _ => None,
        });
        // `None < Some(_)`, so the first bid of an auction always passes.
        if last >= Some(bid) {
            return Err(IllegalCall::InsufficientBid { this: bid, last });
        }
        Ok(())
    }

    /// Checks whether `call` may legally be appended, without modifying the
    /// auction.
    fn can_push(&self, call: Call) -> Result<(), IllegalCall> {
        if self.has_ended() {
            return Err(IllegalCall::AfterFinalPass);
        }
        match call {
            Call::Pass => Ok(()),
            Call::Double => self.can_double(),
            Call::Redouble => self.can_redouble(),
            Call::Bid(bid) => self.can_bid(bid),
        }
    }

    /// Appends `call` to the auction.
    ///
    /// # Panics
    ///
    /// Panics if the call is illegal at this point of the auction. Use
    /// [`Auction::try_push`] to handle that case as an error instead.
    pub fn push(&mut self, call: Call) {
        self.try_push(call).unwrap();
    }

    /// Appends `call` to the auction if it is legal.
    ///
    /// # Errors
    ///
    /// Returns the reason as an [`IllegalCall`] and leaves the auction
    /// unchanged when the call is not legal.
    pub fn try_push(&mut self, call: Call) -> Result<(), IllegalCall> {
        self.can_push(call)?;
        self.0.push(call);
        Ok(())
    }

    /// Appends calls from `iter` one by one, validating each.
    ///
    /// # Errors
    ///
    /// Stops at the first illegal call and returns its [`IllegalCall`]. Calls
    /// accepted before the failure remain in the auction.
    pub fn try_extend(&mut self, iter: impl IntoIterator<Item = Call>) -> Result<(), IllegalCall> {
        let iter = iter.into_iter();
        // Reserve from the lower bound only, and cap it: the hint comes from
        // the caller's iterator and may be arbitrarily large, which would
        // otherwise overflow or over-allocate before any call is validated.
        let (lower, _) = iter.size_hint();
        let headroom = Self::MAX_PREALLOCATION.saturating_sub(self.0.len());
        self.0.reserve(lower.min(headroom));
        for call in iter {
            self.try_push(call)?;
        }
        Ok(())
    }

    /// Removes and returns the last call, or `None` if the auction is empty.
    pub fn pop(&mut self) -> Option<Call> {
        self.0.pop()
    }

    /// Shortens the auction to its first `len` calls; does nothing if the
    /// auction is already that short.
    pub fn truncate(&mut self, len: usize) {
        self.0.truncate(len);
    }

    /// Returns the index of the call with which the declaring side first bid
    /// the strain of the last bid.
    ///
    /// Sides are identified by index parity. Returns `None` when the auction
    /// contains no bid. The auction does not need to have ended.
    #[must_use]
    pub fn declarer(&self) -> Option<usize> {
        // Side (index parity) and strain of the most recent bid.
        let (parity, strain) =
            self.iter()
                .copied()
                .enumerate()
                .rev()
                .find_map(|(index, call)| match call {
                    Call::Bid(bid) => Some((index & 1, bid.strain)),
                    _ => None,
                })?;
        // First call by that side naming the same strain; `position` counts
        // only that side's calls, so map it back to an index in the auction.
        self.iter()
            .skip(parity)
            .step_by(2)
            .position(|call| match call {
                Call::Bid(bid) => bid.strain == strain,
                _ => false,
            })
            .map(|position| (position << 1) | parity)
    }
}
/// A bidding system that can evaluate a hand in the context of an auction.
pub trait System {
/// Classifies `hand` given the vulnerability and the auction so far.
///
/// Returns `None` when the system has nothing to say for this input; the
/// exact meaning of the returned `array::Logits` is defined by that type.
// NOTE(review): presumably the logits score candidate calls — confirm
// against `array::Logits`.
fn classify(
&self,
hand: Hand,
vul: RelativeVulnerability,
auction: &[Call],
) -> Option<array::Logits>;
}
/// Looks up the entry stored for `auction` and lets it classify the hand.
///
/// Returns `None` when the trie has no entry for the auction.
impl System for Trie {
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        auction: &[Call],
    ) -> Option<array::Logits> {
        let entry = self.get(auction)?;
        // Only computed once an entry is known to exist.
        let prefixes = self.common_prefixes(auction);
        Some(entry.classify(hand, vul, prefixes))
    }
}
/// Delegates to the member selected by indexing the forest with `vul`.
impl System for trie::Forest {
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        auction: &[Call],
    ) -> Option<array::Logits> {
        let selected = &self[vul];
        selected.classify(hand, vul, auction)
    }
}
/// Serde support that represents [`Call`] and [`Auction`] as strings, using
/// their `Display` and `FromStr` implementations.
#[cfg(feature = "serde")]
mod serde_string {
    use super::{Auction, Call};
    use core::fmt::{self, Display};
    use core::marker::PhantomData;
    use core::str::FromStr;
    use serde::{Deserialize, Deserializer, Serialize, Serializer, de};

    /// Serializes `value` as the string produced by its `Display` impl.
    fn serialize<T: Display, S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.collect_str(value)
    }

    /// Visitor that parses whatever string the deserializer provides.
    ///
    /// Implementing `visit_str` accepts borrowed, transient and owned strings
    /// alike, because the default `visit_borrowed_str` and `visit_string`
    /// forward to it.
    struct FromStrVisitor<T>(PhantomData<T>);

    impl<'de, T> de::Visitor<'de> for FromStrVisitor<T>
    where
        T: FromStr,
        T::Err: Display,
    {
        type Value = T;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a string")
        }

        fn visit_str<E: de::Error>(self, value: &str) -> Result<T, E> {
            value.parse().map_err(E::custom)
        }
    }

    /// Deserializes a string and parses it with `T`'s `FromStr` impl.
    ///
    /// Uses a visitor rather than `<&str>::deserialize`, which only succeeds
    /// when the deserializer can lend a string borrowed from its input and so
    /// fails for reader-backed sources or strings that need unescaping.
    fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
    where
        T: FromStr,
        T::Err: Display,
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(FromStrVisitor(PhantomData))
    }

    impl Serialize for Call {
        fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serialize(self, s)
        }
    }

    impl<'de> Deserialize<'de> for Call {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize(d)
        }
    }

    impl Serialize for Auction {
        fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            serialize(self, s)
        }
    }

    impl<'de> Deserialize<'de> for Auction {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            deserialize(d)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use dds_bridge::Level;

    /// Shorthand for building a bid call from a raw level and a strain.
    fn bid(level: u8, strain: Strain) -> Call {
        let level = Level::new(level);
        Call::Bid(Bid { level, strain })
    }

    #[test]
    fn call_roundtrip() {
        let calls = [
            Call::Pass,
            Call::Double,
            Call::Redouble,
            bid(1, Strain::Spades),
            bid(3, Strain::Notrump),
            bid(7, Strain::Clubs),
        ];
        for call in calls {
            let text = call.to_string();
            assert_eq!(text.parse::<Call>().unwrap(), call);
        }
    }

    #[test]
    fn call_parses_aliases_case_insensitive() {
        let cases = [
            ("p", Call::Pass),
            ("PASS", Call::Pass),
            ("pass", Call::Pass),
            ("x", Call::Double),
            ("dbl", Call::Double),
            ("DOUBLE", Call::Double),
            ("xx", Call::Redouble),
            ("RDBL", Call::Redouble),
            ("redouble", Call::Redouble),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<Call>().unwrap(), expected);
        }
    }

    #[test]
    fn call_rejects_garbage() {
        for s in ["", "Q", "8C", "1Z", "pas", "xxx"] {
            let parsed = s.parse::<Call>();
            assert!(parsed.is_err(), "should reject: {s:?}");
        }
    }

    #[test]
    fn relative_vulnerability_roundtrip() {
        let all_values = [
            RelativeVulnerability::NONE,
            RelativeVulnerability::WE,
            RelativeVulnerability::THEY,
            RelativeVulnerability::ALL,
        ];
        for v in all_values {
            let text = v.to_string();
            assert_eq!(text.parse::<RelativeVulnerability>().unwrap(), v);
        }
    }

    #[test]
    fn relative_vulnerability_parses_case_insensitive_and_aliases() {
        let cases = [
            ("NONE", RelativeVulnerability::NONE),
            ("We", RelativeVulnerability::WE),
            ("all", RelativeVulnerability::ALL),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<RelativeVulnerability>().unwrap(), expected);
        }
        assert!("ns".parse::<RelativeVulnerability>().is_err());
    }

    #[test]
    fn auction_roundtrip() {
        let calls = [
            Call::Pass,
            bid(1, Strain::Spades),
            bid(2, Strain::Hearts),
            Call::Double,
            Call::Pass,
            Call::Pass,
            Call::Pass,
        ];
        let mut auction = Auction::new();
        calls
            .into_iter()
            .try_for_each(|call| auction.try_push(call))
            .unwrap();

        let s = auction.to_string();
        assert_eq!(s, "P 1♠ 2♥ X P P P");
        assert_eq!(s.parse::<Auction>().unwrap(), auction);
    }

    #[test]
    fn empty_auction_roundtrip() {
        let auction = Auction::new();
        assert_eq!(auction.to_string(), "");
        for text in ["", " \t "] {
            assert_eq!(text.parse::<Auction>().unwrap(), auction);
        }
    }

    #[test]
    fn auction_rejects_illegal_sequence() {
        let err = "3♥ 2♠".parse::<Auction>().unwrap_err();
        assert!(matches!(err, ParseAuctionError::Illegal(_)));
    }

    #[test]
    fn auction_rejects_bad_token() {
        let err = "P 1♠ Q".parse::<Auction>().unwrap_err();
        assert!(matches!(err, ParseAuctionError::Call(_)));
    }
}