use smallvec::{smallvec, SmallVec};
use crate::network::NetworkEdit;
use crate::truth_table::bitflip_iter::BitFlippable;
use crate::truth_table::small_lut::SmallTT;
use crate::{
network::{
BinaryOp, EdgeWithInversion, MutNetworkNodeWithReferenceCount, NetworkNode,
NetworkNodeWithReferenceCount, TernaryOp, UnaryOp,
},
traits::*,
truth_table::small_lut::{truth_table_library, SmallTruthTable},
};
use truth_table_library as ttlib;
use super::generic_network::{LogicNetwork, NodeId};
use super::SimplifyResult;
/// A logic network whose nodes are look-up tables ([`LutNode`]).
pub type KLutNetwork = LogicNetwork<LutNode<NodeId>>;

impl KLutNetwork {
    /// Normalizes `node` and returns the signal that computes its function.
    ///
    /// If normalization yields an existing signal, that signal is returned directly.
    /// Otherwise the canonical node is added via `create_node`. In both cases the
    /// output-inversion flag reported by normalization is applied to the returned edge.
    fn insert_node_normalized(&mut self, node: LutNode<NodeId>) -> NodeId {
        match node.normalized() {
            SimplifyResult::Simplified(existing, inverted) => existing.invert_if(inverted),
            SimplifyResult::Node(canonical, inverted) => {
                let id = self.create_node(canonical);
                id.invert_if(inverted)
            }
        }
    }
}
/// A look-up-table node: an arbitrary Boolean function (given as a truth table) of
/// its fanin edges.
///
/// Input variable `i` of `truth_table` corresponds to `inputs[i]`; normalization keeps
/// the two in sync. Up to six fanins are stored inline without a heap allocation.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct LutNode<NodeId> {
    /// Function computed by the node over `inputs`.
    truth_table: SmallTruthTable,
    /// Fanin edges. Before normalization these may carry an inversion flag.
    inputs: SmallVec<[NodeId; 6]>,
    /// Reference count, maintained through `MutNetworkNodeWithReferenceCount`.
    // NOTE(review): this field participates in the derived `Hash`/`PartialEq`, so two
    // structurally identical nodes with different reference counts compare unequal.
    // Confirm this is intended if nodes are used as structural-hashing keys.
    num_references: usize,
}
impl<NodeId> LutNode<NodeId> {
    /// Builds an unreferenced LUT node computing `truth_table` over `inputs`.
    ///
    /// The caller is responsible for passing as many inputs as the truth table has
    /// input variables; this constructor does not check it.
    pub fn new(truth_table: SmallTruthTable, inputs: SmallVec<[NodeId; 6]>) -> Self {
        // A freshly built node is not referenced by anything yet.
        let num_references = 0;
        Self {
            inputs,
            truth_table,
            num_references,
        }
    }
}
impl<NodeId: EdgeWithInversion + Copy + Ord> LutNode<NodeId> {
    /// Rewrites the node into a canonical form so that equivalent LUTs look alike.
    ///
    /// The steps, in order:
    /// 1. Every inverted fanin edge is replaced by its non-inverted version, and
    ///    `flip_bit` is applied to the matching truth-table variable. This presumably
    ///    negates that input variable to compensate (see `BitFlippable`).
    /// 2. Fanins are sorted ascending with a selection sort. Each swap of two inputs
    ///    is mirrored by `swap_inputs` on the truth table.
    /// 3. Of the truth table and its complement, the one that is smaller under
    ///    `PartialOrd` is kept. The returned flag is `true` when the complement was
    ///    kept, meaning the caller must invert the node's output.
    ///
    /// Always returns `SimplifyResult::Node`. Duplicate or constant inputs are not
    /// collapsed here.
    fn normalize(self) -> SimplifyResult<Self, NodeId> {
        let Self {
            mut truth_table,
            mut inputs,
            num_references,
        } = self;

        // Step 1: move edge inversions into the truth table.
        for (position, edge) in inputs.iter_mut().enumerate() {
            if edge.is_inverted() {
                *edge = edge.invert();
                truth_table.flip_bit(position);
            }
        }

        // Step 2: selection sort. `min_by_key` returns the first minimum on ties,
        // so equal inputs keep their relative order.
        let len = inputs.len();
        for target in 0..len.saturating_sub(1) {
            let smallest = (target..len)
                .min_by_key(|&candidate| inputs[candidate])
                .expect("target..len is non-empty because target < len - 1");
            if smallest != target {
                inputs.swap(target, smallest);
                truth_table.swap_inputs(target, smallest);
            }
        }

        // Step 3: canonical output polarity.
        let complement = truth_table.invert();
        let keep_original = truth_table < complement;
        let truth_table = if keep_original { truth_table } else { complement };

        SimplifyResult::Node(
            Self {
                truth_table,
                inputs,
                num_references,
            },
            !keep_original,
        )
    }
}
/// Checks that normalization absorbs an inverted fanin, sorts the fanins, and keeps the
/// node's function intact (up to the reported output inversion).
#[test]
fn test_normalize_lut_node() {
    let a = NodeId::new_node_id(2);
    let b = NodeId::new_node_id(1);
    let c = NodeId::new_node_id(3);
    {
        let maj3 = SmallTruthTable::new(|[a, b, c]| (a as u8) + (b as u8) + (c as u8) >= 2);
        // The node below computes maj(a, !b, c). After normalization the inversion on
        // `b` is moved into the truth table and the inputs are sorted to [b, a, c]
        // (b has the smallest id). Over that input order the function must still be
        // maj(a, !b, c): variable 0 is `b`, variable 1 is `a`.
        let maj3_modified =
            SmallTruthTable::new(|[b, a, c]| (a as u8) + ((!b) as u8) + (c as u8) >= 2);
        let node = LutNode::new(maj3, smallvec![a, b.invert(), c]);
        match node.normalize() {
            SimplifyResult::Node(node, inverted) => {
                assert_eq!(node.inputs.as_slice(), [b, a, c].as_slice());
                assert_eq!(node.truth_table, maj3_modified.invert_if(inverted));
            }
            SimplifyResult::Simplified(_, _) => {
                panic!("normalizing a plain LUT node must not simplify it")
            }
        }
    }
}
impl<NodeId: IdType + EdgeWithInversion> NetworkNode for LutNode<NodeId> {
    type NodeId = NodeId;

    /// Input count as reported by the truth table (not by the fanin vector).
    fn num_inputs(&self) -> usize {
        let table = &self.truth_table;
        table.num_inputs()
    }

    /// Returns fanin `i`. Panics if `i` is out of range.
    fn get_input(&self, i: usize) -> Self::NodeId {
        let fanins: &[NodeId] = &self.inputs;
        fanins[i]
    }

    /// Returns a copy of the node's truth table.
    fn function(&self) -> SmallTruthTable {
        let Self { truth_table, .. } = self;
        *truth_table
    }

    /// Delegates to the inherent `LutNode::normalize`.
    fn normalized(self) -> SimplifyResult<Self, Self::NodeId> {
        Self::normalize(self)
    }
}
/// Consumes the node and yields its fanin edges in order.
impl<NodeId> IntoIterator for LutNode<NodeId> {
    type Item = NodeId;
    type IntoIter = smallvec::IntoIter<[NodeId; 6]>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { inputs, .. } = self;
        inputs.into_iter()
    }
}
impl<NodeId: IdType + EdgeWithInversion> NetworkNodeWithReferenceCount for LutNode<NodeId> {
    /// Current reference count stored on the node.
    fn num_references(&self) -> usize {
        let Self { num_references, .. } = self;
        *num_references
    }
}
impl<NodeId: IdType + EdgeWithInversion> MutNetworkNodeWithReferenceCount for LutNode<NodeId> {
    /// Increments the reference count.
    fn reference(&mut self) {
        self.num_references += 1;
    }

    /// Decrements the reference count.
    ///
    /// # Panics
    ///
    /// Panics if the count is already zero. That is a reference-counting bug in the
    /// caller. A plain `-= 1` would report it only as an arithmetic overflow in debug
    /// builds, and in release builds would silently wrap to `usize::MAX`.
    fn dereference(&mut self) {
        self.num_references = self
            .num_references
            .checked_sub(1)
            .expect("dereferenced a LUT node whose reference count is already zero");
    }
}
impl UnaryOp for KLutNetwork {
    /// Negation allocates no node: the inversion is carried on the returned edge.
    fn create_not(&mut self, signal: Self::Signal) -> Self::Signal {
        signal.invert()
    }
}
/// Each two-input gate becomes one 2-LUT over `[a, b]`, inserted in normalized form.
impl BinaryOp for KLutNetwork {
    fn create_and(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::and2(), smallvec![a, b]))
    }

    fn create_or(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::or2(), smallvec![a, b]))
    }

    fn create_nand(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::nand2(), smallvec![a, b]))
    }

    fn create_nor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::nor2(), smallvec![a, b]))
    }

    fn create_xor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::xor2(), smallvec![a, b]))
    }
}
/// Each three-input gate becomes one 3-LUT, inserted in normalized form.
impl TernaryOp for KLutNetwork {
    /// Majority of three, using the library truth table.
    fn create_maj3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        self.insert_node_normalized(LutNode::new(ttlib::maj3(), smallvec![a, b, c]))
    }

    /// If-then-else: selects `then` when `condition` holds, otherwise `otherwise`.
    fn create_ite(
        &mut self,
        condition: Self::Signal,
        then: Self::Signal,
        otherwise: Self::Signal,
    ) -> Self::Signal {
        let mux = SmallTruthTable::new(|[sel, on_true, on_false]| if sel { on_true } else { on_false });
        let node = LutNode::new(mux, smallvec![condition, then, otherwise]);
        self.insert_node_normalized(node)
    }

    /// Three-input parity.
    fn create_xor3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        let parity = SmallTruthTable::new(|[x, y, z]| x ^ y ^ z);
        let node = LutNode::new(parity, smallvec![a, b, c]);
        self.insert_node_normalized(node)
    }
}
/// Builds a full adder from LUTs and checks it against a native reference
/// for all eight input combinations.
#[test]
fn test_simulate_klut_graph() {
    use crate::native_boolean_functions::NativeBooleanFunction;
    use crate::network::NetworkEdit;
    use crate::network::NetworkEditShortcuts;
    use crate::traits::BooleanSystem;

    let mut g = KLutNetwork::new();
    let [in1, in2, carry_in] = g.create_primary_inputs();
    let sum = g.create_xor3(in1, in2, carry_in);
    let carry = g.create_maj3(in1, in2, carry_in);
    let output_sum = g.create_primary_output(sum);
    let output_carry = g.create_primary_output(carry);
    let simulator = crate::network_simulator::RecursiveSim::new(&g);

    /// Reference full adder: returns `[sum bit, carry bit]`.
    fn full_adder([a, b, c]: [bool; 3]) -> [bool; 2] {
        let sum = (a as usize) + (b as usize) + (c as usize);
        [sum & 0b1 == 1, sum & 0b10 == 0b10]
    }
    let reference = NativeBooleanFunction::new(full_adder);

    for i in 0..(1 << 3) {
        let inputs = [0, 1, 2].map(|idx| (i >> idx) & 1 == 1);
        let expected_output = [0, 1].map(|out| reference.evaluate_term(&out, &inputs));
        let actual_output: Vec<_> = simulator
            .simulate(&[output_sum, output_carry], &inputs)
            .collect();
        assert_eq!(
            expected_output.as_slice(),
            actual_output.as_slice(),
            "full-adder mismatch for inputs {:?}",
            inputs
        );
    }
}