// libreda_logic/algorithms/cut_enumeration.rs
#![allow(unused)]

use std::{cmp::Ordering, hash::Hash};

use crate::algorithms::visit::TopologicalIter;
use crate::network::Network;
use crate::network::NetworkShortcuts;
use fnv::FnvHashMap;
use itertools::Itertools;
use smallvec::{smallvec, SmallVec};
17
18#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Cut<NodeId> {
24 root_node: NodeId,
25 leaf_nodes: SmallVec<[NodeId; 16]>,
27}
28
29impl<NodeId> Cut<NodeId>
30where
31 NodeId: Clone + Ord,
32{
33 fn trivial_cut(node: NodeId) -> Self {
34 Self {
35 root_node: node.clone(),
36 leaf_nodes: smallvec![node],
37 }
38 }
39
40 fn new(node: NodeId, mut leaf_nodes: SmallVec<[NodeId; 16]>) -> Self {
41 leaf_nodes.sort();
42 Self {
43 root_node: node,
44 leaf_nodes,
45 }
46 }
47}
48
49fn cut_union<'a, NodeId: 'a>(
51 root: NodeId,
52 cuts: impl Iterator<Item = &'a Cut<NodeId>>,
53) -> Cut<NodeId>
54where
55 NodeId: Clone + Ord + Eq,
56{
57 let leaf_nodes_union = cuts
58 .map(|cut| cut.leaf_nodes.iter().cloned())
59 .kmerge()
61 .dedup()
62 .collect();
63
64 Cut {
65 root_node: root,
66 leaf_nodes: leaf_nodes_union,
67 }
68}
69
/// Checks that `cut_union` merges leaf sets into one sorted, duplicate-free set.
#[test]
fn test_cut_union() {
    let a = Cut::new("root_a", smallvec!["w", "x", "z"]);
    let b = Cut::new("root_b", smallvec!["x", "y", "z"]);

    // The shared leaves "x" and "z" appear only once in the union.
    let u = cut_union("root_u", [a.clone(), b].iter());
    assert_eq!(u, Cut::new("root_u", smallvec!["w", "x", "y", "z"]));

    // The union of a cut with itself has exactly the same leaves.
    let twice = cut_union("root_a", [a.clone(), a.clone()].iter());
    assert_eq!(twice, a);

    // Merging no cuts at all yields a cut without leaves.
    let none = cut_union("root_e", std::iter::empty::<&Cut<&str>>());
    assert!(none.leaf_nodes.is_empty());
}
79
/// Parameters that control cut enumeration.
///
/// `F` is the comparison used to rank the candidate cuts of each node before
/// they are truncated to `max_cuts_per_node`.
#[derive(Debug, Clone, Copy)]
struct CutEnumerationConfig<F> {
    /// Candidate cuts with more leaves than this are discarded.
    max_cut_size: usize,
    /// Number of cuts kept per node after sorting; the rest are dropped.
    max_cuts_per_node: usize,
    /// Comparator for candidate cuts; cuts that sort first are the ones kept.
    sort_cuts_by_fn: F,
}
90
91#[derive(Clone, Debug)]
94struct CutEnumeration<NodeId> {
95 cuts: FnvHashMap<NodeId, Vec<Cut<NodeId>>>,
96}
97
98impl<NodeId> CutEnumeration<NodeId>
99where
100 NodeId: Hash + Eq,
101{
102 pub fn get_cuts(&self, node: &NodeId) -> &Vec<Cut<NodeId>> {
104 &self.cuts[node]
105 }
106}
107
108impl<F> CutEnumerationConfig<F> {
109 fn compute_cuts<N>(&self, network: &N) -> CutEnumeration<N::NodeId>
111 where
112 N: Network,
113 N::NodeId: Ord,
114 F: Fn(&Cut<N::NodeId>, &Cut<N::NodeId>) -> Ordering,
115 {
116 let topo_sorted: Vec<_> = {
122 let top_nodes = network
123 .primary_outputs()
124 .map(|signal| network.get_source_node(&signal))
125 .collect();
126
127 let sorted: Vec<_> = TopologicalIter::new(network, top_nodes)
128 .visit_primary_inputs(true)
129 .visit_constants(true)
130 .collect();
131
132 sorted
133 };
134
135 let mut cuts: FnvHashMap<N::NodeId, Vec<Cut<N::NodeId>>> = Default::default();
137 cuts.reserve(topo_sorted.len());
138
139 for node in &topo_sorted {
140 let signal = network.get_node_output(node);
141
142 let mut current_cuts = if network.is_input(signal) || network.is_constant(signal) {
143 vec![Cut::trivial_cut(node.clone())]
145 } else {
146 let child_nodes = (0..network.num_node_inputs(node))
148 .map(|i| network.get_node_input(node, i))
149 .map(|input_signal| network.get_source_node(&input_signal));
150
151 let child_cuts: SmallVec<[&Vec<Cut<_>>; 16]> =
153 child_nodes.map(|child_node| &cuts[&child_node]).collect();
154
155 let cut_unions: Vec<_> = child_cuts
157 .into_iter()
158 .multi_cartesian_product()
160 .map(|selected_cuts| cut_union(node.clone(), selected_cuts.into_iter()))
161 .chain([Cut::trivial_cut(node.clone())])
162 .filter(|cut| cut.leaf_nodes.len() <= self.max_cut_size)
164 .collect();
165
166 cut_unions
167 };
168
169 current_cuts.sort_by(|c1, c2| (self.sort_cuts_by_fn)(c1, c2));
171
172 current_cuts.truncate(self.max_cuts_per_node);
173
174 cuts.insert(node.clone(), current_cuts);
175 }
176
177 CutEnumeration { cuts }
178 }
179}
180
/// Enumerates the cuts of `(a & b) & (c & d)` and checks the cut counts per node.
#[test]
fn test_cut_enumeration() {
    use crate::network::*;
    use crate::networks::aig::*;

    // Limits high enough to keep every candidate; the comparator keeps
    // generation order.
    let config = CutEnumerationConfig {
        max_cut_size: 100,
        max_cuts_per_node: 100,
        sort_cuts_by_fn: |_a: &Cut<AigNodeId>, _b: &Cut<AigNodeId>| -> Ordering {
            Ordering::Equal
        },
    };

    let mut net = Aig::new();
    let [a, b, c, d] = net.create_primary_inputs();
    let anb = net.create_and(a, b);
    let cnd = net.create_and(c, d);
    let top = net.create_and(anb, cnd);
    let _out = net.create_primary_output(top);

    let cuts = config.compute_cuts(&net);

    // Primary inputs only have their trivial cut.
    for input in [a, b, c, d] {
        assert_eq!(cuts.get_cuts(&input).len(), 1);
    }

    // Each two-input AND: the merged input cuts plus its trivial cut.
    for and2 in [anb, cnd] {
        assert_eq!(cuts.get_cuts(&and2).len(), 2);
    }

    // Top AND: 2 x 2 fan-in combinations plus its trivial cut.
    assert_eq!(cuts.get_cuts(&top).len(), 5);
}