1use crate::{
2 circuit::{GlobalNodeId, NodeId, metadata::OperatorLocation, trace::EdgeKind},
3 monitor::visual_graph::{
4 ClusterNode, Edge as VisEdge, Graph as VisGraph, Node as VisNode, SimpleNode,
5 },
6};
7use std::{
8 borrow::Cow,
9 collections::{HashMap, hash_map::Entry},
10 slice,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15#[repr(transparent)]
16pub(super) struct RegionId(Vec<usize>);
17
18impl RegionId {
19 pub(super) fn root() -> Self {
21 Self(Vec::new())
22 }
23
24 pub(super) fn pop(&mut self) {
27 self.0.pop();
28 }
29
30 pub(super) fn child(&self, child_id: usize) -> Self {
31 let mut path = Vec::with_capacity(self.0.len() + 1);
32 path.extend_from_slice(&self.0);
33 path.push(child_id);
34 Self(path)
35 }
36}
37
38pub(super) struct Region {
43 id: RegionId,
44 pub(super) nodes: Vec<NodeId>,
45 name: Cow<'static, str>,
46 location: OperatorLocation,
47 children: Vec<Region>,
48}
49
50impl Region {
51 pub(super) fn new(id: RegionId, name: Cow<'static, str>, location: OperatorLocation) -> Self {
52 Self {
53 id,
54 nodes: Vec::new(),
55 name,
56 location,
57 children: Vec::new(),
58 }
59 }
60
61 fn region_identifier(node_id: &GlobalNodeId, region_id: &RegionId) -> String {
64 let mut region_ident = format!(
65 "{}{}",
66 Node::node_identifier(node_id),
67 if region_id.0.is_empty() { "" } else { "_r" }
68 );
69
70 for i in 0..region_id.0.len() {
71 region_ident.push_str(®ion_id.0[i].to_string());
72 if i < region_id.0.len() - 1 {
73 region_ident.push('_');
74 }
75 }
76
77 region_ident
78 }
79
80 fn visualize(
88 &self,
89 scope: &Node,
90 annotation: &str,
91 annotate: &dyn Fn(&GlobalNodeId) -> (String, f64),
92 ) -> ClusterNode {
93 let mut nodes = Vec::new();
94 for nodeid in self.nodes.iter() {
95 if let Some(vnode) = scope
96 .children()
97 .unwrap()
98 .get(nodeid)
99 .unwrap()
100 .visualize(annotate)
101 {
102 nodes.push(vnode)
103 }
104 }
105
106 for child in self.children.iter() {
107 nodes.push(VisNode::Cluster(child.visualize(scope, "", annotate)));
108 }
109
110 ClusterNode::new(
111 Self::region_identifier(&scope.id, &self.id),
112 format!(
113 "{}{}{}",
114 label(&self.name, self.location),
115 if annotation.is_empty() { "" } else { "\\l" },
116 annotation
117 ),
118 nodes,
119 )
120 }
121
122 fn get_graph(&self, scope: &Node) -> ClusterNode {
125 let mut nodes = Vec::new();
126 for nodeid in self.nodes.iter() {
127 if let Some(vnode) = scope.children().unwrap().get(nodeid).unwrap().get_graph() {
128 nodes.push(vnode)
129 }
130 }
131
132 for child in self.children.iter() {
133 nodes.push(VisNode::Cluster(child.get_graph(scope)));
134 }
135
136 ClusterNode::new(
137 Self::region_identifier(&scope.id, &self.id),
138 label(&self.name, self.location),
139 nodes,
140 )
141 }
142
143 fn do_add_region(
144 &mut self,
145 path: &[usize],
146 name: Cow<'static, str>,
147 location: OperatorLocation,
148 ) -> RegionId {
149 match path.split_first() {
150 None => {
151 let new_region_id = self.id.child(self.children.len());
152 self.children
153 .push(Region::new(new_region_id.clone(), name, location));
154 new_region_id
155 }
156 Some((id, ids)) => self.children[*id].do_add_region(ids, name, location),
157 }
158 }
159
160 pub(super) fn add_region(
166 &mut self,
167 parent: &RegionId,
168 name: Cow<'static, str>,
169 location: OperatorLocation,
170 ) -> RegionId {
171 debug_assert_eq!(self.id, RegionId::root());
172 self.do_add_region(parent.0.as_slice(), name, location)
173 }
174
175 fn do_get_region(&mut self, path: &[usize]) -> &mut Region {
176 match path.split_first() {
177 None => self,
178 Some((id, ids)) => self.children[*id].do_get_region(ids),
179 }
180 }
181
182 pub(super) fn get_region(&mut self, region_id: &RegionId) -> &mut Region {
187 debug_assert_eq!(self.id, RegionId::root());
188
189 self.do_get_region(region_id.0.as_slice())
190 }
191}
192
193pub(super) enum NodeKind {
194 Operator,
196 Circuit {
198 iterative: bool,
199 children: HashMap<NodeId, Node>,
200 region: Region,
201 },
202 StrictInput { output: NodeId },
205 StrictOutput,
207}
208
209pub(super) struct Node {
211 id: GlobalNodeId,
212 pub name: Cow<'static, str>,
213 pub location: OperatorLocation,
214 #[allow(dead_code)]
215 pub region_id: RegionId,
216 pub kind: NodeKind,
217}
218
219impl Node {
220 pub(super) fn new(
221 id: GlobalNodeId,
222 name: Cow<'static, str>,
223 location: OperatorLocation,
224 region_id: RegionId,
225 kind: NodeKind,
226 ) -> Self {
227 Self {
228 id,
229 name,
230 location,
231 region_id,
232 kind,
233 }
234 }
235
236 fn node_ref(&self, mut path: slice::Iter<NodeId>) -> Option<&Node> {
238 match path.next() {
239 None => Some(self),
240 Some(node_id) => match &self.kind {
241 NodeKind::Circuit { children, .. } => children.get(node_id)?.node_ref(path),
242 _ => None,
243 },
244 }
245 }
246
247 fn node_mut(&mut self, mut path: slice::Iter<NodeId>) -> Option<&mut Node> {
249 match path.next() {
250 None => Some(self),
251 Some(node_id) => match &mut self.kind {
252 NodeKind::Circuit { children, .. } => children.get_mut(node_id)?.node_mut(path),
253 _ => None,
254 },
255 }
256 }
257
258 pub(super) fn is_circuit(&self) -> bool {
260 matches!(self.kind, NodeKind::Circuit { .. })
261 }
262
263 pub(super) fn is_iterative(&self) -> bool {
266 matches!(
267 self.kind,
268 NodeKind::Circuit {
269 iterative: true,
270 ..
271 }
272 )
273 }
274
275 pub(super) fn children(&self) -> Option<&HashMap<NodeId, Node>> {
277 if let NodeKind::Circuit { children, .. } = &self.kind {
278 Some(children)
279 } else {
280 None
281 }
282 }
283
284 pub(super) fn region_mut(&mut self) -> Option<&mut Region> {
287 if let NodeKind::Circuit { region, .. } = &mut self.kind {
288 Some(region)
289 } else {
290 None
291 }
292 }
293
294 pub(super) fn is_strict_input(&self) -> bool {
297 matches!(self.kind, NodeKind::StrictInput { .. })
298 }
299
300 pub(super) fn output_id(&self) -> Option<NodeId> {
302 if let NodeKind::StrictInput { output } = &self.kind {
303 Some(*output)
304 } else {
305 None
306 }
307 }
308
309 pub(super) fn node_identifier(node_id: &GlobalNodeId) -> String {
312 node_id.node_identifier()
313 }
314
315 fn visualize(&self, annotate: &dyn Fn(&GlobalNodeId) -> (String, f64)) -> Option<VisNode> {
317 let (annotation, importance) = annotate(&self.id);
318
319 match &self.kind {
320 NodeKind::Operator => Some(VisNode::Simple(SimpleNode::new(
321 Self::node_identifier(&self.id),
322 format!(
323 "{}{}{}",
324 label(&self.name, self.location),
325 if annotation.is_empty() { "" } else { "\\l" },
326 annotation
327 ),
328 importance,
329 ))),
330
331 NodeKind::Circuit { region, .. } => Some(VisNode::Cluster(region.visualize(
332 self,
333 &annotation,
334 annotate,
335 ))),
336
337 NodeKind::StrictInput { output } => Some(VisNode::Simple(SimpleNode::new(
338 Self::node_identifier(&self.id.parent_id().unwrap().child(*output)),
339 format!(
340 "{}{}{}",
341 label(&self.name, self.location),
342 if annotation.is_empty() { "" } else { "\\l" },
343 annotation
344 ),
345 importance,
346 ))),
347 NodeKind::StrictOutput => None,
348 }
349 }
350
351 fn get_graph(&self) -> Option<VisNode> {
353 match &self.kind {
354 NodeKind::Operator => Some(VisNode::Simple(SimpleNode::new(
355 Self::node_identifier(&self.id),
356 label(&self.name, self.location),
357 0f64,
358 ))),
359
360 NodeKind::Circuit { region, .. } => Some(VisNode::Cluster(region.get_graph(self))),
361
362 NodeKind::StrictInput { .. } => Some(VisNode::Simple(SimpleNode::new(
363 Self::node_identifier(&self.id),
364 label(&self.name, self.location),
365 0f64,
366 ))),
367 NodeKind::StrictOutput => Some(VisNode::Simple(SimpleNode::new(
368 Self::node_identifier(&self.id),
369 format!("{}{}", label(&self.name, self.location), " (output)"),
370 0f64,
371 ))),
372 }
373 }
374}
375
376pub(super) struct CircuitGraph {
377 nodes: Node,
379 edges: HashMap<GlobalNodeId, Vec<(GlobalNodeId, EdgeKind)>>,
383}
384
385impl CircuitGraph {
386 pub(super) fn new() -> Self {
387 Self {
388 nodes: Node::new(
389 GlobalNodeId::root(),
390 Cow::Borrowed("root"),
391 None,
392 RegionId::root(),
393 NodeKind::Circuit {
394 iterative: true,
395 children: HashMap::new(),
396 region: Region::new(RegionId::root(), Cow::Borrowed("root"), None),
397 },
398 ),
399 edges: HashMap::new(),
400 }
401 }
402
403 pub(super) fn node_ref(&self, id: &GlobalNodeId) -> Option<&Node> {
405 self.nodes.node_ref(id.path().iter())
406 }
407
408 pub(super) fn node_mut(&mut self, id: &GlobalNodeId) -> Option<&mut Node> {
410 self.nodes.node_mut(id.path().iter())
411 }
412
413 pub(super) fn add_edge(&mut self, from: &GlobalNodeId, to: &GlobalNodeId, kind: &EdgeKind) {
414 match self.edges.entry(from.clone()) {
415 Entry::Occupied(mut oe) => {
416 oe.get_mut().push((to.clone(), kind.clone()));
417 }
418 Entry::Vacant(ve) => {
419 ve.insert(vec![(to.clone(), kind.clone())]);
420 }
421 }
422 }
423
424 pub(super) fn visualize(&self, annotate: &dyn Fn(&GlobalNodeId) -> (String, f64)) -> VisGraph {
426 let cluster = self.nodes.visualize(annotate).unwrap().cluster().unwrap();
427
428 let mut edges = Vec::new();
429
430 for (from_id, to) in self.edges.iter() {
431 let from_node = self.node_ref(from_id).unwrap();
432
433 for (to_id, _kind) in to.iter() {
434 let to_node = self.node_ref(to_id).unwrap();
435 let to_id = match to_node.kind {
436 NodeKind::StrictInput { output } => to_id.parent_id().unwrap().child(output),
437 _ => to_id.clone(),
438 };
439
440 if from_id != &to_id {
442 edges.push(VisEdge::new(
443 Node::node_identifier(from_id),
444 from_node.is_circuit(),
445 Node::node_identifier(&to_id),
446 to_node.is_circuit(),
447 ));
448 }
449 }
450 }
451
452 VisGraph::new(cluster, edges)
453 }
454
455 pub(super) fn get_graph(&self) -> VisGraph {
457 let cluster = self.nodes.get_graph().unwrap().cluster().unwrap();
458
459 let mut edges = Vec::new();
460
461 for (from_id, to) in self.edges.iter() {
462 let from_node = self.node_ref(from_id).unwrap();
463
464 for (to_id, _kind) in to.iter() {
465 let to_node = self.node_ref(to_id).unwrap();
466 edges.push(VisEdge::new(
467 Node::node_identifier(from_id),
468 from_node.is_circuit(),
469 Node::node_identifier(to_id),
470 to_node.is_circuit(),
471 ));
472 }
473 }
474
475 VisGraph::new(cluster, edges)
476 }
477}
478
479fn label(name: &str, location: OperatorLocation) -> String {
480 if let Some(location) = location {
481 let file = location
482 .file()
483 .trim_start_matches(env!("CARGO_MANIFEST_DIR"))
485 .replace('\\', "/");
487
488 let mut components = file.split('/');
491 let base_name = components.next_back().unwrap();
492 let mut file = String::new();
493 for dir_name in components {
494 if let Some(c) = dir_name.chars().next() {
495 file.push(c);
496 file.push('/');
497 }
498 }
499 file.push_str(base_name);
500
501 format!(
502 "{} @ {}:{}:{}",
503 name,
504 file,
505 location.line(),
506 location.column(),
507 )
508 } else {
509 name.to_owned()
510 }
511}