Skip to main content

dbsp/monitor/
circuit_graph.rs

1use crate::{
2    circuit::{GlobalNodeId, NodeId, metadata::OperatorLocation, trace::EdgeKind},
3    monitor::visual_graph::{
4        ClusterNode, Edge as VisEdge, Graph as VisGraph, Node as VisNode, SimpleNode,
5    },
6};
7use std::{
8    borrow::Cow,
9    collections::{HashMap, hash_map::Entry},
10    slice,
11};
12
13/// Region id is a path from the root of the region tree.
14#[derive(Debug, Clone, PartialEq, Eq)]
15#[repr(transparent)]
16pub(super) struct RegionId(Vec<usize>);
17
18impl RegionId {
19    /// Root region.
20    pub(super) fn root() -> Self {
21        Self(Vec::new())
22    }
23
24    /// Pop the innermost child, transforming region id into its parent region
25    /// id.
26    pub(super) fn pop(&mut self) {
27        self.0.pop();
28    }
29
30    pub(super) fn child(&self, child_id: usize) -> Self {
31        let mut path = Vec::with_capacity(self.0.len() + 1);
32        path.extend_from_slice(&self.0);
33        path.push(child_id);
34        Self(path)
35    }
36}
37
38/// A region is a named grouping of operators in a circuit.
39///
40/// Regions can be nested inside other regions, forming a tree.
41/// A circuit is created with a single root region.
42pub(super) struct Region {
43    id: RegionId,
44    pub(super) nodes: Vec<NodeId>,
45    name: Cow<'static, str>,
46    location: OperatorLocation,
47    children: Vec<Region>,
48}
49
50impl Region {
51    pub(super) fn new(id: RegionId, name: Cow<'static, str>, location: OperatorLocation) -> Self {
52        Self {
53            id,
54            nodes: Vec::new(),
55            name,
56            location,
57            children: Vec::new(),
58        }
59    }
60
61    /// Generate unique name for a region to use as a node label in a visual
62    /// graph.
63    fn region_identifier(node_id: &GlobalNodeId, region_id: &RegionId) -> String {
64        let mut region_ident = format!(
65            "{}{}",
66            Node::node_identifier(node_id),
67            if region_id.0.is_empty() { "" } else { "_r" }
68        );
69
70        for i in 0..region_id.0.len() {
71            region_ident.push_str(&region_id.0[i].to_string());
72            if i < region_id.0.len() - 1 {
73                region_ident.push('_');
74            }
75        }
76
77        region_ident
78    }
79
80    /// Output region as a cluster in a visual graph.
81    ///
82    /// # Arguments
83    ///
84    /// * `annotation` - annotation to attach to the region.
85    /// * `annotate` - function used to annotate nodes inside the region.
86    ///   Returns a label and an "importance" (between 0 and 1) for the node.
87    fn visualize(
88        &self,
89        scope: &Node,
90        annotation: &str,
91        annotate: &dyn Fn(&GlobalNodeId) -> (String, f64),
92    ) -> ClusterNode {
93        let mut nodes = Vec::new();
94        for nodeid in self.nodes.iter() {
95            if let Some(vnode) = scope
96                .children()
97                .unwrap()
98                .get(nodeid)
99                .unwrap()
100                .visualize(annotate)
101            {
102                nodes.push(vnode)
103            }
104        }
105
106        for child in self.children.iter() {
107            nodes.push(VisNode::Cluster(child.visualize(scope, "", annotate)));
108        }
109
110        ClusterNode::new(
111            Self::region_identifier(&scope.id, &self.id),
112            format!(
113                "{}{}{}",
114                label(&self.name, self.location),
115                if annotation.is_empty() { "" } else { "\\l" },
116                annotation
117            ),
118            nodes,
119        )
120    }
121
122    /// Output region as a cluster in a visual graph.
123    /// Similar to 'visualize', but does not merge halves of strict operators.
124    fn get_graph(&self, scope: &Node) -> ClusterNode {
125        let mut nodes = Vec::new();
126        for nodeid in self.nodes.iter() {
127            if let Some(vnode) = scope.children().unwrap().get(nodeid).unwrap().get_graph() {
128                nodes.push(vnode)
129            }
130        }
131
132        for child in self.children.iter() {
133            nodes.push(VisNode::Cluster(child.get_graph(scope)));
134        }
135
136        ClusterNode::new(
137            Self::region_identifier(&scope.id, &self.id),
138            label(&self.name, self.location),
139            nodes,
140        )
141    }
142
143    fn do_add_region(
144        &mut self,
145        path: &[usize],
146        name: Cow<'static, str>,
147        location: OperatorLocation,
148    ) -> RegionId {
149        match path.split_first() {
150            None => {
151                let new_region_id = self.id.child(self.children.len());
152                self.children
153                    .push(Region::new(new_region_id.clone(), name, location));
154                new_region_id
155            }
156            Some((id, ids)) => self.children[*id].do_add_region(ids, name, location),
157        }
158    }
159
160    /// Add a subregion to `self`.
161    ///
162    /// * `self` - must be a root region.
163    /// * `parent` - existing sub-region id.
164    /// * `description` - name of a new region to add as child to `parent`.
165    pub(super) fn add_region(
166        &mut self,
167        parent: &RegionId,
168        name: Cow<'static, str>,
169        location: OperatorLocation,
170    ) -> RegionId {
171        debug_assert_eq!(self.id, RegionId::root());
172        self.do_add_region(parent.0.as_slice(), name, location)
173    }
174
175    fn do_get_region(&mut self, path: &[usize]) -> &mut Region {
176        match path.split_first() {
177            None => self,
178            Some((id, ids)) => self.children[*id].do_get_region(ids),
179        }
180    }
181
182    /// Get a mutable reference to a subregion of `self`.
183    ///
184    /// * `self` - must be a root region.
185    /// * `region_id` - existing subregion id.
186    pub(super) fn get_region(&mut self, region_id: &RegionId) -> &mut Region {
187        debug_assert_eq!(self.id, RegionId::root());
188
189        self.do_get_region(region_id.0.as_slice())
190    }
191}
192
193pub(super) enum NodeKind {
194    /// Regular operator.
195    Operator,
196    /// Root circuit or subcircuit.
197    Circuit {
198        iterative: bool,
199        children: HashMap<NodeId, Node>,
200        region: Region,
201    },
202    /// The input half of a [strict
203    /// operator](`crate::circuit::operator_traits::StrictOperator`).
204    StrictInput { output: NodeId },
205    /// The output half of a strict operator.
206    StrictOutput,
207}
208
209/// A node in a circuit graph represents an operator or a circuit.
210pub(super) struct Node {
211    id: GlobalNodeId,
212    pub name: Cow<'static, str>,
213    pub location: OperatorLocation,
214    #[allow(dead_code)]
215    pub region_id: RegionId,
216    pub kind: NodeKind,
217}
218
219impl Node {
220    pub(super) fn new(
221        id: GlobalNodeId,
222        name: Cow<'static, str>,
223        location: OperatorLocation,
224        region_id: RegionId,
225        kind: NodeKind,
226    ) -> Self {
227        Self {
228            id,
229            name,
230            location,
231            region_id,
232            kind,
233        }
234    }
235
236    /// Lookup node in the subtree with the root in `self` by path.
237    fn node_ref(&self, mut path: slice::Iter<NodeId>) -> Option<&Node> {
238        match path.next() {
239            None => Some(self),
240            Some(node_id) => match &self.kind {
241                NodeKind::Circuit { children, .. } => children.get(node_id)?.node_ref(path),
242                _ => None,
243            },
244        }
245    }
246
247    /// Lookup node in the subtree with the root in `self` by path.
248    fn node_mut(&mut self, mut path: slice::Iter<NodeId>) -> Option<&mut Node> {
249        match path.next() {
250            None => Some(self),
251            Some(node_id) => match &mut self.kind {
252                NodeKind::Circuit { children, .. } => children.get_mut(node_id)?.node_mut(path),
253                _ => None,
254            },
255        }
256    }
257
258    /// `true` if `self` is a circuit node.
259    pub(super) fn is_circuit(&self) -> bool {
260        matches!(self.kind, NodeKind::Circuit { .. })
261    }
262
263    /// `true` if `self` is an iterative circuit (including the root
264    /// circuit).
265    pub(super) fn is_iterative(&self) -> bool {
266        matches!(
267            self.kind,
268            NodeKind::Circuit {
269                iterative: true,
270                ..
271            }
272        )
273    }
274
275    /// Returns children of `self` if `self` is a circuit.
276    pub(super) fn children(&self) -> Option<&HashMap<NodeId, Node>> {
277        if let NodeKind::Circuit { children, .. } = &self.kind {
278            Some(children)
279        } else {
280            None
281        }
282    }
283
284    /// Returns a mutable reference to the root region of `self` if
285    /// `self` is a circuit node.
286    pub(super) fn region_mut(&mut self) -> Option<&mut Region> {
287        if let NodeKind::Circuit { region, .. } = &mut self.kind {
288            Some(region)
289        } else {
290            None
291        }
292    }
293
294    /// `true` if `self` is the input half of a [strict
295    /// operator](`crate::circuit::operator_traits::StrictOperator`).
296    pub(super) fn is_strict_input(&self) -> bool {
297        matches!(self.kind, NodeKind::StrictInput { .. })
298    }
299
300    /// Returns `self.output` if `self` is a strict input operator.
301    pub(super) fn output_id(&self) -> Option<NodeId> {
302        if let NodeKind::StrictInput { output } = &self.kind {
303            Some(*output)
304        } else {
305            None
306        }
307    }
308
309    /// Generate unique name for the node to use as a node label in a visual
310    /// graph.
311    pub(super) fn node_identifier(node_id: &GlobalNodeId) -> String {
312        node_id.node_identifier()
313    }
314
315    /// Output circuit node as a node in a visual graph.
316    fn visualize(&self, annotate: &dyn Fn(&GlobalNodeId) -> (String, f64)) -> Option<VisNode> {
317        let (annotation, importance) = annotate(&self.id);
318
319        match &self.kind {
320            NodeKind::Operator => Some(VisNode::Simple(SimpleNode::new(
321                Self::node_identifier(&self.id),
322                format!(
323                    "{}{}{}",
324                    label(&self.name, self.location),
325                    if annotation.is_empty() { "" } else { "\\l" },
326                    annotation
327                ),
328                importance,
329            ))),
330
331            NodeKind::Circuit { region, .. } => Some(VisNode::Cluster(region.visualize(
332                self,
333                &annotation,
334                annotate,
335            ))),
336
337            NodeKind::StrictInput { output } => Some(VisNode::Simple(SimpleNode::new(
338                Self::node_identifier(&self.id.parent_id().unwrap().child(*output)),
339                format!(
340                    "{}{}{}",
341                    label(&self.name, self.location),
342                    if annotation.is_empty() { "" } else { "\\l" },
343                    annotation
344                ),
345                importance,
346            ))),
347            NodeKind::StrictOutput => None,
348        }
349    }
350
351    /// Output circuit node as a node in a visual graph, without merging the two halves of strict operators.
352    fn get_graph(&self) -> Option<VisNode> {
353        match &self.kind {
354            NodeKind::Operator => Some(VisNode::Simple(SimpleNode::new(
355                Self::node_identifier(&self.id),
356                label(&self.name, self.location),
357                0f64,
358            ))),
359
360            NodeKind::Circuit { region, .. } => Some(VisNode::Cluster(region.get_graph(self))),
361
362            NodeKind::StrictInput { .. } => Some(VisNode::Simple(SimpleNode::new(
363                Self::node_identifier(&self.id),
364                label(&self.name, self.location),
365                0f64,
366            ))),
367            NodeKind::StrictOutput => Some(VisNode::Simple(SimpleNode::new(
368                Self::node_identifier(&self.id),
369                format!("{}{}", label(&self.name, self.location), " (output)"),
370                0f64,
371            ))),
372        }
373    }
374}
375
376pub(super) struct CircuitGraph {
377    /// Tree of nodes.
378    nodes: Node,
379    /// Matches a node to the vector of nodes that read from its output
380    /// stream or have a dependency on it.
381    /// A node can occur in this vector multiple times.
382    edges: HashMap<GlobalNodeId, Vec<(GlobalNodeId, EdgeKind)>>,
383}
384
385impl CircuitGraph {
386    pub(super) fn new() -> Self {
387        Self {
388            nodes: Node::new(
389                GlobalNodeId::root(),
390                Cow::Borrowed("root"),
391                None,
392                RegionId::root(),
393                NodeKind::Circuit {
394                    iterative: true,
395                    children: HashMap::new(),
396                    region: Region::new(RegionId::root(), Cow::Borrowed("root"), None),
397                },
398            ),
399            edges: HashMap::new(),
400        }
401    }
402
403    /// Locate node by its global id.
404    pub(super) fn node_ref(&self, id: &GlobalNodeId) -> Option<&Node> {
405        self.nodes.node_ref(id.path().iter())
406    }
407
408    /// Locate node by its global id.
409    pub(super) fn node_mut(&mut self, id: &GlobalNodeId) -> Option<&mut Node> {
410        self.nodes.node_mut(id.path().iter())
411    }
412
413    pub(super) fn add_edge(&mut self, from: &GlobalNodeId, to: &GlobalNodeId, kind: &EdgeKind) {
414        match self.edges.entry(from.clone()) {
415            Entry::Occupied(mut oe) => {
416                oe.get_mut().push((to.clone(), kind.clone()));
417            }
418            Entry::Vacant(ve) => {
419                ve.insert(vec![(to.clone(), kind.clone())]);
420            }
421        }
422    }
423
424    /// Output circuit graph as visual graph.
425    pub(super) fn visualize(&self, annotate: &dyn Fn(&GlobalNodeId) -> (String, f64)) -> VisGraph {
426        let cluster = self.nodes.visualize(annotate).unwrap().cluster().unwrap();
427
428        let mut edges = Vec::new();
429
430        for (from_id, to) in self.edges.iter() {
431            let from_node = self.node_ref(from_id).unwrap();
432
433            for (to_id, _kind) in to.iter() {
434                let to_node = self.node_ref(to_id).unwrap();
435                let to_id = match to_node.kind {
436                    NodeKind::StrictInput { output } => to_id.parent_id().unwrap().child(output),
437                    _ => to_id.clone(),
438                };
439
440                // Don't draw self-loops on strict operators.
441                if from_id != &to_id {
442                    edges.push(VisEdge::new(
443                        Node::node_identifier(from_id),
444                        from_node.is_circuit(),
445                        Node::node_identifier(&to_id),
446                        to_node.is_circuit(),
447                    ));
448                }
449            }
450        }
451
452        VisGraph::new(cluster, edges)
453    }
454
455    /// Similar to 'visualize' without any annotations, but it does not merge the two halves of strict operators.
456    pub(super) fn get_graph(&self) -> VisGraph {
457        let cluster = self.nodes.get_graph().unwrap().cluster().unwrap();
458
459        let mut edges = Vec::new();
460
461        for (from_id, to) in self.edges.iter() {
462            let from_node = self.node_ref(from_id).unwrap();
463
464            for (to_id, _kind) in to.iter() {
465                let to_node = self.node_ref(to_id).unwrap();
466                edges.push(VisEdge::new(
467                    Node::node_identifier(from_id),
468                    from_node.is_circuit(),
469                    Node::node_identifier(to_id),
470                    to_node.is_circuit(),
471                ));
472            }
473        }
474
475        VisGraph::new(cluster, edges)
476    }
477}
478
479fn label(name: &str, location: OperatorLocation) -> String {
480    if let Some(location) = location {
481        let file = location
482            .file()
483            // Strip the crate's path from any of its operators
484            .trim_start_matches(env!("CARGO_MANIFEST_DIR"))
485            // Windows uses "\" for paths which dot interprets as an escape char
486            .replace('\\', "/");
487
488        // Abbreviate the file name to the first letter of each directory
489        // followed by the full name of the file.
490        let mut components = file.split('/');
491        let base_name = components.next_back().unwrap();
492        let mut file = String::new();
493        for dir_name in components {
494            if let Some(c) = dir_name.chars().next() {
495                file.push(c);
496                file.push('/');
497            }
498        }
499        file.push_str(base_name);
500
501        format!(
502            "{} @ {}:{}:{}",
503            name,
504            file,
505            location.line(),
506            location.column(),
507        )
508    } else {
509        name.to_owned()
510    }
511}