use smallvec::SmallVec;
use crate::{bitmanip, traits::*};
use super::{bitflip_iter::BitFlippable, permutation_iter::Permutable};
/// Implemented by values whose output can be inverted, either by value or in place.
pub trait InvertOutput {
/// Consumes `self` and returns the output-inverted value.
///
/// The receiver is taken by value, so the returned value must be used; calling this
/// on a `Copy` type as a bare statement leaves the original untouched.
fn inverted_output(self) -> Self;
/// Inverts the output of `self` in place.
fn invert_output(&mut self);
}
/// A transformation of a function with at most 16 inputs, consisting of an optional
/// output inversion, a set of input inversions and a reordering of the inputs.
///
/// Note that equality, ordering and hashing are derived and therefore compare the raw
/// field values (in declaration order), including entries of `input_reordering` that
/// lie beyond `num_inputs`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct NPNTransform {
/// Sixteen packed 4-bit entries; entry `i` (bits `4 * i .. 4 * i + 4`) is the
/// permutation target of input `i`.
input_reordering: u64,
/// Bit `i` is set when input `i` is inverted.
invert_inputs: u16,
/// Whether the output is inverted.
invert_output: bool,
/// Number of inputs the transform acts on (at most 16).
num_inputs: u8,
}
impl NPNTransform {
    /// Creates the transform that leaves a function with `num_inputs` inputs unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `num_inputs` is greater than 16.
    pub fn identity(num_inputs: usize) -> Self {
        assert!(
            num_inputs <= 16,
            "maximum number of inputs is 16, found {}",
            num_inputs
        );
        NPNTransform {
            // The 4-bit entry at position `i` holds `i`: every input maps to itself.
            input_reordering: 0xfedcba9876543210,
            invert_inputs: 0,
            invert_output: false,
            num_inputs: num_inputs as u8,
        }
    }

    /// Returns the number of inputs this transform acts on.
    pub fn num_inputs(&self) -> usize {
        self.num_inputs as usize
    }

    /// Returns whether input `idx` is inverted.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than the number of inputs.
    fn get_invert_bit(&self, idx: usize) -> bool {
        assert!(
            idx < self.num_inputs(),
            "input index out of bounds: {} (num_inputs = {})",
            idx,
            self.num_inputs()
        );
        (self.invert_inputs >> idx) & 1 == 1
    }

    /// Toggles the inversion of input `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than the number of inputs.
    fn toggle_invert_bit(&mut self, idx: usize) {
        assert!(
            idx < self.num_inputs(),
            "input index out of bounds: {} (num_inputs = {})",
            idx,
            self.num_inputs()
        );
        self.invert_inputs ^= 1 << idx
    }

    /// Returns the position that input `input_idx` is moved to by [`apply`](Self::apply).
    ///
    /// # Panics
    ///
    /// Panics if `input_idx` is not smaller than the number of inputs.
    fn get_permutation_target(&self, input_idx: usize) -> usize {
        assert!(input_idx < self.num_inputs());
        ((self.input_reordering >> (4 * input_idx)) & 0b1111) as usize
    }

    /// Sets the position that input `input_idx` is moved to.
    ///
    /// Both arguments are only range-checked in debug builds.
    fn set_permutation_target(&mut self, input_idx: usize, value: usize) {
        debug_assert!(value < self.num_inputs());
        debug_assert!(input_idx < self.num_inputs());
        let shift = input_idx * 4;
        let mask = 0b1111u64 << shift;
        self.input_reordering = (self.input_reordering & !mask) | ((value as u64) << shift)
    }

    /// Applies this transform to `tt` and returns the transformed value.
    ///
    /// The steps are performed in this order: the output is inverted (if requested), every
    /// input whose inversion bit is set is flipped, and finally the inputs are reordered by a
    /// sequence of `swap` calls so that input `i` ends up at its permutation target
    /// (assuming `Permutable::swap` exchanges two inputs of `tt`).
    ///
    /// # Panics
    ///
    /// Panics if `tt` does not have the same number of inputs as this transform, or if the
    /// stored reordering is not a permutation of the input indices.
    pub fn apply<T>(&self, tt: T) -> T
    where
        T: NumInputs + BitFlippable + Permutable + InvertOutput,
    {
        assert_eq!(tt.num_inputs(), self.num_inputs());
        let num_inputs = self.num_inputs();

        let mut tt = if self.invert_output {
            tt.inverted_output()
        } else {
            tt
        };

        for i in 0..num_inputs {
            if self.get_invert_bit(i) {
                tt.flip_bit(i);
            }
        }

        // Sort the list of targets with swaps, mirroring every swap on `tt`. After step `i`
        // all positions `0..=i` are final, so the search only has to look at `i..`.
        let mut targets: SmallVec<[usize; 8]> = (0..num_inputs)
            .map(|i| self.get_permutation_target(i))
            .collect();
        for i in 0..num_inputs {
            if targets[i] != i {
                let offset = targets[i..]
                    .iter()
                    .position(|&t| t == i)
                    .expect("input reordering must be a permutation of the input indices");
                let j = i + offset;
                tt.swap(i, j);
                targets.swap(i, j);
            }
        }
        tt
    }

    /// Returns the transform that undoes this one, i.e. `t.inverse().apply(t.apply(x)) == x`.
    ///
    /// The result is built on top of [`identity`](Self::identity) so that the unused
    /// reordering entries stay in their canonical state and the derived `Eq`/`Hash`
    /// treat `identity(n).inverse()` and `identity(n)` as equal.
    pub fn inverse(&self) -> Self {
        let mut inverse = Self {
            invert_output: self.invert_output,
            ..Self::identity(self.num_inputs())
        };
        for i in 0..self.num_inputs() {
            let t = self.get_permutation_target(i);
            inverse.set_permutation_target(t, i);
            // `apply` flips input `i` *before* moving it to position `t`, so the inverse,
            // which also flips before permuting, has to flip position `t`.
            if self.get_invert_bit(i) {
                inverse.invert_inputs |= 1 << t;
            }
        }
        inverse
    }
}
impl InvertOutput for NPNTransform {
    /// Returns this transform with its output-inversion flag toggled.
    fn inverted_output(self) -> Self {
        Self {
            invert_output: !self.invert_output,
            ..self
        }
    }

    /// Toggles the output-inversion flag in place.
    fn invert_output(&mut self) {
        self.invert_output = !self.invert_output;
    }
}
impl BitFlippable for NPNTransform {
    /// The transform exposes one flippable bit per input.
    fn num_bits(&self) -> usize {
        usize::from(self.num_inputs)
    }

    /// Toggles the inversion of input `bit_idx`; panics if the index is out of bounds.
    fn flip_bit(&mut self, bit_idx: usize) {
        Self::toggle_invert_bit(self, bit_idx)
    }
}
impl Permutable for NPNTransform {
    /// The number of permutable positions equals the number of inputs.
    fn len(&self) -> usize {
        self.num_inputs()
    }

    /// Exchanges the reordering entries and the inversion bits of positions `i` and `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than the number of inputs. Without this check an
    /// index of 16 or more would produce a shift amount of at least 64 bits, and an index in
    /// `num_inputs..16` would silently disturb the unused reordering entries that take part
    /// in the derived equality and hash.
    fn swap(&mut self, i: usize, j: usize) {
        assert!(
            i < self.num_inputs() && j < self.num_inputs(),
            "input indices out of bounds: {} and {} (num_inputs = {})",
            i,
            j,
            self.num_inputs()
        );
        self.input_reordering =
            bitmanip::swap_bit_patterns(self.input_reordering, 0b1111, i * 4, j * 4);
        self.invert_inputs = bitmanip::swap_bits(self.invert_inputs, i, j);
    }
}
/// A sequence of flips and swaps that cancels out must yield the identity transform again.
#[test]
fn test_npn_transform_identity() {
    let mut transform = NPNTransform::identity(8);

    // The inversion set at position 3 travels to position 1 with the swap and is cleared there.
    transform.flip_bit(3);
    transform.swap(1, 3);
    transform.flip_bit(1);

    // These three swaps undo the earlier exchange of positions 1 and 3.
    for (a, b) in [(3, 5), (1, 5), (5, 3)] {
        transform.swap(a, b);
    }

    assert_eq!(transform, NPNTransform::identity(transform.num_inputs()));
}
/// A 2:1 multiplexer is unchanged when the select input is inverted and the two data
/// inputs are exchanged, even though that transform is not the identity.
#[test]
fn test_npn_transform_apply() {
    use crate::truth_table::small_lut::SmallTruthTable;

    let mux2 = SmallTruthTable::new(|[s, d0, d1]| if s { d1 } else { d0 });

    let mut transform = NPNTransform::identity(3);
    transform.flip_bit(0);
    transform.swap(1, 2);
    assert_ne!(transform, NPNTransform::identity(3));

    let transformed = transform.apply(mux2);
    assert_eq!(mux2, transformed);
}
/// Applying a transform followed by its inverse (in either order) must give back the
/// original truth table.
#[test]
fn test_npn_transform_inverse() {
    use crate::truth_table::small_lut::SmallTruthTable;
    let tt = SmallTruthTable::from_table(0b0110011010111001, 4);
    let mut tf = NPNTransform::identity(4);
    tf.flip_bit(0);
    tf.swap(1, 2);
    // Invert in place: `inverted_output` takes `self` by value and returns the result, so
    // calling it as a bare statement would leave `tf` unchanged and the output inversion
    // untested.
    tf.invert_output();
    tf.flip_bit(2);
    let inverse = tf.inverse();
    assert_eq!(inverse.apply(tf.apply(tt)), tt);
    assert_eq!(tf.apply(inverse.apply(tt)), tt);
}