portdiff/port_diff/rewrite.rs
1use std::collections::{BTreeMap, BTreeSet};
2
3use bimap::BiBTreeMap;
4use thiserror::Error;
5
6use crate::{
7 port::{BoundPort, EdgeEnd, Port},
8 port_diff::IncomingEdgeIndex,
9 subgraph::Subgraph,
10 Graph, PortDiff,
11};
12
13use super::{BoundarySite, EdgeData, Owned, PortDiffData};
14
15#[derive(Error, Debug)]
16pub enum InvalidRewriteError {
17 #[error("{0}")]
18 BoundPortsEdge(String),
19 #[error("{0}")]
20 InvalidEdge(String),
21}
22
23impl<G: Graph> PortDiff<G> {
24 /// Create a new diff that rewrites `nodes` and `edges` to `new_graph`.
25 ///
26 /// The returned diff will be a child of all diffs in `nodes`. Edges are
27 /// expressed as pairs of ports. The nodes they belong to must be in `nodes`.
28 ///
29 /// The function `boundary_map` will be called once for every boundary port
30 /// of the new diff. It is passed as argument an owned port, the image of
31 /// the boundary port in a parent diff. It must return the site of the
32 /// boundary port in the new graph, or a sentinel node.
33 pub fn rewrite(
34 nodes: impl IntoIterator<Item = Owned<G::Node, G>>,
35 edges: impl IntoIterator<Item = (Owned<Port<G>, G>, Owned<Port<G>, G>)>,
36 new_graph: G,
37 mut boundary_map: impl FnMut(Owned<Port<G>, G>) -> BoundarySite<G>,
38 ) -> Result<Self, InvalidRewriteError> {
39 // Collect nodes per portdiff
40 let nodes: BTreeMap<_, BTreeSet<_>> =
41 nodes.into_iter().fold(BTreeMap::new(), |mut map, n| {
42 map.entry(n.owner).or_default().insert(n.data);
43 map
44 });
45 // Split edges into edges within and between portdiffs
46 let mut internal_edges: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
47 let mut used_bound_ports: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
48 let mut used_unbound_ports: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
49 for (left, right) in edges {
50 match (left.data, right.data) {
51 (Port::Bound(left_port), Port::Bound(right_port)) => {
52 if left.owner != right.owner {
53 return Err(InvalidRewriteError::BoundPortsEdge(
54 "Edges between bound ports must be on the same portdiff".to_string(),
55 ));
56 }
57 if left_port.edge != right_port.edge {
58 return Err(InvalidRewriteError::BoundPortsEdge(
59 "Edges between bound ports must be on the same edge".to_string(),
60 ));
61 }
62 internal_edges
63 .entry(left.owner)
64 .or_default()
65 .insert(left_port.edge);
66 }
67 (Port::Boundary(left_port), Port::Boundary(right_port)) => {
68 check_valid_edge(&left, &right)?;
69 used_unbound_ports
70 .entry(left.owner)
71 .or_default()
72 .insert(left_port);
73 used_unbound_ports
74 .entry(right.owner)
75 .or_default()
76 .insert(right_port);
77 }
78 (Port::Boundary(left_port), Port::Bound(right_port)) => {
79 check_valid_edge(&left, &right)?;
80 if left.owner == right.owner {
81 return Err(InvalidRewriteError::BoundPortsEdge(
82 "A bound port may only connect distinct diffs".to_string(),
83 ));
84 }
85 used_unbound_ports
86 .entry(left.owner)
87 .or_default()
88 .insert(left_port);
89 used_bound_ports
90 .entry(right.owner)
91 .or_default()
92 .insert(right_port);
93 }
94 (Port::Bound(left_port), Port::Boundary(right_port)) => {
95 check_valid_edge(&right, &left)?;
96 if left.owner == right.owner {
97 return Err(InvalidRewriteError::BoundPortsEdge(
98 "A bound port may only connect distinct diffs".to_string(),
99 ));
100 }
101 used_bound_ports
102 .entry(left.owner)
103 .or_default()
104 .insert(left_port);
105 used_unbound_ports
106 .entry(right.owner)
107 .or_default()
108 .insert(right_port);
109 }
110 }
111 }
112
113 // Create the incoming edges between parents and the new diff
114 let mut parents = Vec::new();
115 let mut boundary = Vec::new();
116 for (i, (diff, nodes)) in nodes.into_iter().enumerate() {
117 let incoming_edge = IncomingEdgeIndex(i);
118 let mut used_bound_ports = used_bound_ports.remove(&diff).unwrap_or_default();
119 let mut used_unbound_ports = used_unbound_ports.remove(&diff).unwrap_or_default();
120
121 // Create subgraph
122 let edges = internal_edges.remove(&diff).unwrap_or_default();
123 let subgraph = Subgraph::new(&diff.graph, nodes, edges);
124
125 // Map boundaries
126 let mut port_map = BiBTreeMap::new();
127 for b in subgraph.boundary(&diff.graph) {
128 if !used_bound_ports.remove(&b) {
129 let port = Port::Bound(b);
130 let site = boundary_map(Owned {
131 data: port,
132 owner: diff.clone(),
133 });
134 let boundary_ind = boundary.len();
135 boundary.push((site, incoming_edge));
136 port_map.insert(port, boundary_ind.into());
137 }
138 }
139 for b in diff.boundary_iter() {
140 let Some(site) = diff.boundary_site(b).try_as_site_ref() else {
141 // Sentinel boundaries cannot be rewritten
142 continue;
143 };
144 if !subgraph.nodes().contains(&site.node) {
145 continue;
146 }
147 if !used_unbound_ports.remove(&b) {
148 let port = Port::Boundary(b);
149 let site = boundary_map(Owned {
150 data: port,
151 owner: diff.clone(),
152 });
153 let boundary_ind = boundary.len();
154 boundary.push((site, incoming_edge));
155 port_map.insert(port, boundary_ind.into());
156 }
157 }
158 let edge_data = EdgeData { subgraph, port_map };
159 parents.push((diff, edge_data));
160
161 // Check that the edges used only valid boundary ports
162 if !used_bound_ports.is_empty() {
163 return Err(InvalidRewriteError::InvalidEdge(
164 "Cross-diff edge uses invalid boundary port".to_string(),
165 ));
166 }
167 if !used_unbound_ports.is_empty() {
168 return Err(InvalidRewriteError::InvalidEdge(
169 "Cross-diff edge uses invalid boundary port".to_string(),
170 ));
171 }
172 }
173 if !internal_edges.is_empty() {
174 return Err(InvalidRewriteError::InvalidEdge(
175 "Edges with no corresponding nodes".to_string(),
176 ));
177 }
178 let data = PortDiffData {
179 graph: new_graph,
180 boundary,
181 value: None,
182 };
183 Ok(PortDiff::new(data, parents))
184 }
185
186 /// Create a new diff that rewrites `edges` to `new_graph`.
187 ///
188 /// The `nodes` are given by the set of end vertices of the edges. See
189 /// [`Self::rewrite`] for more details.
190 pub fn rewrite_edges(
191 edges: impl IntoIterator<Item = (Owned<Port<G>, G>, Owned<Port<G>, G>)> + Clone,
192 new_graph: G,
193 boundary_map: impl FnMut(Owned<Port<G>, G>) -> BoundarySite<G>,
194 ) -> Result<Self, InvalidRewriteError> {
195 let nodes: BTreeSet<_> = edges
196 .clone()
197 .into_iter()
198 .flat_map(|(l, r)| {
199 [l, r].map(|p| Owned {
200 data: p.site().unwrap().node, // TODO: what to do with sentinels?
201 owner: p.owner,
202 })
203 })
204 .collect();
205 Self::rewrite(nodes, edges, new_graph, boundary_map)
206 }
207
208 /// Create a new diff that rewrites the subgraph of `self` induced by `nodes`.
209 ///
210 /// See [`Self::rewrite`] for more details.
211 pub fn rewrite_induced(
212 &self,
213 nodes: &BTreeSet<G::Node>,
214 new_graph: G,
215 mut boundary_map: impl FnMut(Port<G>) -> BoundarySite<G>,
216 ) -> Result<Self, InvalidRewriteError> {
217 let edges = self
218 .graph()
219 .edges_iter()
220 .filter(|&e| {
221 let left_node = self.graph().incident_node(e, EdgeEnd::Left);
222 let right_node = self.graph().incident_node(e, EdgeEnd::Right);
223 nodes.contains(&left_node) && nodes.contains(&right_node)
224 })
225 .map(|edge| {
226 let left_port = Port::Bound(BoundPort {
227 edge,
228 end: EdgeEnd::Left,
229 });
230 let right_port = Port::Bound(BoundPort {
231 edge,
232 end: EdgeEnd::Right,
233 });
234 (
235 Owned {
236 data: left_port,
237 owner: self.clone(),
238 },
239 Owned {
240 data: right_port,
241 owner: self.clone(),
242 },
243 )
244 });
245 let nodes = nodes.into_iter().copied().map(|data| Owned {
246 data,
247 owner: self.clone(),
248 });
249 Self::rewrite(nodes, edges, new_graph, |p| boundary_map(p.data))
250 }
251}
252
253fn check_valid_edge<G: Graph>(
254 left: &Owned<Port<G>, G>,
255 right: &Owned<Port<G>, G>,
256) -> Result<(), InvalidRewriteError> {
257 match left
258 .owner
259 .opposite_ports(left.data)
260 .iter()
261 .find(|p| p == &right)
262 {
263 Some(_) => Ok(()),
264 None => Err(InvalidRewriteError::InvalidEdge(
265 "Valid edges must have opposite ports".to_string(),
266 )),
267 }
268}
269
270// /// The sets of nodes and edges to be rewritten
271// #[derive(Clone)]
272// struct RewriteSubgraph<G: Graph> {
273// nodes: BTreeSet<G::Node>,
274// internal_edges: BTreeSet<G::Edge>,
275// new_boundary_ports: BTreeSet<Site<G::Node, G::PortLabel>>,
276// }
277
278// impl<G: Graph> Default for RewriteSubgraph<G> {
279// fn default() -> Self {
280// Self {
281// nodes: BTreeSet::new(),
282// internal_edges: BTreeSet::new(),
283// new_boundary_ports: BTreeSet::new(),
284// }
285// }
286// }
287
288// impl<G: Graph> RewriteSubgraph<G> {
289// fn collect(
290// nodes: impl IntoIterator<Item = UniqueNodeId<G>>,
291// edges: impl IntoIterator<Item = EdgeData<G>>,
292// ) -> HashMap<PortDiff<G>, Self> {
293// let mut ret_map = HashMap::<PortDiff<G>, Self>::new();
294
295// for node in nodes.into_iter() {
296// ret_map
297// .entry(node.owner)
298// .or_default()
299// .nodes
300// .insert(node.node);
301// }
302
303// for edge in edges {
304// match edge {
305// EdgeData::Internal { owner, edge } => {
306// ret_map
307// .entry(owner)
308// .or_default()
309// .internal_edges
310// .insert(edge);
311// }
312// EdgeData::Boundary { left, right } => {
313// if let Port::Unbound { owner, port } = left {
314// ret_map
315// .entry(owner)
316// .or_default()
317// .new_boundary_ports
318// .insert(port);
319// }
320// if let Port::Unbound { owner, port } = right {
321// ret_map
322// .entry(owner)
323// .or_default()
324// .new_boundary_ports
325// .insert(port);
326// }
327// }
328// }
329// }
330// ret_map
331// }
332
333// fn filter_boundary<'a>(
334// &'a self,
335// boundary: &'a Boundary<G>,
336// ) -> impl Iterator<Item = (Site<G::Node, G::PortLabel>, ParentPort<G>)> + 'a {
337// let mut boundary_ports = self.new_boundary_ports.clone();
338// boundary
339// .iter()
340// .filter(|(port, _)| self.nodes.contains(&port.node))
341// .filter(move |(port, _)| {
342// // Only keep in boundary if not present in new boundary edges
343// !boundary_ports.remove(port)
344// })
345// .map(|(port, parent)| (port.clone(), parent.clone()))
346// }
347
348// fn new_boundary_from_edges<'a>(
349// &'a self,
350// graph: &'a G,
351// ) -> impl Iterator<Item = (Site<G::Node, G::PortLabel>, BoundPort<G::Edge>)> + 'a {
352// graph
353// .edges_iter()
354// .filter(|e| !self.internal_edges.contains(e))
355// .flat_map(move |edge| {
356// let left_port = graph.get_port_site(BoundPort {
357// edge,
358// port: EdgeEnd::Left,
359// });
360// let right_port = graph.get_port_site(BoundPort {
361// edge,
362// port: EdgeEnd::Right,
363// });
364// let mut boundary_ports = Vec::new();
365// if self.nodes.contains(&left_port.node) {
366// boundary_ports.push((
367// left_port,
368// BoundPort {
369// edge,
370// port: EdgeEnd::Left,
371// },
372// ));
373// }
374// if self.nodes.contains(&right_port.node) {
375// boundary_ports.push((
376// right_port,
377// BoundPort {
378// edge,
379// port: EdgeEnd::Right,
380// },
381// ));
382// }
383// boundary_ports
384// })
385// }
386// }
387
#[cfg(feature = "portgraph")]
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;
    use itertools::Itertools;
    use portgraph::{
        render::DotFormat, LinkMut, LinkView, NodeIndex, PortGraph, PortMut, PortOffset, PortView,
    };
    use rstest::rstest;

    use crate::{
        port::Port,
        port_diff::tests::{parent_child_diffs, TestPortDiff},
        Site,
    };

    use super::*;

    /// Rewrites two distinct single nodes of `parent` into two children, then
    /// rewrites across both children using a boundary-to-boundary edge and
    /// snapshots the extracted graph.
    #[ignore = "TODO this is currently not deterministic"]
    #[rstest]
    fn test_rewrite(parent_child_diffs: [TestPortDiff; 2]) {
        let [parent, _] = parent_child_diffs;
        // Replace the single node `v` by a two-node graph; every boundary port
        // is mapped onto `n0` at the same port offset it had in `parent`.
        let rewrite = |v| {
            let mut rhs = PortGraph::new();
            let n0 = rhs.add_node(0, 4);
            let n1 = rhs.add_node(1, 0);
            rhs.link_nodes(n0, 3, n1, 0).unwrap();
            parent.rewrite_induced(&BTreeSet::from_iter([v]), rhs, |p| {
                let offset = Owned::new(p, parent.clone()).site().unwrap().port;
                Site {
                    node: n0,
                    port: offset,
                }
                .into()
            })
        };
        let (_, n1, n2, _) = PortView::nodes_iter(&parent.graph).collect_tuple().unwrap();
        let child_a = rewrite(n1).unwrap();
        let child_b = rewrite(n2).unwrap();

        let pg: PortGraph =
            PortDiff::extract_graph([child_a.clone(), child_b.clone()].to_vec()).unwrap();
        assert_eq!(pg.node_count(), 6);
        assert_eq!(pg.link_count(), 3 + 3 + 1 + 2);

        // Now rewrite across child_a and child_b
        let mut rhs = PortGraph::new();
        let n0 = rhs.add_node(0, 2);
        let n1 = rhs.add_node(2, 0);
        rhs.link_nodes(n0, 0, n1, 0).unwrap();
        rhs.link_nodes(n0, 1, n1, 1).unwrap();

        // Boundary port of child_a at outgoing offset 0 and of child_b at
        // incoming offset 0: the two ends of the edge joining the children.
        let child_a_out0 = child_a
            .boundary_iter()
            .find(|&bd| {
                child_a.boundary_site(bd).try_as_site_ref().unwrap().port == PortOffset::Outgoing(0)
            })
            .unwrap();
        let child_b_in0 = child_b
            .boundary_iter()
            .find(|&bd| {
                child_b.boundary_site(bd).try_as_site_ref().unwrap().port == PortOffset::Incoming(0)
            })
            .unwrap();
        let cross_edge = (
            Owned::new(Port::Boundary(child_a_out0), child_a.clone()),
            Owned::new(Port::Boundary(child_b_in0), child_b.clone()),
        );

        // NOTE(review): `n0` is the first node of the new `rhs` above, reused
        // here as a node of `child_a`/`child_b`. This presumably only works
        // because each child's graph is also a fresh `PortGraph` whose first
        // node got the same index — TODO confirm and use the children's own
        // node indices explicitly.
        let nodes = BTreeSet::from_iter([
            Owned::new(n0, child_a.clone()),
            Owned::new(n0, child_b.clone()),
        ]);
        // Boundary ports coming from child_a attach to `n0`, all others to `n1`.
        let merged = PortDiff::rewrite(nodes, [cross_edge], rhs, |n| {
            if n.owner == child_a {
                Site {
                    node: n0,
                    port: n.site().unwrap().port,
                }
                .into()
            } else {
                Site {
                    node: n1,
                    port: n.site().unwrap().port,
                }
                .into()
            }
        })
        .unwrap();
        let pg: PortGraph = PortDiff::extract_graph([merged].to_vec()).unwrap();
        assert_snapshot!("extracted_graph_2", pg.dot_string());
    }

    /// Rewriting nodes 1 and 2 into an empty graph, with both ends of each cut
    /// edge mapped to the same wire, leaves the outer nodes directly linked.
    #[rstest]
    fn test_rewrite_empty(parent_child_diffs: [TestPortDiff; 2]) {
        let [parent, _] = parent_child_diffs;
        //         a --           -- d          a --- d
        // Rewrite a -- b -- c -- d       =>    a --- d
        //         a --           -- d          a --- d
        let rewritten = parent
            .rewrite_induced(
                &BTreeSet::from([NodeIndex::new(1), NodeIndex::new(2)]),
                PortGraph::new(),
                |p| {
                    let Port::Bound(BoundPort { edge, end }) = p else {
                        panic!("expected bound port")
                    };
                    // Both ends of a cut edge share a wire id (the edge's
                    // outgoing offset), so they are joined in the result.
                    BoundarySite::Wire {
                        id: edge.out_offset().index(),
                        end,
                    }
                },
            )
            .unwrap();
        let g = PortDiff::extract_graph([rewritten].to_vec()).unwrap();
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.link_count(), 3);
    }
}