use crate::circuit::{Instantiable, Net};
use crate::error::Error;
#[cfg(feature = "graph")]
use crate::netlist::Connection;
use crate::netlist::{InputPort, NetRef, Netlist};
#[cfg(feature = "graph")]
use petgraph::graph::DiGraph;
use std::cmp::Reverse;
use std::collections::hash_map::Entry;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A whole-netlist analysis computed in one pass from a borrowed [`Netlist`].
///
/// `build` receives the netlist borrowed for `'a`, and every implementor must
/// itself be valid for `'a` (and be `Sized`, since `build` returns it by value).
pub trait Analysis<'a, I: Instantiable>: Sized + 'a {
    /// Runs the analysis over `netlist` and returns its result.
    ///
    /// # Errors
    /// Implementation-defined; implementors document which [`Error`]s they
    /// propagate.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
}
/// Fan-out (user) lookup tables for a netlist, built via [`Analysis::build`].
///
/// For every net and every driving node, records the nodes that consume it
/// (one entry per connection), plus the set of nets tied to netlist outputs.
pub struct FanOutTable<'a, I: Instantiable> {
    /// The netlist these tables describe.
    _netlist: &'a Netlist<I>,
    /// Consumers of each net, one entry per connection reading that net.
    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
    /// Consumers of each driving node, one entry per connection it sources.
    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
    /// Both nets reported for each entry of `Netlist::outputs`.
    is_an_output: HashSet<Net>,
}

impl<I> FanOutTable<'_, I>
where
    I: Instantiable,
{
    /// Returns the nodes that read `net`, one item per connection.
    ///
    /// The iterator is empty when no connection reads `net`.
    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
        self.net_fan_out.get(net).into_iter().flatten().cloned()
    }

    /// Returns the nodes fed by any connection whose source is `node`, one
    /// item per connection.
    ///
    /// The iterator is empty when `node` sources no connection.
    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
        self.node_fan_out.get(node).into_iter().flatten().cloned()
    }

    /// Returns `true` if `net` is read by at least one connection or is tied
    /// to a netlist output.
    pub fn net_has_uses(&self, net: &Net) -> bool {
        // Single lookup instead of `contains_key` followed by `get().unwrap()`.
        self.net_fan_out
            .get(net)
            .is_some_and(|users| !users.is_empty())
            || self.is_an_output.contains(net)
    }
}
impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
where
I: Instantiable,
{
fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
let mut is_an_output: HashSet<Net> = HashSet::new();
if let Err(e) = netlist.verify() {
match e {
Error::NoOutputs => (),
_ => return Err(e),
}
}
for c in netlist.connections() {
if let Entry::Vacant(e) = net_fan_out.entry(c.net()) {
e.insert(vec![c.target().unwrap()]);
} else {
net_fan_out
.get_mut(&c.net())
.unwrap()
.push(c.target().unwrap());
}
if let Entry::Vacant(e) = node_fan_out.entry(c.src().unwrap()) {
e.insert(vec![c.target().unwrap()]);
} else {
node_fan_out
.get_mut(&c.src().unwrap())
.unwrap()
.push(c.target().unwrap());
}
}
for (o, n) in netlist.outputs() {
is_an_output.insert(o.as_net().clone());
is_an_output.insert(n);
}
Ok(FanOutTable {
_netlist: netlist,
net_fan_out,
node_fan_out,
is_an_output,
})
}
}
/// Outcome of computing the combinational depth of a single node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CombDepthResult {
    /// Some input of the node has no driver, or a driver's own depth is undefined.
    Undefined,
    /// The node lies on, or transitively depends on, a combinational loop.
    CombCycle,
    /// Number of combinational levels along the longest path back to a primary
    /// input or sequential element; those sources themselves are `Depth(0)`.
    Depth(usize),
}
/// Combinational-depth results for a netlist, built via [`Analysis::build`].
pub struct CombDepthInfo<'a, I: Instantiable> {
    /// The analysed netlist; used to walk drivers when rebuilding the critical path.
    _netlist: &'a Netlist<I>,
    /// Depth classification of every node the analysis evaluated.
    results: HashMap<NetRef<I>, CombDepthResult>,
    /// For nodes with a defined depth whose deepest driver has depth > 0, the
    /// input port through which that deepest path enters.
    critical_par: HashMap<NetRef<I>, InputPort<I>>,
    /// Bounded heap of the `SIZE_HEAP` deepest nodes. `Reverse` puts the
    /// shallowest retained entry on top, so it is the one evicted.
    critical_ends: BinaryHeap<(Reverse<usize>, NetRef<I>)>,
    /// Largest depth seen at an output driver or at a driver of a sequential
    /// element's input; `None` if no such depth was defined.
    max_depth: Option<usize>,
}

impl<I> CombDepthInfo<'_, I>
where
    I: Instantiable,
{
    /// Maximum number of nodes retained in `critical_ends`.
    const SIZE_HEAP: usize = 10;

    /// Returns the depth classification of `node`.
    ///
    /// Returns `None` if the analysis never evaluated `node`.
    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<CombDepthResult> {
        self.results.get(node).copied()
    }

    /// Returns the input port of `node` through which its deepest
    /// combinational fan-in arrives, if one was recorded.
    pub fn get_crit_input(&self, node: &NetRef<I>) -> Option<&InputPort<I>> {
        self.critical_par.get(node)
    }

    /// Returns the retained deepest nodes (at most `SIZE_HEAP`), deepest first.
    pub fn get_critical_points(&self) -> impl IntoIterator<Item = &NetRef<I>> {
        let mut v = self.critical_ends.iter().collect::<Vec<_>>();
        // Ascending order of `Reverse<usize>` is descending order of depth.
        v.sort_by_key(|(d, _)| *d);
        v.into_iter().map(|(_, n)| n)
    }

    /// Reconstructs the critical path that ends at the deepest retained node.
    ///
    /// The path is ordered from that endpoint back toward its source, following
    /// the recorded critical input of each node. It stops at the first node
    /// without one. Returns `None` when no node with a defined depth was recorded.
    ///
    /// # Panics
    /// Panics if a recorded critical input has no driver. Critical inputs are
    /// only recorded for inputs whose driver was found, so this would indicate
    /// an internal inconsistency.
    pub fn build_critical_path(&self) -> Option<Vec<NetRef<I>>> {
        let mut path = Vec::new();
        let mut current = self.get_critical_points().into_iter().next()?.clone();
        // Fixed: this lookup previously read `¤t` (a mis-decoded `&current`).
        while let Some(crit) = self.critical_par.get(&current) {
            let driver = self
                ._netlist
                .get_driver(current.clone(), crit.get_input_num())
                .expect("a recorded critical input always has a driver");
            path.push(current);
            current = driver;
        }
        path.push(current);
        Some(path)
    }

    /// Returns the largest depth observed at an output driver or at a driver
    /// of a sequential element's input.
    pub fn get_max_depth(&self) -> Option<usize> {
        self.max_depth
    }
}
impl<'a, I> Analysis<'a, I> for CombDepthInfo<'a, I>
where
I: Instantiable,
{
/// Computes combinational depths for every node reachable, through
/// non-sequential drivers, from a netlist output driver or from a driver of a
/// sequential element's input.
///
/// Primary inputs and sequential elements have depth 0. Any other node has
/// depth `1 + max(depth of its non-sequential drivers)`, where the max is 0
/// if there are none. A node with a missing driver, or with an `Undefined`
/// driver, is `Undefined`. A node on or depending on a combinational loop is
/// `CombCycle`; this takes precedence over `Undefined`.
///
/// This implementation always returns `Ok`.
fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
let mut results: HashMap<NetRef<I>, CombDepthResult> = HashMap::new();
let mut critical_par: HashMap<NetRef<I>, InputPort<I>> = HashMap::new();
let mut critical_ends: BinaryHeap<(_, NetRef<I>)> = BinaryHeap::new();
// Nodes on the current recursion (DFS) stack; used to detect loops.
let mut visiting: HashSet<NetRef<I>> = HashSet::new();
let mut max_depth: Option<usize> = None;
// Memoized recursive DFS that returns (and caches in `results`) the depth of `node`.
fn compute<I: Instantiable>(
node: NetRef<I>,
netlist: &Netlist<I>,
results: &mut HashMap<NetRef<I>, CombDepthResult>,
critical_par: &mut HashMap<NetRef<I>, InputPort<I>>,
critical_ends: &mut BinaryHeap<(Reverse<usize>, NetRef<I>)>,
visiting: &mut HashSet<NetRef<I>>,
) -> CombDepthResult {
// Each node is evaluated at most once.
if let Some(&r) = results.get(&node) {
return r;
}
// Re-entering a node that is still on the stack means a combinational loop.
// Every node currently on the stack transitively depends on it, so all are
// marked as cycles.
if visiting.contains(&node) {
for n in visiting.iter() {
results.insert(n.clone(), CombDepthResult::CombCycle);
}
return CombDepthResult::CombCycle;
}
// Depth sources: primary inputs and sequential elements.
if node.is_an_input() || node.get_instance_type().is_some_and(|inst| inst.is_seq()) {
let r = CombDepthResult::Depth(0);
results.insert(node.clone(), r);
return r;
}
visiting.insert(node.clone());
let mut max_depth = 0;
let mut crit: Option<InputPort<I>> = None;
let mut is_undefined = false;
for i in 0..node.get_num_input_ports() {
let driver = match netlist.get_driver(node.clone(), i) {
Some(d) => d,
None => {
// An undriven input makes the depth undefined; keep scanning in
// case a later input reveals a cycle.
is_undefined = true;
continue;
}
};
// Sequential drivers end the combinational path and contribute nothing.
if let Some(inst) = driver.get_instance_type()
&& inst.is_seq()
{
continue;
}
match compute(
driver,
netlist,
results,
critical_par,
critical_ends,
visiting,
) {
CombDepthResult::Depth(d) => {
// Strict `>`: the first input reaching the maximum wins, and a
// depth-0 driver is never recorded as critical.
if d > max_depth {
max_depth = d;
crit = Some(node.get_input(i));
}
}
CombDepthResult::Undefined => {
is_undefined = true;
}
CombDepthResult::CombCycle => {
// Propagate the cycle; each caller unwinds the same way.
let r = CombDepthResult::CombCycle;
results.insert(node.clone(), r);
visiting.remove(&node);
return r;
}
}
}
visiting.remove(&node);
let r = if is_undefined {
CombDepthResult::Undefined
} else {
if let Some(crit) = crit {
critical_par.insert(node.clone(), crit);
}
let d = max_depth + 1;
// Keep only the SIZE_HEAP deepest nodes. `Reverse` makes the heap top the
// shallowest entry, which is the one popped.
critical_ends.push((Reverse(d), node.clone()));
if critical_ends.len() > CombDepthInfo::<I>::SIZE_HEAP {
critical_ends.pop();
}
CombDepthResult::Depth(d)
};
results.insert(node.clone(), r);
r
}
// Endpoints, part 1: the drivers of the netlist outputs.
for (driven, _) in netlist.outputs() {
let node = driven.unwrap();
let r = compute(
node,
netlist,
&mut results,
&mut critical_par,
&mut critical_ends,
&mut visiting,
);
if let CombDepthResult::Depth(d) = r {
max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
}
}
// Endpoints, part 2: every sequential element (recorded as Depth(0)) and
// the non-sequential drivers of its inputs.
for node in netlist.matches(|inst| inst.is_seq()) {
compute(
node.clone(),
netlist,
&mut results,
&mut critical_par,
&mut critical_ends,
&mut visiting,
);
for i in 0..node.get_num_input_ports() {
if let Some(driver) = netlist.get_driver(node.clone(), i) {
if driver.get_instance_type().is_some_and(|inst| inst.is_seq()) {
continue;
}
let r = compute(
driver,
netlist,
&mut results,
&mut critical_par,
&mut critical_ends,
&mut visiting,
);
if let CombDepthResult::Depth(d) = r {
max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
}
}
}
}
Ok(CombDepthInfo {
_netlist: netlist,
results,
critical_par,
critical_ends,
max_depth,
})
}
}
/// A vertex of [`MultiDiGraph`]: either a real netlist object or a synthetic
/// vertex labelled with a `T` (for example, an output sink).
#[cfg(feature = "graph")]
#[derive(Clone, Debug)]
pub enum Node<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A vertex backed by an object in the netlist.
    NetRef(NetRef<I>),
    /// A synthetic vertex with no netlist counterpart.
    Pseudo(T),
}

#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Formats either the wrapped netlist object or the pseudo label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(label) => std::fmt::Display::fmt(label, f),
            Self::NetRef(netref) => netref.fmt(f),
        }
    }
}
/// An edge of [`MultiDiGraph`]: either a real netlist connection or a synthetic
/// edge carrying a `T` payload (for example, the net feeding an output sink).
#[cfg(feature = "graph")]
#[derive(Clone, Debug)]
pub enum Edge<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A connection between two netlist objects.
    Connection(Connection<I>),
    /// A synthetic edge with no netlist connection behind it.
    Pseudo(T),
}

#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Formats either the wrapped connection or the pseudo payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(payload) => std::fmt::Display::fmt(payload, f),
            Self::Connection(conn) => conn.fmt(f),
        }
    }
}
/// Directed-graph view of a netlist, built via [`Analysis::build`].
///
/// Netlist objects become [`Node::NetRef`] vertices and connections become
/// [`Edge::Connection`] edges. Each netlist output also gets a
/// [`Node::Pseudo`] sink, reached by an [`Edge::Pseudo`] edge that carries
/// the output's net.
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I: Instantiable> {
    /// The netlist this graph was built from.
    _netlist: &'a Netlist<I>,
    /// The petgraph representation; vertex labels for pseudo nodes are `String`s.
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}

#[cfg(feature = "graph")]
impl<I> MultiDiGraph<'_, I>
where
    I: Instantiable,
{
    /// Borrows the underlying petgraph graph.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        &self.graph
    }

    /// Returns netlist connections whose removal makes the graph acyclic.
    ///
    /// Uses petgraph's greedy heuristic, so the set is small but not
    /// guaranteed to be minimum.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) if an output pseudo-edge is reported.
    /// Output sinks have no outgoing edges, so such an edge cannot be on a cycle.
    pub fn greedy_feedback_arcs(&self) -> impl Iterator<Item = Connection<I>> {
        let arcs = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&self.graph);
        arcs.map(|edge| match edge.weight() {
            Edge::Connection(conn) => conn.clone(),
            Edge::Pseudo(_) => unreachable!("Outputs should be sinks"),
        })
    }

    /// Returns the strongly connected components, restricted to netlist nodes.
    ///
    /// Components are in the order `tarjan_scc` yields them. Nodes not on any
    /// cycle appear as singleton components. Pseudo vertices are dropped, and a
    /// component left empty by that is omitted.
    pub fn sccs(&self) -> Vec<Vec<NetRef<I>>> {
        petgraph::algo::tarjan_scc(&self.graph)
            .into_iter()
            .map(|component| {
                component
                    .into_iter()
                    .filter_map(|idx| match &self.graph[idx] {
                        Node::NetRef(nr) => Some(nr.clone()),
                        Node::Pseudo(_) => None,
                    })
                    .collect::<Vec<_>>()
            })
            .filter(|members| !members.is_empty())
            .collect()
    }
}
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the graph: one vertex per netlist object, one edge per
    /// connection, and one pseudo sink vertex (plus edge) per netlist output.
    ///
    /// # Errors
    /// Propagates any error reported by [`Netlist::verify`].
    ///
    /// # Panics
    /// Panics if a connection or output refers to an object that
    /// `Netlist::objects` did not yield.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        netlist.verify()?;
        // Vertices are keyed by the object handle itself rather than by its
        // rendered name. This avoids allocating a `String` per object and per
        // lookup, and two distinct objects that print the same can no longer
        // be conflated into one vertex.
        let mut mapping = HashMap::new();
        let mut graph = DiGraph::new();
        for obj in netlist.objects() {
            let id = graph.add_node(Node::NetRef(obj.clone()));
            mapping.insert(obj, id);
        }
        for connection in netlist.connections() {
            let s_id = mapping[&connection.src().unwrap()];
            let t_id = mapping[&connection.target().unwrap()];
            graph.add_edge(s_id, t_id, Edge::Connection(connection));
        }
        for (o, n) in netlist.outputs() {
            let s_id = mapping[&o.clone().unwrap()];
            let t_id = graph.add_node(Node::Pseudo(format!("Output({n})")));
            graph.add_edge(s_id, t_id, Edge::Pseudo(o.as_net().clone()));
        }
        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{format_id, netlist::*};
/// A full-adder cell with inputs `CIN`, `A`, `B` and outputs `S` (index 0), `COUT` (index 1).
fn full_adder() -> Gate {
Gate::new_logical_multi(
"FA".into(),
vec!["CIN".into(), "A".into(), "B".into()],
vec!["S".into(), "COUT".into()],
)
}
/// Builds a 4-bit ripple-carry adder from input buses `a`, `b` and carry-in `cin`.
///
/// Stage `fa_{i}` takes the previous carry. Every stage exposes its sum output,
/// and the last stage also exposes its carry as `cout`.
fn ripple_adder() -> GateNetlist {
let netlist = Netlist::new("ripple_adder".to_string());
let bitwidth = 4;
let a = netlist.insert_input_escaped_logic_bus("a".to_string(), bitwidth);
let b = netlist.insert_input_escaped_logic_bus("b".to_string(), bitwidth);
let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());
for (i, (a, b)) in a.into_iter().zip(b).enumerate() {
// Input order matches full_adder(): CIN, A, B.
let fa = netlist
.insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a, b])
.unwrap();
// Output 0 is the sum bit `S`.
fa.expose_net(&fa.get_net(0)).unwrap();
// This stage's carry-out feeds the next stage's CIN.
carry = fa.find_output(&"COUT".into()).unwrap();
if i == bitwidth - 1 {
fa.get_output(1).expose_with_name("cout".into()).unwrap();
}
}
netlist.reclaim().unwrap()
}
/// Checks net fan-out on the ripple adder. No sum bit is read by any
/// connection. Each stage's carry is read by exactly one connection (the next
/// stage), except for the last stage `fa_3`, whose carry has no readers.
#[test]
fn fanout_table() {
let netlist = ripple_adder();
let analysis = FanOutTable::build(&netlist);
assert!(analysis.is_ok());
let analysis = analysis.unwrap();
assert!(netlist.verify().is_ok());
// Every non-input object is one of the full-adder stages.
for item in netlist.objects().filter(|o| !o.is_an_input()) {
assert!(
analysis
.get_net_users(&item.find_output(&"S".into()).unwrap().as_net())
.next()
.is_none(),
"Sum bit should not have users"
);
assert!(
item.get_instance_name().is_some(),
"Item should have a name. Filtered inputs"
);
let net = item.find_output(&"COUT".into()).unwrap().as_net().clone();
let mut cout_users = analysis.get_net_users(&net);
// "fa_3" is the last stage for bitwidth 4 in ripple_adder().
if item.get_instance_name().unwrap().to_string() != "fa_3" {
assert!(cout_users.next().is_some(), "Carry bit should have users");
}
assert!(
cout_users.next().is_none(),
"Carry bit should have 1 or 0 user"
);
}
}
}