hugr_core/hugr/persistent/trait_impls.rs

1use std::collections::HashMap;
2
3use itertools::{Either, Itertools};
4use portgraph::render::{DotFormat, MermaidFormat};
5
6use crate::{
7    Direction, Hugr, HugrView, Node, Port,
8    hugr::{
9        Patch, SimpleReplacementError,
10        internal::HugrInternals,
11        views::{
12            ExtractionResult,
13            render::{self, RenderConfig},
14        },
15    },
16};
17
18use super::{
19    InvalidCommit, PatchNode, PersistentHugr, PersistentReplacement, state_space::CommitData,
20};
21
22impl Patch<PersistentHugr> for PersistentReplacement {
23    type Outcome = ();
24    const UNCHANGED_ON_FAILURE: bool = true;
25
26    fn apply(self, h: &mut PersistentHugr) -> Result<Self::Outcome, Self::Error> {
27        match h.try_add_replacement(self) {
28            Ok(_) => Ok(()),
29            Err(
30                InvalidCommit::UnknownParent(_)
31                | InvalidCommit::IncompatibleHistory(_, _)
32                | InvalidCommit::EmptyReplacement,
33            ) => Err(SimpleReplacementError::InvalidRemovedNode()),
34            _ => unreachable!(),
35        }
36    }
37}
38
39impl HugrInternals for PersistentHugr {
40    type RegionPortgraph<'p>
41        = portgraph::MultiPortGraph
42    where
43        Self: 'p;
44
45    type Node = PatchNode;
46
47    type RegionPortgraphNodes = HashMap<PatchNode, Node>;
48
49    fn region_portgraph(
50        &self,
51        parent: Self::Node,
52    ) -> (
53        portgraph::view::FlatRegion<'_, Self::RegionPortgraph<'_>>,
54        Self::RegionPortgraphNodes,
55    ) {
56        // TODO: this is currently not very efficient (see #2248)
57        let (hugr, node_map) = self.apply_all();
58        let parent = node_map[&parent];
59
60        let region = portgraph::view::FlatRegion::new_without_root(
61            hugr.graph,
62            hugr.hierarchy,
63            parent.into_portgraph(),
64        );
65        (region, node_map)
66    }
67
68    fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap {
69        self.as_state_space().node_metadata_map(node)
70    }
71}
72
73// TODO: A lot of these implementations (especially the ones relating to node
74// hierarchies) are very inefficient as they (often unnecessarily) construct
75// the whole extracted HUGR in memory. We are currently prioritizing correctness
76// and clarity over performance and will optimise some of these operations in
77// the future as bottlenecks are encountered. (see #2248)
78impl HugrView for PersistentHugr {
79    fn entrypoint(&self) -> Self::Node {
80        // The entrypoint remains unchanged throughout the patch history, and is
81        // found in the base hugr.
82        let entry = self.base_hugr().entrypoint();
83        let node = PatchNode(self.base(), entry);
84
85        debug_assert!(self.contains_node(node), "invalid entrypoint");
86        node
87    }
88
89    fn module_root(&self) -> Self::Node {
90        // The module root remains unchanged throughout the patch history, and is
91        // found in the base hugr.
92        let root = self.base_hugr().module_root();
93        let node = PatchNode(self.base(), root);
94
95        debug_assert!(self.contains_node(node), "invalid module root");
96        node
97    }
98
99    fn contains_node(&self, node: Self::Node) -> bool {
100        self.contains_node(node)
101    }
102
103    fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
104        assert!(self.contains_node(node), "invalid node");
105        let (hugr, node_map) = self.apply_all();
106        let parent = hugr.get_parent(node_map[&node])?;
107        let parent_inv = node_map
108            .iter()
109            .find_map(|(&k, &v)| (v == parent).then_some(k))
110            .expect("parent not found in node map");
111        Some(parent_inv)
112    }
113
114    fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType {
115        self.as_state_space().get_optype(node)
116    }
117
118    fn num_nodes(&self) -> usize {
119        let mut num_nodes = 0isize;
120        for commit in self.all_commit_ids() {
121            num_nodes += self.inserted_nodes(commit).count() as isize;
122            num_nodes -= self.deleted_nodes(commit).count() as isize;
123        }
124        num_nodes as usize
125    }
126
127    fn num_edges(&self) -> usize {
128        self.to_hugr().num_edges()
129    }
130
131    fn num_ports(&self, node: Self::Node, dir: Direction) -> usize {
132        self.as_state_space().num_ports(node, dir)
133    }
134
135    fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone {
136        self.all_commit_ids()
137            .flat_map(|commit_id| {
138                let to_patch_node = move |node: Node| PatchNode(commit_id, node);
139                match self.get_commit(commit_id).value() {
140                    CommitData::Base(hugr) => Either::Left(hugr.nodes().map(to_patch_node)),
141                    CommitData::Replacement(repl) => Either::Right(
142                        repl.replacement()
143                            .children(repl.replacement().entrypoint())
144                            .filter(|&n| {
145                                let ot = repl.replacement().get_optype(n);
146                                !ot.is_input() && !ot.is_output()
147                            })
148                            .map(to_patch_node),
149                    ),
150                }
151            })
152            .filter(|&n| self.contains_node(n))
153    }
154
155    fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
156        self.as_state_space().node_ports(node, dir)
157    }
158
159    fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = Port> + Clone {
160        self.as_state_space().all_node_ports(node)
161    }
162
163    fn linked_ports(
164        &self,
165        node: Self::Node,
166        port: impl Into<Port>,
167    ) -> impl Iterator<Item = (Self::Node, Port)> + Clone {
168        let port = port.into();
169        let mut ret_ports = Vec::new();
170        if !self.is_value_port(node, port) {
171            // currently non-value ports are not modified by patches
172            let commit_id = node.0;
173            let to_patch_node = |(node, port)| (PatchNode(commit_id, node), port);
174            ret_ports.extend(
175                self.commit_hugr(commit_id)
176                    .linked_ports(node.1, port)
177                    .map(to_patch_node),
178            );
179        } else {
180            match port.as_directed() {
181                Either::Left(incoming) => {
182                    let (out_node, out_port) = self.get_single_outgoing_port(node, incoming);
183                    ret_ports.push((out_node, out_port.into()))
184                }
185                Either::Right(outgoing) => ret_ports.extend(
186                    self.get_all_incoming_ports(node, outgoing)
187                        .map(|(node, port)| (node, port.into())),
188                ),
189            }
190        }
191
192        ret_ports.into_iter()
193    }
194
195    fn node_connections(
196        &self,
197        node: Self::Node,
198        other: Self::Node,
199    ) -> impl Iterator<Item = [Port; 2]> + Clone {
200        self.node_outputs(node)
201            .flat_map(move |port| {
202                self.linked_ports(node, port)
203                    .map(move |(opp_node, opp_port)| (port, opp_node, opp_port))
204            })
205            .filter(move |&(_, opp_node, _)| opp_node == other)
206            .map(|(port, _, opp_port)| [port.into(), opp_port])
207    }
208
209    fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
210        let (hugr, node_map) = self.apply_all();
211        let children = hugr.children(node_map[&node]).collect_vec();
212        let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
213        children.into_iter().map(move |child| {
214            *inv_node_map
215                .get(&child)
216                .expect("node not found in node map")
217        })
218    }
219
220    fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
221        let (hugr, node_map) = self.apply_all();
222        let descendants = hugr.descendants(node_map[&node]).collect_vec();
223        let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
224        descendants.into_iter().map(move |child| {
225            *inv_node_map
226                .get(&child)
227                .expect("node not found in node map")
228        })
229    }
230
231    fn neighbours(
232        &self,
233        node: Self::Node,
234        dir: Direction,
235    ) -> impl Iterator<Item = Self::Node> + Clone {
236        self.node_ports(node, dir)
237            .flat_map(move |port| self.linked_ports(node, port).map(|(opp_node, _)| opp_node))
238    }
239
240    fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone {
241        self.all_node_ports(node)
242            .flat_map(move |port| self.linked_ports(node, port).map(|(opp_node, _)| opp_node))
243    }
244
245    fn mermaid_string(&self) -> String {
246        self.mermaid_string_with_config(RenderConfig {
247            node_indices: true,
248            port_offsets_in_edges: true,
249            type_labels_in_edges: true,
250            entrypoint: Some(self.entrypoint()),
251        })
252    }
253
254    fn mermaid_string_with_config(&self, config: RenderConfig<Self::Node>) -> String {
255        // Extract a concrete HUGR for displaying
256        let (hugr, node_map) = self.apply_all();
257
258        // Map config accordingly
259        let config = RenderConfig {
260            entrypoint: config.entrypoint.map(|n| node_map[&n]),
261            node_indices: config.node_indices,
262            port_offsets_in_edges: config.port_offsets_in_edges,
263            type_labels_in_edges: config.type_labels_in_edges,
264        };
265
266        // Render the extracted HUGR but map the node indices back to the
267        // original patch node IDs
268        let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
269        let fmt_node_index = |n: portgraph::NodeIndex| format!("{:?}", inv_node_map[&n.into()]);
270        hugr.graph
271            .mermaid_format()
272            .with_hierarchy(&hugr.hierarchy)
273            .with_node_style(render::node_style(&hugr, config, fmt_node_index))
274            .with_edge_style(render::edge_style(&hugr, config))
275            .finish()
276    }
277
278    fn dot_string(&self) -> String
279    where
280        Self: Sized,
281    {
282        // Extract a concrete HUGR for displaying
283        let (hugr, node_map) = self.apply_all();
284
285        // Map config accordingly
286        let config = RenderConfig {
287            entrypoint: Some(node_map[&self.entrypoint()]),
288            ..RenderConfig::default()
289        };
290
291        // Render the extracted HUGR but map the node indices back to the
292        // original patch node IDs
293        let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
294        let fmt_node_index = |n: portgraph::NodeIndex| format!("{:?}", inv_node_map[&n.into()]);
295        hugr.graph
296            .dot_format()
297            .with_hierarchy(&hugr.hierarchy)
298            .with_node_style(render::node_style(&hugr, config, fmt_node_index))
299            .with_port_style(render::port_style(&hugr, config))
300            .with_edge_style(render::edge_style(&hugr, config))
301            .finish()
302    }
303
304    fn extensions(&self) -> &crate::extension::ExtensionRegistry {
305        &self.base_hugr().extensions
306    }
307
308    fn extract_hugr(
309        &self,
310        parent: Self::Node,
311    ) -> (
312        Hugr,
313        impl crate::hugr::views::ExtractionResult<Self::Node> + 'static,
314    ) {
315        let (hugr, apply_node_map) = self.apply_all();
316        let (extracted_hugr, extracted_node_map) = hugr.extract_hugr(apply_node_map[&parent]);
317
318        let node_map: HashMap<_, _> = apply_node_map
319            .into_iter()
320            .filter_map(|(patch_node, node)| {
321                let extracted_node = extracted_node_map.extracted_node(node);
322                if extracted_hugr.contains_node(extracted_node) {
323                    Some((patch_node, extracted_node))
324                } else {
325                    None
326                }
327            })
328            .collect();
329
330        (extracted_hugr, node_map)
331    }
332}
333
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::hugr::persistent::{CommitStateSpace, state_space::CommitId};

    use super::super::tests::test_state_space;
    use super::*;

    use portgraph::PortView;
    use rstest::rstest;

    /// Builds the persistent HUGR shared by the tests below.
    ///
    /// It uses the first, second and fourth fixture commits; the third commit
    /// is deliberately left out.
    fn extract_without_commit3(
        state_space: &CommitStateSpace,
        commits: [CommitId; 4],
    ) -> PersistentHugr {
        let [commit1, commit2, _commit3, commit4] = commits;
        state_space
            .try_extract_hugr([commit1, commit2, commit4])
            .unwrap()
    }

    /// Rendering the persistent HUGR must match rendering its materialisation.
    #[rstest]
    fn test_mermaid_string(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, commits) = test_state_space;
        let hugr = extract_without_commit3(&state_space, commits);

        let actual = hugr.mermaid_string_with_config(RenderConfig {
            node_indices: false,
            entrypoint: Some(hugr.entrypoint()),
            ..Default::default()
        });

        let materialised = hugr.to_hugr();
        let expected = materialised.mermaid_string_with_config(RenderConfig {
            node_indices: false,
            entrypoint: Some(materialised.entrypoint()),
            ..Default::default()
        });

        assert_eq!(actual, expected);
    }

    /// Parent/children/descendants agree with the materialised HUGR.
    #[rstest]
    fn test_hierarchy(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, commits) = test_state_space;
        let [_, commit2, _, commit4] = commits;
        let hugr = extract_without_commit3(&state_space, commits);

        // Every node introduced by commit 2 or commit 4 hangs directly below
        // the entrypoint.
        let entry_children: HashSet<_> = hugr.children(hugr.entrypoint()).collect();
        for commit in [commit2, commit4] {
            assert!(
                hugr.nodes()
                    .filter(|n| n.0 == commit)
                    .all(|n| entry_children.contains(&n))
            );
        }

        let (applied, node_map) = hugr.apply_all();
        for node in hugr.nodes() {
            let mapped = node_map[&node];
            assert_eq!(
                applied.get_parent(mapped),
                hugr.get_parent(node).map(|p| node_map[&p])
            );
            assert_eq!(
                applied.children(mapped).collect_vec(),
                hugr.children(node).map(|c| node_map[&c]).collect_vec()
            );
            assert_eq!(
                applied.descendants(mapped).collect_vec(),
                hugr.descendants(node).map(|d| node_map[&d]).collect_vec()
            );
        }
    }

    /// Port links, neighbours and node connections agree with the
    /// materialised HUGR.
    #[rstest]
    fn test_linked_ports(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, commits) = test_state_space;
        let hugr = extract_without_commit3(&state_space, commits);
        let (applied, node_map) = hugr.apply_all();
        let to_applied = |node: PatchNode| node_map[&node];

        for node in hugr.nodes() {
            let applied_node = to_applied(node);
            for port in hugr.all_node_ports(node) {
                let persistent_links = hugr
                    .linked_ports(node, port)
                    .map(|(n, p)| (to_applied(n), p))
                    .collect_vec();
                let applied_links = applied.linked_ports(applied_node, port).collect_vec();
                assert_eq!(persistent_links, applied_links);

                // Neighbours in each direction.
                for dir in [Direction::Incoming, Direction::Outgoing] {
                    assert_eq!(
                        hugr.neighbours(node, dir).map(to_applied).collect_vec(),
                        applied.neighbours(applied_node, dir).collect_vec()
                    );
                }

                // Neighbours in both directions at once.
                assert_eq!(
                    hugr.all_neighbours(node).map(to_applied).collect_vec(),
                    applied.all_neighbours(applied_node).collect_vec()
                );

                // Connections towards every other node.
                for other in hugr.nodes() {
                    assert_eq!(
                        hugr.node_connections(node, other).collect_vec(),
                        applied
                            .node_connections(applied_node, to_applied(other))
                            .collect_vec()
                    );
                }
            }
        }
    }

    /// Module root, counts, region portgraph and extraction are consistent.
    #[rstest]
    fn test_extract_hugr(test_state_space: (CommitStateSpace, [CommitId; 4])) {
        let (state_space, commits) = test_state_space;
        let hugr = extract_without_commit3(&state_space, commits);
        let materialised = hugr.to_hugr();

        let expected_root = PatchNode(state_space.base(), state_space.base_hugr().module_root());
        assert_eq!(hugr.module_root(), expected_root);

        assert_eq!(hugr.num_nodes(), materialised.num_nodes());
        assert_eq!(hugr.num_edges(), materialised.num_edges());

        let entrypoint = hugr.entrypoint();
        let (region, _) = hugr.region_portgraph(entrypoint);
        assert_eq!(region.node_count(), hugr.children(entrypoint).count());

        let (new_hugr, _) = hugr.extract_hugr(entrypoint);
        assert_eq!(new_hugr.num_nodes(), materialised.num_nodes());
    }
}