1use std::{borrow::Cow, rc::Rc, sync::Arc};
2
3use delegate::delegate;
4use itertools::Either;
5
6use super::{render::RenderConfig, HugrView, RootChecked};
7use crate::{
8 extension::ExtensionRegistry,
9 hugr::{NodeMetadata, NodeMetadataMap, ValidationError},
10 ops::OpType,
11 types::{PolyFuncType, Signature, Type},
12 Direction, Hugr, IncomingPort, OutgoingPort, Port,
13};
14
/// Generates the method bodies of a `HugrView` impl for a wrapper type by
/// forwarding every method listed below to an inner `HugrView`.
///
/// * `$arg` — an identifier that the generated code binds to `self`.
/// * `$e` — an expression, written in terms of `$arg`, that yields the value
///   each call is forwarded to (e.g. `*this` or `this.as_ref()`).
///
/// The extra identifier exists because an invocation of this macro cannot
/// simply pass `self` as the expression (presumably due to macro hygiene: a
/// caller-supplied `self` token would not resolve to the receiver of the
/// generated methods).
macro_rules! hugr_view_methods {
    ($arg:ident, $e:expr) => {
        // `delegate!` turns each signature below into a method whose body calls
        // the same-named method on `{let $arg=self; $e}` with the same arguments.
        delegate! {
            to ({let $arg=self; $e}) {
                // NOTE(review): only the methods listed here are forwarded; keep
                // this list in sync with the `HugrView` trait definition.
                fn root(&self) -> Self::Node;
                fn root_type(&self) -> &OpType;
                fn contains_node(&self, node: Self::Node) -> bool;
                fn valid_node(&self, node: Self::Node) -> bool;
                fn valid_non_root(&self, node: Self::Node) -> bool;
                fn get_parent(&self, node: Self::Node) -> Option<Self::Node>;
                fn get_optype(&self, node: Self::Node) -> &OpType;
                fn get_metadata(&self, node: Self::Node, key: impl AsRef<str>) -> Option<&NodeMetadata>;
                fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap>;
                fn node_count(&self) -> usize;
                fn edge_count(&self) -> usize;
                fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone;
                fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone;
                fn node_outputs(&self, node: Self::Node) -> impl Iterator<Item = OutgoingPort> + Clone;
                fn node_inputs(&self, node: Self::Node) -> impl Iterator<Item = IncomingPort> + Clone;
                fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone;
                fn linked_ports(
                    &self,
                    node: Self::Node,
                    port: impl Into<Port>,
                ) -> impl Iterator<Item = (Self::Node, Port)> + Clone;
                fn all_linked_ports(
                    &self,
                    node: Self::Node,
                    dir: Direction,
                ) -> Either<
                    impl Iterator<Item = (Self::Node, OutgoingPort)>,
                    impl Iterator<Item = (Self::Node, IncomingPort)>,
                >;
                fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator<Item = (Self::Node, OutgoingPort)>;
                fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator<Item = (Self::Node, IncomingPort)>;
                fn single_linked_port(&self, node: Self::Node, port: impl Into<Port>) -> Option<(Self::Node, Port)>;
                fn single_linked_output(&self, node: Self::Node, port: impl Into<IncomingPort>) -> Option<(Self::Node, OutgoingPort)>;
                fn single_linked_input(&self, node: Self::Node, port: impl Into<OutgoingPort>) -> Option<(Self::Node, IncomingPort)>;
                fn linked_outputs(&self, node: Self::Node, port: impl Into<IncomingPort>) -> impl Iterator<Item = (Self::Node, OutgoingPort)>;
                fn linked_inputs(&self, node: Self::Node, port: impl Into<OutgoingPort>) -> impl Iterator<Item = (Self::Node, IncomingPort)>;
                fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator<Item = [Port; 2]> + Clone;
                fn is_linked(&self, node: Self::Node, port: impl Into<Port>) -> bool;
                fn num_ports(&self, node: Self::Node, dir: Direction) -> usize;
                fn num_inputs(&self, node: Self::Node) -> usize;
                fn num_outputs(&self, node: Self::Node) -> usize;
                fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone;
                fn first_child(&self, node: Self::Node) -> Option<Self::Node>;
                fn neighbours(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Self::Node> + Clone;
                fn input_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
                fn output_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
                fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
                fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>;
                fn inner_function_type(&self) -> Option<Cow<'_, Signature>>;
                fn poly_func_type(&self) -> Option<PolyFuncType>;
                fn mermaid_string(&self) -> String;
                fn mermaid_string_with_config(&self, config: RenderConfig) -> String;
                fn dot_string(&self) -> String;
                fn static_source(&self, node: Self::Node) -> Option<Self::Node>;
                fn static_targets(&self, node: Self::Node) -> Option<impl Iterator<Item = (Self::Node, IncomingPort)>>;
                fn signature(&self, node: Self::Node) -> Option<Cow<'_, Signature>>;
                fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = (Port, Type)>;
                fn in_value_types(&self, node: Self::Node) -> impl Iterator<Item = (IncomingPort, Type)>;
                fn out_value_types(&self, node: Self::Node) -> impl Iterator<Item = (OutgoingPort, Type)>;
                fn extensions(&self) -> &ExtensionRegistry;
                fn validate(&self) -> Result<(), ValidationError>;
                fn validate_no_extensions(&self) -> Result<(), ValidationError>;
            }
        }
    }
}
90
91impl<T: HugrView> HugrView for &T {
92 hugr_view_methods! {this, *this}
93}
94
95impl<T: HugrView> HugrView for &mut T {
96 hugr_view_methods! {this, &**this}
97}
98
99impl<T: HugrView> HugrView for Rc<T> {
100 hugr_view_methods! {this, this.as_ref()}
101}
102
103impl<T: HugrView> HugrView for Arc<T> {
104 hugr_view_methods! {this, this.as_ref()}
105}
106
107impl<T: HugrView> HugrView for Box<T> {
108 hugr_view_methods! {this, this.as_ref()}
109}
110
111impl<T: HugrView + ToOwned> HugrView for Cow<'_, T> {
112 hugr_view_methods! {this, this.as_ref()}
113}
114
115impl<H: AsRef<Hugr>, Root> HugrView for RootChecked<H, Root> {
116 hugr_view_methods! {this, this.as_ref()}
117}
118
#[cfg(test)]
mod test {
    use std::{borrow::Cow, rc::Rc, sync::Arc};

    use crate::hugr::views::{DescendantsGraph, HierarchyView};
    use crate::{Hugr, HugrView, Node};

    /// Minimal generic consumer of a [`HugrView`], used to check that each
    /// wrapper type is accepted wherever an `H: HugrView` is expected.
    struct ViewWrapper<H>(H);
    impl<H: HugrView> ViewWrapper<H> {
        fn nodes(&self) -> impl Iterator<Item = H::Node> + '_ {
            self.0.nodes()
        }
    }

    #[test]
    fn test_refs_to_view() {
        let h = Hugr::default();
        let v = ViewWrapper(&h);
        let c = h.nodes().count();
        assert_eq!(v.nodes().count(), c);
        let v2 = ViewWrapper(DescendantsGraph::<Node>::try_new(&h, h.root()).unwrap());
        assert_eq!(v2.nodes().count(), v.nodes().count());
        assert_eq!(ViewWrapper(&v2.0).nodes().count(), v.nodes().count());

        let vh = ViewWrapper(h);
        assert_eq!(vh.nodes().count(), c);
        let h: Hugr = vh.0;
        assert_eq!(h.nodes().count(), c);

        let vb = ViewWrapper(Box::new(&h));
        assert_eq!(vb.nodes().count(), c);
        let va = ViewWrapper(Arc::new(h));
        assert_eq!(va.nodes().count(), c);
        let h = Arc::try_unwrap(va.0).unwrap();
        let vr = Rc::new(&h);
        assert_eq!(ViewWrapper(&vr).nodes().count(), h.nodes().count());
    }

    /// Covers the `&mut T`, `Cow<'_, T>` and owning `Box<T>` impls, which
    /// `test_refs_to_view` does not exercise (it only boxes a `&Hugr`).
    #[test]
    fn test_mut_cow_and_box_to_view() {
        let mut h = Hugr::default();
        let c = h.nodes().count();

        // `&mut T` forwards to the referent through a shared reborrow.
        {
            let vm = ViewWrapper(&mut h);
            assert_eq!(vm.nodes().count(), c);
        }

        // A borrowed `Cow` forwards to the referenced Hugr.
        {
            let vcb = ViewWrapper(Cow::Borrowed(&h));
            assert_eq!(vcb.nodes().count(), c);
        }

        // An owned `Cow` forwards to the Hugr it owns.
        let vco = ViewWrapper(Cow::<Hugr>::Owned(h));
        assert_eq!(vco.nodes().count(), c);

        // An owning `Box<Hugr>` (not just a boxed reference) also forwards.
        let vbo = ViewWrapper(Box::new(vco.0.into_owned()));
        assert_eq!(vbo.nodes().count(), c);
    }
}