libreda-logic 0.0.3

Logic library for LibrEDA.
Documentation
// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

//! Enumerate cuts for each node in a logic network.

#![allow(unused)] // TODO remove once stabilized

use std::{cmp::Ordering, hash::Hash};

use crate::algorithms::visit::TopologicalIter;
use crate::network::Network;
use crate::network::NetworkShortcuts;
use fnv::FnvHashMap;
use itertools::Itertools;
use smallvec::{smallvec, SmallVec};

/// A `cut` in a logic network.
///
/// A cut is defined by a root node and a set of leaf nodes. Each path from the root node to a primary input of the network must pass
/// exactly once through a leaf node. Each path from a leaf node to a primary output must pass through the root node.
///
/// Equality and hashing are derived, so two cuts are equal only if both the root node
/// and the stored sequence of leaf nodes are equal.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct Cut<NodeId> {
    // Node at which the cut is rooted.
    root_node: NodeId,
    // IDs of leaf nodes of the cut. They are kept sorted in ascending order (by `Ord`) for faster set operations.
    // Set operations such as the union rely on this ordering because they merge the sorted sequences.
    leaf_nodes: SmallVec<[NodeId; 16]>,
}

impl<NodeId> Cut<NodeId>
where
    NodeId: Clone + Ord,
{
    /// Create the trivial cut of `node`: the cut rooted at `node` whose only leaf is `node` itself.
    fn trivial_cut(node: NodeId) -> Self {
        Self {
            root_node: node.clone(),
            leaf_nodes: smallvec![node],
        }
    }

    /// Create a cut rooted at `node` with the given leaf nodes.
    ///
    /// The leaves are treated as a set: they are sorted in ascending order and
    /// repeated entries are removed. Without the deduplication two cuts with the same
    /// leaf set could compare (and hash) as different, and the number of leaves
    /// would be over-counted.
    fn new(node: NodeId, mut leaf_nodes: SmallVec<[NodeId; 16]>) -> Self {
        // Stability is irrelevant here because equal elements are collapsed right after sorting.
        leaf_nodes.sort_unstable();
        // After sorting, equal IDs are adjacent, so removing consecutive duplicates removes all duplicates.
        leaf_nodes.dedup();
        Self {
            root_node: node,
            leaf_nodes,
        }
    }
}

/// Compute the union of multiple cuts.
///
/// The resulting cut is rooted at `root` and its leaves are the set union of the
/// leaves of all `cuts`. The leaf lists of the input cuts are expected to be sorted
/// in ascending order; the resulting leaf list is sorted as well.
fn cut_union<'a, NodeId: 'a>(
    root: NodeId,
    cuts: impl Iterator<Item = &'a Cut<NodeId>>,
) -> Cut<NodeId>
where
    NodeId: Clone + Ord + Eq,
{
    // Borrow the leaves of every cut; each of these sequences is sorted on its own.
    let sorted_leaf_lists = cuts.map(|cut| cut.leaf_nodes.iter());

    Cut {
        root_node: root,
        leaf_nodes: sorted_leaf_lists
            // A k-way merge of sorted sequences yields one globally sorted sequence...
            .kmerge()
            // ...in which equal IDs are adjacent and can be collapsed into one.
            .dedup()
            // Clone only the IDs which end up in the union.
            .cloned()
            .collect(),
    }
}

/// The union of two overlapping cuts must contain every leaf exactly once, in sorted order.
#[test]
fn test_cut_onion() {
    // The two cuts share the leaves "x" and "z".
    let cuts = [
        Cut::new("root_a", smallvec!["w", "x", "z"]),
        Cut::new("root_b", smallvec!["x", "y", "z"]),
    ];

    let merged = cut_union("root_u", cuts.iter());
    let expected = Cut::new("root_u", smallvec!["w", "x", "y", "z"]);

    assert_eq!(merged, expected);
}

/// Configuration for cut enumeration.
///
/// `F` is the type of the comparison function used to rank the cuts of a node.
#[derive(Clone, Copy, Debug)]
struct CutEnumerationConfig<F> {
    /// Maximum number of leaf nodes of a cut. Larger cuts are discarded.
    max_cut_size: usize,
    /// Maximum number of cuts stored per node.
    max_cuts_per_node: usize,
    /// Comparison function used for prioritizing cuts. Only the `max_cuts_per_node` best (least-cost) cuts will be stored.
    /// Cuts which compare as `Ordering::Less` are considered better.
    sort_cuts_by_fn: F,
}

/// Result of cut enumeration.
/// Stores a vector of cuts for each node.
#[derive(Clone, Debug)]
struct CutEnumeration<NodeId> {
    // Maps a node ID to the cuts which were selected for this node.
    cuts: FnvHashMap<NodeId, Vec<Cut<NodeId>>>,
}

impl<NodeId> CutEnumeration<NodeId>
where
    NodeId: Hash + Eq,
{
    /// Get the cuts which are stored for `node`.
    ///
    /// # Panics
    ///
    /// Panics if there is no entry for `node` in the enumeration result.
    pub fn get_cuts(&self, node: &NodeId) -> &Vec<Cut<NodeId>> {
        let Self { cuts: cuts_by_node } = self;
        // Indexing the map panics when the key is absent.
        &cuts_by_node[node]
    }
}

impl<F> CutEnumerationConfig<F> {
    /// Compute all cuts of the network according to the configuration.
    ///
    /// Only nodes which are reachable from the primary outputs are processed.
    /// For each such node the result holds at most `max_cuts_per_node` cuts, each with
    /// at most `max_cut_size` leaves. The cuts of a node are ordered by `sort_cuts_by_fn`
    /// (best first) and contain no duplicates.
    ///
    /// # Panics
    ///
    /// Panics if the cuts of a child node are not available when a node is processed,
    /// i.e. if the traversal does not yield child nodes before their parents.
    fn compute_cuts<N>(&self, network: &N) -> CutEnumeration<N::NodeId>
    where
        N: Network,
        N::NodeId: Ord,
        F: Fn(&Cut<N::NodeId>, &Cut<N::NodeId>) -> Ordering,
    {
        // Algorithm:
        // * Sort nodes topologically (leaves first, roots last)
        // * For each node create a set of 'best' cuts. Assemble a cut `n` from the cuts of the children of `n`.

        // Sort nodes topologically, starting the traversal at the drivers of the primary outputs.
        let topo_sorted: Vec<_> = {
            let top_nodes = network
                .primary_outputs()
                .map(|signal| network.get_source_node(&signal))
                .collect();

            let sorted: Vec<_> = TopologicalIter::new(network, top_nodes)
                .visit_primary_inputs(true)
                .visit_constants(true)
                .collect();

            sorted
        };

        // Storage for computed cuts.
        let mut cuts: FnvHashMap<N::NodeId, Vec<Cut<N::NodeId>>> = Default::default();
        cuts.reserve(topo_sorted.len());

        for node in &topo_sorted {
            let signal = network.get_node_output(node);

            let mut current_cuts = if network.is_input(signal) || network.is_constant(signal) {
                // Primary inputs (and constants) have only a trivial cut.
                vec![Cut::trivial_cut(node.clone())]
            } else {
                // Iterator over child nodes.
                let child_nodes = (0..network.num_node_inputs(node))
                    .map(|i| network.get_node_input(node, i))
                    .map(|input_signal| network.get_source_node(&input_signal));

                // Get all computed cuts of the child nodes.
                // The children have been processed already because of the topological order.
                let child_cuts: SmallVec<[&Vec<Cut<_>>; 16]> =
                    child_nodes.map(|child_node| &cuts[&child_node]).collect();

                // Construct cuts based on cuts of child nodes.
                let cut_unions: Vec<_> = child_cuts
                    .into_iter()
                    // Select one cut per child node. Iterate over all possible combinations of selections.
                    .multi_cartesian_product()
                    .map(|selected_cuts| cut_union(node.clone(), selected_cuts.into_iter()))
                    .chain([Cut::trivial_cut(node.clone())])
                    // Take cuts which are small enough only.
                    .filter(|cut| cut.leaf_nodes.len() <= self.max_cut_size)
                    // Different selections of child cuts can lead to the same set of leaves
                    // (for example on reconvergent paths). Keep only the first occurrence such that
                    // duplicates do not occupy slots which are limited by `max_cuts_per_node`.
                    .unique()
                    .collect();

                cut_unions
            };

            // Take the best cuts only.
            // The sort is stable: cuts which compare as equal keep their order of creation.
            current_cuts.sort_by(|c1, c2| (self.sort_cuts_by_fn)(c1, c2));

            current_cuts.truncate(self.max_cuts_per_node);

            cuts.insert(node.clone(), current_cuts);
        }

        CutEnumeration { cuts }
    }
}

/// Enumerate the cuts of a two-level tree of AND gates and check the number of cuts of each node.
#[test]
fn test_cut_enumeration() {
    use crate::network::*;
    use crate::networks::aig::*;

    // The limits are far above what this small network can produce.
    let config = CutEnumerationConfig {
        max_cut_size: 100,
        max_cuts_per_node: 100,
        sort_cuts_by_fn: |_a: &Cut<AigNodeId>, _b: &Cut<AigNodeId>| -> Ordering { Ordering::Equal }, // Don't sort.
    };

    // Build `top = (a & b) & (c & d)` with a single primary output.
    let mut aig = Aig::new();
    let [a, b, c, d] = aig.create_primary_inputs();
    let and_ab = aig.create_and(a, b);
    let and_cd = aig.create_and(c, d);
    let top = aig.create_and(and_ab, and_cd);
    let _out = aig.create_primary_output(top);

    let enumeration = config.compute_cuts(&aig);

    // Input nodes have only trivial cut.
    for input in [a, b, c, d] {
        assert_eq!(enumeration.get_cuts(&input).len(), 1);
    }

    // First-level nodes should have the trivial cut + a cut including both inputs.
    for first_level_node in [and_ab, and_cd] {
        assert_eq!(enumeration.get_cuts(&first_level_node).len(), 2);
    }

    // Expected cuts:
    // * Trivial cut
    // * (top, [and_ab, and_cd])
    // * (top, [a, b, and_cd])
    // * (top, [and_ab, c, d])
    // * (top, [a, b, c, d])
    assert_eq!(enumeration.get_cuts(&top).len(), 5);
}