1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! Bit manipulation on integers.

use num_traits::{
    ops::overflowing::{OverflowingAdd, OverflowingSub},
    PrimInt,
};

/// Return `x` with the bit at position `bit_idx` forced to `value`.
///
/// Every other bit of `x` is left unchanged.
pub fn set_bit<T: PrimInt>(x: T, bit_idx: usize, value: bool) -> T {
    // Clear the target bit first, then optionally set it again.
    let cleared = x & !(T::one() << bit_idx);
    if value {
        cleared | (T::one() << bit_idx)
    } else {
        cleared
    }
}

/// Report whether the bit at position `bit_idx` of `x` is set.
pub fn get_bit<T: PrimInt>(x: T, bit_idx: usize) -> bool {
    // After masking, `lowest` is either zero or one.
    let lowest = (x >> bit_idx) & T::one();
    lowest != T::zero()
}

/// Exchange bit `i` with bit `j` of `x`, leaving all other bits untouched.
///
/// This is the single-bit special case of [`swap_bit_patterns`];
/// see (68) in "The Art of Computer Programming", vol. 4A, section 7.1.3.
pub fn swap_bits<T: PrimInt>(x: T, i: usize, j: usize) -> T {
    let single_bit = T::one();
    swap_bit_patterns(x, single_bit, i, j)
}

#[test]
fn test_swap_bits() {
    // (input, i, j, expected)
    let cases = [
        (0b100, 2, 2, 0b100),
        (0b01, 0, 1, 0b10),
        (0b100000, 5, 11, 0b100000000000),
    ];
    for (x, i, j, expected) in cases {
        assert_eq!(swap_bits(x, i, j), expected);
    }
}

/// Swap the bits selected by `pattern << i` with the bits selected by `pattern << j`.
///
/// The two shifted patterns are expected not to overlap. This is the "delta swap",
/// see (68) in "The Art of Computer Programming", vol. 4A, section 7.1.3.
pub fn swap_bit_patterns<T: PrimInt>(x: T, pattern: T, i: usize, j: usize) -> T {
    // Normalise so that `lo` is the lower offset and `delta` the distance to the upper one.
    let (lo, hi) = if i <= j { (i, j) } else { (j, i) };
    let delta = hi - lo;

    let low_mask = pattern << lo;
    // `diff` marks each low-side position whose bit disagrees with its partner `delta` above.
    let diff = (x ^ (x >> delta)) & low_mask;

    // Flipping every disagreeing bit on both sides exchanges the two groups.
    x ^ (diff ^ (diff << delta))
}

#[test]
fn test_swap_bit_pattern() {
    // (input, pattern, i, j, expected)
    let cases = [(0b0011, 0b11, 0, 2, 0b1100), (0b0101, 0b101, 0, 1, 0b1010)];
    for (x, pattern, i, j, expected) in cases {
        assert_eq!(swap_bit_patterns(x, pattern, i, j), expected);
    }
}

/// Given an integer with `n` 1-bits, generate the lexicographically next permutation.
///
/// For unsigned types this is the smallest integer greater than `current_bits` with the
/// same number of 1-bits. `0` (the only pattern without 1-bits) maps to `0`. If
/// `current_bits` is already the largest pattern of its popcount that fits in `T`, the
/// result wraps around and does not preserve the number of 1-bits.
///
/// Reference: <https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation>, [archived](https://web.archive.org/web/20221118181551/https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation)
pub fn next_bit_permutation<T>(current_bits: T) -> T
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    if current_bits == T::zero() {
        // `trailing_zeros` would equal the bit width and the shift below would overflow.
        return T::zero();
    }

    let one = T::one();
    // Set the least significant 0-bits (those below the lowest 1-bit) to 1.
    let t = current_bits | current_bits.overflowing_sub(&one).0;
    // Adding one turns the lowest 0-bit of `t` into a 1 and clears the run of 1-bits below it.
    let t_plus_one = t.overflowing_add(&one).0;
    // All bits strictly below the lowest 0-bit of `t`.
    let low_ones = (!t & t_plus_one).overflowing_sub(&one).0;

    // Shift right by `trailing_zeros + 1` in two steps: as a single shift the amount equals
    // the bit width when only the most significant bit is set, which overflows.
    let tz = current_bits.trailing_zeros() as usize;
    t_plus_one | ((low_ones >> tz) >> 1)
}

#[test]
fn test_next_bit_permutation() {
    // Every 3-of-5 bit pattern in increasing order; each must map to its successor.
    let sequence = [
        0b00111, 0b01011, 0b01101, 0b01110, 0b10011, 0b10101, 0b10110, 0b11001, 0b11010, 0b11100,
    ];

    for (&current, &successor) in sequence.iter().zip(&sequence[1..]) {
        assert_eq!(next_bit_permutation(current), successor);
    }
}

/// Iterator over all choices of `k` bits out of `n`, created by [`all_bit_choices`].
///
/// Each item is an integer bitmap in which the chosen bits are `1`.
#[derive(Debug, Clone)]
pub struct BitChoiceIter<T> {
    /// Number of patterns that are still to be yielded.
    remaining_len: usize,
    /// The pattern returned by the next call to `next`.
    state: T,
}

impl<T> Iterator for BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    type Item = T;

    /// Yield the current pattern and advance to the next one.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining_len == 0 {
            return None;
        }

        let output = self.state;

        // Do not compute a successor of the final pattern: it would never be yielded.
        if self.remaining_len > 1 {
            self.state = next_bit_permutation(self.state);
        }

        self.remaining_len -= 1;
        Some(output)
    }

    /// The number of remaining patterns is tracked exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining_len, Some(self.remaining_len))
    }
}

// `size_hint` is exact, so callers can use `len()`.
impl<T> ExactSizeIterator for BitChoiceIter<T> where
    T: PrimInt + OverflowingSub + OverflowingAdd
{
}

// Once `remaining_len` reaches 0 it stays 0, so `next` keeps returning `None`.
impl<T> std::iter::FusedIterator for BitChoiceIter<T> where
    T: PrimInt + OverflowingSub + OverflowingAdd
{
}

/// Iterate over all choices of `k` bits out of `n`.
///
/// Each choice is encoded as a bitmap in the `n` least significant bits of an integer:
/// every chosen bit is `1`, all others are `0`. The first pattern has the `k` least
/// significant bits set, and each following pattern is produced by [`next_bit_permutation`].
/// `n` should not exceed the bit width of `T`, otherwise later patterns do not fit in `T`.
///
/// # Panics
///
/// Panics if `k > n`, if `k` exceeds the bit width of `T`, or if the number of
/// combinations `C(n, k)` does not fit in a `usize`.
pub fn all_bit_choices<T>(n: usize, k: usize) -> BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    assert!(k <= n);

    // Binomial coefficient
    // (n)
    // (k)
    // computed with the multiplicative formula so that no factorial is ever formed
    // (`30! / 15!` alone already overflows a 64-bit integer).
    fn comb(n: usize, k: usize) -> usize {
        assert!(n >= k);
        // C(n, k) == C(n, n - k): use the shorter product.
        let k = k.min(n - k);
        // Invariant: before step `i` the accumulator equals C(n, i), so the division is exact.
        // `u128` leaves headroom for the intermediate product `C(n, i) * (n - i)`.
        let count = (0..k).fold(1u128, |acc, i| {
            acc.checked_mul((n - i) as u128)
                .expect("binomial coefficient overflows u128")
                / (i as u128 + 1)
        });
        usize::try_from(count).expect("number of bit choices does not fit in usize")
    }

    let bit_width = T::zero().count_zeros() as usize;
    assert!(k <= bit_width, "cannot choose more bits than the integer type has");

    // Start pattern: the `k` least significant bits set to 1.
    // `T::one() << bit_width` would overflow, so the all-ones pattern is built directly.
    let start = if k == bit_width {
        !T::zero()
    } else {
        (T::one() << k).overflowing_sub(&T::one()).0
    };

    BitChoiceIter {
        remaining_len: comb(n, k),
        state: start,
    }
}

#[test]
fn test_bit_choices_iter() {
    // Choosing zero bits yields exactly one, empty, pattern regardless of `n`.
    for n in [10, 0] {
        let patterns: Vec<u64> = all_bit_choices(n, 0).collect();
        assert_eq!(patterns, [0b0]);
    }

    // C(5, 2) == 10 patterns, from the two lowest bits up to the two highest.
    let patterns: Vec<u64> = all_bit_choices(5, 2).collect();
    assert_eq!(patterns.len(), 10);
    assert_eq!(patterns.first(), Some(&0b11));
    assert_eq!(patterns.last(), Some(&0b11000));
}