portdiff/port_diff/
rewrite.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use bimap::BiBTreeMap;
4use thiserror::Error;
5
6use crate::{
7    port::{BoundPort, EdgeEnd, Port},
8    port_diff::IncomingEdgeIndex,
9    subgraph::Subgraph,
10    Graph, PortDiff,
11};
12
13use super::{BoundarySite, EdgeData, Owned, PortDiffData};
14
15#[derive(Error, Debug)]
16pub enum InvalidRewriteError {
17    #[error("{0}")]
18    BoundPortsEdge(String),
19    #[error("{0}")]
20    InvalidEdge(String),
21}
22
23impl<G: Graph> PortDiff<G> {
24    /// Create a new diff that rewrites `nodes` and `edges` to `new_graph`.
25    ///
26    /// The returned diff will be a child of all diffs in `nodes`. Edges are
27    /// expressed as pairs of ports. The nodes they belong to must be in `nodes`.
28    ///
29    /// The function `boundary_map` will be called once for every boundary port
30    /// of the new diff. It is passed as argument an owned port, the image of
31    /// the boundary port in a parent diff. It must return the site of the
32    /// boundary port in the new graph, or a sentinel node.
33    pub fn rewrite(
34        nodes: impl IntoIterator<Item = Owned<G::Node, G>>,
35        edges: impl IntoIterator<Item = (Owned<Port<G>, G>, Owned<Port<G>, G>)>,
36        new_graph: G,
37        mut boundary_map: impl FnMut(Owned<Port<G>, G>) -> BoundarySite<G>,
38    ) -> Result<Self, InvalidRewriteError> {
39        // Collect nodes per portdiff
40        let nodes: BTreeMap<_, BTreeSet<_>> =
41            nodes.into_iter().fold(BTreeMap::new(), |mut map, n| {
42                map.entry(n.owner).or_default().insert(n.data);
43                map
44            });
45        // Split edges into edges within and between portdiffs
46        let mut internal_edges: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
47        let mut used_bound_ports: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
48        let mut used_unbound_ports: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
49        for (left, right) in edges {
50            match (left.data, right.data) {
51                (Port::Bound(left_port), Port::Bound(right_port)) => {
52                    if left.owner != right.owner {
53                        return Err(InvalidRewriteError::BoundPortsEdge(
54                            "Edges between bound ports must be on the same portdiff".to_string(),
55                        ));
56                    }
57                    if left_port.edge != right_port.edge {
58                        return Err(InvalidRewriteError::BoundPortsEdge(
59                            "Edges between bound ports must be on the same edge".to_string(),
60                        ));
61                    }
62                    internal_edges
63                        .entry(left.owner)
64                        .or_default()
65                        .insert(left_port.edge);
66                }
67                (Port::Boundary(left_port), Port::Boundary(right_port)) => {
68                    check_valid_edge(&left, &right)?;
69                    used_unbound_ports
70                        .entry(left.owner)
71                        .or_default()
72                        .insert(left_port);
73                    used_unbound_ports
74                        .entry(right.owner)
75                        .or_default()
76                        .insert(right_port);
77                }
78                (Port::Boundary(left_port), Port::Bound(right_port)) => {
79                    check_valid_edge(&left, &right)?;
80                    if left.owner == right.owner {
81                        return Err(InvalidRewriteError::BoundPortsEdge(
82                            "A bound port may only connect distinct diffs".to_string(),
83                        ));
84                    }
85                    used_unbound_ports
86                        .entry(left.owner)
87                        .or_default()
88                        .insert(left_port);
89                    used_bound_ports
90                        .entry(right.owner)
91                        .or_default()
92                        .insert(right_port);
93                }
94                (Port::Bound(left_port), Port::Boundary(right_port)) => {
95                    check_valid_edge(&right, &left)?;
96                    if left.owner == right.owner {
97                        return Err(InvalidRewriteError::BoundPortsEdge(
98                            "A bound port may only connect distinct diffs".to_string(),
99                        ));
100                    }
101                    used_bound_ports
102                        .entry(left.owner)
103                        .or_default()
104                        .insert(left_port);
105                    used_unbound_ports
106                        .entry(right.owner)
107                        .or_default()
108                        .insert(right_port);
109                }
110            }
111        }
112
113        // Create the incoming edges between parents and the new diff
114        let mut parents = Vec::new();
115        let mut boundary = Vec::new();
116        for (i, (diff, nodes)) in nodes.into_iter().enumerate() {
117            let incoming_edge = IncomingEdgeIndex(i);
118            let mut used_bound_ports = used_bound_ports.remove(&diff).unwrap_or_default();
119            let mut used_unbound_ports = used_unbound_ports.remove(&diff).unwrap_or_default();
120
121            // Create subgraph
122            let edges = internal_edges.remove(&diff).unwrap_or_default();
123            let subgraph = Subgraph::new(&diff.graph, nodes, edges);
124
125            // Map boundaries
126            let mut port_map = BiBTreeMap::new();
127            for b in subgraph.boundary(&diff.graph) {
128                if !used_bound_ports.remove(&b) {
129                    let port = Port::Bound(b);
130                    let site = boundary_map(Owned {
131                        data: port,
132                        owner: diff.clone(),
133                    });
134                    let boundary_ind = boundary.len();
135                    boundary.push((site, incoming_edge));
136                    port_map.insert(port, boundary_ind.into());
137                }
138            }
139            for b in diff.boundary_iter() {
140                let Some(site) = diff.boundary_site(b).try_as_site_ref() else {
141                    // Sentinel boundaries cannot be rewritten
142                    continue;
143                };
144                if !subgraph.nodes().contains(&site.node) {
145                    continue;
146                }
147                if !used_unbound_ports.remove(&b) {
148                    let port = Port::Boundary(b);
149                    let site = boundary_map(Owned {
150                        data: port,
151                        owner: diff.clone(),
152                    });
153                    let boundary_ind = boundary.len();
154                    boundary.push((site, incoming_edge));
155                    port_map.insert(port, boundary_ind.into());
156                }
157            }
158            let edge_data = EdgeData { subgraph, port_map };
159            parents.push((diff, edge_data));
160
161            // Check that the edges used only valid boundary ports
162            if !used_bound_ports.is_empty() {
163                return Err(InvalidRewriteError::InvalidEdge(
164                    "Cross-diff edge uses invalid boundary port".to_string(),
165                ));
166            }
167            if !used_unbound_ports.is_empty() {
168                return Err(InvalidRewriteError::InvalidEdge(
169                    "Cross-diff edge uses invalid boundary port".to_string(),
170                ));
171            }
172        }
173        if !internal_edges.is_empty() {
174            return Err(InvalidRewriteError::InvalidEdge(
175                "Edges with no corresponding nodes".to_string(),
176            ));
177        }
178        let data = PortDiffData {
179            graph: new_graph,
180            boundary,
181            value: None,
182        };
183        Ok(PortDiff::new(data, parents))
184    }
185
186    /// Create a new diff that rewrites `edges` to `new_graph`.
187    ///
188    /// The `nodes` are given by the set of end vertices of the edges. See
189    /// [`Self::rewrite`] for more details.
190    pub fn rewrite_edges(
191        edges: impl IntoIterator<Item = (Owned<Port<G>, G>, Owned<Port<G>, G>)> + Clone,
192        new_graph: G,
193        boundary_map: impl FnMut(Owned<Port<G>, G>) -> BoundarySite<G>,
194    ) -> Result<Self, InvalidRewriteError> {
195        let nodes: BTreeSet<_> = edges
196            .clone()
197            .into_iter()
198            .flat_map(|(l, r)| {
199                [l, r].map(|p| Owned {
200                    data: p.site().unwrap().node, // TODO: what to do with sentinels?
201                    owner: p.owner,
202                })
203            })
204            .collect();
205        Self::rewrite(nodes, edges, new_graph, boundary_map)
206    }
207
208    /// Create a new diff that rewrites the subgraph of `self` induced by `nodes`.
209    ///
210    /// See [`Self::rewrite`] for more details.
211    pub fn rewrite_induced(
212        &self,
213        nodes: &BTreeSet<G::Node>,
214        new_graph: G,
215        mut boundary_map: impl FnMut(Port<G>) -> BoundarySite<G>,
216    ) -> Result<Self, InvalidRewriteError> {
217        let edges = self
218            .graph()
219            .edges_iter()
220            .filter(|&e| {
221                let left_node = self.graph().incident_node(e, EdgeEnd::Left);
222                let right_node = self.graph().incident_node(e, EdgeEnd::Right);
223                nodes.contains(&left_node) && nodes.contains(&right_node)
224            })
225            .map(|edge| {
226                let left_port = Port::Bound(BoundPort {
227                    edge,
228                    end: EdgeEnd::Left,
229                });
230                let right_port = Port::Bound(BoundPort {
231                    edge,
232                    end: EdgeEnd::Right,
233                });
234                (
235                    Owned {
236                        data: left_port,
237                        owner: self.clone(),
238                    },
239                    Owned {
240                        data: right_port,
241                        owner: self.clone(),
242                    },
243                )
244            });
245        let nodes = nodes.into_iter().copied().map(|data| Owned {
246            data,
247            owner: self.clone(),
248        });
249        Self::rewrite(nodes, edges, new_graph, |p| boundary_map(p.data))
250    }
251}
252
253fn check_valid_edge<G: Graph>(
254    left: &Owned<Port<G>, G>,
255    right: &Owned<Port<G>, G>,
256) -> Result<(), InvalidRewriteError> {
257    match left
258        .owner
259        .opposite_ports(left.data)
260        .iter()
261        .find(|p| p == &right)
262    {
263        Some(_) => Ok(()),
264        None => Err(InvalidRewriteError::InvalidEdge(
265            "Valid edges must have opposite ports".to_string(),
266        )),
267    }
268}
269
270// /// The sets of nodes and edges to be rewritten
271// #[derive(Clone)]
272// struct RewriteSubgraph<G: Graph> {
273//     nodes: BTreeSet<G::Node>,
274//     internal_edges: BTreeSet<G::Edge>,
275//     new_boundary_ports: BTreeSet<Site<G::Node, G::PortLabel>>,
276// }
277
278// impl<G: Graph> Default for RewriteSubgraph<G> {
279//     fn default() -> Self {
280//         Self {
281//             nodes: BTreeSet::new(),
282//             internal_edges: BTreeSet::new(),
283//             new_boundary_ports: BTreeSet::new(),
284//         }
285//     }
286// }
287
288// impl<G: Graph> RewriteSubgraph<G> {
289//     fn collect(
290//         nodes: impl IntoIterator<Item = UniqueNodeId<G>>,
291//         edges: impl IntoIterator<Item = EdgeData<G>>,
292//     ) -> HashMap<PortDiff<G>, Self> {
293//         let mut ret_map = HashMap::<PortDiff<G>, Self>::new();
294
295//         for node in nodes.into_iter() {
296//             ret_map
297//                 .entry(node.owner)
298//                 .or_default()
299//                 .nodes
300//                 .insert(node.node);
301//         }
302
303//         for edge in edges {
304//             match edge {
305//                 EdgeData::Internal { owner, edge } => {
306//                     ret_map
307//                         .entry(owner)
308//                         .or_default()
309//                         .internal_edges
310//                         .insert(edge);
311//                 }
312//                 EdgeData::Boundary { left, right } => {
313//                     if let Port::Unbound { owner, port } = left {
314//                         ret_map
315//                             .entry(owner)
316//                             .or_default()
317//                             .new_boundary_ports
318//                             .insert(port);
319//                     }
320//                     if let Port::Unbound { owner, port } = right {
321//                         ret_map
322//                             .entry(owner)
323//                             .or_default()
324//                             .new_boundary_ports
325//                             .insert(port);
326//                     }
327//                 }
328//             }
329//         }
330//         ret_map
331//     }
332
333//     fn filter_boundary<'a>(
334//         &'a self,
335//         boundary: &'a Boundary<G>,
336//     ) -> impl Iterator<Item = (Site<G::Node, G::PortLabel>, ParentPort<G>)> + 'a {
337//         let mut boundary_ports = self.new_boundary_ports.clone();
338//         boundary
339//             .iter()
340//             .filter(|(port, _)| self.nodes.contains(&port.node))
341//             .filter(move |(port, _)| {
342//                 // Only keep in boundary if not present in new boundary edges
343//                 !boundary_ports.remove(port)
344//             })
345//             .map(|(port, parent)| (port.clone(), parent.clone()))
346//     }
347
348//     fn new_boundary_from_edges<'a>(
349//         &'a self,
350//         graph: &'a G,
351//     ) -> impl Iterator<Item = (Site<G::Node, G::PortLabel>, BoundPort<G::Edge>)> + 'a {
352//         graph
353//             .edges_iter()
354//             .filter(|e| !self.internal_edges.contains(e))
355//             .flat_map(move |edge| {
356//                 let left_port = graph.get_port_site(BoundPort {
357//                     edge,
358//                     port: EdgeEnd::Left,
359//                 });
360//                 let right_port = graph.get_port_site(BoundPort {
361//                     edge,
362//                     port: EdgeEnd::Right,
363//                 });
364//                 let mut boundary_ports = Vec::new();
365//                 if self.nodes.contains(&left_port.node) {
366//                     boundary_ports.push((
367//                         left_port,
368//                         BoundPort {
369//                             edge,
370//                             port: EdgeEnd::Left,
371//                         },
372//                     ));
373//                 }
374//                 if self.nodes.contains(&right_port.node) {
375//                     boundary_ports.push((
376//                         right_port,
377//                         BoundPort {
378//                             edge,
379//                             port: EdgeEnd::Right,
380//                         },
381//                     ));
382//                 }
383//                 boundary_ports
384//             })
385//     }
386// }
387
// Tests for the rewrite constructors; only compiled with the `portgraph` feature.
#[cfg(feature = "portgraph")]
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;
    use itertools::Itertools;
    use portgraph::{
        render::DotFormat, LinkMut, LinkView, NodeIndex, PortGraph, PortMut, PortOffset, PortView,
    };
    use rstest::rstest;

    use crate::{
        port::Port,
        port_diff::tests::{parent_child_diffs, TestPortDiff},
        Site,
    };

    use super::*;

    /// Rewrite two nodes of the parent fixture independently, then merge the
    /// two resulting children with one rewrite spanning both of them through a
    /// boundary-to-boundary edge.
    #[ignore = "TODO this is currently not deterministic"]
    #[rstest]
    fn test_rewrite(parent_child_diffs: [TestPortDiff; 2]) {
        let [parent, _] = parent_child_diffs;
        // Replace the single node `v` of `parent` by the two-node graph
        // `n0 -> n1`. Every boundary port is mapped onto `n0`, at the same port
        // offset it had in `parent`.
        let rewrite = |v| {
            let mut rhs = PortGraph::new();
            let n0 = rhs.add_node(0, 4);
            let n1 = rhs.add_node(1, 0);
            rhs.link_nodes(n0, 3, n1, 0).unwrap();
            parent.rewrite_induced(&BTreeSet::from_iter([v]), rhs, |p| {
                let offset = Owned::new(p, parent.clone()).site().unwrap().port;
                Site {
                    node: n0,
                    port: offset,
                }
                .into()
            })
        };
        // The fixture must have exactly four nodes (`collect_tuple` returns
        // `None` otherwise); rewrite the second and third in iteration order.
        let (_, n1, n2, _) = PortView::nodes_iter(&parent.graph).collect_tuple().unwrap();
        let child_a = rewrite(n1).unwrap();
        let child_b = rewrite(n2).unwrap();

        // Extract the graph obtained by applying both children together.
        let pg: PortGraph =
            PortDiff::extract_graph([child_a.clone(), child_b.clone()].to_vec()).unwrap();
        assert_eq!(pg.node_count(), 6);
        assert_eq!(pg.link_count(), 3 + 3 + 1 + 2);

        // Now rewrite across child_a and child_b
        // Replacement: `n0` (two outputs) linked to `n1` (two inputs) by two
        // parallel links. These shadow the earlier `n1`.
        let mut rhs = PortGraph::new();
        let n0 = rhs.add_node(0, 2);
        let n1 = rhs.add_node(2, 0);
        rhs.link_nodes(n0, 0, n1, 0).unwrap();
        rhs.link_nodes(n0, 1, n1, 1).unwrap();

        // Boundary port of child_a whose site is at outgoing port 0.
        let child_a_out0 = child_a
            .boundary_iter()
            .find(|&bd| {
                child_a.boundary_site(bd).try_as_site_ref().unwrap().port == PortOffset::Outgoing(0)
            })
            .unwrap();
        // Boundary port of child_b whose site is at incoming port 0.
        let child_b_in0 = child_b
            .boundary_iter()
            .find(|&bd| {
                child_b.boundary_site(bd).try_as_site_ref().unwrap().port == PortOffset::Incoming(0)
            })
            .unwrap();
        // Glue the two boundary ports together as one edge of the rewrite.
        let cross_edge = (
            Owned::new(Port::Boundary(child_a_out0), child_a.clone()),
            Owned::new(Port::Boundary(child_b_in0), child_b.clone()),
        );

        // NOTE(review): `n0` here is the first node of the new `rhs` above, not
        // a node looked up in `child_a`/`child_b`. This presumably relies on the
        // children's rewritten node also having index 0 — confirm.
        let nodes = BTreeSet::from_iter([
            Owned::new(n0, child_a.clone()),
            Owned::new(n0, child_b.clone()),
        ]);
        // Ports coming from child_a are mapped to `n0`, all others (child_b) to
        // `n1`, keeping their port offsets.
        let merged = PortDiff::rewrite(nodes, [cross_edge], rhs, |n| {
            if n.owner == child_a {
                Site {
                    node: n0,
                    port: n.site().unwrap().port,
                }
                .into()
            } else {
                Site {
                    node: n1,
                    port: n.site().unwrap().port,
                }
                .into()
            }
        })
        .unwrap();
        let pg: PortGraph = PortDiff::extract_graph([merged].to_vec()).unwrap();
        assert_snapshot!("extracted_graph_2", pg.dot_string());
    }

    /// Rewriting nodes 1 and 2 to an empty graph, turning each cut edge into a
    /// wire, should leave 2 nodes and 3 links.
    #[rstest]
    fn test_rewrite_empty(parent_child_diffs: [TestPortDiff; 2]) {
        let [parent, _] = parent_child_diffs;
        //          a --        -- d              a --- d
        // Rewrite  a -- b -- c -- d      =>      a --- d
        //          a --        -- d              a --- d
        let rewritten = parent
            .rewrite_induced(
                &BTreeSet::from([NodeIndex::new(1), NodeIndex::new(2)]),
                PortGraph::new(),
                |p| {
                    // Only bound ports are expected here. The wire id is the
                    // output offset of the cut edge, presumably so that the
                    // i-th edge on each side ends up on the same wire.
                    let Port::Bound(BoundPort { edge, end }) = p else {
                        panic!("expected bound port")
                    };
                    BoundarySite::Wire {
                        id: edge.out_offset().index(),
                        end,
                    }
                },
            )
            .unwrap();
        let g = PortDiff::extract_graph([rewritten].to_vec()).unwrap();
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.link_count(), 3);
    }
}