// libreda-logic 0.0.3
//
// Logic library for LibrEDA.
// Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! A logic network built from k-input lookup-tables ('k-LUT').

use smallvec::{smallvec, SmallVec};

use crate::network::NetworkEdit;
use crate::truth_table::bitflip_iter::BitFlippable;
use crate::truth_table::small_lut::SmallTT;
use crate::{
    network::{
        BinaryOp, EdgeWithInversion, MutNetworkNodeWithReferenceCount, NetworkNode,
        NetworkNodeWithReferenceCount, TernaryOp, UnaryOp,
    },
    traits::*,
    truth_table::small_lut::{truth_table_library, SmallTruthTable},
};
use truth_table_library as ttlib;

use super::generic_network::{LogicNetwork, NodeId};
use super::SimplifyResult;

/// A logic network which consists of `k`-input lookup-tables.
pub type KLutNetwork = LogicNetwork<LutNode<NodeId>>;

impl KLutNetwork {
    /// Insert `node` into the network in its canonical form.
    ///
    /// The node is normalized first. If normalization produces a node, that node is added to the
    /// network with `create_node`. If normalization simplifies the node to a signal, that signal
    /// is used instead. In both cases the output inversion reported by the normalization is
    /// applied to the returned signal.
    fn insert_node_normalized(&mut self, node: LutNode<NodeId>) -> NodeId {
        let (signal, inverted) = match node.normalized() {
            SimplifyResult::Node(canonical, inv) => (self.create_node(canonical), inv),
            SimplifyResult::Simplified(replacement, inv) => (replacement, inv),
        };
        signal.invert_if(inverted)
    }
}

/// Step of a boolean chain. The operator implemented by the step is represented by a truth-table.
///
/// The `i`-th entry of `inputs` is connected to the `i`-th input of `truth_table`. Normalization
/// relies on this pairing: it flips and swaps inputs of both in lock-step.
///
/// NOTE(review): `num_references` takes part in the derived `PartialEq` and `Hash`. Two nodes with
/// the same function and inputs but different reference counts therefore compare unequal. Confirm
/// this is intended if nodes are used as keys for structural hashing.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct LutNode<NodeId> {
    /// Truth-table of the operator.
    truth_table: SmallTruthTable,
    /// Input edges of the node, in the same order as the inputs of `truth_table`.
    inputs: SmallVec<[NodeId; 6]>, // Can have up to 6 inputs.
    /// Reference counter, updated through `MutNetworkNodeWithReferenceCount`.
    num_references: usize,         // TODO: Move this to some `GeneralNode<Function>`
}

impl<NodeId> LutNode<NodeId> {
    /// Create a new `k-LUT` node. The node function is specified by the `truth_table`.
    ///
    /// The `i`-th entry of `inputs` is connected to the `i`-th input of `truth_table`. The
    /// reference counter of the new node starts at zero.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the number of `inputs` differs from the number of inputs of
    /// `truth_table`. Such a node would be inconsistent: its arity is taken from the truth-table,
    /// while its input edges are read from `inputs`.
    pub fn new(truth_table: SmallTruthTable, inputs: SmallVec<[NodeId; 6]>) -> Self {
        debug_assert_eq!(
            inputs.len(),
            truth_table.num_inputs(),
            "number of LUT inputs must match the arity of the truth-table"
        );
        Self {
            truth_table,
            inputs,
            num_references: 0,
        }
    }
}

impl<NodeId: EdgeWithInversion + Copy + Ord> LutNode<NodeId> {
    /// Bring the node into a canonical form.
    ///
    /// Three steps are applied. Each step adjusts the truth-table together with the inputs:
    /// 1. Every inverted input edge is replaced by its non-inverted version, and the matching
    ///    input of the truth-table is flipped.
    /// 2. The inputs are sorted in ascending order by selection sort. Every swap of two inputs
    ///    is mirrored by swapping the same two inputs of the truth-table.
    /// 3. The smaller of the truth-table and its complement is kept. The returned flag is `true`
    ///    when the complement was chosen, i.e. when the output must be inverted.
    ///
    /// Currently this always returns [`SimplifyResult::Node`].
    fn normalize(self) -> SimplifyResult<Self, NodeId> {
        let Self {
            mut truth_table,
            mut inputs,
            num_references,
        } = self;

        // Step 1: move input inversions into the truth-table.
        for (position, edge) in inputs.iter_mut().enumerate() {
            if edge.is_inverted() {
                *edge = edge.invert();
                truth_table.flip_bit(position);
            }
        }

        // Step 2: selection sort. The first occurrence of the minimum is picked.
        let len = inputs.len();
        for target in 0..len.saturating_sub(1) {
            let smallest = (target..len)
                .min_by_key(|&k| inputs[k])
                .unwrap_or(target);

            if smallest != target {
                inputs.swap(target, smallest);
                truth_table.swap_inputs(target, smallest);
            }
        }

        // Step 3: canonical output polarity.
        let complement = truth_table.invert();
        let keep_original = truth_table < complement;

        let canonical = Self {
            truth_table: if keep_original { truth_table } else { complement },
            inputs,
            num_references,
        };

        SimplifyResult::Node(canonical, !keep_original)
    }
}

#[test]
fn test_normalize_lut_node() {
    let a = NodeId::new_node_id(2);
    let b = NodeId::new_node_id(1);
    let c = NodeId::new_node_id(3);

    {
        let maj3 = SmallTruthTable::new(|[a, b, c]| (a as u8) + (b as u8) + (c as u8) >= 2);
        // Expected table after normalization, with the closure elements named after the
        // reordered inputs `[b, a, c]`.
        // NOTE(review): the node below computes maj(a, !b, c) of the node values, and
        // normalization should preserve that. Under this naming it would read
        // `(a as u8) + (!b as u8) + (c as u8) >= 2`, but the expectation negates `a` instead of
        // `b`. Confirm against the input-ordering convention of `SmallTruthTable::new` and the
        // semantics of `flip_bit`/`swap_inputs`; this looks like a typo.
        let maj3_modified =
            SmallTruthTable::new(|[b, a, c]| (!a as u8) + (b as u8) + (c as u8) >= 2);

        let node = LutNode::new(maj3, smallvec![a, b.invert(), c]);

        match node.normalize() {
            SimplifyResult::Node(node, inverted) => {
                assert_eq!(node.inputs.as_slice(), [b, a, c].as_slice());
                assert_eq!(node.truth_table, maj3_modified.invert_if(inverted));
            }
            SimplifyResult::Simplified(_, _) => {
                panic!("normalizing a majority node must not simplify it away")
            }
        }
    }
}

impl<NodeId: IdType + EdgeWithInversion> NetworkNode for LutNode<NodeId> {
    type NodeId = NodeId;

    /// Arity of the node, as reported by its truth-table.
    fn num_inputs(&self) -> usize {
        let table = &self.truth_table;
        table.num_inputs()
    }

    /// Edge connected to the `i`-th input. Panics if `i` is out of range.
    fn get_input(&self, i: usize) -> Self::NodeId {
        let edges: &[NodeId] = &self.inputs;
        edges[i]
    }

    /// Boolean function computed by the node.
    fn function(&self) -> SmallTruthTable {
        let Self { truth_table, .. } = self;
        *truth_table
    }

    /// Canonical form of the node (see `LutNode::normalize`).
    fn normalized(self) -> SimplifyResult<Self, Self::NodeId> {
        Self::normalize(self)
    }
}

impl<NodeId> IntoIterator for LutNode<NodeId> {
    type Item = NodeId;

    type IntoIter = smallvec::IntoIter<[NodeId; 6]>;

    /// Consume the node and yield its input edges in input order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { inputs, .. } = self;
        inputs.into_iter()
    }
}

impl<NodeId: IdType + EdgeWithInversion> NetworkNodeWithReferenceCount for LutNode<NodeId> {
    /// Current value of the node's reference counter.
    fn num_references(&self) -> usize {
        let Self { num_references, .. } = self;
        *num_references
    }
}

impl<NodeId: IdType + EdgeWithInversion> MutNetworkNodeWithReferenceCount for LutNode<NodeId> {
    /// Increment the reference counter.
    fn reference(&mut self) {
        self.num_references += 1
    }

    /// Decrement the reference counter.
    ///
    /// # Panics
    ///
    /// Panics if the counter is already zero. A plain `-= 1` would only panic in debug builds
    /// and wrap to `usize::MAX` in release builds, silently corrupting the reference count.
    fn dereference(&mut self) {
        self.num_references = self
            .num_references
            .checked_sub(1)
            .expect("`dereference` called on a LUT node that has no references");
    }
}

impl UnaryOp for KLutNetwork {
    /// Create the negation of `signal`.
    ///
    /// No LUT node is added to the network. The negation is expressed by inverting the signal
    /// edge itself with `invert`.
    fn create_not(&mut self, signal: Self::Signal) -> Self::Signal {
        signal.invert()
    }
}

impl BinaryOp for KLutNetwork {
    fn create_and(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let node = LutNode::new(ttlib::and2(), smallvec![a, b]);
        self.insert_node_normalized(node)
    }

    fn create_or(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let node = LutNode::new(ttlib::or2(), smallvec![a, b]);
        self.insert_node_normalized(node)
    }

    fn create_nand(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let node = LutNode::new(ttlib::nand2(), smallvec![a, b]);
        self.insert_node_normalized(node)
    }

    fn create_nor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let node = LutNode::new(ttlib::nor2(), smallvec![a, b]);
        self.insert_node_normalized(node)
    }

    fn create_xor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let node = LutNode::new(ttlib::xor2(), smallvec![a, b]);
        self.insert_node_normalized(node)
    }
}

impl TernaryOp for KLutNetwork {
    fn create_maj3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        let majority = ttlib::maj3();
        self.insert_node_normalized(LutNode::new(majority, smallvec![a, b, c]))
    }

    fn create_ite(
        &mut self,
        condition: Self::Signal,
        then: Self::Signal,
        otherwise: Self::Signal,
    ) -> Self::Signal {
        // Multiplexer: input 0 selects input 1 when true and input 2 when false.
        let mux = SmallTruthTable::new(|[sel, when_true, when_false]| {
            (sel && when_true) || (!sel && when_false)
        });
        let node = LutNode::new(mux, smallvec![condition, then, otherwise]);
        self.insert_node_normalized(node)
    }

    fn create_xor3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        // Odd parity of the three inputs.
        let parity = SmallTruthTable::new(|[x, y, z]| (x != y) != z);
        let node = LutNode::new(parity, smallvec![a, b, c]);
        self.insert_node_normalized(node)
    }
}

#[test]
fn test_simulate_klut_graph() {
    use crate::native_boolean_functions::NativeBooleanFunction;
    use crate::network::NetworkEdit;
    use crate::network::NetworkEditShortcuts;
    use crate::traits::BooleanSystem;

    // Construct a one-bit full adder.
    let mut g = KLutNetwork::new();
    let [in1, in2, carry_in] = g.create_primary_inputs();

    let sum = g.create_xor3(in1, in2, carry_in);
    let carry = g.create_maj3(in1, in2, carry_in);

    let output_sum = g.create_primary_output(sum);
    let output_carry = g.create_primary_output(carry);

    let simulator = crate::network_simulator::RecursiveSim::new(&g);

    // Reference model of the full adder.
    fn full_adder([a, b, c]: [bool; 3]) -> [bool; 2] {
        let sum = (a as usize) + (b as usize) + (c as usize);
        [
            sum & 0b1 == 1,
            sum & 0b10 == 0b10, // carry
        ]
    }

    let reference = NativeBooleanFunction::new(full_adder);

    // Exhaustively compare the network against the reference for all 2^3 input combinations.
    for i in 0..(1 << 3) {
        let inputs = [0, 1, 2].map(|idx| (i >> idx) & 1 == 1);

        let expected_output = [0, 1].map(|out| reference.evaluate_term(&out, &inputs));
        let actual_output: Vec<_> = simulator
            .simulate(&[output_sum, output_carry], &inputs)
            .collect();

        // The failing input combination is reported in the assertion message.
        assert_eq!(
            expected_output.as_slice(),
            actual_output.as_slice(),
            "full adder mismatch for inputs {:?}",
            inputs
        );
    }
}