1use std::ops::{Bound, RangeBounds};
2
3pub trait Combinations {
21 type Item;
22
23 fn combinations(&self, k: usize) -> CombinationIter<'_, Self::Item>;
39
40 fn combinations_range(&self, range: impl RangeBounds<usize>) -> CombinationRangeIter<'_, Self::Item>;
57}
58
59impl<T> Combinations for [T] {
60 type Item = T;
61 fn combinations(&self, k: usize) -> CombinationIter<'_, T> {
62 CombinationIter::new(self, k)
63 }
64 fn combinations_range(&self, range: impl RangeBounds<usize>) -> CombinationRangeIter<'_, T> {
65 CombinationRangeIter::new(self, range)
66 }
67}
68
69pub struct CombinationIter<'a, T> {
90 slice: &'a [T],
91 state: State,
92}
93
/// Cursor for [`CombinationIter`]; the variant is picked from the slice length.
enum State {
    /// Used for slices of at most 64 elements. Bit `i` of `current` set means
    /// element `i` belongs to the next combination to emit.
    Bitmask {
        /// Membership mask of the next combination to yield.
        current: u64,
        /// Mask with the low `n` bits set; successors above it are out of range.
        limit: u64,
        /// `true` once no further combinations remain.
        done: bool,
    },
    /// Used for longer slices: strictly increasing indices of the next
    /// combination to emit.
    Index {
        /// Positions in the slice of the next combination's elements.
        indices: Vec<usize>,
        /// `true` once no further combinations remain.
        done: bool,
    },
}
105
/// Returns a `u64` whose lowest `bits` bits are set; all 64 bits are set when
/// `bits >= 64`.
fn low_mask(bits: u32) -> u64 {
    match bits {
        0 => 0,
        // Shifting an all-ones word right keeps every shift amount below 64.
        1..=63 => u64::MAX >> (64 - bits),
        _ => u64::MAX,
    }
}
116
117impl<'a, T> CombinationIter<'a, T> {
118 pub fn next_into(&mut self, buf: &mut Vec<&'a T>) -> bool {
133 buf.clear();
134 match &mut self.state {
135 State::Bitmask {
136 current,
137 limit,
138 done,
139 } => {
140 if *done {
141 return false;
142 }
143
144 let v = *current;
145
146 let mut bits = v;
147 while bits != 0 {
148 let i = bits.trailing_zeros();
149 buf.push(&self.slice[i as usize]);
150 bits &= bits - 1;
151 }
152
153 if v == 0 {
154 *done = true;
155 return true;
156 }
157
158 let t = v | (v - 1);
159 if let Some(t1) = t.checked_add(1) {
160 let next = t1 | (((!t & t1) - 1) >> (v.trailing_zeros() + 1));
161 if next > *limit {
162 *done = true;
163 } else {
164 *current = next;
165 }
166 } else {
167 *done = true;
168 }
169
170 true
171 }
172 State::Index { indices, done } => {
173 if *done {
174 return false;
175 }
176
177 let k = indices.len();
178 let n = self.slice.len();
179
180 buf.extend(indices.iter().map(|&i| &self.slice[i]));
181
182 let mut i = k;
183 while i > 0 {
184 i -= 1;
185 if indices[i] < n - k + i {
186 indices[i] += 1;
187 for j in (i + 1)..k {
188 indices[j] = indices[j - 1] + 1;
189 }
190 return true;
191 }
192 }
193
194 *done = true;
195 true
196 }
197 }
198 }
199
200 fn new(slice: &'a [T], k: usize) -> Self {
201 let n = slice.len();
202 if n <= 64 {
203 if k > n {
204 return Self {
205 slice,
206 state: State::Bitmask {
207 current: 0,
208 limit: 0,
209 done: true,
210 },
211 };
212 }
213 let limit = low_mask(n as u32);
214 let start = low_mask(k as u32);
215 Self {
216 slice,
217 state: State::Bitmask {
218 current: start,
219 limit,
220 done: false,
221 },
222 }
223 } else {
224 if k > n {
225 return Self {
226 slice,
227 state: State::Index {
228 indices: Vec::new(),
229 done: true,
230 },
231 };
232 }
233 Self {
234 slice,
235 state: State::Index {
236 indices: (0..k).collect(),
237 done: false,
238 },
239 }
240 }
241 }
242}
243
244impl<'a, T> Iterator for CombinationIter<'a, T> {
245 type Item = Vec<&'a T>;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 let mut buf = Vec::new();
249 if self.next_into(&mut buf) {
250 Some(buf)
251 } else {
252 None
253 }
254 }
255}
256
257pub struct CombinationRangeIter<'a, T> {
274 slice: &'a [T],
275 current_k: usize,
276 end_k: usize,
277 inner: CombinationIter<'a, T>,
278}
279
280impl<'a, T> CombinationRangeIter<'a, T> {
281 fn new(slice: &'a [T], range: impl RangeBounds<usize>) -> Self {
282 let start_k = match range.start_bound() {
283 Bound::Included(&s) => s,
284 Bound::Excluded(&s) => s + 1,
285 Bound::Unbounded => 0,
286 };
287 let end_k = match range.end_bound() {
288 Bound::Included(&e) => e.min(slice.len()),
289 Bound::Excluded(&0) => {
290 return Self {
291 slice,
292 current_k: 1,
293 end_k: 0,
294 inner: CombinationIter::new(slice, 0),
295 };
296 }
297 Bound::Excluded(&e) => (e - 1).min(slice.len()),
298 Bound::Unbounded => slice.len(),
299 };
300 let current_k = start_k;
301 let inner = CombinationIter::new(slice, current_k);
302 Self {
303 slice,
304 current_k,
305 end_k,
306 inner,
307 }
308 }
309
310 pub fn next_into(&mut self, buf: &mut Vec<&'a T>) -> bool {
326 loop {
327 if self.inner.next_into(buf) {
328 return true;
329 }
330 if self.current_k >= self.end_k {
331 return false;
332 }
333 self.current_k += 1;
334 self.inner = CombinationIter::new(self.slice, self.current_k);
335 }
336 }
337}
338
339impl<'a, T> Iterator for CombinationRangeIter<'a, T> {
340 type Item = Vec<&'a T>;
341
342 fn next(&mut self) -> Option<Self::Item> {
343 let mut buf = Vec::new();
344 if self.next_into(&mut buf) {
345 Some(buf)
346 } else {
347 None
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies every borrowed combination into an owned `Vec<i32>` so results
    /// can be compared against literals.
    fn owned<'a>(combos: impl Iterator<Item = Vec<&'a i32>>) -> Vec<Vec<i32>> {
        combos
            .map(|combo| combo.into_iter().copied().collect())
            .collect()
    }

    #[test]
    fn choose_2() {
        let words = vec!["hej", "på", "dig"];
        let expected: Vec<Vec<&&str>> = vec![
            vec![&"hej", &"på"],
            vec![&"hej", &"dig"],
            vec![&"på", &"dig"],
        ];
        let actual: Vec<Vec<&&str>> = words.combinations(2).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn choose_0() {
        assert_eq!(owned([1, 2, 3].combinations(0)), vec![Vec::<i32>::new()]);
    }

    #[test]
    fn choose_all() {
        assert_eq!(owned([1, 2, 3].combinations(3)), [[1, 2, 3]]);
    }

    #[test]
    fn k_exceeds_len() {
        assert!(owned([1, 2].combinations(5)).is_empty());
    }

    #[test]
    fn count() {
        let zeros = [0; 6];
        // C(6, 3) = 20
        assert_eq!(zeros.combinations(3).count(), 20);
    }

    #[test]
    fn correct_combinations() {
        let items = [0, 1, 2, 3];
        // Colexicographic order: grouped by the largest chosen element.
        let expected = [[0, 1], [0, 2], [1, 2], [0, 3], [1, 3], [2, 3]];
        assert_eq!(owned(items.combinations(2)), expected);
    }

    #[test]
    fn empty_slice() {
        let empty: &[i32] = &[];
        assert_eq!(owned(empty.combinations(0)), vec![Vec::<i32>::new()]);
        assert_eq!(empty.combinations(1).count(), 0);
    }

    #[test]
    fn next_into_reuses_buffer() {
        let items = [1, 2, 3];
        let mut iter = items.combinations(2);
        let mut buf = Vec::new();
        let mut seen: Vec<Vec<i32>> = Vec::new();
        while iter.next_into(&mut buf) {
            seen.push(buf.iter().map(|x| **x).collect());
        }
        assert_eq!(seen, [[1, 2], [1, 3], [2, 3]]);
    }

    #[test]
    fn range_inclusive() {
        let expected: Vec<Vec<i32>> = vec![
            vec![],
            vec![1],
            vec![2],
            vec![3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
        ];
        assert_eq!(owned([1, 2, 3].combinations_range(0..=2)), expected);
    }

    #[test]
    fn range_exclusive() {
        let expected: Vec<Vec<i32>> = vec![
            vec![1],
            vec![2],
            vec![3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
        ];
        assert_eq!(owned([1, 2, 3].combinations_range(1..3)), expected);
    }

    #[test]
    fn range_full() {
        let expected: Vec<Vec<i32>> = vec![
            vec![],
            vec![1],
            vec![2],
            vec![3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
            vec![1, 2, 3],
        ];
        assert_eq!(owned([1, 2, 3].combinations_range(..)), expected);
    }

    #[test]
    fn range_empty() {
        assert_eq!([1, 2].combinations_range(5..=6).count(), 0);
    }

    #[test]
    fn works_on_vec() {
        let v = vec![10, 20, 30];
        assert_eq!(owned(v.combinations(1)), [[10], [20], [30]]);
    }
}