Skip to main content

combinations/
lib.rs

1use std::ops::{Bound, RangeBounds};
2
3/// Extension trait providing combination methods on slices.
4///
5/// ```
6/// use combinations::Combinations;
7///
8/// let items = [1, 2, 3, 4];
9///
10/// // All pairs
11/// let pairs: Vec<Vec<&i32>> = items.combinations(2).collect();
12/// assert_eq!(pairs.len(), 6);
13///
14/// // Works on Vec too
15/// let v = vec!["a", "b", "c"];
16/// for combo in v.combinations(2) {
17///     assert_eq!(combo.len(), 2);
18/// }
19/// ```
20pub trait Combinations {
21    type Item;
22
23    /// Returns an iterator over all `k`-element combinations.
24    ///
25    /// ```
26    /// use combinations::Combinations;
27    ///
28    /// let items = ["a", "b", "c"];
29    /// let got: Vec<Vec<&&str>> = items.combinations(2).collect();
30    /// assert_eq!(got, [vec![&"a", &"b"], vec![&"a", &"c"], vec![&"b", &"c"]]);
31    ///
32    /// // k=0 yields one empty combination
33    /// assert_eq!(items.combinations(0).count(), 1);
34    ///
35    /// // k > len yields nothing
36    /// assert!(items.combinations(10).next().is_none());
37    /// ```
38    fn combinations(&self, k: usize) -> CombinationIter<'_, Self::Item>;
39
40    /// Returns an iterator over combinations of all sizes within `range`.
41    ///
42    /// Accepts any [`RangeBounds<usize>`]: `0..=k`, `1..3`, `..`, etc.
43    /// Combinations are yielded in order of increasing size.
44    ///
45    /// ```
46    /// use combinations::Combinations;
47    ///
48    /// // All subsets of size 1 or 2
49    /// let items = [1, 2, 3];
50    /// let got: Vec<Vec<&i32>> = items.combinations_range(1..=2).collect();
51    /// assert_eq!(got.len(), 6); // C(3,1) + C(3,2) = 3 + 3
52    ///
53    /// // All 2^3 = 8 subsets
54    /// assert_eq!(items.combinations_range(..).count(), 8);
55    /// ```
56    fn combinations_range(&self, range: impl RangeBounds<usize>) -> CombinationRangeIter<'_, Self::Item>;
57}
58
59impl<T> Combinations for [T] {
60    type Item = T;
61    fn combinations(&self, k: usize) -> CombinationIter<'_, T> {
62        CombinationIter::new(self, k)
63    }
64    fn combinations_range(&self, range: impl RangeBounds<usize>) -> CombinationRangeIter<'_, T> {
65        CombinationRangeIter::new(self, range)
66    }
67}
68
/// Iterator over all `k`-element combinations of a slice.
/// Each item is a `Vec<&T>` listing the chosen elements in slice order.
///
/// Combinations are produced in co-lexicographic order (compared by their
/// last element first, then the one before it, and so on); e.g. choosing 2
/// from `[a, b, c, d]` gives `ab, ac, bc, ad, bd, cd`. The order does not
/// depend on the slice length.
///
/// For slices with at most 64 elements, uses a bitmask (Gosper's hack)
/// internally for faster iteration. Falls back to index-based iteration
/// for larger slices, which steps through the same order.
///
/// Created by [`Combinations::combinations`].
///
/// ```
/// use combinations::Combinations;
///
/// let items = [10, 20, 30];
/// let mut iter = items.combinations(2);
///
/// assert_eq!(iter.next(), Some(vec![&10, &20]));
/// assert_eq!(iter.next(), Some(vec![&10, &30]));
/// assert_eq!(iter.next(), Some(vec![&20, &30]));
/// assert_eq!(iter.next(), None);
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct CombinationIter<'a, T> {
    slice: &'a [T],
    state: State,
}

/// Internal cursor of a [`CombinationIter`].
enum State {
    /// Used when `slice.len() <= 64`: bit `i` of `current` is set iff
    /// `slice[i]` belongs to the next combination to yield.
    Bitmask {
        current: u64,
        /// Mask of the low `slice.len()` bits; no valid combination exceeds it.
        limit: u64,
        done: bool,
    },
    /// Used for longer slices: the strictly increasing indices of the next
    /// combination to yield.
    Index { indices: Vec<usize>, done: bool },
}

/// Mask with the lowest `bits` bits set (all 64 bits when `bits >= 64`).
fn low_mask(bits: u32) -> u64 {
    if bits >= u64::BITS {
        u64::MAX
    } else {
        // For bits == 0 this is (1 << 0) - 1 == 0.
        (1u64 << bits) - 1
    }
}

impl<'a, T> CombinationIter<'a, T> {
    /// Fills `buf` with the next combination, returning `true` if one was
    /// produced. The buffer is cleared before each call, so callers can
    /// reuse the same `Vec` across iterations to avoid repeated allocation.
    ///
    /// ```
    /// use combinations::Combinations;
    ///
    /// let items = [1, 2, 3];
    /// let mut iter = items.combinations(2);
    /// let mut buf = Vec::new();
    /// while iter.next_into(&mut buf) {
    ///     println!("{buf:?}"); // no allocation after the first call
    /// }
    /// ```
    pub fn next_into(&mut self, buf: &mut Vec<&'a T>) -> bool {
        buf.clear();
        // Copy the `&'a [T]` out so element references carry lifetime `'a`
        // while `self.state` is borrowed mutably below.
        let slice = self.slice;
        match &mut self.state {
            State::Bitmask {
                current,
                limit,
                done,
            } => {
                if *done {
                    return false;
                }
                let v = *current;

                // Emit the selected elements, lowest bit (earliest index) first.
                buf.reserve(v.count_ones() as usize);
                let mut bits = v;
                while bits != 0 {
                    buf.push(&slice[bits.trailing_zeros() as usize]);
                    bits &= bits - 1; // clear the lowest set bit
                }

                // k == 0: the single empty combination has just been produced.
                if v == 0 {
                    *done = true;
                    return true;
                }

                // Gosper's hack: the next larger integer with the same popcount.
                // If `t + 1` overflows, the set bits already fill the top of the
                // word, so this was the last combination. This also guarantees
                // the shift below is at most 63: `trailing_zeros(v) == 63` only
                // for `v == 1 << 63`, where `t == u64::MAX` overflows first.
                let t = v | (v - 1);
                match t.checked_add(1) {
                    Some(t1) => {
                        let next = t1 | (((!t & t1) - 1) >> (v.trailing_zeros() + 1));
                        if next > *limit {
                            *done = true;
                        } else {
                            *current = next;
                        }
                    }
                    None => *done = true,
                }
                true
            }
            State::Index { indices, done } => {
                if *done {
                    return false;
                }
                buf.extend(indices.iter().map(|&i| &slice[i]));

                // Step to the co-lexicographic successor, the same order the
                // bitmask path yields: bump the lowest index that still has
                // room below its right neighbour (or below `n` for the last
                // index), then pack every lower index back to 0, 1, 2, ...
                let n = slice.len();
                let pivot = (0..indices.len()).find(|&i| {
                    let ceiling = indices.get(i + 1).copied().unwrap_or(n);
                    indices[i] + 1 < ceiling
                });
                match pivot {
                    Some(i) => {
                        indices[i] += 1;
                        for (j, slot) in indices[..i].iter_mut().enumerate() {
                            *slot = j;
                        }
                    }
                    // Every index is packed against its ceiling: this was the
                    // last combination (this also covers k == 0).
                    None => *done = true,
                }
                true
            }
        }
    }

    /// Creates an iterator over the `k`-element combinations of `slice`.
    fn new(slice: &'a [T], k: usize) -> Self {
        let n = slice.len();
        let state = if k > n {
            // No k-subset exists: start out exhausted.
            State::Index {
                indices: Vec::new(),
                done: true,
            }
        } else if n <= 64 {
            // k <= n <= 64, so both casts are lossless.
            State::Bitmask {
                current: low_mask(k as u32),
                limit: low_mask(n as u32),
                done: false,
            }
        } else {
            State::Index {
                indices: (0..k).collect(),
                done: false,
            }
        };
        Self { slice, state }
    }
}

impl<'a, T> Iterator for CombinationIter<'a, T> {
    type Item = Vec<&'a T>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut buf = Vec::new();
        if self.next_into(&mut buf) {
            Some(buf)
        } else {
            None
        }
    }
}

// Once `done` is set it is never cleared, so `next` keeps returning `None`.
impl<T> std::iter::FusedIterator for CombinationIter<'_, T> {}

/// Iterator over combinations of multiple sizes within a range.
///
/// Sizes are visited in increasing order. An empty or reversed range (such
/// as `2..2` or `3..=1`) yields nothing.
///
/// Created by [`Combinations::combinations_range`].
///
/// ```
/// use combinations::Combinations;
///
/// let items = [1, 2, 3];
/// let mut iter = items.combinations_range(0..=1);
///
/// assert_eq!(iter.next(), Some(vec![]));       // k=0
/// assert_eq!(iter.next(), Some(vec![&1]));     // k=1
/// assert_eq!(iter.next(), Some(vec![&2]));
/// assert_eq!(iter.next(), Some(vec![&3]));
/// assert_eq!(iter.next(), None);
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct CombinationRangeIter<'a, T> {
    slice: &'a [T],
    /// Size of the combinations `inner` is currently producing.
    current_k: usize,
    /// Largest size to produce (inclusive); never exceeds `slice.len()`.
    end_k: usize,
    inner: CombinationIter<'a, T>,
}

impl<'a, T> CombinationRangeIter<'a, T> {
    fn new(slice: &'a [T], range: impl RangeBounds<usize>) -> Self {
        // Smallest size requested; `None` for a start of `Excluded(usize::MAX)`,
        // which no size can satisfy (and which would overflow `s + 1`).
        let first_k = match range.start_bound() {
            Bound::Included(&s) => Some(s),
            Bound::Excluded(&s) => s.checked_add(1),
            Bound::Unbounded => Some(0),
        };
        // Largest size requested, capped at the slice length; `None` for an
        // end of `Excluded(0)`, i.e. `..0`.
        let last_k = match range.end_bound() {
            Bound::Included(&e) => Some(e),
            Bound::Excluded(&e) => e.checked_sub(1),
            Bound::Unbounded => Some(usize::MAX),
        }
        .map(|e| e.min(slice.len()));

        match (first_k, last_k) {
            (Some(first), Some(last)) if first <= last => Self {
                slice,
                current_k: first,
                end_k: last,
                inner: CombinationIter::new(slice, first),
            },
            // Empty or reversed range, or a start beyond the slice length:
            // start exhausted instead of emitting size-`first` combinations.
            _ => Self {
                slice,
                current_k: 0,
                end_k: 0,
                // Choosing 1 element from an empty slice yields nothing.
                inner: CombinationIter::new(&slice[..0], 1),
            },
        }
    }

    /// Fills `buf` with the next combination, returning `true` if one was
    /// produced. See [`CombinationIter::next_into`] for details.
    ///
    /// ```
    /// use combinations::Combinations;
    ///
    /// let items = [1, 2, 3];
    /// let mut iter = items.combinations_range(1..=2);
    /// let mut buf = Vec::new();
    /// let mut count = 0;
    /// while iter.next_into(&mut buf) {
    ///     count += 1;
    /// }
    /// assert_eq!(count, 6); // C(3,1) + C(3,2)
    /// ```
    pub fn next_into(&mut self, buf: &mut Vec<&'a T>) -> bool {
        // When the current size is used up, move on to the next one until the
        // last requested size has been exhausted too.
        while !self.inner.next_into(buf) {
            if self.current_k >= self.end_k {
                return false;
            }
            self.current_k += 1;
            self.inner = CombinationIter::new(self.slice, self.current_k);
        }
        true
    }
}

impl<'a, T> Iterator for CombinationRangeIter<'a, T> {
    type Item = Vec<&'a T>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut buf = Vec::new();
        if self.next_into(&mut buf) {
            Some(buf)
        } else {
            None
        }
    }
}

// After returning `None`, `inner` is exhausted (and fused) and
// `current_k >= end_k`, so every later call also returns `None`.
impl<T> std::iter::FusedIterator for CombinationRangeIter<'_, T> {}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones each borrowed combination into an owned `Vec<T>` so results can
    /// be compared against plain literals.
    fn owned<'a, T: Clone + 'a>(combos: impl Iterator<Item = Vec<&'a T>>) -> Vec<Vec<T>> {
        combos.map(|combo| combo.into_iter().cloned().collect()).collect()
    }

    #[test]
    fn choose_2() {
        let words = vec!["hej", "på", "dig"];
        let got: Vec<Vec<&&str>> = words.combinations(2).collect();
        let expected = vec![
            vec![&"hej", &"på"],
            vec![&"hej", &"dig"],
            vec![&"på", &"dig"],
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn choose_0() {
        // Choosing nothing yields exactly one (empty) combination.
        let got = owned([1, 2, 3].combinations(0));
        assert_eq!(got, vec![Vec::<i32>::new()]);
    }

    #[test]
    fn choose_all() {
        let got = owned([1, 2, 3].combinations(3));
        assert_eq!(got, vec![vec![1, 2, 3]]);
    }

    #[test]
    fn k_exceeds_len() {
        let got: Vec<Vec<i32>> = owned([1, 2].combinations(5));
        assert!(got.is_empty());
    }

    #[test]
    fn count() {
        // C(6, 3) = 20
        assert_eq!([0; 6].combinations(3).count(), 20);
    }

    #[test]
    fn correct_combinations() {
        // Co-lexicographic order (Gosper's hack / bitmask ascending).
        let expected = vec![
            vec![0, 1],
            vec![0, 2],
            vec![1, 2],
            vec![0, 3],
            vec![1, 3],
            vec![2, 3],
        ];
        assert_eq!(owned([0, 1, 2, 3].combinations(2)), expected);
    }

    #[test]
    fn empty_slice() {
        let empty: &[i32] = &[];
        assert_eq!(owned(empty.combinations(0)), vec![Vec::<i32>::new()]);
        assert_eq!(empty.combinations(1).count(), 0);
    }

    #[test]
    fn next_into_reuses_buffer() {
        let items = [1, 2, 3];
        let mut iter = items.combinations(2);
        let mut buf = Vec::new();
        let mut seen: Vec<Vec<i32>> = Vec::new();
        while iter.next_into(&mut buf) {
            seen.push(buf.iter().map(|&&x| x).collect());
        }
        assert_eq!(seen, vec![vec![1, 2], vec![1, 3], vec![2, 3]]);
    }

    #[test]
    fn range_inclusive() {
        // k=0, then k=1, then k=2
        let expected = vec![
            vec![],
            vec![1],
            vec![2],
            vec![3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
        ];
        assert_eq!(owned([1, 2, 3].combinations_range(0..=2)), expected);
    }

    #[test]
    fn range_exclusive() {
        let expected = vec![vec![1], vec![2], vec![3], vec![1, 2], vec![1, 3], vec![2, 3]];
        assert_eq!(owned([1, 2, 3].combinations_range(1..3)), expected);
    }

    #[test]
    fn range_full() {
        // All 2^3 = 8 subsets of [1,2,3]
        let expected = vec![
            vec![],
            vec![1],
            vec![2],
            vec![3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
            vec![1, 2, 3],
        ];
        assert_eq!(owned([1, 2, 3].combinations_range(..)), expected);
    }

    #[test]
    fn range_empty() {
        let got: Vec<Vec<&i32>> = [1, 2].combinations_range(5..=6).collect();
        assert!(got.is_empty());
    }

    #[test]
    fn works_on_vec() {
        let v = vec![10, 20, 30];
        assert_eq!(owned(v.combinations(1)), vec![vec![10], vec![20], vec![30]]);
    }

    #[test]
    fn bitmask_boundary_at_64_elements() {
        // 64 elements is the largest slice handled by the u64 bitmask path;
        // these sizes push Gosper's hack into the top bit of the word.
        let items: Vec<u8> = (0..64).collect();
        assert_eq!(items.combinations(0).count(), 1);
        assert_eq!(items.combinations(1).count(), 64);
        assert_eq!(items.combinations(1).last(), Some(vec![&63]));
        assert_eq!(items.combinations(2).count(), 2016); // C(64, 2)
        assert_eq!(items.combinations(63).count(), 64);
        assert_eq!(items.combinations(64).count(), 1);
        assert_eq!(items.combinations(65).count(), 0);
    }

    #[test]
    fn index_fallback_above_64_elements() {
        // 65 elements no longer fit in a u64 mask, so the index-based path runs.
        let items: Vec<u8> = (0..65).collect();
        assert_eq!(items.combinations(0).count(), 1);
        assert_eq!(items.combinations(1).count(), 65);
        assert_eq!(items.combinations(2).count(), 2080); // C(65, 2)
        assert_eq!(items.combinations(65).count(), 1);
        assert_eq!(items.combinations(66).count(), 0);
    }
}