// libreda-logic 0.0.3
//
// Logic library for LibrEDA.
// Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! Bit manipulation on integers.

#![allow(unused)]

use num_traits::{
    ops::overflowing::{OverflowingAdd, OverflowingSub},
    PrimInt,
};

/// Return `x` with the bit at position `bit_idx` forced to `value`.
///
/// All other bits of `x` are left untouched.
///
/// # Panics
/// For the primitive integer types, debug builds panic if `bit_idx` is not smaller
/// than the bit width of `T` (shift overflow).
pub fn set_bit<T: PrimInt>(x: T, bit_idx: usize, value: bool) -> T {
    let selected = T::one() << bit_idx;
    if value {
        x | selected
    } else {
        x & !selected
    }
}

/// Test whether the bit at position `bit_idx` of `x` is set.
///
/// # Panics
/// For the primitive integer types, debug builds panic if `bit_idx` is not smaller
/// than the bit width of `T` (shift overflow).
pub fn get_bit<T: PrimInt>(x: T, bit_idx: usize) -> bool {
    // Move the requested bit down to position 0 and look only at that bit.
    let lowest_bit = (x >> bit_idx) & T::one();
    lowest_bit != T::zero()
}

/// Exchange the bits at positions `i` and `j` of `x`.
///
/// This is the single-bit special case of [`swap_bit_patterns`].
/// See (68) in "The Art of Computer Programming", vol. 4A, section 7.1.3.
pub fn swap_bits<T: PrimInt>(x: T, i: usize, j: usize) -> T {
    let single_bit_pattern = T::one();
    swap_bit_patterns(x, single_bit_pattern, i, j)
}

#[test]
fn test_swap_bits() {
    // (input, i, j, expected)
    let cases = [
        (0b100, 2, 2, 0b100),
        (0b01, 0, 1, 0b10),
        (0b100000, 5, 11, 0b100000000000),
    ];
    for &(input, i, j, expected) in cases.iter() {
        assert_eq!(swap_bits(input, i, j), expected);
    }
}

/// Swap the bits selected by `pattern << i` with the bits selected by `pattern << j`.
///
/// The order of `i` and `j` does not matter; if `i == j`, `x` is returned unchanged.
/// See (68) in "The Art of Computer Programming", vol. 4A, section 7.1.3.
///
/// # Preconditions
/// * `pattern << max(i, j)` must fit into `T`.
/// * The two bit groups must not overlap, i.e. `pattern & (pattern << |i - j|) == 0`.
///   Otherwise the result is not a swap of the two groups.
pub fn swap_bit_patterns<T: PrimInt>(x: T, pattern: T, i: usize, j: usize) -> T {
    // Normalize so that `lo <= hi`; the swap is symmetric in `i` and `j`.
    let (lo, hi) = if i <= j { (i, j) } else { (j, i) };
    let delta = hi - lo;

    let mask_lo = pattern << lo;
    // At the positions of the lower group, `y` holds the XOR of the two groups.
    let y = (x ^ (x >> delta)) & mask_lo;

    // XOR-ing that difference into both groups exchanges their contents.
    x ^ y ^ (y << delta)
}

#[test]
fn test_swap_bit_pattern() {
    // (input, pattern, i, j, expected)
    let cases = [(0b0011, 0b11, 0, 2, 0b1100), (0b0101, 0b101, 0, 1, 0b1010)];
    for &(input, pattern, i, j, expected) in cases.iter() {
        assert_eq!(swap_bit_patterns(input, pattern, i, j), expected);
    }
}

/// Given an integer with `n` 1-bits, generate the lexicographically next permutation.
///
/// The result is the next larger integer with the same number of 1-bits. `0` maps to `0`.
/// If no larger integer with that many 1-bits fits into `T`, the computation wraps around
/// and the result is not a valid permutation.
///
/// Reference: <https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation>, [archived](https://web.archive.org/web/20221118181551/https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation)
pub fn next_bit_permutation<T>(current_bits: T) -> T
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    let one = T::one();
    let bit_width = T::zero().count_zeros() as usize;

    // Set the least significant 0-bits (below the lowest 1-bit) to 1.
    let t = current_bits | current_bits.overflowing_sub(&one).0;
    let t_plus_one = t.overflowing_add(&one).0;

    // `!t & (t + 1)` isolates the lowest 0-bit of `t`; subtracting one turns it into
    // a run of ones below that position.
    let low_ones = (!t & t_plus_one).overflowing_sub(&one).0;

    // The shift amount reaches the bit width when `current_bits` is zero or its lowest
    // set bit is the most significant bit. Shifting by the full width overflows: it
    // panics in debug builds and masks the amount in release builds. Logically every
    // bit is shifted out, so use zero.
    let shift = current_bits.trailing_zeros() as usize + 1;
    let refill = if shift >= bit_width {
        T::zero()
    } else {
        low_ones >> shift
    };

    t_plus_one | refill
}

#[test]
fn test_next_bit_permutation_edge_cases() {
    // Zero has no 1-bits; its only "permutation" is itself.
    assert_eq!(next_bit_permutation(0u64), 0);
    assert_eq!(next_bit_permutation(0u8), 0);
    // Lowest set bit is the MSB: must not overflow the shift.
    assert_eq!(next_bit_permutation(1u64 << 63), 0);
    // Regular case near the top of a small type.
    assert_eq!(next_bit_permutation(0b0110_0000u8), 0b1000_0001);
}

#[test]
fn test_next_bit_permutation() {
    // All 5-bit patterns with exactly three 1-bits, in increasing order.
    let sequence = [
        0b00111, 0b01011, 0b01101, 0b01110, 0b10011, 0b10101, 0b10110, 0b11001, 0b11010, 0b11100,
    ];

    for (current, expected_next) in sequence.iter().zip(sequence.iter().skip(1)) {
        assert_eq!(next_bit_permutation(*current), *expected_next);
    }
}

/// Iterator over all choices of `k` bits out of `n`, as created by [`all_bit_choices`].
///
/// Each item is a bitmap with the chosen bits set to `1`. Successive items are computed
/// with [`next_bit_permutation`], so they appear in increasing numerical order.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BitChoiceIter<T> {
    /// Number of patterns that are still to be yielded.
    remaining_len: usize,
    /// The pattern that will be yielded next.
    state: T,
}

impl<T> Iterator for BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining_len == 0 {
            return None;
        }

        let output = self.state;

        // Only advance while more patterns are due. The final pattern's successor is
        // never needed and may not be representable in `T`.
        if self.remaining_len > 1 {
            self.state = next_bit_permutation(self.state);
        }

        self.remaining_len -= 1;
        Some(output)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining_len, Some(self.remaining_len))
    }
}

// `size_hint` is exact, so the iterator knows its length.
impl<T> ExactSizeIterator for BitChoiceIter<T> where T: PrimInt + OverflowingSub + OverflowingAdd {}

// Once `remaining_len` reaches zero, `next` keeps returning `None`.
impl<T> std::iter::FusedIterator for BitChoiceIter<T> where
    T: PrimInt + OverflowingSub + OverflowingAdd
{
}

/// Iterate over all choices of `k` bits out of `n`.
///
/// Each choice is encoded as a bitmap in the `n` least significant bits of an integer.
/// Each chosen bit is `1`; all others are `0`. The first pattern has the `k` least
/// significant bits set. Exactly `n choose k` patterns are produced.
///
/// `n` should not exceed the bit width of `T`, otherwise some patterns cannot be
/// represented.
///
/// # Panics
/// * if `k > n`,
/// * if `k` exceeds the bit width of `T`,
/// * if the number of choices does not fit into a `usize`.
pub fn all_bit_choices<T>(n: usize, k: usize) -> BitChoiceIter<T>
where
    T: PrimInt + OverflowingSub + OverflowingAdd,
{
    assert!(k <= n);

    let bit_width = T::zero().count_zeros() as usize;
    assert!(
        k <= bit_width,
        "cannot choose {} bits in an integer type with only {} bits",
        k,
        bit_width
    );

    /// Binomial coefficient `n choose k`.
    fn comb(n: usize, k: usize) -> usize {
        use std::convert::TryFrom;

        assert!(n >= k);
        // `n choose k == n choose (n - k)`: iterate over the shorter side.
        let k = k.min(n - k);

        // Multiplicative formula. After step `i` the accumulator equals `n choose (i + 1)`,
        // so every division is exact. No factorial is computed, which used to overflow
        // long before the binomial itself. `usize -> u128` casts are lossless widenings.
        let c = (0..k).fold(1u128, |acc, i| {
            acc.checked_mul((n - i) as u128)
                .expect("number of bit choices does not fit into `usize`")
                / (i + 1) as u128
        });

        usize::try_from(c).expect("number of bit choices does not fit into `usize`")
    }

    // Get the start bit-pattern: the `k` least significant bits are set to 1.
    // `T::one() << k` would overflow when every bit of `T` is chosen.
    let start = if k == bit_width {
        !T::zero()
    } else {
        (T::one() << k).overflowing_sub(&T::one()).0
    };

    BitChoiceIter {
        remaining_len: comb(n, k),
        state: start,
    }
}

#[test]
fn test_bit_choices_edge_cases() {
    // Choosing every bit of the integer type must not overflow the start-pattern shift.
    let patterns: Vec<u8> = all_bit_choices(8, 8).collect();
    assert_eq!(patterns, vec![0xff]);
    let patterns: Vec<u64> = all_bit_choices(64, 64).collect();
    assert_eq!(patterns, vec![u64::MAX]);

    // The count must not overflow through intermediate factorials.
    assert_eq!(all_bit_choices::<u64>(34, 17).size_hint().0, 2_333_606_220);
    assert_eq!(all_bit_choices::<u64>(64, 1).size_hint().0, 64);
    assert_eq!(all_bit_choices::<u64>(64, 63).size_hint().0, 64);
}

#[test]
fn test_bit_choices_iter() {
    // Choosing zero bits yields exactly one pattern: the empty bitmap.
    for &n in &[10, 0] {
        let patterns: Vec<u64> = all_bit_choices(n, 0).collect();
        assert_eq!(patterns[0], 0b0);
        assert_eq!(patterns.len(), 1);
    }

    // 5 choose 2 = 10 patterns, from the lowest to the highest pair of bits.
    let patterns: Vec<u64> = all_bit_choices(5, 2).collect();
    assert_eq!(patterns.len(), 10);
    assert_eq!(patterns.first(), Some(&0b11));
    assert_eq!(patterns.last(), Some(&0b11000));
}