1mod impls;
4mod nodes_iter;
5pub mod petgraph;
6pub mod render;
7mod rerooted;
8mod root_checked;
9pub mod sibling_subgraph;
10
11#[cfg(test)]
12mod tests;
13
14use std::borrow::Cow;
15use std::collections::HashMap;
16
17pub use self::petgraph::PetgraphWrapper;
18#[expect(deprecated)]
19use self::render::{MermaidFormatter, RenderConfig};
20pub use nodes_iter::NodesIter;
21pub use rerooted::Rerooted;
22pub use root_checked::{InvalidSignature, RootCheckable, RootChecked, check_tag};
23pub use sibling_subgraph::SiblingSubgraph;
24
25use itertools::Itertools;
26use portgraph::render::{DotFormat, MermaidFormat};
27use portgraph::{LinkView, PortView};
28
29use super::internal::{HugrInternals, HugrMutInternals};
30use super::validate::ValidationContext;
31use super::{Hugr, HugrMut, Node, NodeMetadata, ValidationError};
32use crate::core::HugrNode;
33use crate::extension::ExtensionRegistry;
34use crate::ops::handle::NodeHandle;
35use crate::ops::{OpParent, OpTag, OpTrait, OpType};
36
37use crate::types::{EdgeKind, PolyFuncType, Signature, Type};
38use crate::{Direction, IncomingPort, OutgoingPort, Port};
39
40use itertools::Either;
41
42pub trait HugrView: HugrInternals {
45 fn entrypoint(&self) -> Self::Node;
52
53 #[inline]
55 fn entrypoint_optype(&self) -> &OpType {
56 self.get_optype(self.entrypoint())
57 }
58
59 #[inline]
69 fn entrypoint_tag(&self) -> OpTag {
70 self.entrypoint_optype().tag()
71 }
72
73 fn with_entrypoint(&self, entrypoint: Self::Node) -> Rerooted<&Self>
81 where
82 Self: Sized,
83 {
84 Rerooted::new(self, entrypoint)
85 }
86
87 fn module_root(&self) -> Self::Node;
95
96 fn contains_node(&self, node: Self::Node) -> bool;
98
99 fn get_parent(&self, node: Self::Node) -> Option<Self::Node>;
101
102 #[inline]
104 fn get_metadata(&self, node: Self::Node, key: impl AsRef<str>) -> Option<&NodeMetadata> {
105 if self.contains_node(node) {
106 self.node_metadata_map(node).get(key.as_ref())
107 } else {
108 None
109 }
110 }
111
112 fn get_optype(&self, node: Self::Node) -> &OpType;
118
119 fn num_nodes(&self) -> usize;
121
122 fn num_edges(&self) -> usize;
124
125 fn num_ports(&self, node: Self::Node, dir: Direction) -> usize;
127
128 #[inline]
131 fn num_inputs(&self, node: Self::Node) -> usize {
132 self.num_ports(node, Direction::Incoming)
133 }
134
135 #[inline]
138 fn num_outputs(&self, node: Self::Node) -> usize {
139 self.num_ports(node, Direction::Outgoing)
140 }
141
142 fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone;
151
152 fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone;
154
155 #[inline]
159 fn node_outputs(&self, node: Self::Node) -> impl Iterator<Item = OutgoingPort> + Clone {
160 self.node_ports(node, Direction::Outgoing)
161 .map(|p| p.as_outgoing().unwrap())
162 }
163
164 #[inline]
168 fn node_inputs(&self, node: Self::Node) -> impl Iterator<Item = IncomingPort> + Clone {
169 self.node_ports(node, Direction::Incoming)
170 .map(|p| p.as_incoming().unwrap())
171 }
172
173 fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone;
175
176 fn linked_ports(
178 &self,
179 node: Self::Node,
180 port: impl Into<Port>,
181 ) -> impl Iterator<Item = (Self::Node, Port)> + Clone;
182
183 fn all_linked_ports(
185 &self,
186 node: Self::Node,
187 dir: Direction,
188 ) -> Either<
189 impl Iterator<Item = (Self::Node, OutgoingPort)>,
190 impl Iterator<Item = (Self::Node, IncomingPort)>,
191 > {
192 match dir {
193 Direction::Incoming => Either::Left(
194 self.node_inputs(node)
195 .flat_map(move |port| self.linked_outputs(node, port)),
196 ),
197 Direction::Outgoing => Either::Right(
198 self.node_outputs(node)
199 .flat_map(move |port| self.linked_inputs(node, port)),
200 ),
201 }
202 }
203
204 fn all_linked_outputs(
206 &self,
207 node: Self::Node,
208 ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
209 self.all_linked_ports(node, Direction::Incoming)
210 .left()
211 .unwrap()
212 }
213
214 fn all_linked_inputs(
216 &self,
217 node: Self::Node,
218 ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
219 self.all_linked_ports(node, Direction::Outgoing)
220 .right()
221 .unwrap()
222 }
223
224 fn single_linked_port(
227 &self,
228 node: Self::Node,
229 port: impl Into<Port>,
230 ) -> Option<(Self::Node, Port)> {
231 self.linked_ports(node, port).exactly_one().ok()
232 }
233
234 fn single_linked_output(
237 &self,
238 node: Self::Node,
239 port: impl Into<IncomingPort>,
240 ) -> Option<(Self::Node, OutgoingPort)> {
241 self.single_linked_port(node, port.into())
242 .map(|(n, p)| (n, p.as_outgoing().unwrap()))
243 }
244
245 fn single_linked_input(
248 &self,
249 node: Self::Node,
250 port: impl Into<OutgoingPort>,
251 ) -> Option<(Self::Node, IncomingPort)> {
252 self.single_linked_port(node, port.into())
253 .map(|(n, p)| (n, p.as_incoming().unwrap()))
254 }
255 fn linked_outputs(
259 &self,
260 node: Self::Node,
261 port: impl Into<IncomingPort>,
262 ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
263 self.linked_ports(node, port.into())
264 .map(|(n, p)| (n, p.as_outgoing().unwrap()))
265 }
266
267 fn linked_inputs(
271 &self,
272 node: Self::Node,
273 port: impl Into<OutgoingPort>,
274 ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
275 self.linked_ports(node, port.into())
276 .map(|(n, p)| (n, p.as_incoming().unwrap()))
277 }
278
279 fn node_connections(
281 &self,
282 node: Self::Node,
283 other: Self::Node,
284 ) -> impl Iterator<Item = [Port; 2]> + Clone;
285
286 fn is_linked(&self, node: Self::Node, port: impl Into<Port>) -> bool {
288 self.linked_ports(node, port).next().is_some()
289 }
290
291 fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone;
293
294 fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
299
300 fn entry_descendants(&self) -> impl Iterator<Item = Self::Node> + Clone {
305 self.descendants(self.entrypoint())
306 }
307
308 fn first_child(&self, node: Self::Node) -> Option<Self::Node> {
311 self.children(node).next()
312 }
313
314 fn neighbours(
317 &self,
318 node: Self::Node,
319 dir: Direction,
320 ) -> impl Iterator<Item = Self::Node> + Clone;
321
322 #[inline]
325 fn input_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
326 self.neighbours(node, Direction::Incoming)
327 }
328
329 #[inline]
332 fn output_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
333 self.neighbours(node, Direction::Outgoing)
334 }
335
336 fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
338
339 #[inline]
342 fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]> {
343 let op = self.get_optype(node);
344 if OpTag::DataflowParent.is_superset(op.tag()) {
346 self.children(node).take(2).collect_vec().try_into().ok()
347 } else {
348 None
349 }
350 }
351
352 fn inner_function_type(&self) -> Option<Cow<'_, Signature>> {
362 self.entrypoint_optype().inner_function_type()
363 }
364
365 fn poly_func_type(&self) -> Option<PolyFuncType> {
368 match self.entrypoint_optype() {
369 OpType::FuncDecl(decl) => Some(decl.signature().clone()),
370 OpType::FuncDefn(defn) => Some(defn.signature().clone()),
371 _ => None,
372 }
373 }
374
375 #[inline]
377 fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>
378 where
379 Self: Sized,
380 {
381 PetgraphWrapper { hugr: self }
382 }
383
384 fn mermaid_string(&self) -> String {
392 self.mermaid_string_with_formatter(self.mermaid_format())
393 }
394
395 #[deprecated(note = "Use `mermaid_format` instead", since = "0.20.2")]
403 #[expect(deprecated)]
404 fn mermaid_string_with_config(&self, config: RenderConfig<Self::Node>) -> String;
405
406 fn mermaid_string_with_formatter(&self, formatter: MermaidFormatter<Self>) -> String {
422 #[expect(deprecated)]
423 let config = match RenderConfig::try_from(formatter) {
424 Ok(config) => config,
425 Err(e) => {
426 panic!("Unsupported format option: {e}");
427 }
428 };
429 #[expect(deprecated)]
430 self.mermaid_string_with_config(config)
431 }
432
433 fn mermaid_format(&self) -> MermaidFormatter<'_, Self> {
444 MermaidFormatter::new(self).with_entrypoint(self.entrypoint())
445 }
446
447 fn dot_string(&self) -> String
451 where
452 Self: Sized;
453
454 fn static_source(&self, node: Self::Node) -> Option<Self::Node> {
456 self.linked_outputs(node, self.get_optype(node).static_input_port()?)
457 .next()
458 .map(|(n, _)| n)
459 }
460
461 fn static_targets(
463 &self,
464 node: Self::Node,
465 ) -> Option<impl Iterator<Item = (Self::Node, IncomingPort)>> {
466 Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?))
467 }
468
469 fn signature(&self, node: Self::Node) -> Option<Cow<'_, Signature>> {
472 self.get_optype(node).dataflow_signature()
473 }
474
475 fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = (Port, Type)> {
478 let sig = self.signature(node).unwrap_or_default();
479 self.node_ports(node, dir)
480 .filter_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone())))
481 }
482
483 fn in_value_types(&self, node: Self::Node) -> impl Iterator<Item = (IncomingPort, Type)> {
486 self.value_types(node, Direction::Incoming)
487 .map(|(p, t)| (p.as_incoming().unwrap(), t))
488 }
489
490 fn out_value_types(&self, node: Self::Node) -> impl Iterator<Item = (OutgoingPort, Type)> {
493 self.value_types(node, Direction::Outgoing)
494 .map(|(p, t)| (p.as_outgoing().unwrap(), t))
495 }
496
497 fn extensions(&self) -> &ExtensionRegistry;
502
503 fn validate(&self) -> Result<(), ValidationError<Self::Node>>
505 where
506 Self: Sized,
507 {
508 let mut validator = ValidationContext::new(self);
509 validator.validate()
510 }
511
512 fn extract_hugr(
526 &self,
527 parent: Self::Node,
528 ) -> (Hugr, impl ExtractionResult<Self::Node> + 'static);
529}
530
/// Node mapping produced by [`HugrView::extract_hugr`].
///
/// Relates nodes of the source view (`SourceN`) to nodes of the extracted
/// [`Hugr`].
pub trait ExtractionResult<SourceN> {
    /// Returns the node in the extracted HUGR that corresponds to `node` in
    /// the source view.
    ///
    /// NOTE(review): the implementations in this module either panic
    /// (`HashMap`) or return `node` unchanged (`DefaultNodeMap`) for nodes that
    /// were not extracted. Callers should only query nodes inside the
    /// extracted region.
    fn extracted_node(&self, node: SourceN) -> Node;
}
543
544struct DefaultNodeMap(HashMap<Node, Node>);
546
547impl ExtractionResult<Node> for DefaultNodeMap {
548 #[inline]
549 fn extracted_node(&self, node: Node) -> Node {
550 self.0.get(&node).copied().unwrap_or(node)
551 }
552}
553
/// Any `HashMap` from source nodes to extracted nodes is a valid extraction result.
impl<S: HugrNode> ExtractionResult<S> for HashMap<S, Node> {
    /// # Panics
    ///
    /// If `node` has no entry in the map (indexing a `HashMap` panics on a
    /// missing key).
    #[inline]
    fn extracted_node(&self, node: S) -> Node {
        self[&node]
    }
}
560
561impl HugrView for Hugr {
562 #[inline]
563 fn entrypoint(&self) -> Self::Node {
564 self.entrypoint.into()
565 }
566
567 #[inline]
568 fn module_root(&self) -> Self::Node {
569 let node: Self::Node = self.module_root.into();
570 let handle = node.try_cast();
571 debug_assert!(
572 handle.is_some(),
573 "The root node in a HUGR must be a module."
574 );
575 handle.unwrap()
576 }
577
578 #[inline]
579 fn contains_node(&self, node: Self::Node) -> bool {
580 self.graph.contains_node(node.into_portgraph())
581 }
582
583 #[inline]
584 fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
585 if !check_valid_non_root(self, node) {
586 return None;
587 }
588 self.hierarchy.parent(node.into_portgraph()).map(Into::into)
589 }
590
591 #[inline]
592 fn get_optype(&self, node: Node) -> &OpType {
593 panic_invalid_node(self, node);
594 self.op_types.get(node.into_portgraph())
595 }
596
597 #[inline]
598 fn num_nodes(&self) -> usize {
599 self.graph.node_count()
600 }
601
602 #[inline]
603 fn num_edges(&self) -> usize {
604 self.graph.link_count()
605 }
606
607 #[inline]
608 fn num_ports(&self, node: Self::Node, dir: Direction) -> usize {
609 self.graph.num_ports(node.into_portgraph(), dir)
610 }
611
612 #[inline]
613 fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
614 self.graph.nodes_iter().map_into()
615 }
616
617 #[inline]
618 fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
619 self.graph
620 .port_offsets(node.into_portgraph(), dir)
621 .map_into()
622 }
623
624 #[inline]
625 fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
626 self.graph
627 .all_port_offsets(node.into_portgraph())
628 .map_into()
629 }
630
631 #[inline]
632 fn linked_ports(
633 &self,
634 node: Node,
635 port: impl Into<Port>,
636 ) -> impl Iterator<Item = (Node, Port)> + Clone {
637 let port = port.into();
638
639 let port = self
640 .graph
641 .port_index(node.into_portgraph(), port.pg_offset())
642 .unwrap();
643 self.graph.port_links(port).map(|(_, link)| {
644 let port = link.port();
645 let node = self.graph.port_node(port).unwrap();
646 let offset = self.graph.port_offset(port).unwrap();
647 (node.into(), offset.into())
648 })
649 }
650
651 #[inline]
652 fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
653 self.graph
654 .get_connections(node.into_portgraph(), other.into_portgraph())
655 .map(|(p1, p2)| {
656 [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into())
657 })
658 }
659
660 #[inline]
661 fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
662 self.hierarchy.children(node.into_portgraph()).map_into()
663 }
664
665 #[inline]
666 fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
667 self.hierarchy.descendants(node.into_portgraph()).map_into()
668 }
669
670 #[inline]
671 fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
672 self.graph.neighbours(node.into_portgraph(), dir).map_into()
673 }
674
675 #[inline]
676 fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
677 self.graph.all_neighbours(node.into_portgraph()).map_into()
678 }
679
680 #[expect(deprecated)]
681 fn mermaid_string_with_config(&self, config: RenderConfig) -> String {
682 self.mermaid_string_with_formatter(MermaidFormatter::from_render_config(config, self))
683 }
684
685 fn mermaid_string_with_formatter(&self, formatter: MermaidFormatter<Self>) -> String {
686 self.graph
687 .mermaid_format()
688 .with_hierarchy(&self.hierarchy)
689 .with_node_style(render::node_style(self, formatter.clone()))
690 .with_edge_style(render::edge_style(self, formatter))
691 .finish()
692 }
693
694 fn dot_string(&self) -> String
695 where
696 Self: Sized,
697 {
698 let formatter = MermaidFormatter::new(self).with_entrypoint(self.entrypoint());
699 self.graph
700 .dot_format()
701 .with_hierarchy(&self.hierarchy)
702 .with_node_style(render::node_style(self, formatter.clone()))
703 .with_port_style(render::port_style(self))
704 .with_edge_style(render::edge_style(self, formatter))
705 .finish()
706 }
707
708 #[inline]
709 fn extensions(&self) -> &ExtensionRegistry {
710 &self.extensions
711 }
712
713 #[inline]
714 fn extract_hugr(&self, target: Node) -> (Hugr, impl ExtractionResult<Node> + 'static) {
715 if target == self.module_root().node() {
717 return (self.clone(), DefaultNodeMap(HashMap::new()));
718 }
719
720 let mut parent = target;
726 let mut extracted = loop {
727 let parent_op = self.get_optype(parent).clone();
728 if let Ok(hugr) = Hugr::new_with_entrypoint(parent_op) {
729 break hugr;
730 }
731 parent = self
734 .get_parent(parent)
735 .expect("The module root is always extractable");
736 };
737
738 let old_entrypoint = extracted.entrypoint();
741 let old_parent = extracted.get_parent(old_entrypoint);
742
743 let inserted = extracted.insert_from_view(old_entrypoint, &self.with_entrypoint(parent));
744 let new_entrypoint = inserted.inserted_entrypoint;
745
746 match old_parent {
747 Some(old_parent) => {
748 let old_ins = extracted
751 .node_inputs(old_entrypoint)
752 .flat_map(|inp| {
753 extracted
754 .linked_outputs(old_entrypoint, inp)
755 .map(move |link| (inp, link))
756 })
757 .collect_vec();
758 let old_outs = extracted
759 .node_outputs(old_entrypoint)
760 .flat_map(|out| {
761 extracted
762 .linked_inputs(old_entrypoint, out)
763 .map(move |link| (out, link))
764 })
765 .collect_vec();
766 extracted.set_entrypoint(inserted.node_map[&target]);
768 extracted.remove_node(old_entrypoint);
769 extracted.set_parent(new_entrypoint, old_parent);
770 for (inp, (neigh, neigh_out)) in old_ins {
772 extracted.connect(neigh, neigh_out, new_entrypoint, inp);
773 }
774 for (out, (neigh, neigh_in)) in old_outs {
775 extracted.connect(new_entrypoint, out, neigh, neigh_in);
776 }
777 }
778 None => {
780 extracted.set_entrypoint(inserted.node_map[&target]);
781 extracted.set_module_root(new_entrypoint);
782 extracted.remove_node(old_entrypoint);
783 }
784 }
785 (extracted, DefaultNodeMap(inserted.node_map))
786 }
787}
788
789pub trait PortIterator<P>: Iterator<Item = (Node, P)>
791where
792 P: Into<Port> + Copy,
793 Self: Sized,
794{
795 fn dataflow_ports_only(
798 self,
799 hugr: &impl HugrView<Node = Node>,
800 ) -> impl Iterator<Item = (Node, P)> {
801 self.filter_edge_kind(
802 |kind| matches!(kind, Some(EdgeKind::Value(..) | EdgeKind::StateOrder)),
803 hugr,
804 )
805 }
806
807 fn filter_edge_kind(
809 self,
810 predicate: impl Fn(Option<EdgeKind>) -> bool,
811 hugr: &impl HugrView<Node = Node>,
812 ) -> impl Iterator<Item = (Node, P)> {
813 self.filter(move |(n, p)| {
814 let kind = HugrView::get_optype(hugr, *n).port_kind(*p);
815 predicate(kind)
816 })
817 }
818}
819
820impl<I, P> PortIterator<P> for I
821where
822 I: Iterator<Item = (Node, P)>,
823 P: Into<Port> + Copy,
824{
825}
826
827pub(super) fn check_valid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
829 hugr.contains_node(node) && node != hugr.entrypoint()
830}
831
832pub(super) fn check_valid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
834 hugr.contains_node(node) && node != hugr.module_root().node()
835}
836
837#[track_caller]
839pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
840 assert!(hugr.contains_node(node), "Received an invalid node {node}.",);
841}
842
843#[track_caller]
845pub(super) fn panic_invalid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
846 assert!(
847 check_valid_non_entrypoint(hugr, node),
848 "Received an invalid non-entrypoint node {node}.",
849 );
850}
851
852#[track_caller]
854pub(super) fn panic_invalid_port(hugr: &Hugr, node: Node, port: impl Into<Port>) {
855 let port = port.into();
856 if hugr
857 .graph
858 .port_index(node.into_portgraph(), port.pg_offset())
859 .is_none()
860 {
861 panic!("Received an invalid {port} for {node} while mutating a HUGR");
862 }
863}