rust_poker/range_filter/
mod.rs

1use crate::constants::HAND_CATEGORY_SHIFT;
2use crate::hand_evaluator::{evaluate, Hand};
3use crate::hand_range::{Combo, HandRange};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
7pub enum MadeHandCategories {
8    QuadsOrBetter,
9    FullHouse,
10    Flush,
11    Straight,
12    ThreeOfAKind,
13    TwoPair,
14    Pair,
15    // OverPair,
16    // TopPair,
17    // MiddlePair,
18    // WeakPair,
19    // AceHigh,
20    NoMadeHand,
21}
22
23impl MadeHandCategories {
24    pub fn get_table_index(&self) -> usize {
25        match self {
26            MadeHandCategories::QuadsOrBetter => 0,
27            MadeHandCategories::FullHouse => 1,
28            MadeHandCategories::Flush => 2,
29            MadeHandCategories::Straight => 3,
30            MadeHandCategories::ThreeOfAKind => 4,
31            MadeHandCategories::TwoPair => 5,
32            MadeHandCategories::Pair => 6,
33            MadeHandCategories::NoMadeHand => 7,
34        }
35    }
36    pub fn category_count() -> usize {
37        8
38    }
39}
40
41#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
42pub enum DrawHandCategories {
43    TwoCardFlushDraw,
44    NutFlushDraw,
45    OESD,
46    // Gutshot,
47    // OverCards,
48    NoDraw,
49}
50
51impl DrawHandCategories {
52    pub fn get_table_index(&self) -> usize {
53        match self {
54            DrawHandCategories::TwoCardFlushDraw => 0,
55            DrawHandCategories::NutFlushDraw => 1,
56            DrawHandCategories::OESD => 2,
57            DrawHandCategories::NoDraw => 3,
58        }
59    }
60    pub fn category_count() -> usize {
61        4
62    }
63}
64
/// Selection of hand categories consumed by `HandRange::apply_filter`.
///
/// A combo passes the filter when its made-hand category is listed in
/// `made_hands` or its draw category is listed in `draw_hands`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RangeFilter {
    /// Made-hand categories that let a combo pass.
    pub made_hands: Vec<MadeHandCategories>,
    /// Draw categories that let a combo pass.
    pub draw_hands: Vec<DrawHandCategories>,
}
70
71impl HandRange {
72    pub fn apply_filter(&mut self, board: u64, filter: &RangeFilter) {
73        self.remove_conflicting_combos(board);
74        self.hands.retain(|combo| {
75            filter
76                .made_hands
77                .contains(&get_made_hand_category(&combo, board))
78                || filter
79                    .draw_hands
80                    .contains(&get_draw_hand_category(&combo, board))
81        });
82    }
83}
84
/// Tables describing how a hand range interacts with a board.
///
/// The combos of a range are split into two tables, one by made-hand category
/// and one by draw category. Each table has one row per category and every
/// combo appears once in each table. Instances are built with
/// `HandCategoryRange::from_range_and_board`.
#[derive(Serialize, Deserialize, Debug)]
pub struct HandCategoryRange {
    /// Bit mask of the board the tables were built for.
    board: u64,
    /// Combo strings grouped by made-hand category; the row index is
    /// `MadeHandCategories::get_table_index`.
    made_hand_table: Vec<Vec<String>>,
    /// Combo strings grouped by draw category; the row index is
    /// `DrawHandCategories::get_table_index`.
    draw_hand_table: Vec<Vec<String>>,
}
93
94impl HandCategoryRange {
95    pub fn from_range_and_board(hand_range: &mut HandRange, board: u64) -> Self {
96        let mut made_hand_table = vec![Vec::new(); MadeHandCategories::category_count()];
97        let mut draw_hand_table = vec![Vec::new(); DrawHandCategories::category_count()];
98        hand_range.remove_conflicting_combos(board);
99        hand_range.hands.iter().for_each(|combo| {
100            made_hand_table[get_made_hand_category(&combo, board).get_table_index()]
101                .push(combo.to_string());
102            draw_hand_table[get_draw_hand_category(&combo, board).get_table_index()]
103                .push(combo.to_string());
104        });
105        HandCategoryRange {
106            made_hand_table,
107            draw_hand_table,
108            board,
109        }
110    }
111}
112
113pub fn get_made_hand_category(hole_cards: &Combo, board: u64) -> MadeHandCategories {
114    let hand = Hand::from_bit_mask(board) + Hand::from_hole_cards(hole_cards.0, hole_cards.1);
115    let score = evaluate(&hand);
116    match score >> HAND_CATEGORY_SHIFT {
117        9 => MadeHandCategories::QuadsOrBetter,
118        8 => MadeHandCategories::QuadsOrBetter,
119        7 => MadeHandCategories::FullHouse,
120        6 => MadeHandCategories::Flush,
121        5 => MadeHandCategories::Straight,
122        4 => MadeHandCategories::ThreeOfAKind,
123        3 => MadeHandCategories::TwoPair,
124        2 => MadeHandCategories::Pair,
125        _ => MadeHandCategories::NoMadeHand,
126    }
127}
128
129pub fn get_draw_hand_category(hole_cards: &Combo, board: u64) -> DrawHandCategories {
130    let eval_hand = Hand::from_bit_mask(board);
131    let eval_board = Hand::default() + Hand::from_hole_cards(hole_cards.0, hole_cards.1);
132    let hand = Hand::from_bit_mask(board) + Hand::from_hole_cards(hole_cards.0, hole_cards.1);
133    // detect two card flush draw
134    for i in 0..4 {
135        if eval_hand.suit_count(i) == 2 && eval_board.suit_count(i) == 2 {
136            return DrawHandCategories::TwoCardFlushDraw;
137        }
138    }
139    // detect ace high flush draw
140    for i in 0..4 {
141        // get suit mask
142        if hand.suit_count(i) == 4 && hand.has_ace_of_suit(i) {
143            return DrawHandCategories::NutFlushDraw;
144        }
145    }
146    // detect OESD
147    let rank_mask = hand.get_rank_mask();
148    for i in 0..8 {
149        let oesd_mask = 0b11110u64 << i;
150        if (rank_mask & oesd_mask) == oesd_mask {
151            return DrawHandCategories::OESD;
152        }
153    }
154
155    DrawHandCategories::NoDraw
156}
157
158impl Hand {
159    /// does hand have an ace of specific suit
160    fn has_ace_of_suit(&self, suit: u8) -> bool {
161        ((self.get_mask() >> 16 * (3 - suit)) & (1u64 << 12)) != 0
162    }
163    /// returns 16 bit rank mask, ignoring suits
164    fn get_rank_mask(&self) -> u64 {
165        let hand_mask = self.get_mask();
166        let mut rank_mask = 0u64;
167        for i in 0..4 {
168            rank_mask |= (hand_mask >> 16 * (3 - i)) & 0xFFFF;
169        }
170        rank_mask
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hand_range::get_card_mask;

    /// Cards 0 and 1 in the hole plus cards 2, 3 and 4 on the board must be
    /// classified as quads or better.
    #[test]
    fn test_get_made_hand_category() {
        let board = 0b11100;
        let category = get_made_hand_category(&Combo(0u8, 1u8, 100), board);
        assert_eq!(category, MadeHandCategories::QuadsOrBetter);
    }

    /// One representative board per draw category, checked in a single pass.
    #[test]
    fn test_get_draw_hand_category() {
        let cases = vec![
            // Two suited hole cards with two more cards of that suit on the board.
            (
                Combo(0u8, 4, 100),
                0b0001000100000010,
                DrawHandCategories::TwoCardFlushDraw,
            ),
            // Four cards of the suit overall, with the ace on the board.
            (
                Combo(0u8, 1u8, 100),
                get_card_mask("4s5sAs"),
                DrawHandCategories::NutFlushDraw,
            ),
            // 3, 3 in the hole on 4-5-6.
            (
                Combo(4u8, 5u8, 100),
                get_card_mask("4s5h6c"),
                DrawHandCategories::OESD,
            ),
            // T, 2 in the hole on J-Q-K.
            (
                Combo(8u8 * 4, 0, 100),
                get_card_mask("JcQsKd"),
                DrawHandCategories::OESD,
            ),
        ];

        for (hole_cards, board, expected) in cases {
            assert_eq!(get_draw_hand_category(&hole_cards, board), expected);
        }
    }

    /// All pocket pairs on A-T-4: nine combos make a set, sixty stay a pair.
    #[test]
    fn test_from_range_and_board() {
        let board = get_card_mask("AsTh4c");
        let mut hand_range = HandRange::from_string("22+".to_string());
        let tables = HandCategoryRange::from_range_and_board(&mut hand_range, board);

        let trips_row = MadeHandCategories::ThreeOfAKind.get_table_index();
        let pair_row = MadeHandCategories::Pair.get_table_index();
        assert_eq!(9, tables.made_hand_table[trips_row].len()); // 9 trips
        assert_eq!(60, tables.made_hand_table[pair_row].len()); // 60 pairs
    }

    /// Filtering the same range for three of a kind leaves only the nine sets.
    #[test]
    fn test_apply_filter() {
        let board = get_card_mask("AsTh4c");
        let filter = RangeFilter {
            made_hands: vec![MadeHandCategories::ThreeOfAKind],
            draw_hands: Vec::new(),
        };

        let mut hand_range = HandRange::from_string("22+".to_string());
        hand_range.apply_filter(board, &filter);

        assert_eq!(9, hand_range.hands.len());
    }
}