1use std::collections::HashMap;
2
3use itertools::{Either, Itertools};
4use portgraph::render::{DotFormat, MermaidFormat};
5
6use crate::{
7 Direction, Hugr, HugrView, Node, Port,
8 hugr::{
9 Patch, SimpleReplacementError,
10 internal::HugrInternals,
11 views::{
12 ExtractionResult,
13 render::{self, RenderConfig},
14 },
15 },
16};
17
18use super::{
19 InvalidCommit, PatchNode, PersistentHugr, PersistentReplacement, state_space::CommitData,
20};
21
22impl Patch<PersistentHugr> for PersistentReplacement {
23 type Outcome = ();
24 const UNCHANGED_ON_FAILURE: bool = true;
25
26 fn apply(self, h: &mut PersistentHugr) -> Result<Self::Outcome, Self::Error> {
27 match h.try_add_replacement(self) {
28 Ok(_) => Ok(()),
29 Err(
30 InvalidCommit::UnknownParent(_)
31 | InvalidCommit::IncompatibleHistory(_, _)
32 | InvalidCommit::EmptyReplacement,
33 ) => Err(SimpleReplacementError::InvalidRemovedNode()),
34 _ => unreachable!(),
35 }
36 }
37}
38
39impl HugrInternals for PersistentHugr {
40 type RegionPortgraph<'p>
41 = portgraph::MultiPortGraph
42 where
43 Self: 'p;
44
45 type Node = PatchNode;
46
47 type RegionPortgraphNodes = HashMap<PatchNode, Node>;
48
49 fn region_portgraph(
50 &self,
51 parent: Self::Node,
52 ) -> (
53 portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
54 Self::RegionPortgraphNodes,
55 ) {
56 let (hugr, node_map) = self.apply_all();
58 let parent = node_map[&parent];
59
60 let region = portgraph::view::FlatRegion::new_without_root(
61 hugr.graph,
62 hugr.hierarchy,
63 parent.into_portgraph(),
64 );
65 (region, node_map)
66 }
67
68 fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap {
69 self.as_state_space().node_metadata_map(node)
70 }
71}
72
73impl HugrView for PersistentHugr {
79 fn entrypoint(&self) -> Self::Node {
80 let entry = self.base_hugr().entrypoint();
83 let node = PatchNode(self.base(), entry);
84
85 debug_assert!(self.contains_node(node), "invalid entrypoint");
86 node
87 }
88
89 fn module_root(&self) -> Self::Node {
90 let root = self.base_hugr().module_root();
93 let node = PatchNode(self.base(), root);
94
95 debug_assert!(self.contains_node(node), "invalid module root");
96 node
97 }
98
99 fn contains_node(&self, node: Self::Node) -> bool {
100 self.contains_node(node)
101 }
102
103 fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
104 assert!(self.contains_node(node), "invalid node");
105 let (hugr, node_map) = self.apply_all();
106 let parent = hugr.get_parent(node_map[&node])?;
107 let parent_inv = node_map
108 .iter()
109 .find_map(|(&k, &v)| (v == parent).then_some(k))
110 .expect("parent not found in node map");
111 Some(parent_inv)
112 }
113
114 fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType {
115 self.as_state_space().get_optype(node)
116 }
117
118 fn num_nodes(&self) -> usize {
119 let mut num_nodes = 0isize;
120 for commit in self.all_commit_ids() {
121 num_nodes += self.inserted_nodes(commit).count() as isize;
122 num_nodes -= self.deleted_nodes(commit).count() as isize;
123 }
124 num_nodes as usize
125 }
126
127 fn num_edges(&self) -> usize {
128 self.to_hugr().num_edges()
129 }
130
131 fn num_ports(&self, node: Self::Node, dir: Direction) -> usize {
132 self.as_state_space().num_ports(node, dir)
133 }
134
135 fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone {
136 self.all_commit_ids()
137 .flat_map(|commit_id| {
138 let to_patch_node = move |node: Node| PatchNode(commit_id, node);
139 match self.get_commit(commit_id).value() {
140 CommitData::Base(hugr) => Either::Left(hugr.nodes().map(to_patch_node)),
141 CommitData::Replacement(repl) => Either::Right(
142 repl.replacement()
143 .children(repl.replacement().entrypoint())
144 .filter(|&n| {
145 let ot = repl.replacement().get_optype(n);
146 !ot.is_input() && !ot.is_output()
147 })
148 .map(to_patch_node),
149 ),
150 }
151 })
152 .filter(|&n| self.contains_node(n))
153 }
154
155 fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
156 self.as_state_space().node_ports(node, dir)
157 }
158
159 fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone {
160 self.as_state_space().all_node_ports(node)
161 }
162
163 fn linked_ports(
164 &self,
165 node: Self::Node,
166 port: impl Into<Port>,
167 ) -> impl Iterator<Item = (Self::Node, Port)> + Clone {
168 let port = port.into();
169 let mut ret_ports = Vec::new();
170 if !self.is_value_port(node, port) {
171 let commit_id = node.0;
173 let to_patch_node = |(node, port)| (PatchNode(commit_id, node), port);
174 ret_ports.extend(
175 self.commit_hugr(commit_id)
176 .linked_ports(node.1, port)
177 .map(to_patch_node),
178 );
179 } else {
180 match port.as_directed() {
181 Either::Left(incoming) => {
182 let (out_node, out_port) = self.get_single_outgoing_port(node, incoming);
183 ret_ports.push((out_node, out_port.into()))
184 }
185 Either::Right(outgoing) => ret_ports.extend(
186 self.get_all_incoming_ports(node, outgoing)
187 .map(|(node, port)| (node, port.into())),
188 ),
189 }
190 }
191
192 ret_ports.into_iter()
193 }
194
195 fn node_connections(
196 &self,
197 node: Self::Node,
198 other: Self::Node,
199 ) -> impl Iterator<Item = [Port; 2]> + Clone {
200 self.node_outputs(node)
201 .flat_map(move |port| {
202 self.linked_ports(node, port)
203 .map(move |(opp_node, opp_port)| (port, opp_node, opp_port))
204 })
205 .filter(move |&(_, opp_node, _)| opp_node == other)
206 .map(|(port, _, opp_port)| [port.into(), opp_port])
207 }
208
209 fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
210 let (hugr, node_map) = self.apply_all();
211 let children = hugr.children(node_map[&node]).collect_vec();
212 let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
213 children.into_iter().map(move |child| {
214 *inv_node_map
215 .get(&child)
216 .expect("node not found in node map")
217 })
218 }
219
220 fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
221 let (hugr, node_map) = self.apply_all();
222 let descendants = hugr.descendants(node_map[&node]).collect_vec();
223 let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
224 descendants.into_iter().map(move |child| {
225 *inv_node_map
226 .get(&child)
227 .expect("node not found in node map")
228 })
229 }
230
231 fn neighbours(
232 &self,
233 node: Self::Node,
234 dir: Direction,
235 ) -> impl Iterator<Item = Self::Node> + Clone {
236 self.node_ports(node, dir)
237 .flat_map(move |port| self.linked_ports(node, port).map(|(opp_node, _)| opp_node))
238 }
239
240 fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
241 self.all_node_ports(node)
242 .flat_map(move |port| self.linked_ports(node, port).map(|(opp_node, _)| opp_node))
243 }
244
245 fn mermaid_string(&self) -> String {
246 self.mermaid_string_with_config(RenderConfig {
247 node_indices: true,
248 port_offsets_in_edges: true,
249 type_labels_in_edges: true,
250 entrypoint: Some(self.entrypoint()),
251 })
252 }
253
254 fn mermaid_string_with_config(&self, config: RenderConfig<Self::Node>) -> String {
255 let (hugr, node_map) = self.apply_all();
257
258 let config = RenderConfig {
260 entrypoint: config.entrypoint.map(|n| node_map[&n]),
261 node_indices: config.node_indices,
262 port_offsets_in_edges: config.port_offsets_in_edges,
263 type_labels_in_edges: config.type_labels_in_edges,
264 };
265
266 let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
269 let fmt_node_index = |n: portgraph::NodeIndex| format!("{:?}", inv_node_map[&n.into()]);
270 hugr.graph
271 .mermaid_format()
272 .with_hierarchy(&hugr.hierarchy)
273 .with_node_style(render::node_style(&hugr, config, fmt_node_index))
274 .with_edge_style(render::edge_style(&hugr, config))
275 .finish()
276 }
277
278 fn dot_string(&self) -> String
279 where
280 Self: Sized,
281 {
282 let (hugr, node_map) = self.apply_all();
284
285 let config = RenderConfig {
287 entrypoint: Some(node_map[&self.entrypoint()]),
288 ..RenderConfig::default()
289 };
290
291 let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
294 let fmt_node_index = |n: portgraph::NodeIndex| format!("{:?}", inv_node_map[&n.into()]);
295 hugr.graph
296 .dot_format()
297 .with_hierarchy(&hugr.hierarchy)
298 .with_node_style(render::node_style(&hugr, config, fmt_node_index))
299 .with_port_style(render::port_style(&hugr, config))
300 .with_edge_style(render::edge_style(&hugr, config))
301 .finish()
302 }
303
304 fn extensions(&self) -> &crate::extension::ExtensionRegistry {
305 &self.base_hugr().extensions
306 }
307
308 fn extract_hugr(
309 &self,
310 parent: Self::Node,
311 ) -> (
312 Hugr,
313 impl crate::hugr::views::ExtractionResult<Self::Node> + 'static,
314 ) {
315 let (hugr, apply_node_map) = self.apply_all();
316 let (extracted_hugr, extracted_node_map) = hugr.extract_hugr(apply_node_map[&parent]);
317
318 let node_map: HashMap<_, _> = apply_node_map
319 .into_iter()
320 .filter_map(|(patch_node, node)| {
321 let extracted_node = extracted_node_map.extracted_node(node);
322 if extracted_hugr.contains_node(extracted_node) {
323 Some((patch_node, extracted_node))
324 } else {
325 None
326 }
327 })
328 .collect();
329
330 (extracted_hugr, node_map)
331 }
332}
333
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::hugr::persistent::{CommitStateSpace, state_space::CommitId};

    use super::super::tests::test_state_space;
    use super::*;

    use portgraph::PortView;
    use rstest::rstest;

    /// The mermaid rendering of the persistent HUGR must match the rendering
    /// of its materialised `Hugr` (node indices disabled on both sides).
    #[rstest]
    fn test_mermaid_string(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, [commit1, commit2, _commit3, commit4]) = test_state_space;

        let hugr = state_space
            .try_extract_hugr([commit1, commit2, commit4])
            .unwrap();

        let mermaid_str = hugr.mermaid_string_with_config(RenderConfig {
            node_indices: false,
            entrypoint: Some(hugr.entrypoint()),
            ..Default::default()
        });
        let extracted_hugr = hugr.to_hugr();
        let exp_str = extracted_hugr.mermaid_string_with_config(RenderConfig {
            node_indices: false,
            entrypoint: Some(extracted_hugr.entrypoint()),
            ..Default::default()
        });

        assert_eq!(mermaid_str, exp_str);
    }

    /// Parent/children/descendants queries must agree with the materialised HUGR.
    #[rstest]
    fn test_hierarchy(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, [commit1, commit2, _commit3, commit4]) = test_state_space;

        let hugr = state_space
            .try_extract_hugr([commit1, commit2, commit4])
            .unwrap();

        let commit2_nodes = hugr.nodes().filter(|&n| n.0 == commit2).collect_vec();
        let commit4_nodes = hugr.nodes().filter(|&n| n.0 == commit4).collect_vec();

        let all_children: HashSet<_> = hugr.children(hugr.entrypoint()).collect();

        assert!(commit2_nodes.iter().all(|&n| all_children.contains(&n)));
        assert!(commit4_nodes.iter().all(|&n| all_children.contains(&n)));

        let (extracted_hugr, node_map) = hugr.apply_all();

        for n in hugr.nodes() {
            assert_eq!(
                extracted_hugr.get_parent(node_map[&n]),
                hugr.get_parent(n).map(|p| node_map[&p])
            );
            assert_eq!(
                extracted_hugr.children(node_map[&n]).collect_vec(),
                hugr.children(n).map(|c| node_map[&c]).collect_vec()
            );
            assert_eq!(
                extracted_hugr.descendants(node_map[&n]).collect_vec(),
                hugr.descendants(n).map(|c| node_map[&c]).collect_vec()
            );
        }
    }

    /// Port links, neighbours and node connections must agree with the
    /// materialised HUGR.
    #[rstest]
    fn test_linked_ports(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, [commit1, commit2, _commit3, commit4]) = test_state_space;

        let hugr = state_space
            .try_extract_hugr([commit1, commit2, commit4])
            .unwrap();
        let (extracted_hugr, node_map) = hugr.apply_all();

        for n in hugr.nodes() {
            // Per-port check: the linked ports of every port of `n`.
            for port in hugr.all_node_ports(n) {
                let linked_ports = hugr
                    .linked_ports(n, port)
                    .map(|(node, port)| (node_map[&node], port))
                    .collect_vec();
                let extracted_linked_ports = extracted_hugr
                    .linked_ports(node_map[&n], port)
                    .collect_vec();

                assert_eq!(linked_ports, extracted_linked_ports);
            }

            // Per-node checks: these do not depend on a port, so they run once
            // per node, including nodes without any ports.
            for dir in [Direction::Incoming, Direction::Outgoing] {
                let neighbours = hugr
                    .neighbours(n, dir)
                    .map(|node| node_map[&node])
                    .collect_vec();
                let extracted_neighbours =
                    extracted_hugr.neighbours(node_map[&n], dir).collect_vec();

                assert_eq!(neighbours, extracted_neighbours);
            }

            let all_neighbours = hugr
                .all_neighbours(n)
                .map(|node| node_map[&node])
                .collect_vec();
            let extracted_all_neighbours =
                extracted_hugr.all_neighbours(node_map[&n]).collect_vec();

            assert_eq!(all_neighbours, extracted_all_neighbours);

            for other in hugr.nodes() {
                let connections = hugr.node_connections(n, other).collect_vec();
                let extracted_connections = extracted_hugr
                    .node_connections(node_map[&n], node_map[&other])
                    .collect_vec();

                assert_eq!(connections, extracted_connections);
            }
        }
    }

    /// Node/edge counts, region views and `extract_hugr` must agree with the
    /// materialised HUGR.
    #[rstest]
    fn test_extract_hugr(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, [commit1, commit2, _commit3, commit4]) = test_state_space;

        let hugr = state_space
            .try_extract_hugr([commit1, commit2, commit4])
            .unwrap();
        let extracted_hugr = hugr.to_hugr();

        assert_eq!(
            hugr.module_root(),
            PatchNode(state_space.base(), state_space.base_hugr().module_root())
        );

        assert_eq!(hugr.num_nodes(), extracted_hugr.num_nodes());
        assert_eq!(hugr.num_edges(), extracted_hugr.num_edges());

        let (pg, _) = hugr.region_portgraph(hugr.entrypoint());

        assert_eq!(pg.node_count(), hugr.children(hugr.entrypoint()).count());

        let (new_hugr, node_map) = hugr.extract_hugr(hugr.entrypoint());

        assert_eq!(new_hugr.num_nodes(), extracted_hugr.num_nodes());

        // Every node of the extracted region must map to a node of `new_hugr`.
        for n in hugr.descendants(hugr.entrypoint()) {
            assert!(new_hugr.contains_node(node_map.extracted_node(n)));
        }
    }
}