libreda-logic 0.0.3

Logic library for LibrEDA.
Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! Abstraction of logic networks.

use blanket::blanket;

use std::hash::Hash;
use std::ops::*;

use crate::networks::SimplifyResult;
use crate::traits::BooleanFunction;
use crate::traits::LogicValue;

use crate::{traits::IdType, truth_table::small_lut::SmallTruthTable};

/// Basic properties of a logic network.
///
/// A network consists of nodes (identified by [`Network::NodeId`]) whose outputs are referred
/// to by signals ([`Network::Signal`]). Primary inputs, primary outputs and constants are all
/// accessed as signals.
#[blanket(derive(Ref, Mut))]
pub trait Network {
    /// Type of the logic node. The inputs of a node are expressed as signals of this network.
    type Node: NetworkNode<NodeId = Self::Signal>;
    /// Identifier of a logic node in this network.
    type NodeId: Clone + PartialEq + Eq + Hash + Send + Sync + std::fmt::Debug + 'static;
    /// Type which represents the outputs of network nodes, e.g. signals.
    /// In contrast to a `NodeId` a `Signal` might also encode a logic inversion.
    type Signal: Copy + Clone + PartialEq + Eq + Hash + Send + Sync + std::fmt::Debug + 'static;
    /// Type of signal values. Typically this might be `bool`.
    type LogicValue: LogicValue;
    /// Type which represents the logic function of nodes in the network.
    type NodeFunction: BooleanFunction;

    /// Visit all gates of the network in an unspecified order.
    /// Each gate is visited exactly once; the callback receives a signal identifying the gate.
    fn foreach_gate(&self, f: impl Fn(Self::Signal));

    /// Visit all logic nodes by shared reference.
    fn foreach_node(&self, f: impl Fn(&Self::Node));

    /// Get the signal which represents the constant `value`.
    fn get_constant(&self, value: Self::LogicValue) -> Self::Signal;

    /// Get the value of `signal` if it is a constant, otherwise `None`.
    fn get_constant_value(&self, signal: Self::Signal) -> Option<Self::LogicValue>;

    /// Get the signal connected to input number `input_index` of `node`.
    fn get_node_input(&self, node: &Self::NodeId, input_index: usize) -> Self::Signal;

    /// Get the output signal of `node`.
    fn get_node_output(&self, node: &Self::NodeId) -> Self::Signal;

    /// Get the node which computes `signal`.
    fn get_source_node(&self, signal: &Self::Signal) -> Self::NodeId;

    /// Get a primary input by its index.
    fn get_primary_input(&self, index: usize) -> Self::Signal;

    /// Get a primary output by its index.
    fn get_primary_output(&self, index: usize) -> Self::Signal;

    /// Tell if the signal is directly connected to a constant.
    ///
    /// The default implementation checks whether [`Network::get_constant_value`] returns `Some`.
    fn is_constant(&self, signal: Self::Signal) -> bool {
        self.get_constant_value(signal).is_some()
    }

    /// Check if the signal is a primary input.
    fn is_input(&self, signal: Self::Signal) -> bool;

    /// Get the logic function implemented by the given node.
    fn node_function(&self, node: Self::NodeId) -> Self::NodeFunction;

    /// Number of gates present in the network.
    fn num_gates(&self) -> usize;

    /// Number of inputs into the given node.
    fn num_node_inputs(&self, node: &Self::NodeId) -> usize;

    /// Number of primary inputs into the network.
    fn num_primary_inputs(&self) -> usize;

    /// Number of primary outputs from the network.
    fn num_primary_outputs(&self) -> usize;
}

/// A logic network in which every node computes one and the same boolean function.
///
/// The node function and the node fan-in are therefore properties of the network as a whole
/// rather than of individual nodes.
pub trait HomogeneousNetwork: Network {
    /// Fan-in shared by the nodes of this network.
    const NUM_NODE_INPUTS: usize;

    /// The boolean function computed by every node of this network.
    fn function(&self) -> Self::NodeFunction;
}

/// Inversion of signals that works through a shared reference, i.e. without editing the network.
///
/// Networks usually support this when the inversion flag is part of the signal identifier.
pub trait ImmutableNot: Network {
    /// Return the logical complement of `a`.
    fn get_inverted(&self, a: Self::Signal) -> Self::Signal;

    /// Tell whether `a` carries an inversion.
    fn is_inverted(&self, a: Self::Signal) -> bool;
}

/// Networks that can report how many nodes consume a given signal.
pub trait ReferenceCounted: Network {
    /// How many nodes have the signal `a` in their fan-in.
    fn num_references(&self, a: Self::Signal) -> usize;
}

/// Operations that modify the structure of a logic network.
#[blanket(derive(Mut))]
pub trait NetworkEdit: Network {
    /// Add a fresh primary input to the network and return its signal.
    fn create_primary_input(&mut self) -> Self::Signal;

    /// Expose `signal` as a primary output of the network.
    fn create_primary_output(&mut self, signal: Self::Signal) -> Self::Signal;

    /// Replace every use of the signal `old` by the signal `new` across all nodes.
    fn substitute(&mut self, old: Self::Signal, new: Self::Signal);

    /// Add `node` to the network and return its ID.
    ///
    /// The returned ID is not necessarily new: an implementation may hand back the ID of an
    /// already existing node, for example one with identical inputs.
    fn create_node(&mut self, node: Self::Node) -> Self::NodeId;

    /// Visit every logic node through a mutable reference.
    fn foreach_node_mut(&mut self, f: impl Fn(&mut Self::Node));
}

/// Networks that allow rewiring the inputs of a single node.
#[blanket(derive(Mut))]
pub trait SubstituteInNode: Network {
    /// Within `node` only, replace the input `old_signal` by `new_signal`.
    fn substitute_in_node(
        &mut self,
        node: Self::NodeId,
        old_signal: Self::Signal,
        new_signal: Self::Signal,
    );
}

/// Convenience helpers built purely on top of the [`Network`] trait.
///
/// Available automatically for every type implementing [`Network`].
pub trait NetworkShortcuts: Network {
    /// Iterate over the primary inputs, in index order.
    ///
    /// Each item is the ID of the node that drives the corresponding primary input signal,
    /// not the signal itself.
    fn primary_inputs(&self) -> Box<dyn Iterator<Item = Self::NodeId> + '_> {
        let input_count = self.num_primary_inputs();
        let source_nodes = (0..input_count).map(move |index| {
            let input_signal = self.get_primary_input(index);
            self.get_source_node(&input_signal)
        });
        Box::new(source_nodes)
    }

    /// Iterate over the primary output signals, in index order.
    fn primary_outputs(&self) -> Box<dyn Iterator<Item = Self::Signal> + '_> {
        let output_count = self.num_primary_outputs();
        let output_signals = (0..output_count).map(move |index| self.get_primary_output(index));
        Box::new(output_signals)
    }
}
impl<N> NetworkShortcuts for N where N: Network {}

/// Convenience helpers built on top of the [`NetworkEdit`] trait.
///
/// Available automatically for every type implementing [`NetworkEdit`].
pub trait NetworkEditShortcuts: NetworkEdit {
    /// Create `NUM_INPUTS` fresh primary inputs in a single call.
    ///
    /// The inputs are created one after another and returned in creation order.
    fn create_primary_inputs<const NUM_INPUTS: usize>(&mut self) -> [Self::Signal; NUM_INPUTS] {
        let mut new_input = || self.create_primary_input();
        [(); NUM_INPUTS].map(|()| new_input())
    }
}

impl<N> NetworkEditShortcuts for N where N: NetworkEdit {}

/// Logic operations with a single operand.
#[blanket(derive(Mut))]
pub trait UnaryOp: NetworkEdit {
    /// Return a signal that carries the same value as `signal`.
    ///
    /// The default implementation creates no node and simply hands `signal` back.
    fn create_buffer(&mut self, signal: Self::Signal) -> Self::Signal {
        signal
    }

    /// Return the logical complement of `signal`.
    fn create_not(&mut self, signal: Self::Signal) -> Self::Signal;
}

/// Logic operations with two operands.
///
/// Implementors must provide the basic gates. XNOR, the comparisons and the implication have
/// default implementations composed from `create_not`, `create_and`, `create_or` and
/// `create_xor`.
#[blanket(derive(Mut))]
pub trait BinaryOp: UnaryOp {
    /// Create `a AND b`.
    fn create_and(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal;
    /// Create `a OR b`.
    fn create_or(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal;
    /// Create `NOT (a AND b)`.
    fn create_nand(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal;
    /// Create `NOT (a OR b)`.
    fn create_nor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal;
    /// Create `a XOR b`.
    fn create_xor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal;

    /// Create `NOT (a XOR b)`: `1` iff `a` and `b` are equal.
    fn create_xnor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let differ = self.create_xor(a, b);
        self.create_not(differ)
    }
    /// Less-than: `1` iff `a == 0` and `b == 1`, built as `!a & b`.
    fn create_lt(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let inv_a = self.create_not(a);
        self.create_and(inv_a, b)
    }
    /// Less-than-or-equal: `0` iff `a == 1` and `b == 0`, built as `!a | b`.
    fn create_le(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let inv_a = self.create_not(a);
        self.create_or(inv_a, b)
    }
    /// Greater-than: `1` iff `a == 1` and `b == 0`, built as `create_lt(b, a)`.
    fn create_gt(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.create_lt(b, a)
    }
    /// Greater-than-or-equal: `0` iff `a == 0` and `b == 1`, built as `create_le(b, a)`.
    fn create_ge(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        self.create_le(b, a)
    }

    /// Implication `a -> b`: `1` iff `a == 0` or `b == 1`, built as `!a | b`.
    fn create_implies(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
        let inv_a = self.create_not(a);
        self.create_or(inv_a, b)
    }
}

/// Logic operations with three operands.
///
/// All methods have default implementations composed from two-operand gates.
#[blanket(derive(Mut))]
pub trait TernaryOp: BinaryOp {
    /// Create the majority of `a`, `b` and `c`: `1` iff at least two operands are `1`.
    fn create_maj3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        // Sum-of-products form `ab + (bc + ac)`; the three AND gates are created in this order.
        let (and_ab, and_bc, and_ac) = (
            self.create_and(a, b),
            self.create_and(b, c),
            self.create_and(a, c),
        );
        let tail = self.create_or(and_bc, and_ac);
        self.create_or(and_ab, tail)
    }

    /// Create a multiplexer: `then` if `condition` is `1`, otherwise `otherwise`.
    fn create_ite(
        &mut self,
        condition: Self::Signal,
        then: Self::Signal,
        otherwise: Self::Signal,
    ) -> Self::Signal {
        // (condition & then) | (!condition & otherwise)
        let not_condition = self.create_not(condition);
        let when_true = self.create_and(condition, then);
        let when_false = self.create_and(not_condition, otherwise);
        self.create_or(when_true, when_false)
    }

    /// Create `a XOR b XOR c` as a chain of two XOR gates.
    fn create_xor3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
        let partial = self.create_xor(a, b);
        self.create_xor(partial, c)
    }
}

/// Logic operations with N inputs.
#[blanket(derive(Mut))]
pub trait NAryOp: TernaryOp {
    /// Create the logic AND of all provided inputs.
    ///
    /// The default implementation combines the inputs pairwise, layer by layer. The resulting
    /// tree of 2-input AND gates therefore has depth `ceil(log2(n))` instead of `n - 1`.
    /// An empty input yields the constant `1` (the neutral element of AND). A single input is
    /// returned unchanged without creating a gate.
    fn create_nary_and(&mut self, inputs: impl Iterator<Item = Self::Signal>) -> Self::Signal {
        let signals: Vec<Self::Signal> = inputs.collect();
        // The constant is only looked up when there is nothing to reduce.
        reduce_pairwise_balanced(&mut *self, signals, |network, a, b| network.create_and(a, b))
            .unwrap_or_else(|| self.get_constant(<Self as Network>::LogicValue::one()))
    }

    /// Create the logic OR of all provided inputs.
    ///
    /// The default implementation combines the inputs pairwise, layer by layer. The resulting
    /// tree of 2-input OR gates therefore has depth `ceil(log2(n))` instead of `n - 1`.
    /// An empty input yields the constant `0` (the neutral element of OR). A single input is
    /// returned unchanged without creating a gate.
    fn create_nary_or(&mut self, inputs: impl Iterator<Item = Self::Signal>) -> Self::Signal {
        let signals: Vec<Self::Signal> = inputs.collect();
        // The constant is only looked up when there is nothing to reduce.
        reduce_pairwise_balanced(&mut *self, signals, |network, a, b| network.create_or(a, b))
            .unwrap_or_else(|| self.get_constant(<Self as Network>::LogicValue::zero()))
    }
}

/// Reduce `signals` with the binary operation `op` as a balanced tree.
///
/// Neighbouring signals are combined pairwise. A leftover signal at the end of an odd-length
/// layer is carried over to the next layer unchanged. This repeats until one signal remains.
/// Returns `None` if `signals` is empty.
fn reduce_pairwise_balanced<N: ?Sized, S: Copy>(
    network: &mut N,
    mut signals: Vec<S>,
    mut op: impl FnMut(&mut N, S, S) -> S,
) -> Option<S> {
    while signals.len() > 1 {
        signals = signals
            .chunks(2)
            .map(|pair| match *pair {
                [a, b] => op(&mut *network, a, b),
                [a] => a,
                _ => unreachable!("`chunks(2)` yields slices of length 1 or 2"),
            })
            .collect();
    }
    signals.pop()
}

/// Basic trait of a node in a logic network.
pub trait NetworkNode: Clone + Eq + PartialEq {
    /// ID type used to refer to the inputs of this node.
    ///
    /// When the node is used as [`Network::Node`], this type is bound to the network's
    /// `Signal` type.
    type NodeId: IdType;

    /// Get the number of inputs into this node.
    fn num_inputs(&self) -> usize;

    /// Get the `i`-th input of this node.
    ///
    /// # Panics
    /// Panics if `i >= self.num_inputs()`.
    fn get_input(&self, i: usize) -> Self::NodeId;

    /// Get the logic function of the node as a truth table.
    fn function(&self) -> SmallTruthTable;

    /// Bring the node into a canonical form by reordering and inverting the inputs
    /// (if possible). This increases the effectiveness of structural hashing.
    ///
    /// The node can also be simplified to a single signal (e.g. an AND node with one input
    /// set to constant `0` reduces to `0`).
    ///
    /// Returns a [`SimplifyResult`] holding either the normalized node or the single
    /// `NodeId` the node simplifies to.
    /// NOTE(review): earlier documentation described a `(node, inverted)` tuple, which does
    /// not match this return type. Check the `SimplifyResult` variants for how an output
    /// inversion is reported.
    fn normalized(self) -> SimplifyResult<Self, Self::NodeId>;
}

/// A network node whose inputs can be rewired in place.
pub trait MutNetworkNode: NetworkNode {
    /// Connect input number `i` of this node to `signal`.
    fn set_input(&mut self, i: usize, signal: Self::NodeId);
}

/// A network node that tracks how often it is referenced.
pub trait NetworkNodeWithReferenceCount: NetworkNode {
    /// Current value of this node's reference counter.
    fn num_references(&self) -> usize;
}

/// Marker for logic nodes whose boolean function is fixed at compile time.
pub trait StaticFunction {
    // TODO: expose the function as an associated constant, e.g. `const FUNCTION: TruthTable`.
}

/// A logic node whose fan-in `NUM_INPUTS` is fixed at compile time.
pub trait StaticInputDegree<const NUM_INPUTS: usize>: NetworkNode {
    /// Return all inputs of the node as a fixed-size array.
    fn to_array(&self) -> [Self::NodeId; NUM_INPUTS];
}

/// A reference-counted node whose counter can be adjusted.
pub trait MutNetworkNodeWithReferenceCount: NetworkNodeWithReferenceCount {
    /// Increase the reference counter by one.
    fn reference(&mut self);

    /// Decrease the reference counter by one.
    ///
    /// # Panics
    /// Panics when the counter is already zero.
    fn dereference(&mut self);
}

/// Marker trait for types that represent a signal (an edge) in a logic network.
pub trait Edge: Sized {}

/// An edge identifier that can additionally carry a logic inversion.
pub trait EdgeWithInversion: Edge {
    /// Tell whether the signal is inverted along this edge.
    fn is_inverted(&self) -> bool;

    /// Return this edge with its inversion flag toggled.
    fn invert(self) -> Self;

    /// Invert the edge only when `condition` holds; otherwise hand it back unchanged.
    fn invert_if(self, condition: bool) -> Self {
        if !condition {
            return self;
        }
        self.invert()
    }

    /// Strip any inversion from this edge.
    ///
    /// The default implementation inverts the edge exactly when it is currently inverted.
    fn non_inverted(self) -> Self {
        if !self.is_inverted() {
            return self;
        }
        self.invert()
    }
}