libreda_logic/algorithms/
cut_enumeration.rs

1// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
2//
3// SPDX-License-Identifier: AGPL-3.0-or-later
4
5//! Enumerate cuts for each node in a logic network.
6
7#![allow(unused)] // TODO remove once stabilized
8
9use std::{cmp::Ordering, hash::Hash};
10
11use crate::algorithms::visit::TopologicalIter;
12use crate::network::Network;
13use crate::network::NetworkShortcuts;
14use fnv::FnvHashMap;
15use itertools::Itertools;
16use smallvec::{smallvec, SmallVec};
17
18/// A `cut` in a logic network.
19///
20/// A cut is defined by a root node and a set of leaf nodes. Each path from the root node to a primary input of the network must pass
21/// exactly once through a leaf node. Each path from a leaf node to a primary output must pass through the root node.
22#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Cut<NodeId> {
24    root_node: NodeId,
25    // IDs of leaf nodes of the cut. They are kept lexicographically sorted for faster set operations.
26    leaf_nodes: SmallVec<[NodeId; 16]>,
27}
28
29impl<NodeId> Cut<NodeId>
30where
31    NodeId: Clone + Ord,
32{
33    fn trivial_cut(node: NodeId) -> Self {
34        Self {
35            root_node: node.clone(),
36            leaf_nodes: smallvec![node],
37        }
38    }
39
40    fn new(node: NodeId, mut leaf_nodes: SmallVec<[NodeId; 16]>) -> Self {
41        leaf_nodes.sort();
42        Self {
43            root_node: node,
44            leaf_nodes,
45        }
46    }
47}
48
49/// Compute the union of multiple cuts.
50fn cut_union<'a, NodeId: 'a>(
51    root: NodeId,
52    cuts: impl Iterator<Item = &'a Cut<NodeId>>,
53) -> Cut<NodeId>
54where
55    NodeId: Clone + Ord + Eq,
56{
57    let leaf_nodes_union = cuts
58        .map(|cut| cut.leaf_nodes.iter().cloned())
59        // Compute the set union by merging the sorted values and deduplication.
60        .kmerge()
61        .dedup()
62        .collect();
63
64    Cut {
65        root_node: root,
66        leaf_nodes: leaf_nodes_union,
67    }
68}
69
/// The union of two overlapping cuts contains each shared leaf (`x`, `z`) only once and is
/// rooted at the requested root.
#[test]
fn test_cut_onion() {
    let first = Cut::new("root_a", smallvec!["w", "x", "z"]);
    let second = Cut::new("root_b", smallvec!["x", "y", "z"]);
    let inputs = [first, second];

    let actual = cut_union("root_u", inputs.iter());
    let expected = Cut::new("root_u", smallvec!["w", "x", "y", "z"]);

    assert_eq!(actual, expected);
}
79
/// Configuration for cut enumeration.
///
/// `F` is the comparison function used to rank the cuts of a node (see `sort_cuts_by_fn`).
#[derive(Clone, Copy, Debug)]
struct CutEnumerationConfig<F> {
    /// Maximum number of leaf nodes of a cut. Cuts with more leaves are discarded.
    max_cut_size: usize,
    /// Maximum number of cuts stored per node.
    max_cuts_per_node: usize,
    /// Comparison function used for prioritizing cuts.
    ///
    /// The cuts of each node are sorted in ascending order with this function. Only the
    /// first `max_cuts_per_node` of them (the best, i.e. least-cost, cuts) are stored.
    sort_cuts_by_fn: F,
}
90
91/// Result of cut enumeration.
92/// Stores a vector of cuts for each node.
93#[derive(Clone, Debug)]
94struct CutEnumeration<NodeId> {
95    cuts: FnvHashMap<NodeId, Vec<Cut<NodeId>>>,
96}
97
98impl<NodeId> CutEnumeration<NodeId>
99where
100    NodeId: Hash + Eq,
101{
102    /// Get cuts of `node`.
103    pub fn get_cuts(&self, node: &NodeId) -> &Vec<Cut<NodeId>> {
104        &self.cuts[node]
105    }
106}
107
108impl<F> CutEnumerationConfig<F> {
109    /// Compute all cuts of the network according to the configuration.
110    fn compute_cuts<N>(&self, network: &N) -> CutEnumeration<N::NodeId>
111    where
112        N: Network,
113        N::NodeId: Ord,
114        F: Fn(&Cut<N::NodeId>, &Cut<N::NodeId>) -> Ordering,
115    {
116        // Algorithm:
117        // * Sort nodes topologically (leaves first, roots last)
118        // * For each node create a set of 'best' cuts. Assemble a cut `n` from the cuts of the children of `n`.
119
120        // Sort nodes topologically.
121        let topo_sorted: Vec<_> = {
122            let top_nodes = network
123                .primary_outputs()
124                .map(|signal| network.get_source_node(&signal))
125                .collect();
126
127            let sorted: Vec<_> = TopologicalIter::new(network, top_nodes)
128                .visit_primary_inputs(true)
129                .visit_constants(true)
130                .collect();
131
132            sorted
133        };
134
135        // Storage for computed cuts.
136        let mut cuts: FnvHashMap<N::NodeId, Vec<Cut<N::NodeId>>> = Default::default();
137        cuts.reserve(topo_sorted.len());
138
139        for node in &topo_sorted {
140            let signal = network.get_node_output(node);
141
142            let mut current_cuts = if network.is_input(signal) || network.is_constant(signal) {
143                // Primary inputs (and constants) have only a trivial cut.
144                vec![Cut::trivial_cut(node.clone())]
145            } else {
146                // Iterator over child nodes.
147                let child_nodes = (0..network.num_node_inputs(node))
148                    .map(|i| network.get_node_input(node, i))
149                    .map(|input_signal| network.get_source_node(&input_signal));
150
151                // Get all computed cuts of the child nodes.
152                let child_cuts: SmallVec<[&Vec<Cut<_>>; 16]> =
153                    child_nodes.map(|child_node| &cuts[&child_node]).collect();
154
155                // Construct cuts based on cuts of child nodes.
156                let cut_unions: Vec<_> = child_cuts
157                    .into_iter()
158                    // Select one cut per child node. Iterate over all possible combinations of selections.
159                    .multi_cartesian_product()
160                    .map(|selected_cuts| cut_union(node.clone(), selected_cuts.into_iter()))
161                    .chain([Cut::trivial_cut(node.clone())])
162                    // Take cuts which are small enough only.
163                    .filter(|cut| cut.leaf_nodes.len() <= self.max_cut_size)
164                    .collect();
165
166                cut_unions
167            };
168
169            // Take the best cuts only.
170            current_cuts.sort_by(|c1, c2| (self.sort_cuts_by_fn)(c1, c2));
171
172            current_cuts.truncate(self.max_cuts_per_node);
173
174            cuts.insert(node.clone(), current_cuts);
175        }
176
177        CutEnumeration { cuts }
178    }
179}
180
#[test]
fn test_cut_enumeration() {
    use crate::network::*;
    use crate::networks::aig::*;

    // Generous limits for this small network; every pair of cuts compares `Equal`, so the
    // sort step keeps the generation order.
    let config = CutEnumerationConfig {
        max_cut_size: 100,
        max_cuts_per_node: 100,
        sort_cuts_by_fn: |_a: &Cut<AigNodeId>, _b: &Cut<AigNodeId>| -> Ordering { Ordering::Equal },
    };

    // Balanced tree of two-input ANDs: top = (a & b) & (c & d).
    let mut net = Aig::new();
    let [a, b, c, d] = net.create_primary_inputs();
    let anb = net.create_and(a, b);
    let cnd = net.create_and(c, d);
    let top = net.create_and(anb, cnd);
    let _out = net.create_primary_output(top);

    let cuts = config.compute_cuts(&net);
    let num_cuts = |node: AigNodeId| cuts.get_cuts(&node).len();

    // Input nodes have only their trivial cut.
    for input in [a, b, c, d] {
        assert_eq!(num_cuts(input), 1);
    }

    // First-level AND nodes: the trivial cut plus the cut made of both inputs.
    for gate in [anb, cnd] {
        assert_eq!(num_cuts(gate), 2);
    }

    // Expected cuts of `top`:
    // * trivial cut
    // * [anb, cnd]
    // * [a, b, cnd]
    // * [anb, c, d]
    // * [a, b, c, d]
    assert_eq!(num_cuts(top), 5);
}