use crate::core::card::Card;
use super::eval;
use super::{CardBitSet, FlatHand, Hand};
/// Strength of an evaluated hand, packed into a single `u16` score.
///
/// The hand category occupies the bits at and above `CATEGORY_SHIFT` and the
/// within-category tie-break bits sit below it, so the derived `Ord` (a plain
/// integer comparison) orders by category first and then by tie-breakers.
/// Larger scores are stronger hands (the `*_MIN` constants grow with category).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Copy)]
#[repr(transparent)]
pub struct Rank(u16);
/// Bit position of the category field inside a `Rank` score; the low 12 bits
/// below it hold the within-category value bits.
const CATEGORY_SHIFT: u32 = 12;
impl Rank {
    /// Wraps a raw evaluator score as-is; no validation is performed.
    ///
    /// The category is read from the bits at and above `CATEGORY_SHIFT`
    /// (see [`Rank::category`]) and the tie-break bits from those below.
    #[inline]
    pub(crate) const fn from_score(score: u16) -> Self {
        Rank(score)
    }
    /// Smallest high-card score (category field 1, no value bits).
    pub const HIGH_CARD_MIN: Rank = Rank(1 << CATEGORY_SHIFT);
    /// Smallest one-pair score.
    pub const ONE_PAIR_MIN: Rank = Rank(2 << CATEGORY_SHIFT);
    /// Smallest two-pair score.
    pub const TWO_PAIR_MIN: Rank = Rank(3 << CATEGORY_SHIFT);
    /// Smallest three-of-a-kind score.
    pub const THREE_OF_A_KIND_MIN: Rank = Rank(4 << CATEGORY_SHIFT);
    /// Smallest straight score.
    pub const STRAIGHT_MIN: Rank = Rank(5 << CATEGORY_SHIFT);
    /// Smallest flush score.
    pub const FLUSH_MIN: Rank = Rank(6 << CATEGORY_SHIFT);
    /// Smallest full-house score.
    pub const FULL_HOUSE_MIN: Rank = Rank(7 << CATEGORY_SHIFT);
    /// Smallest four-of-a-kind score.
    pub const FOUR_OF_A_KIND_MIN: Rank = Rank(8 << CATEGORY_SHIFT);
    /// Smallest straight-flush score.
    pub const STRAIGHT_FLUSH_MIN: Rank = Rank(9 << CATEGORY_SHIFT);
    /// Returns the hand category stored in the upper bits of the score.
    ///
    /// The mapping is monotone with the derived `Ord`: `a <= b` implies
    /// `a.category() <= b.category()`. A category field below
    /// `HIGH_CARD_MIN`'s (e.g. a raw score of 0) reports `HighCard`, and a
    /// field above `STRAIGHT_FLUSH_MIN`'s reports `StraightFlush`.
    #[inline]
    #[must_use]
    pub const fn category(self) -> CoreRank {
        match self.0 >> CATEGORY_SHIFT {
            // Field 0 sorts below every high card, so it must map to the
            // lowest category rather than falling into the catch-all
            // StraightFlush arm (which would contradict `Ord`).
            0 | 1 => CoreRank::HighCard,
            2 => CoreRank::OnePair,
            3 => CoreRank::TwoPair,
            4 => CoreRank::ThreeOfAKind,
            5 => CoreRank::Straight,
            6 => CoreRank::Flush,
            7 => CoreRank::FullHouse,
            8 => CoreRank::FourOfAKind,
            _ => CoreRank::StraightFlush,
        }
    }
    /// Returns the within-category tie-break bits (the low `CATEGORY_SHIFT`
    /// bits of the score).
    #[inline]
    #[must_use]
    pub const fn value_bits(self) -> u16 {
        self.0 & ((1 << CATEGORY_SHIFT) - 1)
    }
}
impl std::fmt::Debug for Rank {
    /// Renders the rank as `Category(value_bits)`, e.g. `Flush(1234)`, so
    /// the category is readable without decoding the packed score by hand.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let category = self.category();
        let bits = self.value_bits();
        write!(f, "{:?}({})", category, bits)
    }
}
/// Coarse hand category, declared from weakest (`HighCard`) to strongest
/// (`StraightFlush`).
///
/// The derived `Ord` follows declaration order, so the variant order below is
/// load-bearing and must not be rearranged. Obtain one from a [`Rank`] with
/// [`Rank::category`] or `CoreRank::from(rank)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Copy)]
pub enum CoreRank {
    HighCard,
    OnePair,
    TwoPair,
    ThreeOfAKind,
    Straight,
    Flush,
    FullHouse,
    FourOfAKind,
    StraightFlush,
}
impl std::fmt::Display for CoreRank {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::HighCard => write!(f, "High Card"),
Self::OnePair => write!(f, "One Pair"),
Self::TwoPair => write!(f, "Two Pair"),
Self::ThreeOfAKind => write!(f, "Three of a Kind"),
Self::Straight => write!(f, "Straight"),
Self::Flush => write!(f, "Flush"),
Self::FullHouse => write!(f, "Full House"),
Self::FourOfAKind => write!(f, "Four of a Kind"),
Self::StraightFlush => write!(f, "Straight Flush"),
}
}
}
impl From<Rank> for CoreRank {
fn from(rank: Rank) -> Self {
rank.category()
}
}
/// A collection of cards that can be evaluated to a hand [`Rank`].
///
/// Every implementation in this module evaluates through [`SevenCardAccum`],
/// so in debug builds ranking more than seven cards panics.
pub trait Rankable {
    /// Evaluates the cards held by `self` and returns their rank.
    fn rank(&self) -> Rank;
}
/// Incremental evaluator for hands of up to seven cards.
///
/// Cards are folded in one at a time with `add` (or `+=` / `+`), and the
/// resulting [`Rank`] is read back with `rank`. The type is `Copy` and 16
/// bytes, so a partially built hand can be duplicated cheaply and then
/// extended in different ways.
#[derive(Clone, Copy, Debug)]
pub struct SevenCardAccum {
    // Evaluator lookup key; starts at `eval::DEFAULT_KEY` and is updated by
    // `eval::add_card` for each card.
    key: u64,
    // Card bit mask, also updated by `eval::add_card`; `rank` treats its
    // popcount as the number of cards added.
    mask: u64,
}
impl Default for SevenCardAccum {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl SevenCardAccum {
    /// Creates an accumulator with no cards added (empty mask, default key).
    #[inline]
    pub fn new() -> Self {
        let key = eval::DEFAULT_KEY;
        Self { mask: 0, key }
    }
    /// Folds one card into the accumulator via `eval::add_card`.
    #[inline]
    pub fn add(&mut self, c: Card) {
        let index = u8::from(c);
        eval::add_card(&mut self.key, &mut self.mask, index);
    }
    /// Evaluates the cards added so far.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the internal mask has more than seven bits
    /// set (i.e. more than seven cards were added).
    #[inline]
    pub fn rank(&self) -> Rank {
        let count = self.mask.count_ones();
        debug_assert!(
            count <= 7,
            "hand evaluator supports at most 7 cards, got {}",
            count
        );
        let score = eval::evaluate_key(self.key, self.mask);
        Rank::from_score(score)
    }
}
impl std::ops::AddAssign<Card> for SevenCardAccum {
#[inline]
fn add_assign(&mut self, c: Card) {
self.add(c);
}
}
impl std::ops::Add<Card> for SevenCardAccum {
type Output = SevenCardAccum;
#[inline]
fn add(mut self, c: Card) -> SevenCardAccum {
SevenCardAccum::add(&mut self, c);
self
}
}
// Compile-time guard: the accumulator is intended to stay two `u64`s
// (16 bytes) so that copying it to branch off a shared partial hand is cheap.
const _: () = assert!(std::mem::size_of::<SevenCardAccum>() == 16);
/// Evaluates an arbitrary card sequence by folding every card into a fresh
/// [`SevenCardAccum`]; inherits the accumulator's seven-card limit.
fn rank_cards<I: Iterator<Item = Card>>(cards: I) -> Rank {
    cards
        .fold(SevenCardAccum::new(), |acc, card| acc + card)
        .rank()
}
/// Ranks the cards yielded by `FlatHand::iter`.
impl Rankable for FlatHand {
    fn rank(&self) -> Rank {
        rank_cards(self.iter().cloned())
    }
}
/// A `Vec<Card>` ranks exactly like the slice it derefs to.
impl Rankable for Vec<Card> {
    fn rank(&self) -> Rank {
        <[Card] as Rankable>::rank(self)
    }
}
/// Ranks every card in the slice, in order.
impl Rankable for [Card] {
    fn rank(&self) -> Rank {
        let cards = self.iter().cloned();
        rank_cards(cards)
    }
}
/// Lets a borrowed slice be ranked directly; defers to the `[Card]` impl.
impl Rankable for &[Card] {
    fn rank(&self) -> Rank {
        <[Card] as Rankable>::rank(self)
    }
}
/// Ranks the cards yielded by `Hand::iter`.
impl Rankable for Hand {
    fn rank(&self) -> Rank {
        let cards = self.iter();
        rank_cards(cards)
    }
}
/// Ranks every card contained in the bit set.
impl Rankable for CardBitSet {
    fn rank(&self) -> Rank {
        let cards = self.into_iter();
        rank_cards(cards)
    }
}
/// Slow, straightforward reference evaluator used only by the tests to
/// cross-check the production evaluator in `eval`.
///
/// The input is a 52-bit card set split into four 13-bit suit masks
/// (`cards & 0x1FFF`, `(cards >> 13) & 0x1FFF`, ...). Within a suit mask the
/// oracle treats bit 0 as the deuce and bit 12 as the ace (see `WHEEL`).
/// NOTE(review): this assumes `u8::from(Card)` encodes `suit * 13 + value`;
/// confirm against `Card`'s encoding.
#[cfg(test)]
pub(crate) mod oracle {
    // A-2-3-4-5: bits 0..=3 (deuce..five) plus bit 12 (ace).
    const WHEEL: u32 = 0b1_0000_0000_1111;
    /// Returns the height of the best straight in the value mask `v`, or
    /// `None` if there is none.
    ///
    /// Height is the bit index of the straight's top card minus 3, so the
    /// wheel is 0, six-high is 1, ..., ace-high is 9; larger is stronger.
    fn straight(v: u32) -> Option<u32> {
        // Bit i of `run` is set iff bits i, i-1, ..., i-4 of `v` are all set,
        // i.e. some five-card run tops out at bit i.
        let run = v & (v << 1) & (v << 2) & (v << 3) & (v << 4);
        if run != 0 {
            // Highest run top is bit (31 - lz); subtracting 3 leaves 0 for the wheel.
            Some(32 - 4 - run.leading_zeros())
        } else if v & WHEEL == WHEEL {
            Some(0)
        } else {
            None
        }
    }
    /// Isolates the highest set bit of `v` (returns 0 when `v` is 0).
    fn keep_high(v: u32) -> u32 {
        if v == 0 {
            0
        } else {
            1 << (31 - v.leading_zeros())
        }
    }
    /// Clears the lowest set bits of `v` until at most `n` remain.
    fn keep_top(mut v: u32, n: u32) -> u32 {
        while v.count_ones() > n {
            // Clears the lowest set bit.
            v &= v - 1;
        }
        v
    }
    /// Evaluates the card set `cards` into a packed score.
    ///
    /// The category (0 = high card ... 8 = straight flush, the same order as
    /// `CoreRank`) is stored in bits 28 and up. Tie-break masks fill the low
    /// bits; primary values (quads/trips/pairs) are shifted up 13 bits above
    /// their kickers so they dominate comparisons within a category.
    ///
    /// The flush check runs first. That is only sound while a flush cannot
    /// coexist with quads or a full house, which holds for at most 7 cards.
    pub(crate) fn rank_u64(cards: u64) -> u32 {
        // Per-suit 13-bit value masks.
        let s = [
            (cards & 0x1FFF) as u32,
            ((cards >> 13) & 0x1FFF) as u32,
            ((cards >> 26) & 0x1FFF) as u32,
            ((cards >> 39) & 0x1FFF) as u32,
        ];
        // Values present in at least one suit.
        let value_set = s[0] | s[1] | s[2] | s[3];
        let pack = |cat: u32, payload: u32| (cat << 28) | payload;
        let flush = s.iter().find(|m| m.count_ones() >= 5);
        if let Some(&fs) = flush {
            return match straight(fs) {
                Some(r) => pack(8, r),
                None => pack(5, keep_top(fs, 5)),
            };
        }
        // Values held in at least two suits.
        let e2 = (s[0] & s[1])
            | (s[0] & s[2])
            | (s[0] & s[3])
            | (s[1] & s[2])
            | (s[1] & s[3])
            | (s[2] & s[3]);
        // Values held in at least three suits.
        let e3 = (s[0] & s[1] & s[2])
            | (s[0] & s[1] & s[3])
            | (s[0] & s[2] & s[3])
            | (s[1] & s[2] & s[3]);
        // Values held in all four suits.
        let e4 = s[0] & s[1] & s[2] & s[3];
        // Exact multiplicities: exactly two, exactly three, exactly four.
        let pairs = e2 & !e3;
        let trips = e3 & !e4;
        let quads = e4;
        // Branch order is the category priority, strongest first.
        if quads != 0 {
            pack(7, (quads << 13) | keep_high(value_set ^ quads))
        } else if trips != 0 && trips.count_ones() == 2 {
            // Two sets of trips: the higher plays as the three, the lower as the pair.
            let set = keep_high(trips);
            pack(6, (set << 13) | (trips ^ set))
        } else if trips != 0 && pairs != 0 {
            pack(6, (trips << 13) | keep_high(pairs))
        } else if let Some(r) = straight(value_set) {
            pack(4, r)
        } else if trips != 0 {
            pack(3, (trips << 13) | keep_top(value_set ^ trips, 2))
        } else if pairs.count_ones() >= 2 {
            // Best two pairs; a third pair's value remains eligible as the kicker.
            let two = keep_top(pairs, 2);
            pack(2, (two << 13) | keep_high(value_set ^ two))
        } else if pairs != 0 {
            pack(1, (pairs << 13) | keep_top(value_set ^ pairs, 3))
        } else {
            pack(0, keep_top(value_set, 5))
        }
    }
    /// Extracts the category field (bits 28 and up) of a packed oracle score.
    pub(crate) fn category(packed: u32) -> u32 {
        packed >> 28
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Card;
    use crate::core::card::*;
    /// Packs cards into the 52-bit set consumed by `oracle::rank_u64`
    /// (one bit at index `u8::from(card)` per card).
    fn bits_of(cards: &[Card]) -> u64 {
        cards.iter().fold(0u64, |a, c| a | (1u64 << u8::from(*c)))
    }
    /// Ranks `cards` with the production evaluator via a fresh accumulator.
    fn new_rank(cards: &[Card]) -> Rank {
        let mut acc = SevenCardAccum::new();
        for c in cards {
            acc.add(*c);
        }
        acc.rank()
    }
    /// Maps the oracle's category field (0 = high card ... 8 = straight
    /// flush) onto `CoreRank`.
    fn old_core(packed: u32) -> CoreRank {
        match oracle::category(packed) {
            0 => CoreRank::HighCard,
            1 => CoreRank::OnePair,
            2 => CoreRank::TwoPair,
            3 => CoreRank::ThreeOfAKind,
            4 => CoreRank::Straight,
            5 => CoreRank::Flush,
            6 => CoreRank::FullHouse,
            7 => CoreRank::FourOfAKind,
            _ => CoreRank::StraightFlush,
        }
    }
    /// Every 5-card hand gets the same category from both evaluators.
    #[test]
    fn category_agrees_exhaustive_five() {
        let cards: Vec<Card> = (0u8..52).map(Card::from).collect();
        let n = cards.len();
        for a in 0..n {
            for b in (a + 1)..n {
                for c in (b + 1)..n {
                    for d in (c + 1)..n {
                        for e in (d + 1)..n {
                            let hand = [cards[a], cards[b], cards[c], cards[d], cards[e]];
                            let newr = new_rank(&hand);
                            let oldp = oracle::rank_u64(bits_of(&hand));
                            assert_eq!(
                                newr.category(),
                                old_core(oldp),
                                "category mismatch {hand:?}"
                            );
                        }
                    }
                }
            }
        }
    }
    /// Over all 5-card hands, production scores are a strictly
    /// order-preserving function of oracle scores: equal oracle scores map to
    /// a single production score, and sorting by oracle score yields strictly
    /// increasing production scores.
    #[test]
    fn order_isomorphism_exhaustive_five() {
        use std::collections::HashMap;
        let cards: Vec<Card> = (0u8..52).map(Card::from).collect();
        let n = cards.len();
        let mut map: HashMap<u32, u16> = HashMap::new();
        for a in 0..n {
            for b in (a + 1)..n {
                for c in (b + 1)..n {
                    for d in (c + 1)..n {
                        for e in (d + 1)..n {
                            let hand = [cards[a], cards[b], cards[c], cards[d], cards[e]];
                            let oldp = oracle::rank_u64(bits_of(&hand));
                            let news = new_rank(&hand).0;
                            if let Some(prev) = map.insert(oldp, news) {
                                assert_eq!(prev, news, "old rank maps to two new scores");
                            }
                        }
                    }
                }
            }
        }
        let mut pairs: Vec<(u32, u16)> = map.into_iter().collect();
        pairs.sort_unstable_by_key(|p| p.0);
        for w in pairs.windows(2) {
            assert!(w[0].1 < w[1].1, "ordering not preserved: {w:?}");
        }
    }
    /// Random 7-card hands (fixed seed): categories agree on every sample,
    /// and relative order agrees pairwise over the first 2000 samples.
    #[test]
    fn differential_random_seven() {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand::seq::SliceRandom;
        let mut cards: Vec<Card> = (0u8..52).map(Card::from).collect();
        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        let mut samples: Vec<(u32, u16)> = Vec::new();
        for _ in 0..200_000 {
            cards.shuffle(&mut rng);
            let hand = &cards[..7];
            let oldp = oracle::rank_u64(bits_of(hand));
            let news = new_rank(hand);
            assert_eq!(
                news.category(),
                old_core(oldp),
                "category mismatch {hand:?}"
            );
            samples.push((oldp, news.0));
        }
        for i in 0..samples.len().min(2000) {
            for j in (i + 1)..samples.len().min(2000) {
                let (oa, na) = samples[i];
                let (ob, nb) = samples[j];
                assert_eq!(oa.cmp(&ob), na.cmp(&nb), "pairwise order disagreement");
            }
        }
    }
    /// Every window of `count` consecutive card indices, including the final
    /// window that ends at index 51, gets the same category from both
    /// evaluators.
    #[test]
    fn differential_partial_hands() {
        let cards: Vec<Card> = (0u8..52).map(Card::from).collect();
        for count in [2usize, 3, 4, 6] {
            // `windows` yields all 52 - count + 1 windows; the previous
            // `0..(52 - count)` start range skipped the last one, so the final
            // card index was never part of any partial hand.
            for hand in cards.windows(count) {
                let oldp = oracle::rank_u64(bits_of(hand));
                let news = new_rank(hand);
                assert_eq!(
                    news.category(),
                    old_core(oldp),
                    "category mismatch {hand:?}"
                );
            }
        }
    }
    #[test]
    fn category_ordering_holds() {
        assert!(Rank::HIGH_CARD_MIN < Rank::ONE_PAIR_MIN);
        assert!(Rank::ONE_PAIR_MIN < Rank::TWO_PAIR_MIN);
        assert!(Rank::TWO_PAIR_MIN < Rank::THREE_OF_A_KIND_MIN);
        assert!(Rank::THREE_OF_A_KIND_MIN < Rank::STRAIGHT_MIN);
        assert!(Rank::STRAIGHT_MIN < Rank::FLUSH_MIN);
        assert!(Rank::FLUSH_MIN < Rank::FULL_HOUSE_MIN);
        assert!(Rank::FULL_HOUSE_MIN < Rank::FOUR_OF_A_KIND_MIN);
        assert!(Rank::FOUR_OF_A_KIND_MIN < Rank::STRAIGHT_FLUSH_MIN);
    }
    #[test]
    fn seven_card_accum_size() {
        assert_eq!(std::mem::size_of::<SevenCardAccum>(), 16);
    }
    #[test]
    fn known_hands_have_expected_categories() {
        let sf = FlatHand::new_from_str("AdKdQdJdTd").unwrap();
        assert_eq!(sf.rank().category(), CoreRank::StraightFlush);
        let quads = FlatHand::new_from_str("AsAhAdAcKs").unwrap();
        assert_eq!(quads.rank().category(), CoreRank::FourOfAKind);
        let wheel = FlatHand::new_from_str("Ad2c3s4h5d").unwrap();
        assert_eq!(wheel.rank().category(), CoreRank::Straight);
    }
    #[test]
    fn test_core_rank_from_categories() {
        assert_eq!(CoreRank::HighCard, Rank::HIGH_CARD_MIN.category());
        assert_eq!(CoreRank::OnePair, Rank::ONE_PAIR_MIN.category());
        assert_eq!(CoreRank::TwoPair, Rank::TWO_PAIR_MIN.category());
        assert_eq!(CoreRank::ThreeOfAKind, Rank::THREE_OF_A_KIND_MIN.category());
        assert_eq!(CoreRank::Straight, Rank::STRAIGHT_MIN.category());
        assert_eq!(CoreRank::Flush, Rank::FLUSH_MIN.category());
        assert_eq!(CoreRank::FullHouse, Rank::FULL_HOUSE_MIN.category());
        assert_eq!(CoreRank::FourOfAKind, Rank::FOUR_OF_A_KIND_MIN.category());
        assert_eq!(CoreRank::StraightFlush, Rank::STRAIGHT_FLUSH_MIN.category());
    }
    #[test]
    fn test_core_rank_into() {
        let r: CoreRank = Rank::FLUSH_MIN.into();
        assert_eq!(r, CoreRank::Flush);
    }
    #[test]
    fn test_core_rank_ordering() {
        assert!(CoreRank::HighCard < CoreRank::OnePair);
        assert!(CoreRank::OnePair < CoreRank::TwoPair);
        assert!(CoreRank::TwoPair < CoreRank::ThreeOfAKind);
        assert!(CoreRank::ThreeOfAKind < CoreRank::Straight);
        assert!(CoreRank::Straight < CoreRank::Flush);
        assert!(CoreRank::Flush < CoreRank::FullHouse);
        assert!(CoreRank::FullHouse < CoreRank::FourOfAKind);
        assert!(CoreRank::FourOfAKind < CoreRank::StraightFlush);
    }
    #[test]
    fn test_core_rank_display() {
        assert_eq!(CoreRank::HighCard.to_string(), "High Card");
        assert_eq!(CoreRank::OnePair.to_string(), "One Pair");
        assert_eq!(CoreRank::TwoPair.to_string(), "Two Pair");
        assert_eq!(CoreRank::ThreeOfAKind.to_string(), "Three of a Kind");
        assert_eq!(CoreRank::Straight.to_string(), "Straight");
        assert_eq!(CoreRank::Flush.to_string(), "Flush");
        assert_eq!(CoreRank::FullHouse.to_string(), "Full House");
        assert_eq!(CoreRank::FourOfAKind.to_string(), "Four of a Kind");
        assert_eq!(CoreRank::StraightFlush.to_string(), "Straight Flush");
    }
    #[test]
    fn test_rank_ordering_within_same_type() {
        let pair_aces = FlatHand::new_from_str("AsAhKdQcJs").unwrap();
        let pair_kings = FlatHand::new_from_str("KsKhAdQcJs").unwrap();
        assert!(pair_aces.rank() > pair_kings.rank());
        let two_pair_ak = FlatHand::new_from_str("AsAhKdKcJs").unwrap();
        let two_pair_aq = FlatHand::new_from_str("AsAhQdQcKs").unwrap();
        assert!(two_pair_ak.rank() > two_pair_aq.rank());
        let trips_aces = FlatHand::new_from_str("AsAhAdKcJs").unwrap();
        let trips_kings = FlatHand::new_from_str("KsKhKdAcJs").unwrap();
        assert!(trips_aces.rank() > trips_kings.rank());
    }
    #[test]
    fn test_rankable_vec_and_slice() {
        let cards: Vec<Card> = vec![
            Card::new(Value::Ace, Suit::Spade),
            Card::new(Value::King, Suit::Spade),
            Card::new(Value::Queen, Suit::Spade),
            Card::new(Value::Jack, Suit::Spade),
            Card::new(Value::Ten, Suit::Spade),
        ];
        assert_eq!(cards.rank().category(), CoreRank::StraightFlush);
        let slice: &[Card] = &cards;
        assert_eq!(slice.rank().category(), CoreRank::StraightFlush);
        assert_eq!(cards[..].rank().category(), CoreRank::StraightFlush);
    }
    #[test]
    fn test_wheel_straight_detection() {
        let wheel = FlatHand::new_from_str("Ad2c3s4h5d").unwrap();
        assert_eq!(wheel.rank().category(), CoreRank::Straight);
        let six_high = FlatHand::new_from_str("2c3s4h5d6c").unwrap();
        assert!(wheel.rank() < six_high.rank());
        let not_wheel = FlatHand::new_from_str("Ad2c3s4h6d").unwrap();
        assert_eq!(not_wheel.rank().category(), CoreRank::HighCard);
        let almost_wheel = FlatHand::new_from_str("Ad2c3s4h6c").unwrap();
        assert_eq!(almost_wheel.rank().category(), CoreRank::HighCard);
    }
    #[test]
    fn test_seven_card_categories() {
        let cards: Vec<Card> = vec![
            Card::new(Value::Ace, Suit::Spade),
            Card::new(Value::King, Suit::Spade),
            Card::new(Value::Queen, Suit::Spade),
            Card::new(Value::Jack, Suit::Spade),
            Card::new(Value::Ten, Suit::Spade),
            Card::new(Value::Nine, Suit::Spade),
            Card::new(Value::Eight, Suit::Spade),
        ];
        assert_eq!(cards.rank().category(), CoreRank::StraightFlush);
    }
    #[test]
    fn seven_card_accum_matches_rank_best_of() {
        for s in [
            "Ad8h9cTc5c2s7d",
            "AdAc9d8cTs2h3s",
            "AdAc9d8cTs8s3s",
            "AdAcAs8cTs2h3s",
            "2c3s4h5s6d8cKh",
            "Ad8d9dTd5d2h3s",
            "AdAc9d9c9s2h3s",
            "AdAcAsAh8cTs2h",
            "AdKdQdJdTd9d8d",
        ] {
            let hand = FlatHand::new_from_str(s).unwrap();
            let mut acc = SevenCardAccum::new();
            for c in hand.iter() {
                acc.add(*c);
            }
            assert_eq!(acc.rank(), hand.rank(), "mismatch for {s}");
        }
    }
    #[test]
    fn seven_card_accum_order_independent() {
        let hand = FlatHand::new_from_str("2s2h2d2c8d8sKd").unwrap();
        let cards: Vec<Card> = hand.iter().copied().collect();
        let mut forward = SevenCardAccum::new();
        for c in &cards {
            forward.add(*c);
        }
        let mut backward = SevenCardAccum::new();
        for c in cards.iter().rev() {
            backward.add(*c);
        }
        assert_eq!(forward.rank(), backward.rank());
        assert_eq!(forward.rank(), hand.rank());
    }
    #[test]
    fn seven_card_accum_add_operators() {
        let base = FlatHand::new_from_str("AsAhKs7c2d").unwrap();
        let mut via_assign = SevenCardAccum::new();
        for c in base.iter() {
            via_assign += *c;
        }
        let via_add = base.iter().fold(SevenCardAccum::new(), |acc, c| acc + *c);
        assert_eq!(via_assign.rank(), via_add.rank());
        assert_eq!(via_assign.rank(), base.rank());
    }
    /// Copying a partially built accumulator and extending each copy gives
    /// the same result as building each full hand from scratch.
    #[test]
    fn seven_card_accum_copy_reuse_equals_from_scratch() {
        let base_hand = FlatHand::new_from_str("AsAhKs7c2d").unwrap();
        let mut base = SevenCardAccum::new();
        for c in base_hand.iter() {
            base.add(*c);
        }
        let mut quads = base;
        quads.add(Card::new(Value::Ace, Suit::Club));
        quads.add(Card::new(Value::Ace, Suit::Diamond));
        let quads_scratch = FlatHand::new_from_str("AsAhKs7c2dAcAd").unwrap();
        assert_eq!(quads.rank(), quads_scratch.rank());
        let mut pair = base;
        pair.add(Card::new(Value::Queen, Suit::Heart));
        pair.add(Card::new(Value::Jack, Suit::Diamond));
        let pair_scratch = FlatHand::new_from_str("AsAhKs7c2dQhJd").unwrap();
        assert_eq!(pair.rank(), pair_scratch.rank());
        assert!(quads.rank() > pair.rank());
    }
}