1use std::num::Wrapping;
4use std::ops::Neg;
5
6use num_traits::{PrimInt, Unsigned, WrappingSub, Zero};
7
8#[derive(Debug)]
9pub struct BitIter<N: PrimInt + Unsigned> {
20 left: N,
21}
22
23impl<N: PrimInt + Unsigned> BitIter<N> {
24 pub fn new(left: N) -> Self {
25 BitIter { left }
26 }
27}
28
29impl<N: PrimInt + Unsigned> Iterator for BitIter<N> {
30 type Item = u8;
31
32 fn next(&mut self) -> Option<<Self as Iterator>::Item> {
33 if self.left == N::zero() {
35 None
36 } else {
37 let index = self.left.trailing_zeros() as u8;
38 self.left = self.left & (self.left - N::one());
39 Some(index)
40 }
41 }
42}
43
44pub fn get_nth_set_bit<N: PrimInt + Unsigned + WrappingSub>(mut x: N, n: u32) -> u8 {
45 for _ in 0..n {
46 x = x & x.wrapping_sub(&N::one());
47 }
48 debug_assert!(x != N::zero());
49 x.trailing_zeros() as u8
50}
51
/// Iterator over every subset of the bits of `mask` in increasing numeric order. The empty set
/// `0` and `mask` itself are both included.
///
/// For `mask = 0b0101` it yields `0b0000`, `0b0001`, `0b0100`, `0b0101`. A mask of `0` yields
/// exactly one value, `0`. The iterator is fused: after the first `None` it keeps returning
/// `None`.
#[derive(Clone, Debug)]
pub struct SubSetIterator {
    /// `true` until the first call to `next`. It stops the initial empty subset (`curr == 0`)
    /// from being mistaken for the end of the iteration.
    start: bool,
    /// The subset returned by the next call to `next`.
    curr: u64,
    /// The bits whose subsets are enumerated.
    mask: u64,
}

impl SubSetIterator {
    /// Creates an iterator over all subsets of `mask`, starting with the empty subset.
    pub fn new(mask: u64) -> Self {
        Self {
            start: true,
            curr: 0,
            mask,
        }
    }
}

impl Iterator for SubSetIterator {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        // Coming back to 0 after the first call means every subset has been produced.
        if self.curr == 0 && !self.start {
            return None;
        }
        self.start = false;

        let result = self.curr;
        // `(curr - mask) & mask` is the next larger subset of `mask`. Subtracting `mask` adds
        // `!mask + 1`, which increments `curr` while letting carries run through the bits outside
        // `mask`. After `mask` itself the value wraps back to 0.
        self.curr = self.curr.wrapping_sub(self.mask) & self.mask;
        Some(result)
    }
}

// Once `curr` is back at 0 with `start == false`, nothing changes any more.
impl std::iter::FusedIterator for SubSetIterator {}
89
90#[derive(Debug)]
97pub struct SubSetCountIterator {
98 mask: u64,
99 curr: u64,
100 m: u32,
101}
102
103impl SubSetCountIterator {
104 pub fn new(mask: u64, m: u32) -> Self {
105 if m > mask.count_ones() {
106 return SubSetCountIterator {
108 mask,
109 curr: 0,
110 m: u32::MAX,
111 };
112 }
113
114 if m == 0 {
115 return SubSetCountIterator { mask, curr: 0, m: 0 };
117 }
118
119 let start = {
120 let mut left = m;
122 let mut start = 0;
123
124 for i in 0..64 {
125 if left == 0 {
126 break;
127 }
128
129 if mask & (1 << i) != 0 {
130 left -= 1;
131 start |= 1 << i;
132 }
133 }
134
135 start
136 };
137
138 SubSetCountIterator { mask, curr: start, m }
139 }
140}
141
142impl Iterator for SubSetCountIterator {
143 type Item = u64;
144
145 fn next(&mut self) -> Option<Self::Item> {
146 if self.m == 0 {
148 self.m = u32::MAX;
149 return Some(0);
150 }
151 if self.m == u32::MAX {
152 return None;
153 }
154 if self.curr.count_ones() != self.m {
155 self.m = u32::MAX;
156 return None;
157 }
158
159 if self.curr == 0 {
160 None
161 } else {
162 let result = self.curr;
163 self.curr = snoob_masked(self.curr, self.mask);
164 Some(result)
165 }
166 }
167}
168
169fn snoob_masked(sub: u64, set: u64) -> u64 {
171 let mut sub = Wrapping(sub);
172 let mut set = Wrapping(set);
173
174 let mut tmp = sub - Wrapping(1);
175 let mut rip = set & (tmp + (sub & sub.neg()) - set);
176
177 sub = (tmp & sub) ^ rip;
178 sub &= sub - Wrapping(1);
179
180 while !sub.is_zero() {
181 tmp = set & set.neg();
182
183 rip ^= tmp;
184 set ^= tmp;
185
186 sub &= sub - Wrapping(1);
187 }
188
189 rip.0
190}
191
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::{get_nth_set_bit, snoob_masked, BitIter, SubSetCountIterator, SubSetIterator};

    #[test]
    fn bit_iter_yields_indices_low_to_high() {
        assert_eq!(BitIter::new(0u32).count(), 0);
        assert_eq!(BitIter::new(0b1011_0100u32).collect_vec(), vec![2, 4, 5, 7]);
        assert_eq!(BitIter::new(1u64 << 63).collect_vec(), vec![63]);
        assert_eq!(BitIter::new(u64::MAX).count(), 64);
    }

    #[test]
    fn nth_set_bit() {
        let x = 0b1011_0100u32;
        for (n, &expected) in [2u8, 4, 5, 7].iter().enumerate() {
            assert_eq!(get_nth_set_bit(x, n as u32), expected);
        }
        assert_eq!(get_nth_set_bit(u64::MAX, 63), 63);
    }

    #[test]
    fn subset_iterator_empty() {
        assert_eq!(1, SubSetIterator::new(0).count());
    }

    #[test]
    fn subset_iterator_standard() {
        let mask = 0b01001100;

        let iter = SubSetIterator::new(mask);
        let values = iter.collect_vec();
        let expected = vec![
            0b00000000, 0b00000100, 0b00001000, 0b00001100, 0b01000000, 0b01000100, 0b01001000,
            0b01001100,
        ];
        assert_eq!(values, expected);

        let mut iter = SubSetIterator::new(mask);
        assert_eq!(8, iter.by_ref().count());
        assert_eq!(0, iter.count());
    }

    #[test]
    fn subset_count_iterator_standard() {
        // The range of `m` covers `m == 0`, `m == popcount(mask)` and `m > popcount(mask)`.
        // The last mask exercises bit 63.
        for mask in [0, 0b01001100, 0b11111111, 0x8000_0000_0000_000b] {
            for m in 0..=9 {
                let expected = SubSetIterator::new(mask)
                    .filter(|v| v.count_ones() == m)
                    .collect_vec();
                let actual = SubSetCountIterator::new(mask, m).collect_vec();
                assert_eq!(expected, actual, "mask={:b} m={}", mask, m);
            }
        }
    }

    #[test]
    fn snoob_random() {
        fn test(x: u64, y: u64, z: u64) {
            let a = snoob_masked(x, y);
            assert_eq!(a, z, "Expected 0x{:x}, got 0x{:x}", z, a);
        }

        test(0x7283e4c96896188c, 0x706b7f2de031bf37, 0x866246011a627);
        test(0xfad96ea1180d0e12, 0x76509766802e6373, 0x7250004400204153);
        test(0xe5a6f8869eb40f35, 0xf16528ff4aace975, 0xf06420780aa8c835);
        test(0x9690d5c2f7a35fe0, 0x74d4cb118f57c2a5, 0x5440c100871442a5);
        test(0x65282b0a08ebb2ec, 0xd1a0372c789578cd, 0x90200420688140cd);
        test(0x3dac1b1d3add1987, 0xf9f42a39293c9190, 0xb8400a19281c1000);
        test(0xa4ff0d56382d1243, 0xdfadd0cd921aba0, 0x8facd0480208900);
        test(0xebec27ac83bd6464, 0x39a0a563e5a54560, 0x9a0252361a54060);
        test(0x3b6a9e51e98ee1dc, 0x4c06832895fdc792, 0x28000846cc592);
        test(0xc9afc50b67c6c180, 0xb96b34721235db30, 0xa92b246202251930);
        test(0x950759b98d0b9d44, 0x2fb8f78ae913838f, 0x25b0958048038207);
        test(0x41014bfee54406c9, 0xdd559e0e54ea7ed9, 0x41018c0644a27809);
        test(0x63ed18bf633dbce0, 0xe68adf1c4dbf1fd6, 0x6402161c488102d6);
        test(0xf8a1baa1355ebec3, 0x90765cca6bdae99, 0x52108201cae91);
        test(0xd157a686cccad263, 0x6f7d04460b4d0672, 0x4e2504060a000402);
        test(0xd93c30843c3fdb27, 0x3a5c127454f3448, 0x2218023010f1000);
    }
}