board_game/util/
bits.rs

1//! Utilities with compact bit data structures.
2
3use std::num::Wrapping;
4use std::ops::Neg;
5
6use num_traits::{PrimInt, Unsigned, WrappingSub, Zero};
7
8#[derive(Debug)]
9/// Iterator over the indices of the set bits of an integer,
10/// from least to most significant.
11///
12/// # Example
13///
14/// ```
15/// use board_game::util::bits::BitIter;
16/// let b = BitIter::new(0b10011u32);
17/// assert_eq!(b.collect::<Vec<_>>(), vec![0, 1, 4]);
18/// ```
19pub struct BitIter<N: PrimInt + Unsigned> {
20    left: N,
21}
22
23impl<N: PrimInt + Unsigned> BitIter<N> {
24    pub fn new(left: N) -> Self {
25        BitIter { left }
26    }
27}
28
29impl<N: PrimInt + Unsigned> Iterator for BitIter<N> {
30    type Item = u8;
31
32    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
33        //TODO report bug to intel-rust that self.left.is_zero() complains about a missing trait
34        if self.left == N::zero() {
35            None
36        } else {
37            let index = self.left.trailing_zeros() as u8;
38            self.left = self.left & (self.left - N::one());
39            Some(index)
40        }
41    }
42}
43
44pub fn get_nth_set_bit<N: PrimInt + Unsigned + WrappingSub>(mut x: N, n: u32) -> u8 {
45    for _ in 0..n {
46        x = x & x.wrapping_sub(&N::one());
47    }
48    debug_assert!(x != N::zero());
49    x.trailing_zeros() as u8
50}
51
/// Iterator over every subset of `mask`, including `0` and `mask` itself,
/// in increasing numeric order.
///
/// A mask with `N` set bits yields `2 ** N` values.
///
/// Each step computes `next = (curr - mask) & mask` with wrapping arithmetic, which
/// wraps back around to `0` right after `mask` itself has been produced.
///
/// Implementation based on https://analog-hors.github.io/writing/magic-bitboards/
/// and https://www.chessprogramming.org/Traversing_Subsets_of_a_Set#All_Subsets_of_any_Set
#[derive(Debug)]
pub struct SubSetIterator {
    /// `true` until the first value (always `0`) has been handed out.
    start: bool,
    /// The subset that will be yielded next.
    curr: u64,
    /// The set whose subsets are enumerated.
    mask: u64,
}

impl SubSetIterator {
    /// Creates an iterator over all subsets of `mask`.
    pub fn new(mask: u64) -> Self {
        SubSetIterator {
            mask,
            curr: 0,
            start: true,
        }
    }
}

impl Iterator for SubSetIterator {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        // `curr` is 0 only before the first step and again after wrapping past `mask`.
        let exhausted = !self.start && self.curr == 0;
        if exhausted {
            None
        } else {
            self.start = false;
            let subset = self.curr;
            self.curr = subset.wrapping_sub(self.mask) & self.mask;
            Some(subset)
        }
    }
}
89
90/// Iterator over all subsets of the given mask that have `M` bits set.
91///
92/// If the mask has `N` set bits this yields `nCr(N, M)` values.
93/// Only yields any values if `N >= M`.
94///
95/// Implementation based on https://www.chessprogramming.org/Traversing_Subsets_of_a_Set#Snoobing_any_Sets
96#[derive(Debug)]
97pub struct SubSetCountIterator {
98    mask: u64,
99    curr: u64,
100    m: u32,
101}
102
103impl SubSetCountIterator {
104    pub fn new(mask: u64, m: u32) -> Self {
105        if m > mask.count_ones() {
106            // don't yield any values
107            return SubSetCountIterator {
108                mask,
109                curr: 0,
110                m: u32::MAX,
111            };
112        }
113
114        if m == 0 {
115            // yield zero once
116            return SubSetCountIterator { mask, curr: 0, m: 0 };
117        }
118
119        let start = {
120            // TODO is there a cleaner/faster way to write this?
121            let mut left = m;
122            let mut start = 0;
123
124            for i in 0..64 {
125                if left == 0 {
126                    break;
127                }
128
129                if mask & (1 << i) != 0 {
130                    left -= 1;
131                    start |= 1 << i;
132                }
133            }
134
135            start
136        };
137
138        SubSetCountIterator { mask, curr: start, m }
139    }
140}
141
142impl Iterator for SubSetCountIterator {
143    type Item = u64;
144
145    fn next(&mut self) -> Option<Self::Item> {
146        // TODO find a better way to do this
147        if self.m == 0 {
148            self.m = u32::MAX;
149            return Some(0);
150        }
151        if self.m == u32::MAX {
152            return None;
153        }
154        if self.curr.count_ones() != self.m {
155            self.m = u32::MAX;
156            return None;
157        }
158
159        if self.curr == 0 {
160            None
161        } else {
162            let result = self.curr;
163            self.curr = snoob_masked(self.curr, self.mask);
164            Some(result)
165        }
166    }
167}
168
169/// Based on https://www.chessprogramming.org/Traversing_Subsets_of_a_Set#Snoobing_any_Sets
170fn snoob_masked(sub: u64, set: u64) -> u64 {
171    let mut sub = Wrapping(sub);
172    let mut set = Wrapping(set);
173
174    let mut tmp = sub - Wrapping(1);
175    let mut rip = set & (tmp + (sub & sub.neg()) - set);
176
177    sub = (tmp & sub) ^ rip;
178    sub &= sub - Wrapping(1);
179
180    while !sub.is_zero() {
181        tmp = set & set.neg();
182
183        rip ^= tmp;
184        set ^= tmp;
185
186        sub &= sub - Wrapping(1);
187    }
188
189    rip.0
190}
191
192#[cfg(test)]
193mod tests {
194    use itertools::Itertools;
195
196    use crate::util::bits::{snoob_masked, SubSetCountIterator, SubSetIterator};
197
198    #[test]
199    fn subset_iterator_empty() {
200        assert_eq!(1, SubSetIterator::new(0).count());
201    }
202
203    #[test]
204    fn subset_iterator_standard() {
205        let mask = 0b01001100;
206
207        // check that the output values are right
208        let iter = SubSetIterator::new(mask);
209        let values = iter.collect_vec();
210        let expected = vec![
211            0b00000000, 0b00000100, 0b00001000, 0b00001100, 0b01000000, 0b01000100, 0b01001000, 0b01001100,
212        ];
213        assert_eq!(values, expected);
214
215        // check that it only iterates once
216        let mut iter = SubSetIterator::new(mask);
217        assert_eq!(8, iter.by_ref().count());
218        assert_eq!(0, iter.count());
219    }
220
221    #[test]
222    fn subset_count_iterator_standard() {
223        for mask in [0b01001100, 0b11111111] {
224            println!("mask={:b}", mask);
225
226            for m in 0..8 {
227                println!("m={}", m);
228
229                let expected = SubSetIterator::new(mask).filter(|v| v.count_ones() == m).collect_vec();
230                let actual = SubSetCountIterator::new(mask, m).collect_vec();
231
232                for a in &actual {
233                    println!("  {:b}", a);
234                }
235
236                assert_eq!(expected, actual);
237            }
238        }
239    }
240
241    #[test]
242    fn snoob_random() {
243        fn test(x: u64, y: u64, z: u64) {
244            let a = snoob_masked(x, y);
245            assert_eq!(a, z, "Expected 0x{:x}, got 0x{:x}", z, a);
246        }
247
248        // test values generated by running C version of the function on random inputs
249        test(0x7283e4c96896188c, 0x706b7f2de031bf37, 0x866246011a627);
250        test(0xfad96ea1180d0e12, 0x76509766802e6373, 0x7250004400204153);
251        test(0xe5a6f8869eb40f35, 0xf16528ff4aace975, 0xf06420780aa8c835);
252        test(0x9690d5c2f7a35fe0, 0x74d4cb118f57c2a5, 0x5440c100871442a5);
253        test(0x65282b0a08ebb2ec, 0xd1a0372c789578cd, 0x90200420688140cd);
254        test(0x3dac1b1d3add1987, 0xf9f42a39293c9190, 0xb8400a19281c1000);
255        test(0xa4ff0d56382d1243, 0xdfadd0cd921aba0, 0x8facd0480208900);
256        test(0xebec27ac83bd6464, 0x39a0a563e5a54560, 0x9a0252361a54060);
257        test(0x3b6a9e51e98ee1dc, 0x4c06832895fdc792, 0x28000846cc592);
258        test(0xc9afc50b67c6c180, 0xb96b34721235db30, 0xa92b246202251930);
259        test(0x950759b98d0b9d44, 0x2fb8f78ae913838f, 0x25b0958048038207);
260        test(0x41014bfee54406c9, 0xdd559e0e54ea7ed9, 0x41018c0644a27809);
261        test(0x63ed18bf633dbce0, 0xe68adf1c4dbf1fd6, 0x6402161c488102d6);
262        test(0xf8a1baa1355ebec3, 0x90765cca6bdae99, 0x52108201cae91);
263        test(0xd157a686cccad263, 0x6f7d04460b4d0672, 0x4e2504060a000402);
264        test(0xd93c30843c3fdb27, 0x3a5c127454f3448, 0x2218023010f1000);
265    }
266}