use std::cmp;
use std::iter::zip;
use fxhash::{FxHashMap, FxHashSet};
use itertools::Itertools;
use crate::network::NaryType;
use crate::{Gate, Network, Signal};
/// Compute the inputs of `g` with eligible direct dependencies inlined.
///
/// Each input of `g` that is a non-inverted variable is looked up in `aig`. If the gate it
/// refers to satisfies `pred` and inlining its own inputs keeps the total number of inputs
/// (already collected + inlined + still to be visited) within `max_size`, those inputs are
/// used in place of the signal. Every other input is kept unchanged.
fn merge_dependencies<F: Fn(&Gate) -> bool>(
    aig: &Network,
    g: &Gate,
    max_size: usize,
    pred: F,
) -> Box<[Signal]> {
    let deps = g.dependencies();
    let mut merged: Vec<Signal> = Vec::new();
    for (idx, sig) in deps.iter().enumerate() {
        // Number of inputs of `g` that come after this one; each takes at least one slot
        let pending = deps.len() - idx - 1;
        if sig.is_var() && !sig.is_inverted() {
            let source = aig.gate(sig.var() as usize);
            let source_deps = source.dependencies();
            if pred(source) && merged.len() + source_deps.len() + pending <= max_size {
                merged.extend(source_deps);
                continue;
            }
        }
        // Constants, inputs, inverted signals and non-mergeable gates are kept as they are
        merged.push(*sig);
    }
    merged.into()
}
/// Turn And and Xor gates into n-ary gates, absorbing inputs of the same kind.
///
/// Works on a copy of `aig`. Every And (resp. Xor) gate is replaced by an n-ary And
/// (resp. Xor) whose inputs are given by `merge_dependencies`, with `max_size` as the bound
/// on the number of inputs after merging. The copy is then cleaned up and made canonical
/// before being returned.
pub fn flatten_nary(aig: &Network, max_size: usize) -> Network {
    let mut flat = aig.clone();
    for idx in 0..flat.nb_nodes() {
        // Decide first what the gate becomes, then apply the replacement
        let merged = {
            let gate = flat.gate(idx);
            if gate.is_and() {
                let inputs = merge_dependencies(&flat, gate, max_size, |t| t.is_and());
                Some((inputs, NaryType::And))
            } else if gate.is_xor() {
                let inputs = merge_dependencies(&flat, gate, max_size, |t| t.is_xor());
                Some((inputs, NaryType::Xor))
            } else {
                None
            }
        };
        if let Some((inputs, kind)) = merged {
            flat.replace(idx, Gate::Nary(inputs, kind));
        }
    }
    flat.cleanup();
    flat.make_canonical();
    flat
}
/// Working state used to decompose a set of n-ary gates into two-input gates while
/// sharing pairs of signals that occur in several gates.
struct Factoring {
/// For each gate, the signals that still have to be combined. Once the exclusive
/// signals have been separated, only signals whose usage count is not 1 remain here.
gate_signals: Vec<Vec<Signal>>,
/// For each gate, the signals whose usage count over all gates is exactly 1;
/// they are appended back to `gate_signals` in `finalize`.
gate_exclusive_signals: Vec<Vec<Signal>>,
/// Variable index that will be given to the next pair that is built.
next_var: u32,
/// Pairs built so far, in creation order (one new variable is allocated per pair).
built_pairs: Vec<(Signal, Signal)>,
/// `count_to_pair[n]` holds the pairs currently recorded as present in `n` gates.
count_to_pair: Vec<FxHashSet<(Signal, Signal)>>,
/// For each ordered pair of signals, the indices of the gates containing both.
pair_to_gates: FxHashMap<(Signal, Signal), FxHashSet<usize>>,
}
impl Factoring {
    /// Create the initial state from the inputs of each gate.
    ///
    /// `next_var` is the index given to the first pair that will be built.
    fn from_gates(gates: Vec<Vec<Signal>>, next_var: u32) -> Factoring {
        Factoring {
            gate_signals: gates,
            gate_exclusive_signals: Vec::new(),
            next_var,
            built_pairs: Vec::new(),
            count_to_pair: Vec::new(),
            pair_to_gates: FxHashMap::default(),
        }
    }

    /// Build the ordered representation of a pair, smallest signal first, so that
    /// `(a, b)` and `(b, a)` map to the same key.
    fn make_pair(a: &Signal, b: &Signal) -> (Signal, Signal) {
        (cmp::min(*a, *b), cmp::max(*a, *b))
    }

    /// Count how many times each signal appears over all gates.
    fn count_signal_usage(&self) -> FxHashMap<Signal, u32> {
        let mut count = FxHashMap::<Signal, u32>::default();
        for s in self.gate_signals.iter().flatten() {
            *count.entry(*s).or_insert(0) += 1;
        }
        count
    }

    /// Move the signals that are used exactly once to `gate_exclusive_signals`.
    ///
    /// Such signals cannot be part of a pair shared between gates, so they are kept
    /// aside and only handled in `finalize`.
    ///
    /// # Panics
    /// Panics if the exclusive signals have already been separated.
    fn separate_exclusive_signals(&mut self) {
        assert!(self.gate_exclusive_signals.is_empty());
        let cnt = self.count_signal_usage();
        for g in &mut self.gate_signals {
            let mut exclusive = g.clone();
            g.retain(|s| cnt[s] != 1);
            exclusive.retain(|s| cnt[s] == 1);
            self.gate_exclusive_signals.push(exclusive);
        }
    }

    /// Replace the pair of every gate that has exactly two signals left.
    ///
    /// Two passes are made over the gates, so that a gate that drops to two signals
    /// because of a replacement in the first pass is picked up by the second one.
    fn consume_binary_gates(&mut self) {
        for _ in 0..2 {
            for i in 0..self.gate_signals.len() {
                let g = &self.gate_signals[i];
                if g.len() == 2 {
                    let p = Factoring::make_pair(&g[0], &g[1]);
                    self.replace_pair(p);
                }
            }
        }
    }

    /// Compute, for each pair of signals found together in a gate, the set of gates
    /// that contain it.
    fn compute_pair_to_gates(&self) -> FxHashMap<(Signal, Signal), FxHashSet<usize>> {
        let mut ret = FxHashMap::<(Signal, Signal), FxHashSet<usize>>::default();
        for (i, v) in self.gate_signals.iter().enumerate() {
            for (a, b) in v.iter().tuple_combinations() {
                // `or_default` only builds a set when the pair is seen for the first time
                ret.entry(Factoring::make_pair(a, b)).or_default().insert(i);
            }
        }
        ret
    }

    /// Separate exclusive signals and fill `pair_to_gates` and `count_to_pair`.
    fn setup_initial(&mut self) {
        self.separate_exclusive_signals();
        self.pair_to_gates = self.compute_pair_to_gates();
        for (p, gates_touched) in &self.pair_to_gates {
            let cnt = gates_touched.len();
            if self.count_to_pair.len() <= cnt {
                self.count_to_pair.resize(cnt + 1, FxHashSet::default());
            }
            self.count_to_pair[cnt].insert(*p);
        }
    }

    /// Put the exclusive signals back and reduce every gate to a single signal.
    ///
    /// Signals of a gate are paired two by two, level by level; a leftover signal at
    /// the end of a level is carried over to the next one. Each pair allocates a new
    /// variable and is recorded in `built_pairs`.
    fn finalize(&mut self) {
        for (g1, g2) in zip(&mut self.gate_signals, &self.gate_exclusive_signals) {
            g1.extend(g2);
        }
        self.gate_exclusive_signals.clear();
        for g in &mut self.gate_signals {
            while g.len() > 1 {
                let mut next_g = Vec::new();
                for i in (0..g.len() - 1).step_by(2) {
                    let p = Signal::from_var(self.next_var);
                    self.next_var += 1;
                    self.built_pairs.push((g[i], g[i + 1]));
                    next_g.push(p);
                }
                if g.len() % 2 != 0 {
                    next_g.push(*g.last().unwrap());
                }
                *g = next_g;
            }
        }
    }

    /// Build the pair `p` and substitute it in every gate that contains it.
    ///
    /// A new variable is allocated for the pair. In each gate touched, both signals of
    /// the pair are removed and the new signal is added, and the pair bookkeeping is
    /// updated for every other signal of the gate.
    ///
    /// # Panics
    /// Panics if `p` is not present in `pair_to_gates`.
    fn replace_pair(&mut self, p: (Signal, Signal)) {
        let p_out = Signal::from_var(self.next_var);
        self.next_var += 1;
        self.built_pairs.push(p);
        let gates_touched = self.pair_to_gates.remove(&p).unwrap();
        self.count_to_pair[gates_touched.len()].remove(&p);
        for i in gates_touched {
            // Take the signals out of the gate while the bookkeeping is updated: this
            // avoids cloning the vector, and the helpers below never read `gate_signals`
            let mut signals = std::mem::take(&mut self.gate_signals[i]);
            signals.retain(|s| *s != p.0 && *s != p.1);
            for s in &signals {
                self.decrement_pair(Factoring::make_pair(s, &p.0), i);
                self.decrement_pair(Factoring::make_pair(s, &p.1), i);
                self.increment_pair(Factoring::make_pair(s, &p_out), i);
            }
            signals.push(p_out);
            self.gate_signals[i] = signals;
        }
    }

    /// Record that `gate` does not contain the pair `p` anymore.
    ///
    /// # Panics
    /// Panics if `p` is not present in `pair_to_gates`.
    fn decrement_pair(&mut self, p: (Signal, Signal), gate: usize) {
        let gates = self
            .pair_to_gates
            .get_mut(&p)
            .expect("pair should be tracked before being decremented");
        // Usage count before the removal: this is the bucket the pair is stored in
        let cnt = gates.len();
        gates.remove(&gate);
        self.count_to_pair[cnt].remove(&p);
        if cnt > 1 {
            self.count_to_pair[cnt - 1].insert(p);
        }
    }

    /// Record that `gate` now contains the pair `p`.
    fn increment_pair(&mut self, p: (Signal, Signal), gate: usize) {
        let gates = self.pair_to_gates.entry(p).or_default();
        gates.insert(gate);
        // Usage count after the insertion: this is the bucket the pair moves to
        let cnt = gates.len();
        if self.count_to_pair.len() <= cnt {
            self.count_to_pair.resize(cnt + 1, FxHashSet::default());
        }
        self.count_to_pair[cnt - 1].remove(&p);
        self.count_to_pair[cnt].insert(p);
    }

    /// Return a pair from the highest non-empty usage bucket, if any.
    ///
    /// Empty buckets at the end of `count_to_pair` are dropped along the way.
    fn find_best_pair(&mut self) -> Option<(Signal, Signal)> {
        while let Some(pairs) = self.count_to_pair.last() {
            if let Some(p) = pairs.iter().next() {
                return Some(*p);
            }
            self.count_to_pair.pop();
        }
        None
    }

    /// Run the whole factoring: setup, binary gates, most used pairs first, then
    /// the final reduction of each gate to a single signal.
    ///
    /// # Panics
    /// Panics if a gate does not end up with exactly one signal (for example if it
    /// was given without any signal).
    fn consume_pairs(&mut self) {
        self.setup_initial();
        self.consume_binary_gates();
        while let Some(p) = self.find_best_pair() {
            self.replace_pair(p);
        }
        for g in &self.gate_signals {
            assert!(g.len() <= 1);
        }
        self.finalize();
        for g in &self.gate_signals {
            assert!(g.len() == 1);
        }
    }

    /// Factor the given gates into pairs.
    ///
    /// `gates` gives the inputs of each gate and `first_var` the variable index of the
    /// first pair created; following pairs use consecutive indices. Returns the pairs
    /// in creation order, and for each gate the signal that replaces it.
    pub fn run(gates: Vec<Vec<Signal>>, first_var: u32) -> (Vec<(Signal, Signal)>, Vec<Signal>) {
        let mut f = Factoring::from_gates(gates, first_var);
        f.consume_pairs();
        let replacement = f.gate_signals.iter().map(|g| g[0]).collect();
        (f.built_pairs, replacement)
    }
}
/// Decompose the gates selected by `pred` into two-input gates made with `builder`.
///
/// Only gates with more than one dependency are considered. The two-input gates
/// returned by `Factoring::run` are appended to a copy of the network, each selected
/// gate is replaced by a buffer of its replacement signal, and the result is
/// topologically sorted and made canonical.
///
/// # Panics
/// Panics if `aig` is not topologically sorted.
fn factor_gates<F: Fn(&Gate) -> bool, G: Fn(Signal, Signal) -> Gate>(
    aig: &Network,
    pred: F,
    builder: G,
) -> Network {
    assert!(aig.is_topo_sorted());

    // Indices of the selected gates, and their inputs in the same order
    let mut inds = Vec::new();
    let mut gates: Vec<Vec<Signal>> = Vec::new();
    for i in 0..aig.nb_nodes() {
        let g = aig.gate(i);
        if !pred(g) || g.dependencies().len() <= 1 {
            continue;
        }
        let deps: Vec<Signal> = g.dependencies().into();
        gates.push(deps);
        inds.push(i);
    }

    let mut ret = aig.clone();
    // New gates are numbered right after the existing nodes
    let first_var = ret.nb_nodes() as u32;
    let (binary_gates, replacements) = Factoring::run(gates, first_var);
    for (a, b) in binary_gates {
        ret.add(builder(a, b));
    }
    for (i, g) in inds.into_iter().zip(replacements) {
        ret.replace(i, Gate::Buf(g));
    }

    // Replaced gates refer to gates added after them, hence the sort
    ret.topo_sort();
    ret.make_canonical();
    ret
}
/// Factor the And gates of the network, then the Xor gates of the result, into
/// two-input gates.
pub fn factor_nary(aig: &Network) -> Network {
    let and_factored = factor_gates(aig, |gate| gate.is_and(), |lhs, rhs| Gate::and(lhs, rhs));
    factor_gates(
        &and_factored,
        |gate| gate.is_xor(),
        |lhs, rhs| Gate::xor(lhs, rhs),
    )
}
/// Flatten the And/Xor gates of the network in place, then factor them back into
/// two-input gates.
///
/// `flattening_limit` is the maximum size passed to `flatten_nary`.
pub fn share_logic(aig: &mut Network, flattening_limit: usize) {
    let flattened = flatten_nary(aig, flattening_limit);
    *aig = flattened;
    let factored = factor_nary(aig);
    *aig = factored;
}
// Unit tests for the flattening and factoring passes.
#[cfg(test)]
mod tests {
use super::{factor_nary, flatten_nary};
use crate::network::NaryType;
use crate::{Gate, Network, Signal};
// A tree of two-input And gates is expected to collapse into one 4-input And.
#[test]
fn test_flatten_and() {
let mut aig = Network::new();
let i0 = aig.add_input();
let i1 = aig.add_input();
let i2 = aig.add_input();
// Fourth input is declared but deliberately left unused
aig.add_input();
let i4 = aig.add_input();
let x0 = aig.and(i0, i1);
let x1 = aig.and(i0, !i2);
let x2 = aig.and(x0, x1);
let x3 = aig.and(x2, i4);
aig.add_output(x3);
aig = flatten_nary(&aig, 64);
// i0 is used by both x0 and x1 but is expected only once in the result
assert_eq!(aig.nb_nodes(), 1);
assert_eq!(
aig.gate(0),
&Gate::Nary(Box::new([i4, !i2, i1, i0]), NaryType::And)
);
}
// Same structure with Xor gates: the expected result is a single 3-input Xor.
#[test]
fn test_flatten_xor() {
let mut aig = Network::new();
let i0 = aig.add_input();
let i1 = aig.add_input();
let i2 = aig.add_input();
// Fourth input is declared but deliberately left unused
aig.add_input();
let i4 = aig.add_input();
let x0 = aig.xor(i0, i1);
let x1 = aig.xor(i0, !i2);
let x2 = aig.xor(x0, x1);
let x3 = aig.xor(x2, i4);
aig.add_output(x3);
aig = flatten_nary(&aig, 64);
assert_eq!(aig.nb_nodes(), 1);
// i0 appears in both x0 and x1 and is absent from the expected gate; the
// inversion on i2 is expected on the output instead
assert_eq!(aig.gate(0), &Gate::xor3(i4, i2, i1));
assert_eq!(aig.output(0), !Signal::from_var(0));
}
// Three n-ary And gates that all contain i1 and i2: the expected result has
// 4 nodes, the first one being the And of i2 and i1.
#[test]
fn test_share_and() {
let mut aig = Network::new();
let i0 = aig.add_input();
let i1 = aig.add_input();
let i2 = aig.add_input();
let i3 = aig.add_input();
let i4 = aig.add_input();
let x0 = aig.add(Gate::Nary(Box::new([i0, i1, i2]), NaryType::And));
let x1 = aig.add(Gate::Nary(Box::new([i0, i1, i2, i3]), NaryType::And));
let x2 = aig.add(Gate::Nary(Box::new([i1, i2, i4]), NaryType::And));
aig.add_output(x0);
aig.add_output(x1);
aig.add_output(x2);
aig = factor_nary(&aig);
assert_eq!(aig.nb_nodes(), 4);
assert_eq!(aig.gate(0), &Gate::and(i2, i1));
}
}