libreda-logic 0.0.3

Logic library for LibrEDA.
Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! Transformations of the MIG based on the distributivity of the three-input majority function.

use crate::networks::generic_network::NodeId;
use crate::networks::{mig::*, SimplifyResult};
use itertools::Itertools;

impl Mig {
    /// Try to eliminate a node based on the distributivity of the majority function.
    ///
    /// Distributivity:
    /// * M(x, y, M(u, v, z)) == M(M(x, y, u), M(x, y, v), z)
    ///
    /// The right-hand side is matched on `node` (see `match_distributivity`). On a match, the
    /// equivalent left-hand side `M(x, y, M(u, v, z))` is built with two calls to `create_maj3`
    /// and the ID of the new outer node is returned. Otherwise `node` is returned unchanged.
    fn simplify_node_by_distributivity(
        &mut self,
        node: Maj3Node,
    ) -> SimplifyResult<Maj3Node, NodeId> {
        match self.match_distributivity(node) {
            Some((x, y, u, v, z)) => {
                let inner = self.create_maj3(u, v, z);
                let new_node = self.create_maj3(x, y, inner);
                SimplifyResult::new_id(new_node)
            }
            None => SimplifyResult::new_node(node),
        }
    }

    /// Try to match the pattern `M(M(x, y, u), M(x, y, v), z)` on the given node.
    ///
    /// This is used for eliminating a node based on the distributivity rule
    /// `M(M(x, y, u), M(x, y, v), z)` => `M(x, y, M(u, v, z))`.
    ///
    /// Children of `node` which are not logic nodes (primary inputs, constants) are skipped
    /// while searching for the shared pair `x`, `y`. If all three children contain both `x`
    /// and `y`, the first two children are used and the third child becomes `z`.
    ///
    /// If the pattern matches, return `(x, y, u, v, z)`, otherwise return `None`.
    ///
    /// NOTE(review): child nodes are looked up with `get_node` and their inputs are used as-is.
    /// If a `MigNodeId` can encode an inverted edge, the polarity of the child edges is not
    /// taken into account here — confirm.
    fn match_distributivity(
        &self,
        node: Maj3Node,
    ) -> Option<(MigNodeId, MigNodeId, MigNodeId, MigNodeId, MigNodeId)> {
        // Find signals which are shared in two or more child nodes.
        // Associate the signals with a bitmap which tells in which child nodes they appear.
        let signals_grouped_by_child_idx = node
            .into_iter()
            .enumerate()
            // Skip children which are not logic nodes (primary inputs, constants).
            .filter_map(|(child_idx, nid)| {
                self.get_node(nid)
                    .to_logic_node()
                    .map(|maj3| (child_idx, maj3))
            })
            // 1) Get the inputs of each child node, deduplicated per child node.
            .map(|(child_idx, maj3)| maj3.into_iter().dedup().map(move |nid| (nid, child_idx)))
            // 2) Merge the inputs of the child nodes. The result is sorted because the inputs are sorted.
            .kmerge()
            // 3) Group equal signal IDs together.
            .group_by(|(nid, _child_idx)| *nid);

        // At most 3 children with 3 inputs each => at most 9 distinct signals.
        let signals_shared_in_two_or_more_child_nodes: Vec<(MigNodeId, u32)> = // TODO: use SmallVec of length 9
            signals_grouped_by_child_idx
                .into_iter()
                .map(|(nid, group_members)| {
                    // Bitmap: bit `j` is set if the signal is an input of child `j`.
                    let signal_appears_in: u32 = group_members
                        .map(|(_nid, child_idx)| child_idx)
                        .fold(0, |acc, child_idx| acc | (1 << child_idx));
                    (nid, signal_appears_in)
                })
                // Keep only signals which appear in two or more child nodes.
                .filter(|(_nid, appears_in)| appears_in.count_ones() >= 2)
                .collect();

        // Find the first pair of distinct signals which appear together in at least two
        // child nodes. Return `None` early if there is no such pair.
        let (nid_x, nid_y, appears_in) = signals_shared_in_two_or_more_child_nodes
            .iter()
            .enumerate()
            .flat_map(|(i, element1)| {
                signals_shared_in_two_or_more_child_nodes[i + 1..]
                    .iter()
                    .map(move |element2| (element1, element2))
            })
            .find_map(|(&(nid1, appears_in1), &(nid2, appears_in2))| {
                let common = appears_in1 & appears_in2;
                if common.count_ones() >= 2 {
                    Some((nid1, nid2, common))
                } else {
                    None
                }
            })?;

        // At least two child nodes use both `nid_x` and `nid_y`. Their indices are encoded
        // as a bitmap in `appears_in`.
        // => It is possible to eliminate a node based on the distributivity.
        debug_assert!(appears_in.count_ones() >= 2);
        debug_assert!(appears_in.count_ones() <= 3);
        debug_assert_ne!(nid_x, nid_y);

        // If all three child nodes use both signals, take the first two only.
        let appears_in = if appears_in.count_ones() > 2 {
            appears_in & 0b011
        } else {
            appears_in
        };

        debug_assert_eq!(appears_in.count_ones(), 2);

        // Convert the bitmap into indices: the two children which contain `x` and `y`
        // (in ascending order), followed by the remaining child which becomes `z`.
        let (child_idx1, child_idx2, third_input_idx) = match appears_in {
            0b011 => (0, 1, 2),
            0b101 => (0, 2, 1),
            0b110 => (1, 2, 0),
            _ => unreachable!("exactly two child nodes must be selected now"),
        };

        let children = node.to_array();

        // The selected children were found through `to_logic_node()` above, so they are
        // regular logic nodes (neither inputs nor constants).
        let child1 = self
            .get_node(children[child_idx1])
            .to_logic_node()
            .expect("the selected child nodes must not be inputs nor constants");
        let child2 = self
            .get_node(children[child_idx2])
            .to_logic_node()
            .expect("the selected child nodes must not be inputs nor constants");

        // Find the `u`, `v` and `z` in `M(M(x, y, u), M(x, y, v), z)`.
        let z = children[third_input_idx];
        let u = Self::distributivity_remaining_input(child1.to_array(), nid_x, nid_y);
        let v = Self::distributivity_remaining_input(child2.to_array(), nid_x, nid_y);

        Some((nid_x, nid_y, u, v, z))
    }

    /// Given the inputs of a three-input majority node which contains both signals `a` and
    /// `b`, return the remaining input. That is, find `w` in `[a, b, w]` (any permutation).
    ///
    /// # Panics
    /// Panics if `a` and `b` are not both inputs of the node.
    fn distributivity_remaining_input(
        inputs: [MigNodeId; 3],
        a: MigNodeId,
        b: MigNodeId,
    ) -> MigNodeId {
        let is_shared_pair = |p: MigNodeId, q: MigNodeId| (p == a && q == b) || (p == b && q == a);
        match inputs {
            // With or-patterns the guard is evaluated for each matching alternative in turn,
            // so every position of the remaining input is tried.
            [p, q, rest] | [p, rest, q] | [rest, p, q] if is_shared_pair(p, q) => rest,
            _ => unreachable!("both shared signals must be inputs of the selected child node"),
        }
    }
}

#[test]
fn test_match_simplify_distributivity() {
    let mut mig = Mig::new();

    let [x, y, u, v, z] = mig.create_primary_inputs();

    // Create the pattern `M(M(x, y, u), M(x, y, v), z)` and test if it is matched correctly.
    let m1 = mig.create_maj3(x, y, u);
    let m2 = mig.create_maj3(x, y, v);
    let m3 = mig.create_maj3(m1, m2, z);

    let (x1, y1, u1, v1, z1) = mig
        .match_distributivity(*mig.get_node(m3).to_logic_node().unwrap())
        .unwrap();

    assert_eq!(z, z1);
    // The shared pair `(x, y)` and the remaining inputs `(u, v)` may each be reported in
    // either order.
    assert!((x1, y1) == (x, y) || (y1, x1) == (x, y));
    assert!((u1, v1) == (u, v) || (v1, u1) == (u, v));
}

#[test]
fn test_match_simplify_distributivity_nomatch() {
    let mut mig = Mig::new();

    let [x, y, u, v, w, z] = mig.create_primary_inputs();

    // Create `M(M(x, y, u), M(x, w, v), z)`. The two child nodes share only the signal `x`,
    // so the distributivity pattern `M(M(x, y, u), M(x, y, v), z)` must NOT be matched.
    let m1 = mig.create_maj3(x, y, u);
    let m2 = mig.create_maj3(x, w, v); // Would need to be (x, y, v) to create a match.
    let m3 = mig.create_maj3(m1, m2, z);

    let match_result = mig.match_distributivity(*mig.get_node(m3).to_logic_node().unwrap());

    assert_eq!(match_result, None);
}