1use std::collections::{HashMap, HashSet, VecDeque};
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::hugr::hugrmut::InsertionResult;
9use crate::hugr::HugrMut;
10use crate::ops::{OpTag, OpTrait};
11use crate::types::EdgeKind;
12use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort};
13
14use super::Rewrite;
15
/// Specifies an edge to be created by a [`Replacement`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NewEdgeSpec {
    /// The source of the new edge. For [`Replacement::mu_inp`] and [`Replacement::mu_new`]
    /// this is a node of the existing Hugr that is not being removed; for
    /// [`Replacement::mu_out`] it is a non-root node of [`Replacement::replacement`].
    pub src: Node,
    /// The target of the new edge. For [`Replacement::mu_inp`] this is a non-root node of
    /// [`Replacement::replacement`]; for [`Replacement::mu_out`] and [`Replacement::mu_new`]
    /// it is a node of the existing Hugr that is not being removed.
    pub tgt: Node,
    /// The kind of edge to create, together with any port offsets it needs.
    pub kind: NewEdgeKind,
}
28
/// The kind of a new edge, and the ports it connects.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NewEdgeKind {
    /// An order edge: the source must have a [`EdgeKind::StateOrder`] "other" output and the
    /// target a [`EdgeKind::StateOrder`] "other" input.
    Order,
    /// A dataflow value edge. When used in [`Replacement::mu_out`] or
    /// [`Replacement::mu_new`], it replaces the existing edge into `tgt_pos`.
    Value {
        /// The outgoing port on the source node (must be of kind [`EdgeKind::Value`]).
        src_pos: OutgoingPort,
        /// The incoming port on the target node (must be of kind [`EdgeKind::Value`]).
        tgt_pos: IncomingPort,
    },
    /// A static edge (both ports must satisfy [`EdgeKind::is_static`]). When used in
    /// [`Replacement::mu_out`] or [`Replacement::mu_new`], it replaces the existing edge
    /// into `tgt_pos`.
    Static {
        /// The outgoing port on the source node.
        src_pos: OutgoingPort,
        /// The incoming port on the target node.
        tgt_pos: IncomingPort,
    },
    /// A control-flow edge. It is always connected to incoming port 0 of the target.
    ControlFlow {
        /// The outgoing port on the source node (must be of kind [`EdgeKind::ControlFlow`]).
        src_pos: OutgoingPort,
    },
}
54
/// Specification of a Replace rewrite: remove a set of sibling subtrees from a Hugr and
/// insert the contents of another Hugr in their place, reconnecting edges as described.
#[derive(Debug, Clone, PartialEq)]
pub struct Replacement {
    /// The nodes to remove, each together with all of its descendants (except the contents
    /// of nodes listed as values of [`Self::adoptions`]).
    ///
    /// All of these must share a single parent, which must exist (so the Hugr's root cannot
    /// be removed). The list must be non-empty; an empty list is rejected with
    /// [`ReplaceError::MultipleParents`].
    pub removal: Vec<Node>,
    /// The Hugr to insert. Its root is discarded; the root's children are placed under the
    /// common parent of [`Self::removal`]. The root's [`OpTag`] must equal that parent's.
    pub replacement: Hugr,
    /// Subtrees that are preserved rather than removed. Each key is a container node of
    /// [`Self::replacement`] with no children. The corresponding value is a node that is,
    /// or descends from, a node in [`Self::removal`]. The children of each value are moved
    /// under the inserted copy of its key.
    ///
    /// Values must be distinct, and no value may be a descendant of another.
    pub adoptions: HashMap<Node, Node>,
    /// Edges from nodes of the existing Hugr that are kept to nodes of [`Self::replacement`].
    pub mu_inp: Vec<NewEdgeSpec>,
    /// Edges from nodes of [`Self::replacement`] to nodes of the existing Hugr that are kept.
    ///
    /// For value and static edges, the target port must currently be fed from a node that
    /// is (a descendant of) one being removed; that edge is replaced.
    pub mu_out: Vec<NewEdgeSpec>,
    /// Edges between nodes of the existing Hugr that are kept.
    ///
    /// For value and static edges, the target port must currently be fed from a node that
    /// is (a descendant of) one being removed; that edge is replaced.
    pub mu_new: Vec<NewEdgeSpec>,
}
84
85impl NewEdgeSpec {
86 fn check_src(
87 &self,
88 h: &impl HugrView<Node = Node>,
89 err_spec: &NewEdgeSpec,
90 ) -> Result<(), ReplaceError> {
91 let optype = h.get_optype(self.src);
92 let ok = match self.kind {
93 NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder),
94 NewEdgeKind::Value { src_pos, .. } => {
95 matches!(optype.port_kind(src_pos), Some(EdgeKind::Value(_)))
96 }
97 NewEdgeKind::Static { src_pos, .. } => optype
98 .port_kind(src_pos)
99 .as_ref()
100 .is_some_and(EdgeKind::is_static),
101 NewEdgeKind::ControlFlow { src_pos } => {
102 matches!(optype.port_kind(src_pos), Some(EdgeKind::ControlFlow))
103 }
104 };
105 ok.then_some(())
106 .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone()))
107 }
108 fn check_tgt(
109 &self,
110 h: &impl HugrView<Node = Node>,
111 err_spec: &NewEdgeSpec,
112 ) -> Result<(), ReplaceError> {
113 let optype = h.get_optype(self.tgt);
114 let ok = match self.kind {
115 NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder),
116 NewEdgeKind::Value { tgt_pos, .. } => {
117 matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Value(_)))
118 }
119 NewEdgeKind::Static { tgt_pos, .. } => optype
120 .port_kind(tgt_pos)
121 .as_ref()
122 .is_some_and(EdgeKind::is_static),
123 NewEdgeKind::ControlFlow { .. } => matches!(
124 optype.port_kind(IncomingPort::from(0)),
125 Some(EdgeKind::ControlFlow)
126 ),
127 };
128 ok.then_some(())
129 .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone()))
130 }
131
132 fn check_existing_edge(
133 &self,
134 h: &impl HugrView<Node = Node>,
135 legal_src_ancestors: &HashSet<Node>,
136 err_edge: impl Fn() -> NewEdgeSpec,
137 ) -> Result<(), ReplaceError> {
138 if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind
139 {
140 let descends_from_legal = |mut descendant: Node| -> bool {
141 while !legal_src_ancestors.contains(&descendant) {
142 let Some(p) = h.get_parent(descendant) else {
143 return false;
144 };
145 descendant = p;
146 }
147 true
148 };
149 let found_incoming = h
150 .single_linked_output(self.tgt, tgt_pos)
151 .is_some_and(|(src_n, _)| descends_from_legal(src_n));
152 if !found_incoming {
153 return Err(ReplaceError::NoRemovedEdge(err_edge()));
154 };
155 };
156 Ok(())
157 }
158}
159
160impl Replacement {
161 fn check_parent(&self, h: &impl HugrView<Node = Node>) -> Result<Node, ReplaceError> {
162 let parent = self
163 .removal
164 .iter()
165 .map(|n| h.get_parent(*n))
166 .unique()
167 .exactly_one()
168 .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))?
169 .ok_or(ReplaceError::CantReplaceRoot)?; let removed = h.get_optype(parent).tag();
174 let replacement = self.replacement.root_type().tag();
175 if removed != replacement {
176 return Err(ReplaceError::WrongRootNodeTag {
177 removed,
178 replacement,
179 });
180 };
181 Ok(parent)
182 }
183
184 fn get_removed_nodes(
185 &self,
186 h: &impl HugrView<Node = Node>,
187 ) -> Result<HashSet<Node>, ReplaceError> {
188 self.adoptions.keys().try_for_each(|&n| {
190 (self.replacement.contains_node(n)
191 && self.replacement.get_optype(n).is_container()
192 && self.replacement.children(n).next().is_none())
193 .then_some(())
194 .ok_or(ReplaceError::InvalidAdoptingParent(n))
195 })?;
196 let mut transferred: HashSet<Node> = self.adoptions.values().copied().collect();
197 if transferred.len() != self.adoptions.values().len() {
198 return Err(ReplaceError::AdopteesNotSeparateDescendants(
199 self.adoptions
200 .values()
201 .filter(|v| !transferred.remove(v))
202 .copied()
203 .collect(),
204 ));
205 }
206
207 let mut removed = HashSet::new();
208 let mut queue = VecDeque::from_iter(self.removal.iter().copied());
209 while let Some(n) = queue.pop_front() {
210 let new = removed.insert(n);
211 debug_assert!(new); if !transferred.remove(&n) {
213 h.children(n).for_each(|ch| queue.push_back(ch))
214 }
215 }
216 if !transferred.is_empty() {
217 return Err(ReplaceError::AdopteesNotSeparateDescendants(
218 transferred.into_iter().collect(),
219 ));
220 }
221 Ok(removed)
222 }
223}
impl Rewrite for Replacement {
    type Error = ReplaceError;

    /// The `node_map` returned by `insert_hugr`. It maps nodes of
    /// [`Replacement::replacement`] to the nodes inserted for them in the target Hugr.
    /// NOTE(review): if it contains an entry for the replacement root, that entry points
    /// at the inserted root, which `apply` deletes.
    type ApplyResult = HashMap<Node, Node>;

    // `apply` inserts the replacement before transferring edges, and edge transfer can
    // still fail. A failed `apply` may therefore leave the Hugr partially modified.
    const UNCHANGED_ON_FAILURE: bool = false;

    /// Checks, without modifying `h`, that this replacement can be applied.
    ///
    /// It performs the parent/tag and removal-set checks of `check_parent` and
    /// `get_removed_nodes`, then checks every edge specification:
    /// * sources of `mu_inp`/`mu_new` and targets of `mu_out`/`mu_new` must be nodes of `h`
    ///   that are not being removed;
    /// * sources of `mu_out` and targets of `mu_inp` must be non-root nodes of the
    ///   replacement;
    /// * every endpoint must have a port of the kind the edge needs;
    /// * value and static edges in `mu_out`/`mu_new` must target a port that is currently
    ///   fed from (a descendant of) a node being removed.
    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
        self.check_parent(h)?;
        let removed = self.get_removed_nodes(h)?;
        // Sources that must be retained nodes of `h`.
        for e in self.mu_inp.iter().chain(self.mu_new.iter()) {
            if !h.contains_node(e.src) || removed.contains(&e.src) {
                return Err(ReplaceError::BadEdgeSpec(
                    Direction::Outgoing,
                    WhichHugr::Retained,
                    e.clone(),
                ));
            }
            e.check_src(h, e)?;
        }
        // Sources that must be (non-root) nodes of the replacement.
        self.mu_out
            .iter()
            .try_for_each(|e| match self.replacement.valid_non_root(e.src) {
                true => e.check_src(&self.replacement, e),
                false => Err(ReplaceError::BadEdgeSpec(
                    Direction::Outgoing,
                    WhichHugr::Replacement,
                    e.clone(),
                )),
            })?;
        // Targets that must be (non-root) nodes of the replacement.
        self.mu_inp
            .iter()
            .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) {
                true => e.check_tgt(&self.replacement, e),
                false => Err(ReplaceError::BadEdgeSpec(
                    Direction::Incoming,
                    WhichHugr::Replacement,
                    e.clone(),
                )),
            })?;
        // Targets that must be retained nodes of `h`.
        for e in self.mu_out.iter().chain(self.mu_new.iter()) {
            if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
                return Err(ReplaceError::BadEdgeSpec(
                    Direction::Incoming,
                    WhichHugr::Retained,
                    e.clone(),
                ));
            }
            e.check_tgt(h, e)?;
            // The existing edge into a value/static target is being replaced, so it must come
            // from a node under `removed`. The ancestor walk also accepts sources inside
            // adopted subtrees.
            e.check_existing_edge(h, &removed, || e.clone())?;
        }
        Ok(())
    }

    /// Applies the replacement to `h`, returning the map from replacement nodes to the
    /// nodes inserted for them.
    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
        // 1. Structural checks; these do not modify `h`.
        let parent = self.check_parent(h)?;
        let to_remove = self.get_removed_nodes(h)?;

        // 2. Insert a copy of the replacement Hugr as a child of `parent`.
        let InsertionResult { new_root, node_map } = h.insert_hugr(parent, self.replacement);

        // 3. Create the new edges. Endpoints are translated into `h`:
        //    * replacement nodes go through `node_map`;
        //    * nodes of `h` are used as-is unless they are being removed.
        //    On failure, each closure reports which Hugr the endpoint was expected in.
        let translate_idx = |n| node_map.get(&n).copied().ok_or(WhichHugr::Replacement);
        let kept = |n| {
            let keep = !to_remove.contains(&n);
            keep.then_some(n).ok_or(WhichHugr::Retained)
        };
        transfer_edges(h, self.mu_inp.iter(), kept, translate_idx, None)?;

        // Value/static edges in mu_out and mu_new replace the existing edge into their
        // target. That existing edge must come from (a descendant of) a removed node.
        transfer_edges(h, self.mu_out.iter(), translate_idx, kept, Some(&to_remove))?;

        transfer_edges(h, self.mu_new.iter(), kept, kept, Some(&to_remove))?;

        // 4. Move the inserted top-level nodes from under the inserted root into `parent`.
        //    The i-th one is moved with `move_before_sibling` relative to `removal[i]`.
        //    Any beyond `removal.len()` are re-parented with `set_parent`.
        let mut remove_top_sibs = self.removal.iter();
        for new_node in h.children(new_root).collect::<Vec<Node>>().into_iter() {
            if let Some(top_sib) = remove_top_sibs.next() {
                h.move_before_sibling(new_node, *top_sib);
            } else {
                h.set_parent(new_node, parent);
            }
        }
        // The inserted root is now childless and is discarded.
        debug_assert!(h.children(new_root).next().is_none());
        h.remove_node(new_root);

        // 5. Adoption: move each adopted node's children under the inserted copy of its new
        //    parent, so that they survive step 6.
        for (new_parent, &old_parent) in self.adoptions.iter() {
            // Keys were checked to be replacement nodes, so they are in `node_map`.
            let new_parent = node_map.get(new_parent).unwrap();
            debug_assert!(h.children(old_parent).next().is_some());
            while let Some(ch) = h.first_child(old_parent) {
                h.set_parent(ch, *new_parent);
            }
        }

        // 6. Delete the removed nodes, including the (now childless) adopted nodes.
        to_remove.into_iter().for_each(|n| {
            h.remove_node(n);
        });
        Ok(node_map)
    }

    /// Returns the nodes listed in [`Replacement::removal`] (not their descendants).
    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        self.removal.iter().copied()
    }
}
344
345fn transfer_edges<'a>(
346 h: &mut impl HugrMut,
347 edges: impl Iterator<Item = &'a NewEdgeSpec>,
348 trans_src: impl Fn(Node) -> Result<Node, WhichHugr>,
349 trans_tgt: impl Fn(Node) -> Result<Node, WhichHugr>,
350 legal_src_ancestors: Option<&HashSet<Node>>,
351) -> Result<(), ReplaceError> {
352 for oe in edges {
353 let e = NewEdgeSpec {
354 src: trans_src(oe.src)
356 .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Outgoing, h, oe.clone()))?,
357 tgt: trans_tgt(oe.tgt)
358 .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Incoming, h, oe.clone()))?,
359 ..oe.clone()
360 };
361 if !h.valid_node(e.src) {
362 return Err(ReplaceError::BadEdgeSpec(
363 Direction::Outgoing,
364 WhichHugr::Retained,
365 oe.clone(),
366 ));
367 }
368 if !h.valid_node(e.tgt) {
369 return Err(ReplaceError::BadEdgeSpec(
370 Direction::Incoming,
371 WhichHugr::Retained,
372 oe.clone(),
373 ));
374 };
375 e.check_src(h, oe)?;
376 e.check_tgt(h, oe)?;
377 match e.kind {
378 NewEdgeKind::Order => {
379 h.add_other_edge(e.src, e.tgt);
380 }
381 NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => {
382 if let Some(legal_src_ancestors) = legal_src_ancestors {
383 e.check_existing_edge(h, legal_src_ancestors, || oe.clone())?;
384 h.disconnect(e.tgt, tgt_pos);
385 }
386 h.connect(e.src, src_pos, e.tgt, tgt_pos);
387 }
388 NewEdgeKind::ControlFlow { src_pos } => h.connect(e.src, src_pos, e.tgt, 0),
389 }
390 }
391 Ok(())
392}
393
/// Errors from [`Replacement::verify`] or [`Replacement::apply`].
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum ReplaceError {
    /// The removed nodes have no parent, i.e. the root of the Hugr was being removed.
    #[error("Cannot replace the root node of the Hugr")]
    CantReplaceRoot,
    /// The removed nodes do not share exactly one parent. Contains the distinct parents
    /// found. The list is empty if [`Replacement::removal`] was empty.
    #[error("Removed nodes had different parents {0:?}")]
    MultipleParents(Vec<Node>),
    /// The [`OpTag`] of the replacement root differs from that of the removed nodes' parent.
    #[error("Expected replacement root with tag {removed} but found {replacement}")]
    WrongRootNodeTag {
        /// The tag of the parent of the removed nodes.
        removed: OpTag,
        /// The tag of the root of the replacement.
        replacement: OpTag,
    },
    /// A key of [`Replacement::adoptions`] is not a suitable childless container node of
    /// the replacement.
    #[error("Node {0} was not an empty container node in the replacement")]
    InvalidAdoptingParent(Node),
    /// The listed values of [`Replacement::adoptions`] are either adopted more than once,
    /// or are not distinct, non-nested descendants of [`Replacement::removal`].
    #[error("Nodes not free to be moved into new locations: {0:?}")]
    AdopteesNotSeparateDescendants(Vec<Node>),
    /// An endpoint (in the given direction) of the given edge was not found in the
    /// indicated Hugr.
    #[error("{0:?} end of edge {2:?} not found in {1}")]
    BadEdgeSpec(Direction, WhichHugr, NewEdgeSpec),
    /// A value/static edge in [`Replacement::mu_out`] or [`Replacement::mu_new`] targets a
    /// port that is not currently fed from (a descendant of) a node being removed.
    #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")]
    NoRemovedEdge(NewEdgeSpec),
    /// The node at the given end of the edge has no port of the edge's kind.
    #[error("The edge kind was not applicable to the {0:?} node: {1:?}")]
    BadEdgeKind(Direction, NewEdgeSpec),
}
430
/// Which Hugr an edge endpoint was expected to be found in (used in error reports).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WhichHugr {
    /// The Hugr being inserted ([`Replacement::replacement`]).
    Replacement,
    /// The existing Hugr, excluding the nodes being removed.
    Retained,
}

impl std::fmt::Display for WhichHugr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match *self {
            WhichHugr::Replacement => "replacement Hugr",
            WhichHugr::Retained => "retained portion of Hugr",
        };
        f.write_str(description)
    }
}
449
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use cool_asserts::assert_matches;
    use itertools::Itertools;

    use crate::builder::{
        endo_sig, BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr,
        DataflowSubContainer, HugrBuilder, SubContainer,
    };
    use crate::extension::prelude::{bool_t, usize_t};
    use crate::extension::{ExtensionRegistry, PRELUDE};
    use crate::hugr::internal::HugrMutInternals;
    use crate::hugr::rewrite::replace::WhichHugr;
    use crate::hugr::{HugrMut, Rewrite};
    use crate::ops::custom::ExtensionOp;
    use crate::ops::dataflow::DataflowOpTrait;
    use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
    use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
    use crate::std_extensions::collections::list;
    use crate::types::{Signature, Type, TypeRow};
    use crate::utils::{depth, test_quantum_extension};
    use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort};

    use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement};

    /// Replaces the two basic blocks of a CFG (holding `pop` and `push` respectively) with
    /// a single basic block containing two DFGs. The DFGs adopt the contents of the old
    /// blocks, so `pop` and `push` end up one level deeper, in sibling DFGs.
    #[test]
    #[ignore] // NOTE(review): no reason is recorded for ignoring this test; confirm and document.
    fn cfg() -> Result<(), Box<dyn std::error::Error>> {
        let reg = ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]);
        reg.validate()?;
        let listy = list::list_type(usize_t());
        let pop: ExtensionOp = list::ListOp::pop
            .with_type(usize_t())
            .to_extension_op()
            .unwrap();
        let push: ExtensionOp = list::ListOp::push
            .with_type(usize_t())
            .to_extension_op()
            .unwrap();
        let just_list = TypeRow::from(vec![listy.clone()]);
        let intermed = TypeRow::from(vec![listy.clone(), usize_t()]);

        let mut cfg = CFGBuilder::new(endo_sig(just_list.clone()))?;

        let pred_const = cfg.add_constant(ops::Value::unary_unit_sum());

        let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
        let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;

        let exit = cfg.exit_block();
        cfg.branch(&entry, 0, &bb2)?;
        cfg.branch(&bb2, 0, &exit)?;

        let mut h = cfg.finish_hugr().unwrap();
        // Before: `pop` and `push` each sit directly inside different basic blocks that
        // share a parent.
        {
            let pop = find_node(&h, "pop");
            let push = find_node(&h, "push");
            assert_eq!(depth(&h, pop), 2);
            assert_eq!(depth(&h, push), 2);

            let popp = h.get_parent(pop).unwrap();
            let pushp = h.get_parent(push).unwrap();
            assert_ne!(popp, pushp);
            assert!(h.get_optype(popp).is_dataflow_block());
            assert!(h.get_optype(pushp).is_dataflow_block());

            assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
        }

        // Replacement: CFG root -> one basic block -> [Input, Output, DFG r_df1, DFG r_df2].
        // r_df1 and r_df2 are left empty so that they can adopt the old blocks' contents.
        let mut replacement = Hugr::new(ops::CFG {
            signature: Signature::new_endo(just_list.clone()),
        });
        let r_bb = replacement.add_node_with_parent(
            replacement.root(),
            DataflowBlock {
                inputs: vec![listy.clone()].into(),
                sum_rows: vec![type_row![]],
                other_outputs: vec![listy.clone()].into(),
                extension_delta: list::EXTENSION_ID.into(),
            },
        );
        let r_df1 = replacement.add_node_with_parent(
            r_bb,
            DFG {
                signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone()))
                    .with_extension_delta(list::EXTENSION_ID),
            },
        );
        let r_df2 = replacement.add_node_with_parent(
            r_bb,
            DFG {
                signature: Signature::new(intermed, simple_unary_plus(just_list.clone()))
                    .with_extension_delta(list::EXTENSION_ID),
            },
        );
        // Wire r_df1's outputs 1 and 2 (skipping the leading unit-sum predicate) into
        // r_df2's inputs 0 and 1.
        [0, 1]
            .iter()
            .for_each(|p| replacement.connect(r_df1, *p + 1, r_df2, *p));

        {
            let inp = replacement.add_node_before(
                r_df1,
                ops::Input {
                    types: just_list.clone(),
                },
            );
            let out = replacement.add_node_before(
                r_df1,
                ops::Output {
                    types: simple_unary_plus(just_list),
                },
            );
            replacement.connect(inp, 0, r_df1, 0);
            replacement.connect(r_df2, 0, out, 0);
            replacement.connect(r_df2, 1, out, 1);
        }

        h.apply_rewrite(Replacement {
            removal: vec![entry.node(), bb2.node()],
            replacement,
            adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]),
            mu_inp: vec![],
            mu_out: vec![NewEdgeSpec {
                src: r_bb,
                tgt: exit.node(),
                kind: NewEdgeKind::ControlFlow {
                    src_pos: OutgoingPort::from(0),
                },
            }],
            mu_new: vec![],
        })?;
        h.validate()?;
        // After: `pop` and `push` are inside two different DFGs that share a single basic
        // block as parent.
        {
            let pop = find_node(&h, "pop");
            let push = find_node(&h, "push");
            assert_eq!(depth(&h, pop), 3);
            assert_eq!(depth(&h, push), 3);

            let popp = h.get_parent(pop).unwrap();
            let pushp = h.get_parent(push).unwrap();
            assert_ne!(popp, pushp);
            assert!(h.get_optype(popp).is_dfg());
            assert!(h.get_optype(pushp).is_dfg());

            let grandp = h.get_parent(popp).unwrap();
            assert_eq!(grandp, h.get_parent(pushp).unwrap());
            assert!(h.get_optype(grandp).is_dataflow_block());
        }

        Ok(())
    }

    /// Returns the unique node whose displayed optype contains `s`. Panics if there is not
    /// exactly one such node.
    fn find_node(h: &Hugr, s: &str) -> crate::Node {
        h.nodes()
            .filter(|n| format!("{}", h.get_optype(*n)).contains(s))
            .exactly_one()
            .ok()
            .unwrap()
    }

    /// Adds to `h` a basic block containing just `op`, whose outputs are passed on along
    /// with a predicate loaded from `pred_const`. If `entry` is true the block is built as
    /// the CFG's entry block, and `op`'s inputs are asserted to match the CFG's inputs.
    fn single_node_block<T: AsRef<Hugr> + AsMut<Hugr>, O: DataflowOpTrait + Into<OpType>>(
        h: &mut CFGBuilder<T>,
        op: O,
        pred_const: &ConstID,
        entry: bool,
    ) -> Result<BasicBlockID, BuildError> {
        let op_sig = op.signature();
        let mut bb = if entry {
            assert_eq!(
                match h.hugr().get_optype(h.container_node()) {
                    OpType::CFG(c) => &c.signature.input,
                    _ => panic!(),
                },
                op_sig.input()
            );
            h.simple_entry_builder_exts(op_sig.output.clone(), 1, op_sig.runtime_reqs.clone())?
        } else {
            h.simple_block_builder(op_sig.into_owned(), 1)?
        };
        let op: OpType = op.into();
        let op = bb.add_dataflow_op(op, bb.input_wires())?;
        let load_pred = bb.load_const(pred_const);
        bb.finish_with_outputs(load_pred, op.outputs())
    }

    /// Returns `t` with a unary unit sum (`Type::new_unit_sum(1)`) prepended.
    fn simple_unary_plus(t: TypeRow) -> TypeRow {
        let mut v = t.into_owned();
        v.insert(0, Type::new_unit_sum(1));
        v.into()
    }

    /// Builds a DFG containing a two-case Conditional, checks a valid replacement of both
    /// cases, and then checks that `verify` and `apply` report the same error for a series
    /// of invalid variations of it.
    #[test]
    fn test_invalid() {
        let utou = Signature::new_endo(vec![usize_t()]);
        let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
            ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
            ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
        });
        let ext_name = ext.name().clone();
        let foo = ext.instantiate_extension_op("foo", []).unwrap();
        let bar = ext.instantiate_extension_op("bar", []).unwrap();
        let baz = ext.instantiate_extension_op("baz", []).unwrap();
        let mut registry = test_quantum_extension::REG.clone();
        registry.register(ext).unwrap();

        // h: DFG -> Conditional with case1 = [foo], case2 = [bar, DFG baz_dfg = [baz]].
        let mut h = DFGBuilder::new(
            Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])
                .with_extension_delta(ext_name.clone()),
        )
        .unwrap();
        let [i, b] = h.input_wires_arr();
        let mut cond = h
            .conditional_builder_exts(
                (vec![type_row![]; 2], b),
                [(usize_t(), i)],
                vec![usize_t()].into(),
                ext_name.clone(),
            )
            .unwrap();
        let mut case1 = cond.case_builder(0).unwrap();
        let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap();
        let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
        let mut case2 = cond.case_builder(1).unwrap();
        let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap();
        let mut baz_dfg = case2
            .dfg_builder(
                utou.clone().with_extension_delta(ext_name.clone()),
                bar.outputs(),
            )
            .unwrap();
        let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap();
        let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
        let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
        let cond = cond.finish_sub_container().unwrap();
        let h = h.finish_hugr_with_outputs(cond.outputs()).unwrap();

        // Replacement: a Conditional root with two empty Cases r1 and r2. r1 adopts the
        // contents of case1; r2 adopts the contents of baz_dfg (nested inside case2).
        let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone());
        let r1 = r_hugr.add_node_with_parent(
            r_hugr.root(),
            Case {
                signature: utou.clone(),
            },
        );
        let r2 = r_hugr.add_node_with_parent(
            r_hugr.root(),
            Case {
                signature: utou.clone(),
            },
        );
        let rep: Replacement = Replacement {
            removal: vec![case1, case2],
            replacement: r_hugr,
            adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
            mu_inp: vec![],
            mu_out: vec![],
            mu_new: vec![],
        };
        assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
        rep.verify(&h).unwrap();
        {
            let mut target = h.clone();
            let node_map = rep.clone().apply(&mut target).unwrap();
            let new_case2 = *node_map.get(&r2).unwrap();
            assert_eq!(target.get_parent(baz.node()), Some(new_case2));
        }

        // Runs both `verify` and `apply` (on a copy of h), asserts they fail identically,
        // and returns the error.
        let check_same_errors = |r: Replacement| {
            let verify_res = r.verify(&h).unwrap_err();
            let apply_res = r.apply(&mut h.clone()).unwrap_err();
            assert_eq!(verify_res, apply_res);
            apply_res
        };
        // Replacement root changed to h's root op (a DFG), which mismatches the Conditional.
        let mut rep2 = rep.clone();
        rep2.replacement
            .replace_op(rep2.replacement.root(), h.root_type().clone())
            .unwrap();
        assert_eq!(
            check_same_errors(rep2),
            ReplaceError::WrongRootNodeTag {
                removed: OpTag::Conditional,
                replacement: OpTag::Dfg
            }
        );
        // Removing the root.
        assert_eq!(
            check_same_errors(Replacement {
                removal: vec![h.root()],
                ..rep.clone()
            }),
            ReplaceError::CantReplaceRoot
        );
        // Removed nodes with different parents.
        assert_eq!(
            check_same_errors(Replacement {
                removal: vec![case1, baz_dfg.node()],
                ..rep.clone()
            }),
            ReplaceError::MultipleParents(vec![cond.node(), case2])
        );
        // The replacement root (which has children) as an adopting parent.
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from([(r1, case1), (rep.replacement.root(), case2)]),
                ..rep.clone()
            }),
            ReplaceError::InvalidAdoptingParent(rep.replacement.root())
        );
        // The same node adopted twice.
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
                ..rep.clone()
            }),
            ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
        );
        // An adopted node (baz_dfg) nested inside another adopted node (case2).
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
                ..rep.clone()
            }),
            ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
        );
        // mu_inp edge whose source is a removed node.
        let edge_from_removed = NewEdgeSpec {
            src: case1,
            tgt: r2,
            kind: NewEdgeKind::Order,
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_inp: vec![edge_from_removed.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed)
        );
        // mu_out edge whose source is the largest node index of h. The test expects this
        // index not to be a valid non-root node of the smaller replacement Hugr.
        let bad_out_edge = NewEdgeSpec {
            src: h.nodes().max().unwrap(),
            tgt: cond.node(),
            kind: NewEdgeKind::Order,
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_out: vec![bad_out_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge)
        );
        // A control-flow edge between dataflow nodes (Conditional -> Output): wrong edge kind.
        let bad_order_edge = NewEdgeSpec {
            src: cond.node(),
            tgt: h.get_io(h.root()).unwrap()[1],
            kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
        };
        assert_matches!(
            check_same_errors(Replacement {
                mu_new: vec![bad_order_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, bad_order_edge)
        );
        // A value edge from a Case node (r1), which has no value output port 0.
        let op = OutgoingPort::from(0);
        let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
        let new_out_edge = NewEdgeSpec {
            src: r1.node(),
            tgt,
            kind: NewEdgeKind::Value {
                src_pos: op,
                tgt_pos: ip,
            },
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_out: vec![new_out_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
        );
    }
}