use num_traits::{
ops::overflowing::{OverflowingAdd, OverflowingSub},
PrimInt,
};
/// Returns a copy of `x` in which bit `bit_idx` is set to 1 when `value` is
/// `true` and cleared to 0 when it is `false`; every other bit is preserved.
///
/// For primitive integer types, a `bit_idx` at or beyond the bit width of `T`
/// overflows the shift (a panic in debug builds).
pub fn set_bit<T: PrimInt>(x: T, bit_idx: usize, value: bool) -> T {
    let target = T::one() << bit_idx;
    if value {
        x | target
    } else {
        x & !target
    }
}
/// Reports whether bit `bit_idx` of `x` is set.
///
/// For primitive integer types, a `bit_idx` at or beyond the bit width of `T`
/// overflows the shift (a panic in debug builds).
pub fn get_bit<T: PrimInt>(x: T, bit_idx: usize) -> bool {
    let probe = T::one() << bit_idx;
    (x & probe) != T::zero()
}
/// Returns `x` with bits `i` and `j` exchanged; all other bits are unchanged.
///
/// Equivalent to [`swap_bit_patterns`] with a one-bit pattern, so the order in
/// which `i` and `j` are given does not matter.
pub fn swap_bits<T: PrimInt>(x: T, i: usize, j: usize) -> T {
    let single_bit = T::one();
    swap_bit_patterns(x, single_bit, i, j)
}
#[test]
fn test_swap_bits() {
    // (input, i, j, expected)
    let cases: [(i32, usize, usize, i32); 3] = [
        // Swapping a bit with itself is a no-op.
        (0b100, 2, 2, 0b100),
        (0b01, 0, 1, 0b10),
        (0b100000, 5, 11, 0b100000000000),
    ];
    for &(input, i, j, expected) in cases.iter() {
        assert_eq!(swap_bits(input, i, j), expected);
    }
}
/// Exchanges two bit fields of `x`: the bits selected by `pattern << i` and
/// the bits selected by `pattern << j`. `i` and `j` may be given in either
/// order.
///
/// This is the XOR "delta swap": the fields are compared and their
/// difference is XOR-ed into both positions. The result is only a true swap
/// when the two shifted copies of `pattern` do not overlap and both fit in
/// `T`. That is the caller's responsibility and is not checked here.
pub fn swap_bit_patterns<T: PrimInt>(x: T, pattern: T, i: usize, j: usize) -> T {
    let (lo, hi) = if j < i { (j, i) } else { (i, j) };
    let shift = hi - lo;
    let low_field = pattern << lo;
    // Bits in which the high field (moved down to `lo`) differs from the low field.
    let diff = (x ^ (x >> shift)) & low_field;
    // Flipping those bits in both fields exchanges their contents.
    x ^ diff ^ (diff << shift)
}
#[test]
fn test_swap_bit_pattern() {
    // (input, pattern, i, j, expected)
    let cases: [(i32, i32, usize, usize, i32); 2] = [
        (0b0011, 0b11, 0, 2, 0b1100),
        // Interleaved but non-overlapping fields.
        (0b0101, 0b101, 0, 1, 0b1010),
    ];
    for &(input, pattern, i, j, expected) in cases.iter() {
        assert_eq!(swap_bit_patterns(input, pattern, i, j), expected);
    }
}
/// Returns the next larger bit pattern with the same number of set bits as
/// `current_bits`. This is the "next lexicographic bit permutation" trick
/// from Bit Twiddling Hacks.
///
/// Starting from `(1 << k) - 1` and applying this repeatedly visits every
/// pattern with `k` set bits in increasing (unsigned) order.
///
/// # Edge cases
/// * `0` has no set bits to permute and is returned unchanged.
/// * If `current_bits` is already the largest pattern of its popcount that
///   fits in `T`, the arithmetic wraps and the result is meaningless. It
///   does not panic.
pub fn next_bit_permutation<T>(current_bits: T) -> T
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    if current_bits == T::zero() {
        return current_bits;
    }
    let one = T::one();
    // `current_bits` with its trailing zeros turned into ones.
    let t = current_bits | current_bits.overflowing_sub(&one).0;
    // Adding one carries through the lowest run of ones, moving one bit up into
    // the first clear position above it.
    let t_plus_one = t.overflowing_add(&one).0;
    // Mask of every bit below the lowest clear bit of `t`.
    let low_ones = (!t & t_plus_one).overflowing_sub(&one).0;
    // The remaining bits of the run go back to the bottom. The shift by
    // `trailing_zeros + 1` is done in two steps. In one step it equals the
    // bit width when only the top bit is set, which panics in debug builds.
    let trailing = current_bits.trailing_zeros() as usize;
    t_plus_one | ((low_ones >> trailing) >> 1)
}

#[test]
fn test_next_bit_permutation_edge_cases() {
    assert_eq!(next_bit_permutation(0b0100_0000u8), 0b1000_0000);
    assert_eq!(next_bit_permutation(0b0110_0000u8), 0b1000_0001);
    // No larger one-bit pattern exists in a u8; this must not panic.
    let _ = next_bit_permutation(0b1000_0000u8);
    assert_eq!(next_bit_permutation(0u32), 0);
}
#[test]
fn test_next_bit_permutation() {
    // Every 5-bit pattern with exactly three bits set, in increasing order.
    let ordered: [i32; 10] = [
        0b00111, 0b01011, 0b01101, 0b01110, 0b10011, 0b10101, 0b10110, 0b11001, 0b11010, 0b11100,
    ];
    for (&current, &following) in ordered.iter().zip(ordered.iter().skip(1)) {
        assert_eq!(next_bit_permutation(current), following);
    }
}
/// Iterator over bit patterns with a fixed number of set bits, yielded in
/// increasing (unsigned) order. Created by [`all_bit_choices`].
#[derive(Debug, Clone)]
pub struct BitChoiceIter<T> {
    /// Number of patterns still to be yielded.
    remaining_len: usize,
    /// The pattern the next call to `next` will return.
    state: T,
}
impl<T> Iterator for BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    type Item = T;

    /// Yields the current pattern and advances to the next permutation.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining_len == 0 {
            return None;
        }
        let output = self.state;
        self.remaining_len -= 1;
        // Only step forward when another pattern will actually be yielded.
        // Stepping past the final pattern is wasted work and may overflow
        // (e.g. for the zero-bit pattern).
        if self.remaining_len > 0 {
            self.state = next_bit_permutation(self.state);
        }
        Some(output)
    }

    /// The remaining count is tracked exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining_len, Some(self.remaining_len))
    }
}

/// `size_hint` is exact, so `.len()` is available.
impl<T> ExactSizeIterator for BitChoiceIter<T> where T: PrimInt + OverflowingSub + OverflowingAdd {}

/// Once `remaining_len` reaches zero, `next` keeps returning `None`.
impl<T> std::iter::FusedIterator for BitChoiceIter<T> where
    T: PrimInt + OverflowingSub + OverflowingAdd
{
}
/// Returns an iterator over every value of `T` whose low `n` bits contain
/// exactly `k` set bits (all higher bits clear). Values come in increasing
/// (unsigned) order, starting from `(1 << k) - 1`.
///
/// The iterator yields exactly `C(n, k)` items. Since `C(n, 0) == 1`,
/// `k == 0` yields the single value `0`.
///
/// # Panics
/// * If `k > n`.
/// * If `k > 0` and `n` exceeds the bit width of `T` (the patterns would
///   not fit).
/// * If `C(n, k)` does not fit in a `usize`.
pub fn all_bit_choices<T>(n: usize, k: usize) -> BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    /// Binomial coefficient `C(n, k)` computed multiplicatively. Each
    /// intermediate value is itself a binomial coefficient no larger than
    /// the result, so only the final multiplication can overflow.
    fn binomial(n: usize, k: usize) -> usize {
        /// Greatest common divisor (Euclid).
        fn gcd(mut a: usize, mut b: usize) -> usize {
            while b != 0 {
                let r = a % b;
                a = b;
                b = r;
            }
            a
        }
        let k = k.min(n - k);
        (0..k).fold(1usize, |acc, i| {
            // `acc == C(n, i)`; compute C(n, i + 1) = acc * (n - i) / (i + 1).
            // Dividing out the common factor first keeps everything exact and
            // avoids an oversized intermediate product.
            let d = i + 1;
            let g = gcd(acc, d);
            (acc / g)
                .checked_mul((n - i) / (d / g))
                .expect("number of bit combinations does not fit in usize")
        })
    }

    assert!(k <= n, "cannot choose {} bits from {} positions", k, n);
    let bit_width = T::zero().count_zeros() as usize;
    assert!(
        k == 0 || n <= bit_width,
        "cannot choose bits among {} positions of a {}-bit type",
        n,
        bit_width
    );

    let one = T::one();
    // Lowest pattern: the `k` lowest bits set. Built as
    // ((1 << (k - 1)) - 1) * 2 + 1 so that `k == bit_width` never shifts by
    // the full width (which panics in debug builds). This also gives all
    // ones for signed types.
    let start = if k == 0 {
        T::zero()
    } else {
        ((one << (k - 1)).overflowing_sub(&one).0 << 1) | one
    };
    BitChoiceIter {
        remaining_len: binomial(n, k),
        state: start,
    }
}

#[test]
fn test_all_bit_choices_edge_cases() {
    // Choosing every bit of the type: `1 << 8` would overflow a u8.
    let full: Vec<u8> = all_bit_choices(8, 8).collect();
    assert_eq!(full, vec![u8::MAX]);
    let full_signed: Vec<i8> = all_bit_choices(8, 8).collect();
    assert_eq!(full_signed, vec![-1]);
    // Every 1-of-8 pattern of a u8, ending at the top bit.
    let singles: Vec<u8> = all_bit_choices(8, 1).collect();
    assert_eq!(singles, vec![1, 2, 4, 8, 16, 32, 64, 128]);
    // The count must not overflow for moderately large n (30!/15! exceeds u64).
    assert_eq!(
        all_bit_choices::<u64>(30, 15).size_hint(),
        (155_117_520, Some(155_117_520))
    );
}
#[test]
fn test_bit_choices_iter() {
    // Choosing zero bits yields exactly the empty pattern, whatever `n` is.
    for &n in [10usize, 0].iter() {
        let empty: Vec<u64> = all_bit_choices(n, 0).collect();
        assert_eq!(empty, vec![0u64]);
    }

    let two_of_five: Vec<u64> = all_bit_choices(5, 2).collect();
    assert_eq!(two_of_five.len(), 10);
    assert_eq!(two_of_five.first(), Some(&0b11));
    assert_eq!(two_of_five.last(), Some(&0b11000));
}