Skip to main content

safety_net/
graph.rs

1/*!
2
3  Graph utils for the `graph` module.
4
5*/
6
7use crate::circuit::{Instantiable, Net};
8use crate::error::Error;
9#[cfg(feature = "graph")]
10use crate::netlist::Connection;
11use crate::netlist::{NetRef, Netlist};
12#[cfg(feature = "graph")]
13use petgraph::graph::DiGraph;
14use std::collections::hash_map::Entry;
15use std::collections::{HashMap, HashSet};
16
17/// A common trait of analyses than can be performed on a netlist.
18/// An analysis becomes stale when the netlist is modified.
19pub trait Analysis<'a, I: Instantiable>
20where
21    Self: Sized + 'a,
22{
23    /// Construct the analysis to the current state of the netlist.
24    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
25}
26
27/// A table that maps nets to the circuit nodes they drive
28pub struct FanOutTable<'a, I: Instantiable> {
29    /// A reference to the underlying netlist
30    _netlist: &'a Netlist<I>,
31    /// Maps a net to the list of nodes it drives
32    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
33    /// Maps a node to the list of nodes it drives
34    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
35    /// Contains nets which are outputs
36    is_an_output: HashSet<Net>,
37}
38
39impl<I> FanOutTable<'_, I>
40where
41    I: Instantiable,
42{
43    /// Returns an iterator to the circuit nodes that use `net`.
44    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
45        self.net_fan_out
46            .get(net)
47            .into_iter()
48            .flat_map(|users| users.iter().cloned())
49    }
50
51    /// Returns an iterator to the circuit nodes that use `node`.
52    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
53        self.node_fan_out
54            .get(node)
55            .into_iter()
56            .flat_map(|users| users.iter().cloned())
57    }
58
59    /// Returns `true` if the net has any used by any cells in the circuit
60    /// This does incude nets that are only used as outputs.
61    pub fn net_has_uses(&self, net: &Net) -> bool {
62        (self.net_fan_out.contains_key(net) && !self.net_fan_out.get(net).unwrap().is_empty())
63            || self.is_an_output.contains(net)
64    }
65}
66
67impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
68where
69    I: Instantiable,
70{
71    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
72        let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
73        let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
74        let mut is_an_output: HashSet<Net> = HashSet::new();
75
76        // This can only be fully-correct on a verified netlist.
77        netlist.verify()?;
78
79        for c in netlist.connections() {
80            if let Entry::Vacant(e) = net_fan_out.entry(c.net()) {
81                e.insert(vec![c.target().unwrap()]);
82            } else {
83                net_fan_out
84                    .get_mut(&c.net())
85                    .unwrap()
86                    .push(c.target().unwrap());
87            }
88
89            if let Entry::Vacant(e) = node_fan_out.entry(c.src().unwrap()) {
90                e.insert(vec![c.target().unwrap()]);
91            } else {
92                node_fan_out
93                    .get_mut(&c.src().unwrap())
94                    .unwrap()
95                    .push(c.target().unwrap());
96            }
97        }
98
99        for (o, n) in netlist.outputs() {
100            is_an_output.insert(o.as_net().clone());
101            is_an_output.insert(n);
102        }
103
104        Ok(FanOutTable {
105            _netlist: netlist,
106            net_fan_out,
107            node_fan_out,
108            is_an_output,
109        })
110    }
111}
112
/// Result of the combinational depth analysis ([`SimpleCombDepth`]) for a
/// single circuit node.
///
/// The analysis checks for cycles but does not check for registers. It walks
/// through the drivers of every non-input node, so sequential elements are
/// traversed like any other logic.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CombDepthResult {
    /// An input port of this node, or of a node in its fan-in, has no driver.
    Undefined,
    /// The node lies on, or depends on, a combinational cycle.
    PartOfCycle,
    /// Number of logic levels above the primary inputs (inputs are depth 0).
    Depth(usize),
}
125
126/// Computes the combinational depth of each net in a netlist.
127///
128/// Each net is classified as having a defined depth, being undefined,
129/// or participating in a combinational cycle.
130pub struct SimpleCombDepth<'a, I: Instantiable> {
131    _netlist: &'a Netlist<I>,
132    results: HashMap<NetRef<I>, CombDepthResult>,
133    /// Max will be None whenever no outputs in the whole netlist have a well defined combinational depth
134    /// for example if they are all undefined or they all partake in a cycle
135    max_depth: Option<usize>,
136}
137
138impl<I> SimpleCombDepth<'_, I>
139where
140    I: Instantiable,
141{
142    /// Returns the logic level of a node in the circuit.
143    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<CombDepthResult> {
144        self.results.get(node).copied()
145    }
146
147    /// Returns the maximum logic level of the circuit.
148    pub fn get_max_depth(&self) -> Option<usize> {
149        self.max_depth
150    }
151}
152impl<'a, I> Analysis<'a, I> for SimpleCombDepth<'a, I>
153where
154    I: Instantiable,
155{
156    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
157        let mut results: HashMap<NetRef<I>, CombDepthResult> = HashMap::new();
158        let mut visiting: HashSet<NetRef<I>> = HashSet::new();
159        let mut max_depth: Option<usize> = None;
160
161        fn compute<I: Instantiable>(
162            node: NetRef<I>,
163            netlist: &Netlist<I>,
164            results: &mut HashMap<NetRef<I>, CombDepthResult>,
165            visiting: &mut HashSet<NetRef<I>>,
166        ) -> CombDepthResult {
167            // Memoized result
168            if let Some(&r) = results.get(&node) {
169                return r;
170            }
171
172            // Cycle detection
173            if visiting.contains(&node) {
174                for n in visiting.iter() {
175                    results.insert(n.clone(), CombDepthResult::PartOfCycle);
176                }
177                return CombDepthResult::PartOfCycle;
178            }
179
180            // Input nodes have depth 0
181            if node.is_an_input() {
182                let r = CombDepthResult::Depth(0);
183                results.insert(node.clone(), r);
184                return r;
185            }
186
187            visiting.insert(node.clone());
188
189            let mut max_depth = 0;
190
191            for i in 0..node.get_num_input_ports() {
192                let driver = match netlist.get_driver(node.clone(), i) {
193                    Some(d) => d,
194                    None => {
195                        let r = CombDepthResult::Undefined;
196                        results.insert(node.clone(), r);
197                        visiting.remove(&node);
198                        return r;
199                    }
200                };
201
202                match compute(driver, netlist, results, visiting) {
203                    CombDepthResult::Depth(d) => {
204                        max_depth = max_depth.max(d);
205                    }
206                    CombDepthResult::Undefined => {
207                        let r = CombDepthResult::Undefined;
208                        results.insert(node.clone(), r);
209                        visiting.remove(&node);
210                        return r;
211                    }
212                    CombDepthResult::PartOfCycle => {
213                        let r = CombDepthResult::PartOfCycle;
214                        results.insert(node.clone(), r);
215                        visiting.remove(&node);
216                        return r;
217                    }
218                }
219            }
220
221            visiting.remove(&node);
222
223            let r = CombDepthResult::Depth(max_depth + 1);
224            results.insert(node.clone(), r);
225            r
226        }
227
228        for (driven, _) in netlist.outputs() {
229            let node = driven.unwrap();
230            let r = compute(node, netlist, &mut results, &mut visiting);
231
232            if let CombDepthResult::Depth(d) = r {
233                max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
234            }
235        }
236
237        Ok(SimpleCombDepth {
238            _netlist: netlist,
239            results,
240            max_depth,
241        })
242    }
243}
244
/// A vertex of the graph view: either a real circuit node or a user-defined
/// pseudo-node carrying arbitrary data.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// A node that exists in the netlist.
    NetRef(NetRef<I>),
    /// A user-defined node with no netlist counterpart (e.g. an output marker).
    Pseudo(T),
}
254
#[cfg(feature = "graph")]
impl<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> std::fmt::Display
    for Node<I, T>
{
    /// Formats a real node by delegating to the node's `fmt`, and a
    /// pseudo-node through its payload's `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NetRef(inner) => inner.fmt(f),
            Self::Pseudo(label) => std::fmt::Display::fmt(label, f),
        }
    }
}
268
/// An edge of the graph view: either a real circuit connection or a
/// user-defined pseudo-edge carrying arbitrary data.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Edge<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A connection that exists in the netlist.
    Connection(Connection<I>),
    /// A user-defined edge with no netlist counterpart (e.g. the link from a
    /// driver to an output marker).
    Pseudo(T),
}
278
#[cfg(feature = "graph")]
impl<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> std::fmt::Display
    for Edge<I, T>
{
    /// Formats a real connection by delegating to the connection's `fmt`, and
    /// a pseudo-edge through its payload's `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Connection(inner) => inner.fmt(f),
            Self::Pseudo(label) => std::fmt::Display::fmt(label, f),
        }
    }
}
292
/// A petgraph view of a netlist as a directed multigraph.
///
/// Every netlist object becomes a [`Node::NetRef`] vertex, and every
/// connection becomes an [`Edge::Connection`] edge. Each netlist output also
/// gets a [`Node::Pseudo`] sink vertex, linked from its driver by an
/// [`Edge::Pseudo`] edge labelled with the output's net. The graph is
/// available through [`MultiDiGraph::get_graph`].
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I: Instantiable> {
    /// Borrow of the analyzed netlist; ties the graph to its lifetime.
    _netlist: &'a Netlist<I>,
    /// The graph: vertices are [`Node`]s with `String` pseudo-labels, and
    /// edges are [`Edge`]s with [`Net`] pseudo-labels.
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}
299
#[cfg(feature = "graph")]
impl<I: Instantiable> MultiDiGraph<'_, I> {
    /// Borrows the petgraph [`DiGraph`] built by this analysis.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        &self.graph
    }
}
310
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the graph: one vertex per netlist object, one edge per
    /// connection, and a pseudo sink vertex plus edge for every output.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`Netlist::verify`] if the netlist does
    /// not verify.
    ///
    /// # Panics
    ///
    /// Panics if a connection or output refers to an object whose printed
    /// name was not registered as a vertex.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        // Once the netlist is verified, objects can be keyed by their printed name.
        netlist.verify()?;

        let mut index_of = HashMap::new();
        let mut graph = DiGraph::new();

        // One vertex per netlist object.
        for obj in netlist.objects() {
            let vertex = graph.add_node(Node::NetRef(obj.clone()));
            index_of.insert(obj.to_string(), vertex);
        }

        // One edge per driver -> consumer connection.
        for conn in netlist.connections() {
            let from = index_of[&conn.src().unwrap().get_obj().to_string()];
            let to = index_of[&conn.target().unwrap().get_obj().to_string()];
            graph.add_edge(from, to, Edge::Connection(conn));
        }

        // Every output becomes a pseudo sink vertex fed by its driver.
        for (driven, name) in netlist.outputs() {
            let from = index_of[&driven.clone().unwrap().get_obj().to_string()];
            let sink = graph.add_node(Node::Pseudo(format!("Output({name})")));
            graph.add_edge(from, sink, Edge::Pseudo(driven.as_net().clone()));
        }

        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{format_id, netlist::*};

    fn full_adder() -> Gate {
        Gate::new_logical_multi(
            "FA".into(),
            vec!["CIN".into(), "A".into(), "B".into()],
            vec!["S".into(), "COUT".into()],
        )
    }

    /// Builds a 4-bit ripple-carry adder out of `FA` cells named `fa_0`..`fa_3`.
    ///
    /// Every sum bit is exposed as an output, as is the final carry (`cout`).
    fn ripple_adder() -> GateNetlist {
        let netlist = Netlist::new("ripple_adder".to_string());
        let bitwidth = 4;

        // Add the inputs
        let a = netlist.insert_input_escaped_logic_bus("a".to_string(), bitwidth);
        let b = netlist.insert_input_escaped_logic_bus("b".to_string(), bitwidth);
        let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());

        for (i, (a, b)) in a.into_iter().zip(b.into_iter()).enumerate() {
            // Instantiate a full adder for each bit
            let fa = netlist
                .insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a, b])
                .unwrap();

            // Expose the sum
            fa.expose_net(&fa.get_net(0)).unwrap();

            carry = fa.find_output(&"COUT".into()).unwrap();

            if i == bitwidth - 1 {
                // Last full adder, expose the carry out
                fa.get_output(1).expose_with_name("cout".into()).unwrap();
            }
        }

        netlist.reclaim().unwrap()
    }

    #[test]
    fn fanout_table() {
        let netlist = ripple_adder();
        let analysis = FanOutTable::build(&netlist);
        assert!(analysis.is_ok());
        let analysis = analysis.unwrap();
        assert!(netlist.verify().is_ok());

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            // Sum bit has no users (it is a direct output)
            let sum = item.find_output(&"S".into()).unwrap().as_net().clone();
            assert!(
                analysis.get_net_users(&sum).next().is_none(),
                "Sum bit should not have users"
            );
            // ...but being an output still counts as a use.
            assert!(
                analysis.net_has_uses(&sum),
                "Sum bit is an output, so it should count as used"
            );

            assert!(
                item.get_instance_name().is_some(),
                "Item should have a name. Filtered inputs"
            );

            let net = item.find_output(&"COUT".into()).unwrap().as_net().clone();
            // Every carry is either consumed by the next adder or exposed as `cout`.
            assert!(
                analysis.net_has_uses(&net),
                "Carry bit should be consumed or be an output"
            );

            let mut cout_users = analysis.get_net_users(&net);
            if item.get_instance_name().unwrap().to_string() != "fa_3" {
                assert!(cout_users.next().is_some(), "Carry bit should have users");
            }

            assert!(
                cout_users.next().is_none(),
                "Carry bit should have 1 or 0 user"
            );
        }
    }

    #[test]
    fn comb_depth() {
        let netlist = ripple_adder();
        let analysis = SimpleCombDepth::build(&netlist).expect("analysis should build");

        // Primary inputs are at depth 0. Adder `fa_i` waits on the carry from
        // `fa_{i-1}` (or on `cin`), so it is at depth `i + 1`.
        for item in netlist.objects() {
            let (label, expected) = if item.is_an_input() {
                ("input".to_string(), 0)
            } else {
                let name = item.get_instance_name().unwrap().to_string();
                let bit: usize = name
                    .strip_prefix("fa_")
                    .and_then(|idx| idx.parse().ok())
                    .expect("adder cells are named fa_<bit>");
                (name, bit + 1)
            };
            assert_eq!(
                analysis.get_comb_depth(&item),
                Some(CombDepthResult::Depth(expected)),
                "unexpected depth for {label}"
            );
        }

        // The deepest outputs come from the last adder in the carry chain.
        assert_eq!(analysis.get_max_depth(), Some(4));
    }
}