1mod impls;
4pub mod petgraph;
5pub mod render;
6mod rerooted;
7mod root_checked;
8pub mod sibling_subgraph;
9
10#[cfg(test)]
11mod tests;
12
13use std::borrow::Cow;
14use std::collections::HashMap;
15
16pub use self::petgraph::PetgraphWrapper;
17#[allow(deprecated)]
18use self::render::{MermaidFormatter, RenderConfig};
19pub use rerooted::Rerooted;
20pub use root_checked::{InvalidSignature, RootCheckable, RootChecked, check_tag};
21pub use sibling_subgraph::SiblingSubgraph;
22
23use itertools::Itertools;
24use portgraph::render::{DotFormat, MermaidFormat};
25use portgraph::{LinkView, PortView};
26
27use super::internal::{HugrInternals, HugrMutInternals};
28use super::validate::ValidationContext;
29use super::{Hugr, HugrMut, Node, NodeMetadata, ValidationError};
30use crate::core::HugrNode;
31use crate::extension::ExtensionRegistry;
32use crate::ops::handle::NodeHandle;
33use crate::ops::{OpParent, OpTag, OpTrait, OpType};
34
35use crate::types::{EdgeKind, PolyFuncType, Signature, Type};
36use crate::{Direction, IncomingPort, OutgoingPort, Port};
37
38use itertools::Either;
39
40pub trait HugrView: HugrInternals {
43 fn entrypoint(&self) -> Self::Node;
50
51 #[inline]
53 fn entrypoint_optype(&self) -> &OpType {
54 self.get_optype(self.entrypoint())
55 }
56
57 #[inline]
67 fn entrypoint_tag(&self) -> OpTag {
68 self.entrypoint_optype().tag()
69 }
70
71 fn with_entrypoint(&self, entrypoint: Self::Node) -> Rerooted<&Self>
79 where
80 Self: Sized,
81 {
82 Rerooted::new(self, entrypoint)
83 }
84
85 fn module_root(&self) -> Self::Node;
93
94 fn contains_node(&self, node: Self::Node) -> bool;
96
97 fn get_parent(&self, node: Self::Node) -> Option<Self::Node>;
99
100 #[inline]
102 fn get_metadata(&self, node: Self::Node, key: impl AsRef<str>) -> Option<&NodeMetadata> {
103 if self.contains_node(node) {
104 self.node_metadata_map(node).get(key.as_ref())
105 } else {
106 None
107 }
108 }
109
110 fn get_optype(&self, node: Self::Node) -> &OpType;
116
117 fn num_nodes(&self) -> usize;
119
120 fn num_edges(&self) -> usize;
122
123 fn num_ports(&self, node: Self::Node, dir: Direction) -> usize;
125
126 #[inline]
129 fn num_inputs(&self, node: Self::Node) -> usize {
130 self.num_ports(node, Direction::Incoming)
131 }
132
133 #[inline]
136 fn num_outputs(&self, node: Self::Node) -> usize {
137 self.num_ports(node, Direction::Outgoing)
138 }
139
140 fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone;
149
150 fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone;
152
153 #[inline]
157 fn node_outputs(&self, node: Self::Node) -> impl Iterator<Item = OutgoingPort> + Clone {
158 self.node_ports(node, Direction::Outgoing)
159 .map(|p| p.as_outgoing().unwrap())
160 }
161
162 #[inline]
166 fn node_inputs(&self, node: Self::Node) -> impl Iterator<Item = IncomingPort> + Clone {
167 self.node_ports(node, Direction::Incoming)
168 .map(|p| p.as_incoming().unwrap())
169 }
170
171 fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone;
173
174 fn linked_ports(
176 &self,
177 node: Self::Node,
178 port: impl Into<Port>,
179 ) -> impl Iterator<Item = (Self::Node, Port)> + Clone;
180
181 fn all_linked_ports(
183 &self,
184 node: Self::Node,
185 dir: Direction,
186 ) -> Either<
187 impl Iterator<Item = (Self::Node, OutgoingPort)>,
188 impl Iterator<Item = (Self::Node, IncomingPort)>,
189 > {
190 match dir {
191 Direction::Incoming => Either::Left(
192 self.node_inputs(node)
193 .flat_map(move |port| self.linked_outputs(node, port)),
194 ),
195 Direction::Outgoing => Either::Right(
196 self.node_outputs(node)
197 .flat_map(move |port| self.linked_inputs(node, port)),
198 ),
199 }
200 }
201
202 fn all_linked_outputs(
204 &self,
205 node: Self::Node,
206 ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
207 self.all_linked_ports(node, Direction::Incoming)
208 .left()
209 .unwrap()
210 }
211
212 fn all_linked_inputs(
214 &self,
215 node: Self::Node,
216 ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
217 self.all_linked_ports(node, Direction::Outgoing)
218 .right()
219 .unwrap()
220 }
221
222 fn single_linked_port(
225 &self,
226 node: Self::Node,
227 port: impl Into<Port>,
228 ) -> Option<(Self::Node, Port)> {
229 self.linked_ports(node, port).exactly_one().ok()
230 }
231
232 fn single_linked_output(
235 &self,
236 node: Self::Node,
237 port: impl Into<IncomingPort>,
238 ) -> Option<(Self::Node, OutgoingPort)> {
239 self.single_linked_port(node, port.into())
240 .map(|(n, p)| (n, p.as_outgoing().unwrap()))
241 }
242
243 fn single_linked_input(
246 &self,
247 node: Self::Node,
248 port: impl Into<OutgoingPort>,
249 ) -> Option<(Self::Node, IncomingPort)> {
250 self.single_linked_port(node, port.into())
251 .map(|(n, p)| (n, p.as_incoming().unwrap()))
252 }
253 fn linked_outputs(
257 &self,
258 node: Self::Node,
259 port: impl Into<IncomingPort>,
260 ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
261 self.linked_ports(node, port.into())
262 .map(|(n, p)| (n, p.as_outgoing().unwrap()))
263 }
264
265 fn linked_inputs(
269 &self,
270 node: Self::Node,
271 port: impl Into<OutgoingPort>,
272 ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
273 self.linked_ports(node, port.into())
274 .map(|(n, p)| (n, p.as_incoming().unwrap()))
275 }
276
277 fn node_connections(
279 &self,
280 node: Self::Node,
281 other: Self::Node,
282 ) -> impl Iterator<Item = [Port; 2]> + Clone;
283
284 fn is_linked(&self, node: Self::Node, port: impl Into<Port>) -> bool {
286 self.linked_ports(node, port).next().is_some()
287 }
288
289 fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone;
291
292 fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
297
298 fn entry_descendants(&self) -> impl Iterator<Item = Self::Node> + Clone {
303 self.descendants(self.entrypoint())
304 }
305
306 fn first_child(&self, node: Self::Node) -> Option<Self::Node> {
309 self.children(node).next()
310 }
311
312 fn neighbours(
315 &self,
316 node: Self::Node,
317 dir: Direction,
318 ) -> impl Iterator<Item = Self::Node> + Clone;
319
320 #[inline]
323 fn input_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
324 self.neighbours(node, Direction::Incoming)
325 }
326
327 #[inline]
330 fn output_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
331 self.neighbours(node, Direction::Outgoing)
332 }
333
334 fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
336
337 #[inline]
340 fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]> {
341 let op = self.get_optype(node);
342 if OpTag::DataflowParent.is_superset(op.tag()) {
344 self.children(node).take(2).collect_vec().try_into().ok()
345 } else {
346 None
347 }
348 }
349
350 fn inner_function_type(&self) -> Option<Cow<'_, Signature>> {
360 self.entrypoint_optype().inner_function_type()
361 }
362
363 fn poly_func_type(&self) -> Option<PolyFuncType> {
366 match self.entrypoint_optype() {
367 OpType::FuncDecl(decl) => Some(decl.signature().clone()),
368 OpType::FuncDefn(defn) => Some(defn.signature().clone()),
369 _ => None,
370 }
371 }
372
373 #[inline]
375 fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>
376 where
377 Self: Sized,
378 {
379 PetgraphWrapper { hugr: self }
380 }
381
382 fn mermaid_string(&self) -> String {
390 self.mermaid_string_with_formatter(self.mermaid_format())
391 }
392
393 #[deprecated(note = "Use `mermaid_format` instead")]
401 #[allow(deprecated)]
402 fn mermaid_string_with_config(&self, config: RenderConfig<Self::Node>) -> String;
403
404 fn mermaid_string_with_formatter(&self, formatter: MermaidFormatter<Self>) -> String {
420 #[allow(deprecated)]
421 let config = match RenderConfig::try_from(formatter) {
422 Ok(config) => config,
423 Err(e) => {
424 panic!("Unsupported format option: {e}");
425 }
426 };
427 #[allow(deprecated)]
428 self.mermaid_string_with_config(config)
429 }
430
431 fn mermaid_format(&self) -> MermaidFormatter<Self> {
442 MermaidFormatter::new(self).with_entrypoint(self.entrypoint())
443 }
444
445 fn dot_string(&self) -> String
449 where
450 Self: Sized;
451
452 fn static_source(&self, node: Self::Node) -> Option<Self::Node> {
454 self.linked_outputs(node, self.get_optype(node).static_input_port()?)
455 .next()
456 .map(|(n, _)| n)
457 }
458
459 fn static_targets(
461 &self,
462 node: Self::Node,
463 ) -> Option<impl Iterator<Item = (Self::Node, IncomingPort)>> {
464 Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?))
465 }
466
467 fn signature(&self, node: Self::Node) -> Option<Cow<'_, Signature>> {
470 self.get_optype(node).dataflow_signature()
471 }
472
473 fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = (Port, Type)> {
476 let sig = self.signature(node).unwrap_or_default();
477 self.node_ports(node, dir)
478 .filter_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone())))
479 }
480
481 fn in_value_types(&self, node: Self::Node) -> impl Iterator<Item = (IncomingPort, Type)> {
484 self.value_types(node, Direction::Incoming)
485 .map(|(p, t)| (p.as_incoming().unwrap(), t))
486 }
487
488 fn out_value_types(&self, node: Self::Node) -> impl Iterator<Item = (OutgoingPort, Type)> {
491 self.value_types(node, Direction::Outgoing)
492 .map(|(p, t)| (p.as_outgoing().unwrap(), t))
493 }
494
495 fn extensions(&self) -> &ExtensionRegistry;
500
501 fn validate(&self) -> Result<(), ValidationError<Self::Node>>
503 where
504 Self: Sized,
505 {
506 let mut validator = ValidationContext::new(self);
507 validator.validate()
508 }
509
510 fn extract_hugr(
524 &self,
525 parent: Self::Node,
526 ) -> (Hugr, impl ExtractionResult<Self::Node> + 'static);
527}
528
529pub trait ExtractionResult<SourceN> {
534 fn extracted_node(&self, node: SourceN) -> Node;
540}
541
542struct DefaultNodeMap(HashMap<Node, Node>);
544
545impl ExtractionResult<Node> for DefaultNodeMap {
546 #[inline]
547 fn extracted_node(&self, node: Node) -> Node {
548 self.0.get(&node).copied().unwrap_or(node)
549 }
550}
551
552impl<S: HugrNode> ExtractionResult<S> for HashMap<S, Node> {
553 #[inline]
554 fn extracted_node(&self, node: S) -> Node {
555 self[&node]
556 }
557}
558
559impl HugrView for Hugr {
560 #[inline]
561 fn entrypoint(&self) -> Self::Node {
562 self.entrypoint.into()
563 }
564
565 #[inline]
566 fn module_root(&self) -> Self::Node {
567 let node: Self::Node = self.module_root.into();
568 let handle = node.try_cast();
569 debug_assert!(
570 handle.is_some(),
571 "The root node in a HUGR must be a module."
572 );
573 handle.unwrap()
574 }
575
576 #[inline]
577 fn contains_node(&self, node: Self::Node) -> bool {
578 self.graph.contains_node(node.into_portgraph())
579 }
580
581 #[inline]
582 fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
583 if !check_valid_non_root(self, node) {
584 return None;
585 }
586 self.hierarchy.parent(node.into_portgraph()).map(Into::into)
587 }
588
589 #[inline]
590 fn get_optype(&self, node: Node) -> &OpType {
591 panic_invalid_node(self, node);
592 self.op_types.get(node.into_portgraph())
593 }
594
595 #[inline]
596 fn num_nodes(&self) -> usize {
597 self.graph.node_count()
598 }
599
600 #[inline]
601 fn num_edges(&self) -> usize {
602 self.graph.link_count()
603 }
604
605 #[inline]
606 fn num_ports(&self, node: Self::Node, dir: Direction) -> usize {
607 self.graph.num_ports(node.into_portgraph(), dir)
608 }
609
610 #[inline]
611 fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
612 self.graph.nodes_iter().map_into()
613 }
614
615 #[inline]
616 fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
617 self.graph
618 .port_offsets(node.into_portgraph(), dir)
619 .map_into()
620 }
621
622 #[inline]
623 fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
624 self.graph
625 .all_port_offsets(node.into_portgraph())
626 .map_into()
627 }
628
629 #[inline]
630 fn linked_ports(
631 &self,
632 node: Node,
633 port: impl Into<Port>,
634 ) -> impl Iterator<Item = (Node, Port)> + Clone {
635 let port = port.into();
636
637 let port = self
638 .graph
639 .port_index(node.into_portgraph(), port.pg_offset())
640 .unwrap();
641 self.graph.port_links(port).map(|(_, link)| {
642 let port = link.port();
643 let node = self.graph.port_node(port).unwrap();
644 let offset = self.graph.port_offset(port).unwrap();
645 (node.into(), offset.into())
646 })
647 }
648
649 #[inline]
650 fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
651 self.graph
652 .get_connections(node.into_portgraph(), other.into_portgraph())
653 .map(|(p1, p2)| {
654 [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into())
655 })
656 }
657
658 #[inline]
659 fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
660 self.hierarchy.children(node.into_portgraph()).map_into()
661 }
662
663 #[inline]
664 fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
665 self.hierarchy.descendants(node.into_portgraph()).map_into()
666 }
667
668 #[inline]
669 fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
670 self.graph.neighbours(node.into_portgraph(), dir).map_into()
671 }
672
673 #[inline]
674 fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
675 self.graph.all_neighbours(node.into_portgraph()).map_into()
676 }
677
678 #[allow(deprecated)]
679 fn mermaid_string_with_config(&self, config: RenderConfig) -> String {
680 self.mermaid_string_with_formatter(MermaidFormatter::from_render_config(config, self))
681 }
682
683 fn mermaid_string_with_formatter(&self, formatter: MermaidFormatter<Self>) -> String {
684 self.graph
685 .mermaid_format()
686 .with_hierarchy(&self.hierarchy)
687 .with_node_style(render::node_style(self, formatter.clone()))
688 .with_edge_style(render::edge_style(self, formatter))
689 .finish()
690 }
691
692 fn dot_string(&self) -> String
693 where
694 Self: Sized,
695 {
696 let formatter = MermaidFormatter::new(self).with_entrypoint(self.entrypoint());
697 self.graph
698 .dot_format()
699 .with_hierarchy(&self.hierarchy)
700 .with_node_style(render::node_style(self, formatter.clone()))
701 .with_port_style(render::port_style(self))
702 .with_edge_style(render::edge_style(self, formatter))
703 .finish()
704 }
705
706 #[inline]
707 fn extensions(&self) -> &ExtensionRegistry {
708 &self.extensions
709 }
710
711 #[inline]
712 fn extract_hugr(&self, target: Node) -> (Hugr, impl ExtractionResult<Node> + 'static) {
713 if target == self.module_root().node() {
715 return (self.clone(), DefaultNodeMap(HashMap::new()));
716 }
717
718 let mut parent = target;
724 let mut extracted = loop {
725 let parent_op = self.get_optype(parent).clone();
726 if let Ok(hugr) = Hugr::new_with_entrypoint(parent_op) {
727 break hugr;
728 }
729 parent = self
732 .get_parent(parent)
733 .expect("The module root is always extractable");
734 };
735
736 let old_entrypoint = extracted.entrypoint();
739 let old_parent = extracted.get_parent(old_entrypoint);
740
741 let inserted = extracted.insert_from_view(old_entrypoint, &self.with_entrypoint(parent));
742 let new_entrypoint = inserted.inserted_entrypoint;
743
744 match old_parent {
745 Some(old_parent) => {
746 let old_ins = extracted
749 .node_inputs(old_entrypoint)
750 .flat_map(|inp| {
751 extracted
752 .linked_outputs(old_entrypoint, inp)
753 .map(move |link| (inp, link))
754 })
755 .collect_vec();
756 let old_outs = extracted
757 .node_outputs(old_entrypoint)
758 .flat_map(|out| {
759 extracted
760 .linked_inputs(old_entrypoint, out)
761 .map(move |link| (out, link))
762 })
763 .collect_vec();
764 extracted.set_entrypoint(inserted.node_map[&target]);
766 extracted.remove_node(old_entrypoint);
767 extracted.set_parent(new_entrypoint, old_parent);
768 for (inp, (neigh, neigh_out)) in old_ins {
770 extracted.connect(neigh, neigh_out, new_entrypoint, inp);
771 }
772 for (out, (neigh, neigh_in)) in old_outs {
773 extracted.connect(new_entrypoint, out, neigh, neigh_in);
774 }
775 }
776 None => {
778 extracted.set_entrypoint(inserted.node_map[&target]);
779 extracted.set_module_root(new_entrypoint);
780 extracted.remove_node(old_entrypoint);
781 }
782 }
783 (extracted, DefaultNodeMap(inserted.node_map))
784 }
785}
786
787pub trait PortIterator<P>: Iterator<Item = (Node, P)>
789where
790 P: Into<Port> + Copy,
791 Self: Sized,
792{
793 fn dataflow_ports_only(
796 self,
797 hugr: &impl HugrView<Node = Node>,
798 ) -> impl Iterator<Item = (Node, P)> {
799 self.filter_edge_kind(
800 |kind| matches!(kind, Some(EdgeKind::Value(..) | EdgeKind::StateOrder)),
801 hugr,
802 )
803 }
804
805 fn filter_edge_kind(
807 self,
808 predicate: impl Fn(Option<EdgeKind>) -> bool,
809 hugr: &impl HugrView<Node = Node>,
810 ) -> impl Iterator<Item = (Node, P)> {
811 self.filter(move |(n, p)| {
812 let kind = HugrView::get_optype(hugr, *n).port_kind(*p);
813 predicate(kind)
814 })
815 }
816}
817
818impl<I, P> PortIterator<P> for I
819where
820 I: Iterator<Item = (Node, P)>,
821 P: Into<Port> + Copy,
822{
823}
824
825pub(super) fn check_valid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
827 hugr.contains_node(node) && node != hugr.entrypoint()
828}
829
830pub(super) fn check_valid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
832 hugr.contains_node(node) && node != hugr.module_root().node()
833}
834
835#[track_caller]
837pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
838 assert!(hugr.contains_node(node), "Received an invalid node {node}.",);
839}
840
841#[track_caller]
843pub(super) fn panic_invalid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
844 assert!(
845 check_valid_non_entrypoint(hugr, node),
846 "Received an invalid non-entrypoint node {node}.",
847 );
848}
849
850#[track_caller]
852pub(super) fn panic_invalid_port(hugr: &Hugr, node: Node, port: impl Into<Port>) {
853 let port = port.into();
854 if hugr
855 .graph
856 .port_index(node.into_portgraph(), port.pg_offset())
857 .is_none()
858 {
859 panic!("Received an invalid {port} for {node} while mutating a HUGR");
860 }
861}