// openpql_prelude/card/hand_iter.rs

1use super::{Card, CardCount, HandN};
2
3/// Iterator for generating all possible hands of N cards.
4///
5/// Iterates through all possible combinations of N cards from either a standard 52-card deck
6/// or a short deck (36 cards).
7#[derive(Debug, Clone)]
8pub struct HandIter<const SD: bool, const N: usize> {
9    indices: [CardCount; N],
10    done: bool,
11}
12
13#[allow(clippy::cast_possible_truncation)]
14impl<const SD: bool, const N: usize> Default for HandIter<SD, N> {
15    fn default() -> Self {
16        let mut indices = [0; N];
17        for i in 0..N as CardCount {
18            indices[i as usize] = i;
19        }
20        Self {
21            indices,
22            done: N == 0,
23        }
24    }
25}
26
27impl<const SD: bool, const N: usize> Iterator for HandIter<SD, N> {
28    type Item = HandN<N>;
29
30    fn next(&mut self) -> Option<Self::Item> {
31        if self.done {
32            return None;
33        }
34
35        let all = Card::all::<SD>();
36        let max_i = all.len();
37
38        let mut cards = [Card::default(); N];
39        for i in 0..N {
40            cards[i] = all[self.indices[i] as usize];
41        }
42
43        let mut pos = N - 1;
44        self.indices[pos] += 1;
45
46        while self.indices[pos] as usize >= max_i - (N - 1 - pos) {
47            if pos == 0 {
48                self.done = true;
49                return Some(HandN::new(cards));
50            }
51
52            pos -= 1;
53            self.indices[pos] += 1;
54        }
55
56        for i in (pos + 1)..N {
57            self.indices[i] = self.indices[i - 1] + 1;
58        }
59
60        Some(HandN::new(cards))
61    }
62
63    /// # Panics
64    /// May panic on 32-bit systems when the result exceeds `u32::MAX`.
65    /// For example, C(52, 26) ≈ 4.96 x 10¹⁴, which is greater than 2³² - 1 (4,294,967,295).
66    /// However, this function works correctly for typical small values of N, such as N = 7.
67    fn size_hint(&self) -> (usize, Option<usize>) {
68        let n = const { if SD { Card::N_CARDS_SD } else { Card::N_CARDS } };
69        let r = N;
70
71        let len = ncr(n as usize, r);
72
73        (len, Some(len))
74    }
75}
76
77impl<const SD: bool, const N: usize> ExactSizeIterator for HandIter<SD, N> {}
78
/// Returns the binomial coefficient `C(n, r)`, the number of ways to choose `r`
/// items from `n`. Returns `0` when `r > n`.
///
/// Each step of the running product is computed in `u128`. The result is
/// therefore exact whenever `C(n, r)` itself fits in `usize`, even where an
/// intermediate product `C(n, i) * (n - i)` would not. For example, `C(64, 32)`
/// is exact on 64-bit targets.
///
/// # Panics
/// Panics if `C(n, r)` does not fit in `usize`.
pub fn ncr(n: usize, r: usize) -> usize {
    if r > n {
        return 0;
    }

    // Symmetry C(n, r) = C(n, n - r) keeps the loop as short as possible.
    // Handles r == 0 and r == n too: the loop body then never runs.
    let r = r.min(n - r);

    let mut result: usize = 1;
    for i in 0..r {
        // After this step `result == C(n, i + 1)`. The division is exact because
        // C(n, i) * (n - i) == C(n, i + 1) * (i + 1). Since i + 1 <= r <= n / 2,
        // C(n, i + 1) <= C(n, r), so a failed conversion means the answer overflows.
        let next = result as u128 * (n - i) as u128 / (i + 1) as u128;
        result = usize::try_from(next).expect("C(n, r) does not fit in usize");
    }

    result
}
97
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::*;

    /// Every hand produced by `HandN::iter_all`, flattened to card vectors.
    fn handiter_vec<const N: usize, const SD: bool>() -> Vec<Vec<Card>> {
        let hands = HandN::<N>::iter_all::<SD>();
        hands.map(|h| h.to_vec()).collect::<Vec<_>>()
    }

    /// Reference enumeration of all `N`-card combinations, produced by itertools.
    fn itertool_vec<const N: usize, const SD: bool>() -> Vec<Vec<Card>> {
        let deck = Card::all::<SD>();
        deck.iter().cloned().combinations(N).collect::<Vec<_>>()
    }

    /// Asserts that `HandN::iter_all` and itertools agree, including ordering.
    fn assert_same_as_itertools<const N: usize, const SD: bool>() {
        assert_eq!(handiter_vec::<N, SD>(), itertool_vec::<N, SD>());
    }

    #[test]
    fn test_hand_iter_holdem() {
        assert_same_as_itertools::<2, false>();
        assert_same_as_itertools::<3, false>();
    }

    #[test]
    fn test_hand_iter_shortdeck() {
        assert_same_as_itertools::<2, true>();
        assert_same_as_itertools::<3, true>();
    }

    #[quickcheck]
    fn test_pascals_identity(n: usize, r: usize) -> TestResult {
        // Only exercise deck-sized inputs with 1 <= r <= n <= 52.
        if !((1..=52).contains(&n) && (1..=n).contains(&r)) {
            return TestResult::discard();
        }
        // Pascal's identity: C(n, r) = C(n - 1, r - 1) + C(n - 1, r).
        let expected = ncr(n - 1, r - 1) + ncr(n - 1, r);
        TestResult::from_bool(ncr(n, r) == expected)
    }

    #[test]
    fn test_ncr() {
        // Omaha starting hands, then five-card hands.
        for (n, r, expected) in [(52, 4, 270_725), (52, 5, 2_598_960)] {
            assert_eq!(ncr(n, r), expected);
        }
    }
}