use crate::{
bitmanip,
traits::{
BooleanFunction, BooleanSystem, NumInputs, NumOutputs, PartialBooleanFunction,
PartialBooleanSystem, StaticNumInputs, StaticNumOutputs,
},
truth_table::{PartialTruthTable, TruthTable},
};
use super::TruthTableEdit;
/// Truth table of a single-output Boolean function with a run-time input count.
///
/// The whole table is packed into one `u64`. Field order is significant: the
/// derived `PartialOrd`/`Ord` compare `lut` first and `num_inputs` second.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SmallTruthTable {
    /// Output bits; `get_bit` reads bit `k` as the output for input assignment `k`.
    lut: u64,
    /// Number of inputs; the constructors in this file require it to be at most 6.
    num_inputs: u8,
}
/// Truth table of a single-output Boolean function whose input count is the
/// compile-time parameter `NUM_INPUTS`; the table is packed into one `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SmallStaticTruthTable<const NUM_INPUTS: usize> {
    /// Output bits; `get_bit` reads bit `k` as the output for input assignment `k`.
    lut: u64,
}
impl<const STATIC_NUM_INPUTS: usize> From<SmallStaticTruthTable<STATIC_NUM_INPUTS>>
for SmallTruthTable
{
fn from(tt: SmallStaticTruthTable<STATIC_NUM_INPUTS>) -> Self {
Self {
lut: tt.lut,
num_inputs: STATIC_NUM_INPUTS as u8,
}
}
}
impl<const NUM_INPUTS: usize> TryFrom<SmallTruthTable> for SmallStaticTruthTable<NUM_INPUTS> {
    type Error = ();

    /// Succeeds only when the run-time input count equals `NUM_INPUTS`; otherwise
    /// returns `Err(())`.
    fn try_from(tt: SmallTruthTable) -> Result<Self, Self::Error> {
        match tt.num_inputs() {
            width if width == NUM_INPUTS => Ok(Self { lut: tt.lut }),
            _ => Err(()),
        }
    }
}
impl<const NUM_INPUTS: usize> NumInputs for SmallStaticTruthTable<NUM_INPUTS> {
    /// The input count is the const generic parameter itself.
    fn num_inputs(&self) -> usize {
        NUM_INPUTS
    }
}
impl<const NUM_INPUTS: usize> NumOutputs for SmallStaticTruthTable<NUM_INPUTS> {
    /// A small truth table always describes exactly one output.
    fn num_outputs(&self) -> usize {
        const SINGLE_OUTPUT: usize = 1;
        SINGLE_OUTPUT
    }
}
const INDEX_COLUMNS: [u64; 6] = index_columns();
/// Operations shared by truth tables of at most six inputs that are packed into a
/// single `u64` (in this file, bit `k` holds the output for input assignment `k`,
/// where input `i` is bit `i` of `k`).
pub trait SmallTT: TruthTable + TruthTableEdit + Sized + Copy {
    /// Raw table bits.
    fn table(&self) -> u64;

    /// Returns a copy of `self` whose raw table bits are replaced by `table`.
    fn set_table(self, table: u64) -> Self;

    /// Negates the output when `condition` is true; otherwise returns `self` unchanged.
    ///
    /// Only the `2^num_inputs` valid table bits are flipped.
    fn invert_if(self, condition: bool) -> Self {
        let mask = if condition {
            valid_bits_mask(self.num_inputs())
        } else {
            0
        };
        self.set_table(self.table() ^ mask)
    }

    /// Negates the output. Only the `2^num_inputs` valid table bits are flipped.
    fn invert(self) -> Self {
        self.set_table(self.table() ^ valid_bits_mask(self.num_inputs()))
    }

    /// Bitwise AND of two tables.
    ///
    /// # Panics
    /// If the two tables have different input counts.
    fn bitwise_and(self, other: Self) -> Self {
        bitwise_op2(self, other, |a, b| a & b)
    }

    /// Returns the table with inputs `i` and `j` exchanged.
    ///
    /// # Panics
    /// If `i` or `j` is not less than `num_inputs()`.
    fn swap_inputs(self, i: usize, j: usize) -> Self {
        assert!(i < self.num_inputs());
        assert!(j < self.num_inputs());
        let (i, j) = (i.min(j), i.max(j));
        let idx_col_i = INDEX_COLUMNS[i];
        let idx_col_j = INDEX_COLUMNS[j];
        // Entries with input i set and input j clear move up by 2^j - 2^i positions,
        // trading places with the entries where input i is clear and input j is set.
        let shift_amount = (1 << j) - (1 << i);
        let select_pattern = idx_col_i & !idx_col_j;
        let permuted_output_bits =
            bitmanip::swap_bit_patterns(self.table(), select_pattern, 0, shift_amount);
        self.set_table(permuted_output_bits)
    }

    /// Returns the table with input `i` negated, so that
    /// `result.get_bit(k ^ (1 << i)) == self.get_bit(k)` for every assignment `k`.
    ///
    /// # Panics
    /// If `i` is not less than `num_inputs()`. Without this check an unused input
    /// would move table bits into the region above `2^num_inputs`.
    fn invert_input(self, i: usize) -> Self {
        assert!(i < self.num_inputs());
        let select_pattern = !INDEX_COLUMNS[i];
        let shift_amount = 1 << i;
        let permuted_output_bits =
            bitmanip::swap_bit_patterns(self.table(), select_pattern, 0, shift_amount);
        self.set_table(permuted_output_bits)
    }

    /// Number of set bits in the raw table.
    ///
    /// This equals the number of true assignments only if the bits above
    /// `2^num_inputs` are zero.
    fn count_ones(&self) -> usize {
        self.table().count_ones() as usize
    }
}

/// Mask with the lowest `2^num_inputs` bits set, i.e. all valid table positions.
///
/// Implemented as a right shift of `u64::MAX` so that six inputs (all 64 bits) does
/// not require the overflowing shift `1 << 64`.
fn valid_bits_mask(num_inputs: usize) -> u64 {
    debug_assert!(num_inputs <= 6);
    u64::MAX >> (64 - (1usize << num_inputs))
}
impl SmallTT for SmallTruthTable {
    /// Raw table bits.
    fn table(&self) -> u64 {
        self.lut
    }

    /// Replaces the table bits and keeps the input count.
    fn set_table(self, table: u64) -> Self {
        Self { lut: table, ..self }
    }
}
impl<const NUM_INPUTS: usize> SmallTT for SmallStaticTruthTable<NUM_INPUTS> {
    /// Raw table bits.
    fn table(&self) -> u64 {
        let Self { lut } = *self;
        lut
    }

    /// Builds a table of the same width from the given bits.
    fn set_table(self, table: u64) -> Self {
        Self { lut: table }
    }
}
impl<const NUM_INPUTS: usize> SmallStaticTruthTable<NUM_INPUTS> {
    /// Tabulates `f` over all `2^NUM_INPUTS` input assignments.
    ///
    /// For assignment number `row`, array element `i` holds bit `i` of `row`, and
    /// `f`'s result is stored in table bit `row`. Rows are evaluated in increasing order.
    ///
    /// # Panics
    /// If `NUM_INPUTS > 6`.
    pub fn new(f: impl Fn([bool; NUM_INPUTS]) -> bool) -> SmallStaticTruthTable<NUM_INPUTS> {
        assert!(
            NUM_INPUTS <= 6,
            "Number of inputs ({NUM_INPUTS}) exceeds the maximum (6)."
        );
        let mut lut = 0u64;
        for row in 0..(1usize << NUM_INPUTS) {
            let mut assignment = [false; NUM_INPUTS];
            for (position, value) in assignment.iter_mut().enumerate() {
                *value = (row >> position) & 1 == 1;
            }
            if f(assignment) {
                lut |= 1 << row;
            }
        }
        Self { lut }
    }
}
impl<const NUM_INPUTS: usize> PartialBooleanSystem for SmallStaticTruthTable<NUM_INPUTS> {
    type TermId = ();
    type LiteralId = u32;

    /// A full truth table never has an unknown output, so this always returns `Some`.
    fn evaluate_term_partial(&self, term: &(), input_values: &[bool]) -> Option<bool> {
        let value = self.evaluate_term(term, input_values);
        Some(value)
    }
}
impl<const NUM_INPUTS: usize> BooleanSystem for SmallStaticTruthTable<NUM_INPUTS> {
    /// Packs `input_values` into a table index (element `i` becomes bit `i`) and
    /// looks up that entry.
    fn evaluate_term(&self, _term: &(), input_values: &[bool]) -> bool {
        let mut index = 0u64;
        for &value in input_values.iter().rev() {
            index = (index << 1) | u64::from(value);
        }
        self.get_bit(index)
    }
}
impl<const NUM_INPUTS: usize> PartialBooleanFunction for SmallStaticTruthTable<NUM_INPUTS> {
    /// Always defined: delegates to `eval` and wraps the result in `Some`.
    fn partial_eval(&self, input_values: &[bool]) -> Option<bool> {
        let output = self.eval(input_values);
        Some(output)
    }
}
impl<const NUM_INPUTS: usize> BooleanFunction for SmallStaticTruthTable<NUM_INPUTS> {
    /// Evaluates the function; `input_values[i]` becomes bit `i` of the table index.
    fn eval(&self, input_values: &[bool]) -> bool {
        let mut index = 0u64;
        for &value in input_values.iter().rev() {
            index = (index << 1) | u64::from(value);
        }
        self.get_bit(index)
    }
}
impl<const NUM_INPUTS: usize> PartialTruthTable for SmallStaticTruthTable<NUM_INPUTS> {
    /// Always defined: looks up `input_bits` with `get_bit` and wraps it in `Some`.
    fn partial_evaluate(&self, input_bits: u64) -> Option<bool> {
        let output = self.get_bit(input_bits);
        Some(output)
    }
}
impl<const NUM_INPUTS: usize> TruthTable for SmallStaticTruthTable<NUM_INPUTS> {
    /// Reads table entry `bits`. Index bits at or above `NUM_INPUTS` are ignored.
    fn get_bit(&self, bits: u64) -> bool {
        let index_mask: u64 = (1 << NUM_INPUTS) - 1;
        ((self.lut >> (bits & index_mask)) & 1) != 0
    }
}
impl<const NUM_INPUTS: usize> TruthTableEdit for SmallStaticTruthTable<NUM_INPUTS> {
    /// Sets table entry `bit_index` to `value`.
    ///
    /// # Panics
    /// If `bit_index >= 2^num_inputs()`.
    fn set_bit(&mut self, bit_index: usize, value: bool) {
        let table_len: usize = 1 << self.num_inputs();
        assert!(bit_index < table_len, "bit index out of range");
        let bit: u64 = 1 << bit_index;
        if value {
            self.lut |= bit;
        } else {
            self.lut &= !bit;
        }
    }
}
impl SmallTruthTable {
    /// Tabulates `f` over all `2^NUM_INPUTS` input assignments, where array element `i`
    /// is input `i`.
    ///
    /// # Panics
    /// If `NUM_INPUTS > 6`.
    pub fn new<const NUM_INPUTS: usize>(f: impl Fn([bool; NUM_INPUTS]) -> bool) -> Self {
        SmallStaticTruthTable::new(f).into()
    }

    /// Constant-false table with `num_inputs` inputs.
    ///
    /// # Panics
    /// If `num_inputs > 6`.
    pub fn zero(num_inputs: usize) -> Self {
        assert!(num_inputs <= 6);
        let num_inputs = num_inputs as u8;
        Self { lut: 0, num_inputs }
    }

    /// Wraps raw table bits, where bit `k` is the output for input assignment `k`.
    ///
    /// Bits at positions `>= 2^num_inputs` are cleared. This keeps `count_ones`,
    /// `invert` and the derived `Eq`/`Hash` consistent with the function the table
    /// represents.
    ///
    /// # Panics
    /// If `num_inputs > 6`.
    pub const fn from_table(table: u64, num_inputs: usize) -> Self {
        assert!(num_inputs <= 6);
        // Right shift instead of `(1 << 2^n) - 1` so that n = 6 does not overflow.
        let valid_bits = u64::MAX >> (64 - (1usize << num_inputs));
        Self {
            lut: table & valid_bits,
            num_inputs: num_inputs as u8,
        }
    }

    /// Tabulates an arbitrary single-output Boolean function with at most six inputs.
    ///
    /// `f.eval` receives a slice of exactly `f.num_inputs()` values, where element `j`
    /// is bit `j` of the assignment index.
    ///
    /// # Panics
    /// If `f.num_inputs() > 6`.
    pub fn from_boolean_function<F: BooleanFunction>(f: &F) -> Self {
        let n_inputs = f.num_inputs();
        assert!(
            n_inputs <= 6,
            "number of inputs must be <= 6 but is {n_inputs}"
        );
        let mut buffer = [false; 6];
        let mut lut = 0u64;
        for i in 0..(1usize << n_inputs) {
            for (j, item) in buffer[..n_inputs].iter_mut().enumerate() {
                *item = ((i >> j) & 1) == 1;
            }
            // Pass only the used prefix so implementations that check the input
            // length see the expected number of values.
            let output = f.eval(&buffer[..n_inputs]);
            lut |= u64::from(output) << i;
        }
        Self {
            lut,
            num_inputs: n_inputs as u8,
        }
    }
}
/// Combines the raw tables of two equally wide truth tables with `binary_op`.
///
/// The result is `tt1` with its table bits replaced.
///
/// # Panics
/// If the tables have different input counts.
fn bitwise_op2<TT: SmallTT>(tt1: TT, tt2: TT, binary_op: impl Fn(u64, u64) -> u64) -> TT {
    assert_eq!(tt1.num_inputs(), tt2.num_inputs());
    let (lhs, rhs) = (tt1.table(), tt2.table());
    let combined = binary_op(lhs, rhs);
    tt1.set_table(combined)
}
/// Computes, for each of the six inputs, the mask of table positions where that
/// input is 1: bit `k` of column `i` is set exactly when bit `i` of `k` is set.
const fn index_columns() -> [u64; 6] {
    const NUM_VARS: usize = 6;
    let mut columns = [0u64; NUM_VARS];
    let mut var = 0;
    while var < NUM_VARS {
        // Column `var` is alternating runs of `width` zeros then `width` ones (from the LSB).
        let width = 1u32 << var;
        // `u64::MAX / (2^width + 1)` repeats `width` ones then `width` zeros; shifting
        // it by `width` swaps the phase to zeros-then-ones.
        let ones_low = u64::MAX / ((1u64 << width) + 1);
        columns[var] = ones_low << width;
        var += 1;
    }
    columns
}
/// Checks all six index columns, not just the first two, both against literal
/// values and against the defining property "bit `k` of column `i` == bit `i` of `k`".
#[test]
fn test_index_columns() {
    let cols = index_columns();
    let expected: [u64; 6] = [
        0b1010101010101010101010101010101010101010101010101010101010101010,
        0b1100110011001100110011001100110011001100110011001100110011001100,
        0xF0F0_F0F0_F0F0_F0F0,
        0xFF00_FF00_FF00_FF00,
        0xFFFF_0000_FFFF_0000,
        0xFFFF_FFFF_0000_0000,
    ];
    assert_eq!(cols, expected);
    for (i, &col) in cols.iter().enumerate() {
        for k in 0..64u64 {
            assert_eq!((col >> k) & 1, (k >> i) & 1, "column {}, bit {}", i, k);
        }
    }
}
#[test]
fn test_swap_inputs() {
    // 2:1 multiplexer: input 0 selects input 2 when true, input 1 when false.
    let mux_ab = SmallTruthTable::new(|[sel, a, b]| if sel { b } else { a });
    let select_low = [false, false, true];
    let select_high = [true, false, true];
    assert!(!mux_ab.eval(&select_low));
    assert!(mux_ab.eval(&select_high));
    // After exchanging the data inputs, each select value routes the other one.
    let mux_ba = mux_ab.swap_inputs(1, 2);
    assert!(mux_ba.eval(&select_low));
    assert!(!mux_ba.eval(&select_high));
}
#[test]
fn test_swap_inputs_random_table() {
    // Arbitrary 6-input table; every ordered pair of inputs is checked exhaustively.
    let tt = SmallTruthTable {
        lut: 0xe3b0c44298fc1c14,
        num_inputs: 6,
    };
    let input_pairs = (0..6).flat_map(|i| (0..6).map(move |j| (i, j)));
    for (i, j) in input_pairs {
        let swapped = tt.swap_inputs(i, j);
        for row in 0..(1 << 6) {
            let row_swapped = bitmanip::swap_bits(row, i, j);
            assert_eq!(tt.get_bit(row), swapped.get_bit(row_swapped));
        }
    }
}
#[test]
fn test_invert_inputs_random_table() {
    let tt = SmallTruthTable {
        lut: 0xe3b0c44298fc1c14,
        num_inputs: 6,
    };
    for input in 0..6 {
        let inverted = tt.invert_input(input);
        // Flipping that input's bit in the assignment must recover the original output.
        for row in 0..(1 << 6) {
            assert_eq!(tt.get_bit(row), inverted.get_bit(row ^ (1 << input)));
        }
    }
}
/// Constructors for frequently used small truth tables.
///
/// In all two-input functions, `a` is input 0 and `b` is input 1.
pub mod truth_table_library {
    use super::SmallTruthTable;

    /// Mask with the lowest `2^num_inputs` bits set, i.e. all valid table positions.
    ///
    /// Right shift of `u64::MAX` so that six inputs (all 64 bits) does not need the
    /// overflowing shift `1 << 64`.
    fn valid_bits_mask(num_inputs: usize) -> u64 {
        u64::MAX >> (64 - (1usize << num_inputs))
    }

    /// Constant `true` with zero inputs.
    pub fn one() -> SmallTruthTable {
        SmallTruthTable::new(|[]| true)
    }

    /// Constant `false` with zero inputs.
    pub fn zero() -> SmallTruthTable {
        SmallTruthTable::new(|[]| false)
    }

    /// Single-input identity.
    pub fn identity1() -> SmallTruthTable {
        SmallTruthTable::new(|[a]| a)
    }

    /// `num_inputs`-input function that returns input `project_input`.
    ///
    /// # Panics
    /// If `project_input >= num_inputs` or `num_inputs > 6`.
    pub fn input_projection(num_inputs: usize, project_input: usize) -> SmallTruthTable {
        assert!(project_input < num_inputs, "selected input out of range");
        assert!(num_inputs <= 6, "no more than 6 inputs supported");
        let num_tt_bits: u64 = 1 << num_inputs;
        // Output bit `idx` equals bit `project_input` of assignment `idx`.
        let tt = (0..num_tt_bits)
            .map(|idx| ((idx >> project_input) & 1) << idx)
            .fold(0, |a, b| a | b);
        SmallTruthTable::from_table(tt, num_inputs)
    }

    /// Single-input inverter.
    pub fn inv1() -> SmallTruthTable {
        SmallTruthTable::new(|[a]| !a)
    }

    /// AND of `num_inputs` inputs (constant `true` for zero inputs).
    ///
    /// # Panics
    /// If `num_inputs > 6`.
    pub fn and(num_inputs: usize) -> SmallTruthTable {
        assert!(num_inputs <= 6, "no more than 6 inputs supported");
        // Only the last table position (all inputs set) is true.
        let lut = 1 << ((1 << num_inputs) - 1);
        let num_inputs = num_inputs as u8;
        SmallTruthTable { lut, num_inputs }
    }

    /// OR of `num_inputs` inputs (constant `false` for zero inputs).
    ///
    /// # Panics
    /// If `num_inputs > 6`.
    pub fn or(num_inputs: usize) -> SmallTruthTable {
        assert!(num_inputs <= 6, "no more than 6 inputs supported");
        // Every valid position except 0 (all inputs clear) is true. The mask must span
        // 2^num_inputs bits; a (num_inputs + 1)-bit mask is wrong for 2 or more inputs.
        let lut = !1 & valid_bits_mask(num_inputs);
        let num_inputs = num_inputs as u8;
        SmallTruthTable { lut, num_inputs }
    }

    /// Two-input AND.
    pub fn and2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a & b)
    }

    /// Two-input OR.
    pub fn or2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a | b)
    }

    /// Two-input NAND.
    pub fn nand2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| !(a & b))
    }

    /// Two-input NOR.
    pub fn nor2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| !(a | b))
    }

    /// Two-input XOR.
    pub fn xor2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a ^ b)
    }

    /// Two-input equivalence (XNOR).
    pub fn eq2() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a == b)
    }

    /// Material implication `a -> b`, i.e. `!a | b`: false only for `a = 1, b = 0`.
    ///
    /// This previously returned `!a & b`, the converse *non*implication, which is
    /// identical to `less_than`.
    pub fn implication() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| !a | b)
    }

    /// Converse implication `b -> a`, i.e. `a | !b`: false only for `a = 0, b = 1`.
    ///
    /// This previously returned `a & !b`, the material *non*implication, which is
    /// identical to `greater_than`.
    pub fn converse() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a | !b)
    }

    /// `a < b` on booleans, i.e. `!a & b`.
    pub fn less_than() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a < b)
    }

    /// `a <= b` on booleans.
    pub fn less_or_equal_than() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a <= b)
    }

    /// `a > b` on booleans, i.e. `a & !b`.
    pub fn greater_than() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a > b)
    }

    /// `a >= b` on booleans.
    pub fn greater_or_equal_than() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b]| a >= b)
    }

    /// Three-input majority: true when at least two inputs are true.
    pub fn maj3() -> SmallTruthTable {
        SmallTruthTable::new(|[a, b, c]| (a as u8) + (b as u8) + (c as u8) >= 2)
    }
}
#[test]
fn test_create_small_truth_table() {
    // XOR is true exactly for assignments 0b01 and 0b10.
    let xor = SmallTruthTable::new(|[a, b]| a ^ b);
    assert_eq!((xor.num_inputs(), xor.lut), (2, 0b0110));
}
#[test]
fn test_input_projection() {
    use truth_table_library::input_projection;
    let cases = (0..=6).flat_map(|n| (0..n).map(move |sel| (n, sel)));
    for (num_inputs, selected_input) in cases {
        let projection = input_projection(num_inputs, selected_input);
        for row in 0..projection.size() {
            let expected = ((row >> selected_input) & 1) == 1;
            assert_eq!(projection.get_bit(row as u64), expected);
        }
    }
}
#[test]
fn test_eval_small_truth_table() {
    let maj3 = SmallTruthTable::new(|[a, b, c]| [a, b, c].iter().filter(|&&x| x).count() >= 2);
    let cases = [(0b000, false), (0b001, false), (0b100, false), (0b011, true)];
    for &(row, expected) in cases.iter() {
        assert_eq!(maj3.get_bit(row), expected);
    }
}
impl NumInputs for SmallTruthTable {
    /// Input count stored at run time.
    fn num_inputs(&self) -> usize {
        usize::from(self.num_inputs)
    }
}
impl NumOutputs for SmallTruthTable {
    /// A small truth table always describes exactly one output.
    fn num_outputs(&self) -> usize {
        const SINGLE_OUTPUT: usize = 1;
        SINGLE_OUTPUT
    }
}
impl PartialBooleanSystem for SmallTruthTable {
    type TermId = ();
    type LiteralId = u32;

    /// A full truth table never has an unknown output, so this always returns `Some`.
    fn evaluate_term_partial(&self, term: &(), input_values: &[bool]) -> Option<bool> {
        let value = self.evaluate_term(term, input_values);
        Some(value)
    }
}
impl BooleanSystem for SmallTruthTable {
    /// Packs `input_values` into a table index (element `i` becomes bit `i`) and
    /// looks up that entry.
    fn evaluate_term(&self, _term: &(), input_values: &[bool]) -> bool {
        let mut index = 0u64;
        for &value in input_values.iter().rev() {
            index = (index << 1) | u64::from(value);
        }
        self.get_bit(index)
    }
}
impl PartialBooleanFunction for SmallTruthTable {
    /// Always defined: delegates to `eval` and wraps the result in `Some`.
    fn partial_eval(&self, input_values: &[bool]) -> Option<bool> {
        let output = self.eval(input_values);
        Some(output)
    }
}
impl PartialTruthTable for SmallTruthTable {
    /// Always defined: looks up `input_bits` with `get_bit` and wraps it in `Some`.
    fn partial_evaluate(&self, input_bits: u64) -> Option<bool> {
        let output = self.get_bit(input_bits);
        Some(output)
    }
}
impl TruthTableEdit for SmallTruthTable {
    /// Sets table entry `bit_index` to `value`.
    ///
    /// # Panics
    /// If `bit_index >= 2^num_inputs()`.
    fn set_bit(&mut self, bit_index: usize, value: bool) {
        let table_len: usize = 1 << self.num_inputs();
        assert!(bit_index < table_len, "bit index out of range");
        let bit: u64 = 1 << bit_index;
        if value {
            self.lut |= bit;
        } else {
            self.lut &= !bit;
        }
    }
}
// Empty-bodied impls declaring compile-time shape: both table types have one
// output, and the static table's input count is its const parameter `N`.
impl<const N: usize> StaticNumInputs<N> for SmallStaticTruthTable<N> {}
impl<const N: usize> StaticNumOutputs<1> for SmallStaticTruthTable<N> {}
impl StaticNumOutputs<1> for SmallTruthTable {}
impl BooleanFunction for SmallTruthTable {
    /// Evaluates the function; `input_values[i]` becomes bit `i` of the table index.
    fn eval(&self, input_values: &[bool]) -> bool {
        let mut index = 0u64;
        for &value in input_values.iter().rev() {
            index = (index << 1) | u64::from(value);
        }
        self.get_bit(index)
    }
}
impl TruthTable for SmallTruthTable {
    /// Reads table entry `bits`. Index bits at or above `num_inputs` are ignored.
    fn get_bit(&self, bits: u64) -> bool {
        let index_mask: u64 = (1 << self.num_inputs) - 1;
        ((self.lut >> (bits & index_mask)) & 1) != 0
    }
}