1mod impls;
4pub mod petgraph;
5pub mod render;
6mod rerooted;
7mod root_checked;
8pub mod sibling_subgraph;
9
10#[cfg(test)]
11mod tests;
12
13use std::borrow::Cow;
14use std::collections::HashMap;
15
16pub use self::petgraph::PetgraphWrapper;
17use self::render::RenderConfig;
18pub use rerooted::Rerooted;
19pub use root_checked::{InvalidSignature, RootCheckable, RootChecked, check_tag};
20pub use sibling_subgraph::SiblingSubgraph;
21
22use itertools::Itertools;
23use portgraph::render::{DotFormat, MermaidFormat};
24use portgraph::{LinkView, PortView};
25
26use super::internal::{HugrInternals, HugrMutInternals};
27use super::validate::ValidationContext;
28use super::{Hugr, HugrMut, Node, NodeMetadata, ValidationError};
29use crate::core::HugrNode;
30use crate::extension::ExtensionRegistry;
31use crate::ops::handle::NodeHandle;
32use crate::ops::{OpParent, OpTag, OpTrait, OpType};
33
34use crate::types::{EdgeKind, PolyFuncType, Signature, Type};
35use crate::{Direction, IncomingPort, OutgoingPort, Port};
36
37use itertools::Either;
38
39pub trait HugrView: HugrInternals {
42    fn entrypoint(&self) -> Self::Node;
49
50    #[inline]
52    fn entrypoint_optype(&self) -> &OpType {
53        self.get_optype(self.entrypoint())
54    }
55
56    #[inline]
66    fn entrypoint_tag(&self) -> OpTag {
67        self.entrypoint_optype().tag()
68    }
69
70    fn with_entrypoint(&self, entrypoint: Self::Node) -> Rerooted<&Self>
78    where
79        Self: Sized,
80    {
81        Rerooted::new(self, entrypoint)
82    }
83
84    fn module_root(&self) -> Self::Node;
92
93    fn contains_node(&self, node: Self::Node) -> bool;
95
96    fn get_parent(&self, node: Self::Node) -> Option<Self::Node>;
98
99    #[inline]
101    fn get_metadata(&self, node: Self::Node, key: impl AsRef<str>) -> Option<&NodeMetadata> {
102        if self.contains_node(node) {
103            self.node_metadata_map(node).get(key.as_ref())
104        } else {
105            None
106        }
107    }
108
109    fn get_optype(&self, node: Self::Node) -> &OpType;
115
116    fn num_nodes(&self) -> usize;
118
119    fn num_edges(&self) -> usize;
121
122    fn num_ports(&self, node: Self::Node, dir: Direction) -> usize;
124
125    #[inline]
128    fn num_inputs(&self, node: Self::Node) -> usize {
129        self.num_ports(node, Direction::Incoming)
130    }
131
132    #[inline]
135    fn num_outputs(&self, node: Self::Node) -> usize {
136        self.num_ports(node, Direction::Outgoing)
137    }
138
139    fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone;
148
149    fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone;
151
152    #[inline]
156    fn node_outputs(&self, node: Self::Node) -> impl Iterator<Item = OutgoingPort> + Clone {
157        self.node_ports(node, Direction::Outgoing)
158            .map(|p| p.as_outgoing().unwrap())
159    }
160
161    #[inline]
165    fn node_inputs(&self, node: Self::Node) -> impl Iterator<Item = IncomingPort> + Clone {
166        self.node_ports(node, Direction::Incoming)
167            .map(|p| p.as_incoming().unwrap())
168    }
169
170    fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone;
172
173    fn linked_ports(
175        &self,
176        node: Self::Node,
177        port: impl Into<Port>,
178    ) -> impl Iterator<Item = (Self::Node, Port)> + Clone;
179
180    fn all_linked_ports(
182        &self,
183        node: Self::Node,
184        dir: Direction,
185    ) -> Either<
186        impl Iterator<Item = (Self::Node, OutgoingPort)>,
187        impl Iterator<Item = (Self::Node, IncomingPort)>,
188    > {
189        match dir {
190            Direction::Incoming => Either::Left(
191                self.node_inputs(node)
192                    .flat_map(move |port| self.linked_outputs(node, port)),
193            ),
194            Direction::Outgoing => Either::Right(
195                self.node_outputs(node)
196                    .flat_map(move |port| self.linked_inputs(node, port)),
197            ),
198        }
199    }
200
201    fn all_linked_outputs(
203        &self,
204        node: Self::Node,
205    ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
206        self.all_linked_ports(node, Direction::Incoming)
207            .left()
208            .unwrap()
209    }
210
211    fn all_linked_inputs(
213        &self,
214        node: Self::Node,
215    ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
216        self.all_linked_ports(node, Direction::Outgoing)
217            .right()
218            .unwrap()
219    }
220
221    fn single_linked_port(
224        &self,
225        node: Self::Node,
226        port: impl Into<Port>,
227    ) -> Option<(Self::Node, Port)> {
228        self.linked_ports(node, port).exactly_one().ok()
229    }
230
231    fn single_linked_output(
234        &self,
235        node: Self::Node,
236        port: impl Into<IncomingPort>,
237    ) -> Option<(Self::Node, OutgoingPort)> {
238        self.single_linked_port(node, port.into())
239            .map(|(n, p)| (n, p.as_outgoing().unwrap()))
240    }
241
242    fn single_linked_input(
245        &self,
246        node: Self::Node,
247        port: impl Into<OutgoingPort>,
248    ) -> Option<(Self::Node, IncomingPort)> {
249        self.single_linked_port(node, port.into())
250            .map(|(n, p)| (n, p.as_incoming().unwrap()))
251    }
252    fn linked_outputs(
256        &self,
257        node: Self::Node,
258        port: impl Into<IncomingPort>,
259    ) -> impl Iterator<Item = (Self::Node, OutgoingPort)> {
260        self.linked_ports(node, port.into())
261            .map(|(n, p)| (n, p.as_outgoing().unwrap()))
262    }
263
264    fn linked_inputs(
268        &self,
269        node: Self::Node,
270        port: impl Into<OutgoingPort>,
271    ) -> impl Iterator<Item = (Self::Node, IncomingPort)> {
272        self.linked_ports(node, port.into())
273            .map(|(n, p)| (n, p.as_incoming().unwrap()))
274    }
275
276    fn node_connections(
278        &self,
279        node: Self::Node,
280        other: Self::Node,
281    ) -> impl Iterator<Item = [Port; 2]> + Clone;
282
283    fn is_linked(&self, node: Self::Node, port: impl Into<Port>) -> bool {
285        self.linked_ports(node, port).next().is_some()
286    }
287
288    fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone;
290
291    fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
296
297    fn entry_descendants(&self) -> impl Iterator<Item = Self::Node> + Clone {
302        self.descendants(self.entrypoint())
303    }
304
305    fn first_child(&self, node: Self::Node) -> Option<Self::Node> {
308        self.children(node).next()
309    }
310
311    fn neighbours(
314        &self,
315        node: Self::Node,
316        dir: Direction,
317    ) -> impl Iterator<Item = Self::Node> + Clone;
318
319    #[inline]
322    fn input_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
323        self.neighbours(node, Direction::Incoming)
324    }
325
326    #[inline]
329    fn output_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
330        self.neighbours(node, Direction::Outgoing)
331    }
332
333    fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
335
336    #[inline]
339    fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]> {
340        let op = self.get_optype(node);
341        if OpTag::DataflowParent.is_superset(op.tag()) {
343            self.children(node).take(2).collect_vec().try_into().ok()
344        } else {
345            None
346        }
347    }
348
349    fn inner_function_type(&self) -> Option<Cow<'_, Signature>> {
359        self.entrypoint_optype().inner_function_type()
360    }
361
362    fn poly_func_type(&self) -> Option<PolyFuncType> {
365        match self.entrypoint_optype() {
366            OpType::FuncDecl(decl) => Some(decl.signature().clone()),
367            OpType::FuncDefn(defn) => Some(defn.signature().clone()),
368            _ => None,
369        }
370    }
371
372    #[inline]
374    fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>
375    where
376        Self: Sized,
377    {
378        PetgraphWrapper { hugr: self }
379    }
380
381    fn mermaid_string(&self) -> String;
389
390    fn mermaid_string_with_config(&self, config: RenderConfig<Self::Node>) -> String;
398
399    fn dot_string(&self) -> String
403    where
404        Self: Sized;
405
406    fn static_source(&self, node: Self::Node) -> Option<Self::Node> {
408        self.linked_outputs(node, self.get_optype(node).static_input_port()?)
409            .next()
410            .map(|(n, _)| n)
411    }
412
413    fn static_targets(
415        &self,
416        node: Self::Node,
417    ) -> Option<impl Iterator<Item = (Self::Node, IncomingPort)>> {
418        Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?))
419    }
420
421    fn signature(&self, node: Self::Node) -> Option<Cow<'_, Signature>> {
424        self.get_optype(node).dataflow_signature()
425    }
426
427    fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = (Port, Type)> {
430        let sig = self.signature(node).unwrap_or_default();
431        self.node_ports(node, dir)
432            .filter_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone())))
433    }
434
435    fn in_value_types(&self, node: Self::Node) -> impl Iterator<Item = (IncomingPort, Type)> {
438        self.value_types(node, Direction::Incoming)
439            .map(|(p, t)| (p.as_incoming().unwrap(), t))
440    }
441
442    fn out_value_types(&self, node: Self::Node) -> impl Iterator<Item = (OutgoingPort, Type)> {
445        self.value_types(node, Direction::Outgoing)
446            .map(|(p, t)| (p.as_outgoing().unwrap(), t))
447    }
448
449    fn extensions(&self) -> &ExtensionRegistry;
454
455    fn validate(&self) -> Result<(), ValidationError<Self::Node>>
457    where
458        Self: Sized,
459    {
460        let mut validator = ValidationContext::new(self);
461        validator.validate()
462    }
463
464    fn extract_hugr(
478        &self,
479        parent: Self::Node,
480    ) -> (Hugr, impl ExtractionResult<Self::Node> + 'static);
481}
482
483pub trait ExtractionResult<SourceN> {
488    fn extracted_node(&self, node: SourceN) -> Node;
494}
495
496struct DefaultNodeMap(HashMap<Node, Node>);
498
499impl ExtractionResult<Node> for DefaultNodeMap {
500    #[inline]
501    fn extracted_node(&self, node: Node) -> Node {
502        self.0.get(&node).copied().unwrap_or(node)
503    }
504}
505
506impl<S: HugrNode> ExtractionResult<S> for HashMap<S, Node> {
507    #[inline]
508    fn extracted_node(&self, node: S) -> Node {
509        self[&node]
510    }
511}
512
513impl HugrView for Hugr {
514    #[inline]
515    fn entrypoint(&self) -> Self::Node {
516        self.entrypoint.into()
517    }
518
519    #[inline]
520    fn module_root(&self) -> Self::Node {
521        let node: Self::Node = self.module_root.into();
522        let handle = node.try_cast();
523        debug_assert!(
524            handle.is_some(),
525            "The root node in a HUGR must be a module."
526        );
527        handle.unwrap()
528    }
529
530    #[inline]
531    fn contains_node(&self, node: Self::Node) -> bool {
532        self.graph.contains_node(node.into_portgraph())
533    }
534
535    #[inline]
536    fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
537        if !check_valid_non_root(self, node) {
538            return None;
539        }
540        self.hierarchy.parent(node.into_portgraph()).map(Into::into)
541    }
542
543    #[inline]
544    fn get_optype(&self, node: Node) -> &OpType {
545        panic_invalid_node(self, node);
546        self.op_types.get(node.into_portgraph())
547    }
548
549    #[inline]
550    fn num_nodes(&self) -> usize {
551        self.graph.node_count()
552    }
553
554    #[inline]
555    fn num_edges(&self) -> usize {
556        self.graph.link_count()
557    }
558
559    #[inline]
560    fn num_ports(&self, node: Self::Node, dir: Direction) -> usize {
561        self.graph.num_ports(node.into_portgraph(), dir)
562    }
563
564    #[inline]
565    fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
566        self.graph.nodes_iter().map_into()
567    }
568
569    #[inline]
570    fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
571        self.graph
572            .port_offsets(node.into_portgraph(), dir)
573            .map_into()
574    }
575
576    #[inline]
577    fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
578        self.graph
579            .all_port_offsets(node.into_portgraph())
580            .map_into()
581    }
582
583    #[inline]
584    fn linked_ports(
585        &self,
586        node: Node,
587        port: impl Into<Port>,
588    ) -> impl Iterator<Item = (Node, Port)> + Clone {
589        let port = port.into();
590
591        let port = self
592            .graph
593            .port_index(node.into_portgraph(), port.pg_offset())
594            .unwrap();
595        self.graph.port_links(port).map(|(_, link)| {
596            let port = link.port();
597            let node = self.graph.port_node(port).unwrap();
598            let offset = self.graph.port_offset(port).unwrap();
599            (node.into(), offset.into())
600        })
601    }
602
603    #[inline]
604    fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
605        self.graph
606            .get_connections(node.into_portgraph(), other.into_portgraph())
607            .map(|(p1, p2)| {
608                [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into())
609            })
610    }
611
612    #[inline]
613    fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
614        self.hierarchy.children(node.into_portgraph()).map_into()
615    }
616
617    #[inline]
618    fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
619        self.hierarchy.descendants(node.into_portgraph()).map_into()
620    }
621
622    #[inline]
623    fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
624        self.graph.neighbours(node.into_portgraph(), dir).map_into()
625    }
626
627    #[inline]
628    fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
629        self.graph.all_neighbours(node.into_portgraph()).map_into()
630    }
631
632    fn mermaid_string(&self) -> String {
633        self.mermaid_string_with_config(RenderConfig {
634            node_indices: true,
635            port_offsets_in_edges: true,
636            type_labels_in_edges: true,
637            entrypoint: Some(self.entrypoint()),
638        })
639    }
640
641    fn mermaid_string_with_config(&self, config: RenderConfig) -> String {
642        self.graph
643            .mermaid_format()
644            .with_hierarchy(&self.hierarchy)
645            .with_node_style(render::node_style(self, config, |n| n.index().to_string()))
646            .with_edge_style(render::edge_style(self, config))
647            .finish()
648    }
649
650    fn dot_string(&self) -> String
651    where
652        Self: Sized,
653    {
654        let config = RenderConfig {
655            entrypoint: Some(self.entrypoint()),
656            ..RenderConfig::default()
657        };
658        self.graph
659            .dot_format()
660            .with_hierarchy(&self.hierarchy)
661            .with_node_style(render::node_style(self, config, |n| n.index().to_string()))
662            .with_port_style(render::port_style(self, config))
663            .with_edge_style(render::edge_style(self, config))
664            .finish()
665    }
666
667    #[inline]
668    fn extensions(&self) -> &ExtensionRegistry {
669        &self.extensions
670    }
671
672    #[inline]
673    fn extract_hugr(&self, target: Node) -> (Hugr, impl ExtractionResult<Node> + 'static) {
674        if target == self.module_root().node() {
676            return (self.clone(), DefaultNodeMap(HashMap::new()));
677        }
678
679        let mut parent = target;
685        let mut extracted = loop {
686            let parent_op = self.get_optype(parent).clone();
687            if let Ok(hugr) = Hugr::new_with_entrypoint(parent_op) {
688                break hugr;
689            }
690            parent = self
693                .get_parent(parent)
694                .expect("The module root is always extractable");
695        };
696
697        let old_entrypoint = extracted.entrypoint();
700        let old_parent = extracted.get_parent(old_entrypoint);
701
702        let inserted = extracted.insert_from_view(old_entrypoint, &self.with_entrypoint(parent));
703        let new_entrypoint = inserted.inserted_entrypoint;
704
705        match old_parent {
706            Some(old_parent) => {
707                let old_ins = extracted
710                    .node_inputs(old_entrypoint)
711                    .flat_map(|inp| {
712                        extracted
713                            .linked_outputs(old_entrypoint, inp)
714                            .map(move |link| (inp, link))
715                    })
716                    .collect_vec();
717                let old_outs = extracted
718                    .node_outputs(old_entrypoint)
719                    .flat_map(|out| {
720                        extracted
721                            .linked_inputs(old_entrypoint, out)
722                            .map(move |link| (out, link))
723                    })
724                    .collect_vec();
725                extracted.set_entrypoint(inserted.node_map[&target]);
727                extracted.remove_node(old_entrypoint);
728                extracted.set_parent(new_entrypoint, old_parent);
729                for (inp, (neigh, neigh_out)) in old_ins {
731                    extracted.connect(neigh, neigh_out, new_entrypoint, inp);
732                }
733                for (out, (neigh, neigh_in)) in old_outs {
734                    extracted.connect(new_entrypoint, out, neigh, neigh_in);
735                }
736            }
737            None => {
739                extracted.set_entrypoint(inserted.node_map[&target]);
740                extracted.set_module_root(new_entrypoint);
741                extracted.remove_node(old_entrypoint);
742            }
743        }
744        (extracted, DefaultNodeMap(inserted.node_map))
745    }
746}
747
748pub trait PortIterator<P>: Iterator<Item = (Node, P)>
750where
751    P: Into<Port> + Copy,
752    Self: Sized,
753{
754    fn dataflow_ports_only(
757        self,
758        hugr: &impl HugrView<Node = Node>,
759    ) -> impl Iterator<Item = (Node, P)> {
760        self.filter_edge_kind(
761            |kind| matches!(kind, Some(EdgeKind::Value(..) | EdgeKind::StateOrder)),
762            hugr,
763        )
764    }
765
766    fn filter_edge_kind(
768        self,
769        predicate: impl Fn(Option<EdgeKind>) -> bool,
770        hugr: &impl HugrView<Node = Node>,
771    ) -> impl Iterator<Item = (Node, P)> {
772        self.filter(move |(n, p)| {
773            let kind = HugrView::get_optype(hugr, *n).port_kind(*p);
774            predicate(kind)
775        })
776    }
777}
778
779impl<I, P> PortIterator<P> for I
780where
781    I: Iterator<Item = (Node, P)>,
782    P: Into<Port> + Copy,
783{
784}
785
786pub(super) fn check_valid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
788    hugr.contains_node(node) && node != hugr.entrypoint()
789}
790
791pub(super) fn check_valid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) -> bool {
793    hugr.contains_node(node) && node != hugr.module_root().node()
794}
795
796#[track_caller]
798pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
799    assert!(hugr.contains_node(node), "Received an invalid node {node}.",);
800}
801
802#[track_caller]
804pub(super) fn panic_invalid_non_entrypoint<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
805    assert!(
806        check_valid_non_entrypoint(hugr, node),
807        "Received an invalid non-entrypoint node {node}.",
808    );
809}
810
811#[track_caller]
813pub(super) fn panic_invalid_port(hugr: &Hugr, node: Node, port: impl Into<Port>) {
814    let port = port.into();
815    if hugr
816        .graph
817        .port_index(node.into_portgraph(), port.pg_offset())
818        .is_none()
819    {
820        panic!("Received an invalid {port} for {node} while mutating a HUGR");
821    }
822}