use std::{
cmp::max,
collections::{HashMap, HashSet},
num::TryFromIntError,
};
use thiserror::Error;
use crate::{
Aig, AigError, AigNode, AigNodeRef, NodeId, Result,
cnf::{Cnf, Lit},
dfs::Dfs,
};
/// Errors raised while building a [`Miter`] or translating its nodes to CNF literals.
#[derive(Debug, Error)]
pub enum MiterError {
    /// The two AIGs do not have the same input ids (ids of the first AIG, ids of the second).
    #[error("AIGs have different inputs: {0:?} vs {1:?}")]
    MiterDifferentInputs(HashSet<NodeId>, HashSet<NodeId>),
    /// The two AIGs do not have the same latch ids (ids of the first AIG, ids of the second).
    #[error("AIGs have different latches: {0:?} vs {1:?}")]
    MiterDifferentLatches(HashSet<NodeId>, HashSet<NodeId>),
    /// The output pairing does not match the outputs of the two AIGs.
    #[error("trying to construct a miter between two AIGs with different outputs")]
    MiterDifferentOutputs,
    /// The given node id has no entry in the literal map that was consulted.
    #[error("node id {0} is not mapped to any literal")]
    UnmappedNodeToLit(NodeId),
    /// A node id could not be converted into a literal.
    #[error("conversion from NodeId to Lit failed because of {0}")]
    NodeIdToLit(TryFromIntError),
}

impl From<TryFromIntError> for MiterError {
    /// Wraps a failed integer conversion as [`MiterError::NodeIdToLit`].
    fn from(err: TryFromIntError) -> Self {
        Self::NodeIdToLit(err)
    }
}
/// A miter between two AIGs `a` and `b`, used to build a CNF that is satisfiable when a
/// pair of matched outputs (or latch next-state functions) can differ.
///
/// Inputs and latches are encoded with the same literal in both AIGs; every other node
/// gets its own literal, tracked separately for each side.
pub struct Miter {
    // The two circuits under comparison.
    pub(super) a: Aig,
    pub(super) b: Aig,
    // Output pairing: `(id, complement)` of an output of `a` -> `(id, complement)` in `b`.
    pub(super) outputs_map: HashMap<(NodeId, bool), (NodeId, bool)>,
    // Node id -> CNF literal, one map per side.
    litmap_a: HashMap<NodeId, Lit>,
    litmap_b: HashMap<NodeId, Lit>,
    // Recorded merges `(id in a, id in b, complement)`.
    merged: HashSet<(NodeId, NodeId, bool)>,
    // Ids of `a` that appear as the first component of some recorded merge.
    merged_a: HashSet<NodeId>,
    // Value of the next literal handed out for a fresh variable.
    next_lit: i64,
}
/// Checks that `outputs_map` pairs exactly the outputs of `a` (its keys) with exactly the
/// outputs of `b` (its values), comparing `(node id, complement)` pairs as sets.
///
/// # Errors
///
/// Returns [`MiterError::MiterDifferentOutputs`] if either side does not match.
fn check_outputs(
    a: &Aig,
    b: &Aig,
    outputs_map: &HashMap<(NodeId, bool), (NodeId, bool)>,
) -> Result<()> {
    // Set of `(node id, complement)` pairs driving the outputs of an AIG.
    let output_signature = |aig: &Aig| -> HashSet<(NodeId, bool)> {
        aig.get_outputs()
            .iter()
            .map(|edge| (edge.get_node().borrow().get_id(), edge.get_complement()))
            .collect()
    };
    let expected_a: HashSet<(NodeId, bool)> = outputs_map.keys().copied().collect();
    let expected_b: HashSet<(NodeId, bool)> = outputs_map.values().copied().collect();

    // `&&` keeps the original order: `b` is only inspected once `a` matched.
    if output_signature(a) == expected_a && output_signature(b) == expected_b {
        Ok(())
    } else {
        Err(MiterError::MiterDifferentOutputs.into())
    }
}
impl Miter {
    /// Builds a miter between `a` and `b`.
    ///
    /// `outputs_map` pairs the `(id, complement)` of every output of `a` with the
    /// `(id, complement)` of the output of `b` it must be equivalent to. Both AIGs are
    /// cloned, and a CNF literal is assigned to the nodes reachable from their outputs.
    ///
    /// # Errors
    ///
    /// - [`MiterError::MiterDifferentInputs`] / [`MiterError::MiterDifferentLatches`] if the
    ///   two AIGs do not have the same input / latch ids;
    /// - [`MiterError::MiterDifferentOutputs`] if the keys (resp. values) of `outputs_map`
    ///   are not exactly the outputs of `a` (resp. `b`);
    /// - [`MiterError::NodeIdToLit`] if a node id does not fit in a literal.
    pub fn new(
        a: &Aig,
        b: &Aig,
        outputs_map: HashMap<(NodeId, bool), (NodeId, bool)>,
    ) -> Result<Self> {
        let inputs = a.get_inputs_id();
        let inputs_b = b.get_inputs_id();
        if inputs != inputs_b {
            return Err(MiterError::MiterDifferentInputs(inputs, inputs_b).into());
        }
        let latches = a.get_latches_id();
        let latches_b = b.get_latches_id();
        if latches != latches_b {
            return Err(MiterError::MiterDifferentLatches(latches, latches_b).into());
        }
        check_outputs(a, b, &outputs_map)?;

        // Inputs and latches use their own id as literal (see `fill_litmap`), so fresh
        // literals must start strictly above the largest of those ids; starting at the
        // maximum itself would alias the first fresh variable with that input/latch.
        let max_input_id = inputs.iter().copied().max().unwrap_or(0);
        let max_latch_id = latches.iter().copied().max().unwrap_or(0);
        let next_lit = i64::try_from(max(max_input_id, max_latch_id).saturating_add(1))
            .map_err(MiterError::from)?;

        let mut miter = Miter {
            a: a.clone(),
            b: b.clone(),
            outputs_map,
            litmap_a: HashMap::new(),
            litmap_b: HashMap::new(),
            merged: HashSet::new(),
            merged_a: HashSet::new(),
            next_lit,
        };
        miter.initialize_litmaps()?;
        Ok(miter)
    }

    /// Fills `litmap_a` then `litmap_b` for the nodes reachable from each AIG's outputs.
    fn initialize_litmaps(&mut self) -> Result<()> {
        Self::fill_litmap(&self.a, &mut self.litmap_a, &mut self.next_lit)?;
        Self::fill_litmap(&self.b, &mut self.litmap_b, &mut self.next_lit)?;
        Ok(())
    }

    /// Walks `aig` from its outputs and records a literal for each visited node.
    ///
    /// Inputs and latches reuse their node id as literal, so both AIGs share these
    /// variables. AND nodes get a fresh literal taken from `next_lit`. The constant
    /// `False` node gets no entry.
    fn fill_litmap(
        aig: &Aig,
        litmap: &mut HashMap<NodeId, Lit>,
        next_lit: &mut i64,
    ) -> Result<()> {
        let mut dfs = Dfs::from_outputs(aig);
        while let Some(node) = dfs.next(aig) {
            let (id, lit) = match *node.borrow() {
                AigNode::Input(id) | AigNode::Latch { id, .. } => {
                    (id, Lit::try_from(id).map_err(MiterError::from)?)
                }
                AigNode::And { id, .. } => (id, Self::take_next_lit(next_lit)),
                AigNode::False => continue,
            };
            litmap.insert(id, lit);
        }
        Ok(())
    }

    /// Returns the literal whose value is `*next_lit` and advances the counter.
    fn take_next_lit(next_lit: &mut i64) -> Lit {
        let lit = (*next_lit).into();
        *next_lit += 1;
        lit
    }

    /// Allocates a new literal. It is strictly greater than every input/latch id and
    /// every literal previously allocated by this miter.
    pub fn fresh_lit(&mut self) -> Lit {
        Self::take_next_lit(&mut self.next_lit)
    }

    /// Returns the literal `id` is mapped to in `litmap`.
    ///
    /// # Errors
    ///
    /// [`MiterError::UnmappedNodeToLit`] if `id` has no entry.
    fn lookup_lit(litmap: &HashMap<NodeId, Lit>, id: NodeId) -> Result<Lit> {
        Ok(*litmap
            .get(&id)
            .ok_or(MiterError::UnmappedNodeToLit(id))?)
    }

    /// Returns `lit`, negated when `complement` is set.
    fn apply_complement(lit: Lit, complement: bool) -> Lit {
        if complement { !lit } else { lit }
    }

    /// Adds to `cnf` the clauses of every node in the cone of `node`, descending through
    /// AND fanins only (latches and inputs are leaves).
    ///
    /// Nodes whose id is already in `done` are skipped, and visited ids are added to it.
    /// Calls sharing a `done` set therefore never encode a node twice.
    fn extract_cnf_from(
        &self,
        node: AigNodeRef,
        cnf: &mut Cnf,
        done: &mut HashSet<NodeId>,
        litmap: &HashMap<NodeId, Lit>,
    ) -> Result<()> {
        let mut stack = vec![node];
        while let Some(node) = stack.pop() {
            let id = node.borrow().get_id();
            // A node shared by several parents can be pushed more than once before it is
            // popped; only the first pop encodes it.
            if !done.insert(id) {
                continue;
            }
            let node_ref = node.borrow();
            cnf.add_clauses_node(&*node_ref, litmap)?;
            if let AigNode::And { fanin0, fanin1, .. } = &*node_ref {
                for fanin in [fanin0, fanin1] {
                    let fanin_id = fanin.get_node().borrow().get_id();
                    if !done.contains(&fanin_id) {
                        stack.push(fanin.get_node());
                    }
                }
            }
        }
        Ok(())
    }

    /// Builds a CNF asserting that edge `(node_a, compl_a)` of `a` differs from edge
    /// `(node_b, compl_b)` of `b`.
    ///
    /// # Errors
    ///
    /// [`AigError::NodeDoesNotExist`] if a node is missing, or
    /// [`MiterError::UnmappedNodeToLit`] if a node has no literal.
    pub fn extract_cnf_node(
        &mut self,
        node_a: NodeId,
        compl_a: bool,
        node_b: NodeId,
        compl_b: bool,
    ) -> Result<Cnf> {
        let mut cnf = Cnf::new();
        let root_a = self
            .a
            .get_node(node_a)
            .ok_or(AigError::NodeDoesNotExist(node_a))?;
        self.extract_cnf_from(root_a, &mut cnf, &mut HashSet::new(), &self.litmap_a)?;
        let root_b = self
            .b
            .get_node(node_b)
            .ok_or(AigError::NodeDoesNotExist(node_b))?;
        self.extract_cnf_from(root_b, &mut cnf, &mut HashSet::new(), &self.litmap_b)?;

        let lit_a = Self::lookup_lit(&self.litmap_a, node_a)?;
        let lit_b = Self::lookup_lit(&self.litmap_b, node_b)?;
        cnf.add_xor_whose_output_is_true(
            Self::apply_complement(lit_a, compl_a),
            Self::apply_complement(lit_b, compl_b),
        );
        Ok(cnf)
    }

    /// Builds the miter CNF.
    ///
    /// The CNF contains the cones of all outputs and latch next-state functions of both
    /// AIGs. It also contains one XOR per paired output and one per latch. A final clause
    /// requires at least one of those XORs to be true. Each call allocates fresh literals
    /// for the XOR outputs.
    ///
    /// # Errors
    ///
    /// Propagates missing-node / unmapped-literal errors from the encoding helpers.
    pub fn extract_cnf(&mut self) -> Result<Cnf> {
        let mut cnf = Cnf::new();
        let mut done_a = HashSet::new();
        for output in self.a.get_outputs() {
            self.extract_cnf_from(output.get_node(), &mut cnf, &mut done_a, &self.litmap_a)?;
        }
        let mut done_b = HashSet::new();
        for output in self.b.get_outputs() {
            self.extract_cnf_from(output.get_node(), &mut cnf, &mut done_b, &self.litmap_b)?;
        }

        let mut xor_lits = Vec::new();
        // Disjoint field borrows: iterate the map while advancing `next_lit`, no clone.
        for (&(id_a, compl_a), &(id_b, compl_b)) in &self.outputs_map {
            let z = Self::take_next_lit(&mut self.next_lit);
            let lit_a = Self::lookup_lit(&self.litmap_a, id_a)?;
            let lit_b = Self::lookup_lit(&self.litmap_b, id_b)?;
            cnf.add_xor(
                Self::apply_complement(lit_a, compl_a),
                Self::apply_complement(lit_b, compl_b),
                z,
            );
            xor_lits.push(z);
        }

        for id in self.a.get_latches_id() {
            let latch_a = self.a.get_node(id).ok_or(AigError::NodeDoesNotExist(id))?;
            let latch_b = self.b.get_node(id).ok_or(AigError::NodeDoesNotExist(id))?;
            let fanin_a = latch_a.borrow().get_fanins()[0].clone();
            let fanin_b = latch_b.borrow().get_fanins()[0].clone();
            // `extract_cnf_from` stops at latches, so a next-state cone that is not also
            // part of an output cone would otherwise have no defining clauses. The `done`
            // sets keep nodes that were already encoded from being emitted twice.
            self.extract_cnf_from(fanin_a.get_node(), &mut cnf, &mut done_a, &self.litmap_a)?;
            self.extract_cnf_from(fanin_b.get_node(), &mut cnf, &mut done_b, &self.litmap_b)?;
            let z = self.fresh_lit();
            cnf.add_xor_pseudo_output(fanin_a, &self.litmap_a, fanin_b, &self.litmap_b, z)?;
            xor_lits.push(z);
        }
        cnf.add_or_whose_output_is_true(xor_lits);
        Ok(cnf)
    }

    /// Records that node `node_a` of `a` is equivalent to node `node_b` of `b`, or to
    /// its complement when `complement` is set.
    ///
    /// From then on, `node_a` is encoded with `node_b`'s literal (possibly negated).
    /// Id 0 is the constant False node. Merging it with itself is recorded, so that AND
    /// nodes with constant fanins can become mergeable, but no literal is changed.
    ///
    /// # Errors
    ///
    /// - [`AigError::InvalidState`] if exactly one of the two nodes is False, or if False
    ///   is merged with its own complement;
    /// - [`MiterError::UnmappedNodeToLit`] if `node_b` has no literal.
    pub fn merge(&mut self, node_a: NodeId, node_b: NodeId, complement: bool) -> Result<()> {
        match (node_a == 0, node_b == 0) {
            (true, true) => {
                if complement {
                    return Err(AigError::InvalidState(format!(
                        "trying to merge node false with its own complement: id_a = {}, id_b = {}",
                        node_a, node_b
                    )));
                }
                // The constant has no literal to rewire; only the merge record is needed.
            }
            (false, false) => {
                let lit_b = Self::lookup_lit(&self.litmap_b, node_b)?;
                self.litmap_a
                    .insert(node_a, Self::apply_complement(lit_b, complement));
            }
            _ => {
                return Err(AigError::InvalidState(format!(
                    "trying to merge node false with non-false node: id_a = {}, id_b = {} --- unsupported feature for now",
                    node_a, node_b
                )));
            }
        }
        self.merged.insert((node_a, node_b, complement));
        self.merged_a.insert(node_a);
        Ok(())
    }

    /// Returns whether `node_a` (in `a`) and `node_b` (in `b`) can be merged, given the
    /// merges recorded so far.
    ///
    /// This holds for:
    /// - two AND nodes whose fanins are pairwise merged with matching polarity, in either
    ///   order;
    /// - two inputs or two latches with the same id;
    /// - the two constant nodes.
    ///
    /// # Errors
    ///
    /// [`AigError::NodeDoesNotExist`] if either node is missing.
    pub fn mergeable(&self, node_a: NodeId, node_b: NodeId) -> Result<bool> {
        let na = self
            .a
            .get_node(node_a)
            .ok_or(AigError::NodeDoesNotExist(node_a))?;
        let nb = self
            .b
            .get_node(node_b)
            .ok_or(AigError::NodeDoesNotExist(node_b))?;
        // Bound to a local so the `Ref` temporaries are released before `na`/`nb` drop.
        let verdict = match (&*na.borrow(), &*nb.borrow()) {
            (
                AigNode::And {
                    fanin0: fanin0_a,
                    fanin1: fanin1_a,
                    ..
                },
                AigNode::And {
                    fanin0: fanin0_b,
                    fanin1: fanin1_b,
                    ..
                },
            ) => {
                let id0a = fanin0_a.get_node().borrow().get_id();
                let c0a = fanin0_a.get_complement();
                let id1a = fanin1_a.get_node().borrow().get_id();
                let c1a = fanin1_a.get_complement();
                let id0b = fanin0_b.get_node().borrow().get_id();
                let c0b = fanin0_b.get_complement();
                let id1b = fanin1_b.get_node().borrow().get_id();
                let c1b = fanin1_b.get_complement();
                // Edges n ^ c and m ^ d are equal iff n == m ^ (c ^ d), hence the XORed flag.
                let straight = self.merged.contains(&(id0a, id0b, c0a ^ c0b))
                    && self.merged.contains(&(id1a, id1b, c1a ^ c1b));
                let crossed = self.merged.contains(&(id0a, id1b, c0a ^ c1b))
                    && self.merged.contains(&(id1a, id0b, c1a ^ c0b));
                straight || crossed
            }
            (AigNode::Latch { id: id_a, .. }, AigNode::Latch { id: id_b, .. }) => id_a == id_b,
            (AigNode::Input(id_a), AigNode::Input(id_b)) => id_a == id_b,
            (AigNode::False, AigNode::False) => true,
            _ => false,
        };
        Ok(verdict)
    }

    /// Returns `true` when every output node of `a` has been merged with some node of `b`.
    ///
    /// NOTE(review): this only checks `merged_a`. It does not check that the partner is the
    /// output paired in `outputs_map`, nor the complement flag. Confirm that callers only
    /// use it as a shortcut.
    pub fn are_outputs_merged(&self) -> bool {
        self.a
            .get_outputs()
            .iter()
            .all(|out| self.merged_a.contains(&out.get_node().borrow().get_id()))
    }
}
#[cfg(test)]
mod test {
    use crate::AigEdge;

    use super::*;

    #[test]
    fn new_miter_test() {
        // Different input sets must be rejected.
        let mut aig_a = Aig::new();
        let in1_a = aig_a.add_node(AigNode::Input(1)).unwrap();
        let mut aig_b = Aig::new();
        let in2_b = aig_b.add_node(AigNode::Input(2)).unwrap();
        assert!(Miter::new(&aig_a, &aig_b, HashMap::new()).is_err());

        // Give both AIGs the inputs {1, 2} and an AND node with id 3.
        let in2_a = aig_a.add_node(AigNode::Input(2)).unwrap();
        aig_b.add_node(AigNode::Input(1)).unwrap();
        aig_a
            .add_node(AigNode::and(
                3,
                AigEdge::new(in1_a.clone(), false),
                AigEdge::new(in2_a.clone(), false),
            ))
            .unwrap();
        let false_b = aig_b.add_node(AigNode::False).unwrap();
        aig_b
            .add_node(AigNode::and(
                3,
                AigEdge::new(false_b.clone(), false),
                AigEdge::new(in2_b.clone(), false),
            ))
            .unwrap();
        aig_a.add_output(3, true).unwrap();

        // `b` has no output yet, so the value side of the map cannot match.
        let missing_b_output = HashMap::from([((3, true), (3, false))]);
        assert!(Miter::new(&aig_a, &aig_b, missing_b_output).is_err());

        // The `b` output exists but the map uses the wrong complement for it.
        aig_b.add_output(3, false).unwrap();
        let wrong_complement = HashMap::from([((3, true), (3, true))]);
        assert!(Miter::new(&aig_a, &aig_b, wrong_complement).is_err());

        // A matching map is accepted, before and after `b.update()`.
        let matching = HashMap::from([((3, true), (3, false))]);
        assert!(Miter::new(&aig_a, &aig_b, matching.clone()).is_ok());
        aig_b.update();
        assert!(Miter::new(&aig_a, &aig_b, matching).is_ok());
    }

    #[test]
    fn mergeable_test() {
        let mut aig_a = Aig::new();
        let in1_a = aig_a.add_node(AigNode::Input(1)).unwrap();
        let in2_a = aig_a.add_node(AigNode::Input(2)).unwrap();
        let and3_a = aig_a
            .add_node(AigNode::and(
                3,
                AigEdge::new(in1_a.clone(), false),
                AigEdge::new(in2_a.clone(), false),
            ))
            .unwrap();
        let _latch4_a = aig_a.add_node(AigNode::latch(4, AigEdge::new(and3_a.clone(), true), None));
        aig_a.add_output(4, false).unwrap();

        // Same structure in `b`, with the AND fanins swapped.
        let mut aig_b = Aig::new();
        let in1_b = aig_b.add_node(AigNode::Input(1)).unwrap();
        let in2_b = aig_b.add_node(AigNode::Input(2)).unwrap();
        let and3_b = aig_b
            .add_node(AigNode::and(
                3,
                AigEdge::new(in2_b.clone(), false),
                AigEdge::new(in1_b.clone(), false),
            ))
            .unwrap();
        let _latch4_b = aig_b.add_node(AigNode::latch(
            4,
            AigEdge::new(and3_b.clone(), false),
            Some(true),
        ));
        aig_b.add_output(4, false).unwrap();

        let outputs = HashMap::from([((4, false), (4, false))]);
        let mut miter = Miter::new(&aig_a, &aig_b, outputs).unwrap();

        // Inputs only match inputs with the same id.
        assert!(!miter.mergeable(1, 2).unwrap());
        assert!(!miter.mergeable(2, 1).unwrap());
        assert!(!miter.mergeable(1, 2).unwrap());
        // The AND nodes need merged fanins; latches only compare ids.
        assert!(!miter.mergeable(3, 3).unwrap());
        assert!(miter.mergeable(4, 4).unwrap());
        assert!(miter.mergeable(1, 1).unwrap());
        assert!(miter.mergeable(2, 2).unwrap());

        // Once the inputs are merged, the (fanin-swapped) AND nodes become mergeable.
        miter.merge(1, 1, false).unwrap();
        miter.merge(2, 2, false).unwrap();
        assert!(miter.mergeable(3, 3).unwrap());
        miter.merge(3, 3, false).unwrap();
        assert!(miter.mergeable(4, 4).unwrap());
        miter.merge(4, 4, false).unwrap();
    }
}