1use std::collections::{HashMap, HashSet, VecDeque};
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::core::HugrNode;
9use crate::hugr::hugrmut::InsertionResult;
10use crate::hugr::views::check_valid_non_entrypoint;
11use crate::hugr::HugrMut;
12use crate::ops::{OpTag, OpTrait};
13use crate::types::EdgeKind;
14use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort};
15
16use super::{PatchHugrMut, PatchVerification};
17
18#[derive(Clone, Debug, PartialEq, Eq)]
20pub struct NewEdgeSpec<SrcNode, TgtNode> {
21 pub src: SrcNode,
25 pub tgt: TgtNode,
29 pub kind: NewEdgeKind,
31}
32
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
35pub enum NewEdgeKind {
36 Order,
38 Value {
40 src_pos: OutgoingPort,
42 tgt_pos: IncomingPort,
44 },
45 Static {
47 src_pos: OutgoingPort,
49 tgt_pos: IncomingPort,
51 },
52 ControlFlow {
54 src_pos: OutgoingPort,
56 },
57}
58
59#[derive(Debug, Clone, PartialEq)]
61pub struct Replacement<HostNode = Node> {
62 pub removal: Vec<HostNode>,
68 pub replacement: Hugr,
72 pub adoptions: HashMap<Node, HostNode>,
81 pub mu_inp: Vec<NewEdgeSpec<HostNode, Node>>,
85 pub mu_out: Vec<NewEdgeSpec<Node, HostNode>>,
88 pub mu_new: Vec<NewEdgeSpec<HostNode, HostNode>>,
94}
95
96impl<SrcNode: Copy, TgtNode: Copy> NewEdgeSpec<SrcNode, TgtNode> {
97 fn check_src<HostNode>(
98 &self,
99 h: &impl HugrView<Node = SrcNode>,
100 err_spec: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
101 ) -> Result<(), ReplaceError<HostNode>> {
102 let optype = h.get_optype(self.src);
103 let ok = match self.kind {
104 NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder),
105 NewEdgeKind::Value { src_pos, .. } => {
106 matches!(optype.port_kind(src_pos), Some(EdgeKind::Value(_)))
107 }
108 NewEdgeKind::Static { src_pos, .. } => optype
109 .port_kind(src_pos)
110 .as_ref()
111 .is_some_and(EdgeKind::is_static),
112 NewEdgeKind::ControlFlow { src_pos } => {
113 matches!(optype.port_kind(src_pos), Some(EdgeKind::ControlFlow))
114 }
115 };
116 ok.then_some(())
117 .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec(self.clone())))
118 }
119
120 fn check_tgt<HostNode>(
121 &self,
122 h: &impl HugrView<Node = TgtNode>,
123 err_spec: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
124 ) -> Result<(), ReplaceError<HostNode>> {
125 let optype = h.get_optype(self.tgt);
126 let ok = match self.kind {
127 NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder),
128 NewEdgeKind::Value { tgt_pos, .. } => {
129 matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Value(_)))
130 }
131 NewEdgeKind::Static { tgt_pos, .. } => optype
132 .port_kind(tgt_pos)
133 .as_ref()
134 .is_some_and(EdgeKind::is_static),
135 NewEdgeKind::ControlFlow { .. } => matches!(
136 optype.port_kind(IncomingPort::from(0)),
137 Some(EdgeKind::ControlFlow)
138 ),
139 };
140 ok.then_some(())
141 .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec(self.clone())))
142 }
143}
144
145impl<HostNode: HugrNode, N: Clone> NewEdgeSpec<N, HostNode> {
146 fn check_existing_edge(
147 &self,
148 h: &impl HugrView<Node = HostNode>,
149 legal_src_ancestors: &HashSet<HostNode>,
150 err_edge: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
151 ) -> Result<(), ReplaceError<HostNode>> {
152 if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind
153 {
154 let descends_from_legal = |mut descendant: HostNode| -> bool {
155 while !legal_src_ancestors.contains(&descendant) {
156 let Some(p) = h.get_parent(descendant) else {
157 return false;
158 };
159 descendant = p;
160 }
161 true
162 };
163 let found_incoming = h
164 .single_linked_output(self.tgt, tgt_pos)
165 .is_some_and(|(src_n, _)| descends_from_legal(src_n));
166 if !found_incoming {
167 return Err(ReplaceError::NoRemovedEdge(err_edge(self.clone())));
168 };
169 };
170 Ok(())
171 }
172}
173
174impl<HostNode: HugrNode> Replacement<HostNode> {
175 fn check_parent(
176 &self,
177 h: &impl HugrView<Node = HostNode>,
178 ) -> Result<HostNode, ReplaceError<HostNode>> {
179 let parent = self
180 .removal
181 .iter()
182 .map(|n| h.get_parent(*n))
183 .unique()
184 .exactly_one()
185 .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))?
186 .ok_or(ReplaceError::CantReplaceRoot)?; let removed = h.get_optype(parent).tag();
192 let replacement = self.replacement.entrypoint_optype().tag();
193 if removed != replacement {
194 return Err(ReplaceError::WrongRootNodeTag {
195 removed,
196 replacement,
197 });
198 };
199 Ok(parent)
200 }
201
202 fn get_removed_nodes(
203 &self,
204 h: &impl HugrView<Node = HostNode>,
205 ) -> Result<HashSet<HostNode>, ReplaceError<HostNode>> {
206 self.adoptions.keys().try_for_each(|&n| {
208 (self.replacement.contains_node(n)
209 && self.replacement.get_optype(n).is_container()
210 && self.replacement.children(n).next().is_none())
211 .then_some(())
212 .ok_or(ReplaceError::InvalidAdoptingParent(n))
213 })?;
214 let mut transferred: HashSet<HostNode> = self.adoptions.values().copied().collect();
215 if transferred.len() != self.adoptions.values().len() {
216 return Err(ReplaceError::AdopteesNotSeparateDescendants(
217 self.adoptions
218 .values()
219 .filter(|v| !transferred.remove(v))
220 .copied()
221 .collect(),
222 ));
223 }
224
225 let mut removed = HashSet::new();
226 let mut queue = VecDeque::from_iter(self.removal.iter().copied());
227 while let Some(n) = queue.pop_front() {
228 let new = removed.insert(n);
229 debug_assert!(new); if !transferred.remove(&n) {
231 h.children(n).for_each(|ch| queue.push_back(ch))
232 }
233 }
234 if !transferred.is_empty() {
235 return Err(ReplaceError::AdopteesNotSeparateDescendants(
236 transferred.into_iter().collect(),
237 ));
238 }
239 Ok(removed)
240 }
241}
242
243impl<HostNode: HugrNode> PatchVerification for Replacement<HostNode> {
244 type Error = ReplaceError<HostNode>;
245 type Node = HostNode;
246
247 fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), Self::Error> {
248 self.check_parent(h)?;
249 let removed = self.get_removed_nodes(h)?;
250 for e in self.mu_inp.iter() {
252 if !h.contains_node(e.src) || removed.contains(&e.src) {
253 return Err(ReplaceError::BadEdgeSpec(
254 Direction::Outgoing,
255 WhichEdgeSpec::HostToRepl(e.clone()),
256 ));
257 }
258 e.check_src(h, WhichEdgeSpec::HostToRepl)?;
259 }
260 for e in self.mu_new.iter() {
261 if !h.contains_node(e.src) || removed.contains(&e.src) {
262 return Err(ReplaceError::BadEdgeSpec(
263 Direction::Outgoing,
264 WhichEdgeSpec::HostToHost(e.clone()),
265 ));
266 }
267 e.check_src(h, WhichEdgeSpec::HostToHost)?;
268 }
269 self.mu_out.iter().try_for_each(|e| {
270 match check_valid_non_entrypoint(&self.replacement, e.src) {
271 true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost),
272 false => Err(ReplaceError::BadEdgeSpec(
273 Direction::Outgoing,
274 WhichEdgeSpec::ReplToHost(e.clone()),
275 )),
276 }
277 })?;
278 self.mu_inp.iter().try_for_each(|e| {
280 match check_valid_non_entrypoint(&self.replacement, e.tgt) {
281 true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl),
282 false => Err(ReplaceError::BadEdgeSpec(
283 Direction::Incoming,
284 WhichEdgeSpec::HostToRepl(e.clone()),
285 )),
286 }
287 })?;
288 for e in self.mu_out.iter() {
289 if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
290 return Err(ReplaceError::BadEdgeSpec(
291 Direction::Incoming,
292 WhichEdgeSpec::ReplToHost(e.clone()),
293 ));
294 }
295 e.check_tgt(h, WhichEdgeSpec::ReplToHost)?;
296 e.check_existing_edge(h, &removed, WhichEdgeSpec::ReplToHost)?;
302 }
303 for e in self.mu_new.iter() {
304 if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
305 return Err(ReplaceError::BadEdgeSpec(
306 Direction::Incoming,
307 WhichEdgeSpec::HostToHost(e.clone()),
308 ));
309 }
310 e.check_tgt(h, WhichEdgeSpec::HostToHost)?;
311 e.check_existing_edge(h, &removed, WhichEdgeSpec::HostToHost)?;
317 }
318 Ok(())
319 }
320
321 fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
322 self.removal.iter().copied()
323 }
324}
325
326impl<HostNode: HugrNode> PatchHugrMut for Replacement<HostNode> {
327 type Outcome = HashMap<Node, HostNode>;
329
330 const UNCHANGED_ON_FAILURE: bool = false;
331
332 fn apply_hugr_mut(
333 self,
334 h: &mut impl HugrMut<Node = HostNode>,
335 ) -> Result<Self::Outcome, Self::Error> {
336 let parent = self.check_parent(h)?;
337 let to_remove = self.get_removed_nodes(h)?;
341
342 let InsertionResult {
346 inserted_entrypoint,
347 node_map,
348 } = h.insert_hugr(parent, self.replacement);
349
350 let translate_idx = |n| node_map.get(&n).copied();
352 let kept = |n| (!to_remove.contains(&n)).then_some(n);
353 transfer_edges(
354 h,
355 self.mu_inp.iter(),
356 kept,
357 translate_idx,
358 WhichEdgeSpec::HostToRepl,
359 None,
360 )?;
361
362 transfer_edges(
365 h,
366 self.mu_out.iter(),
367 translate_idx,
368 kept,
369 WhichEdgeSpec::ReplToHost,
370 Some(&to_remove),
371 )?;
372
373 transfer_edges(
376 h,
377 self.mu_new.iter(),
378 kept,
379 kept,
380 WhichEdgeSpec::HostToHost,
381 Some(&to_remove),
382 )?;
383
384 let mut remove_top_sibs = self.removal.iter();
387 for new_node in h
388 .children(inserted_entrypoint)
389 .collect::<Vec<HostNode>>()
390 .into_iter()
391 {
392 if let Some(top_sib) = remove_top_sibs.next() {
393 h.move_before_sibling(new_node, *top_sib);
394 } else {
395 h.set_parent(new_node, parent);
396 }
397 }
398 debug_assert!(h.children(inserted_entrypoint).next().is_none());
399 h.remove_node(inserted_entrypoint);
400
401 for (new_parent, &old_parent) in self.adoptions.iter() {
403 let new_parent = node_map.get(new_parent).unwrap();
404 debug_assert!(h.children(old_parent).next().is_some());
405 while let Some(ch) = h.first_child(old_parent) {
406 h.set_parent(ch, *new_parent);
407 }
408 }
409
410 to_remove.into_iter().for_each(|n| {
412 h.remove_node(n);
413 });
414 Ok(node_map)
415 }
416}
417
418fn transfer_edges<'a, SrcNode, TgtNode, HostNode>(
419 h: &mut impl HugrMut<Node = HostNode>,
420 edges: impl Iterator<Item = &'a NewEdgeSpec<SrcNode, TgtNode>>,
421 trans_src: impl Fn(SrcNode) -> Option<HostNode>,
422 trans_tgt: impl Fn(TgtNode) -> Option<HostNode>,
423 err_spec: impl Fn(NewEdgeSpec<SrcNode, TgtNode>) -> WhichEdgeSpec<HostNode>,
424 legal_src_ancestors: Option<&HashSet<HostNode>>,
425) -> Result<(), ReplaceError<HostNode>>
426where
427 SrcNode: 'a + HugrNode,
428 TgtNode: 'a + HugrNode,
429 HostNode: 'a + HugrNode,
430{
431 for oe in edges {
432 let err_spec = err_spec(oe.clone());
433 let e = NewEdgeSpec {
434 src: trans_src(oe.src)
436 .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Outgoing, err_spec.clone()))?,
437 tgt: trans_tgt(oe.tgt)
438 .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?,
439 kind: oe.kind,
440 };
441 if !h.contains_node(e.src) {
442 return Err(ReplaceError::BadEdgeSpec(
443 Direction::Outgoing,
444 err_spec.clone(),
445 ));
446 }
447 if !h.contains_node(e.tgt) {
448 return Err(ReplaceError::BadEdgeSpec(
449 Direction::Incoming,
450 err_spec.clone(),
451 ));
452 };
453 let err_spec = |_| err_spec.clone();
454 e.check_src(h, err_spec)?;
455 e.check_tgt(h, err_spec)?;
456 match e.kind {
457 NewEdgeKind::Order => {
458 h.add_other_edge(e.src, e.tgt);
459 }
460 NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => {
461 if let Some(legal_src_ancestors) = legal_src_ancestors {
462 e.check_existing_edge(h, legal_src_ancestors, err_spec)?;
463 h.disconnect(e.tgt, tgt_pos);
464 }
465 h.connect(e.src, src_pos, e.tgt, tgt_pos);
466 }
467 NewEdgeKind::ControlFlow { src_pos } => h.connect(e.src, src_pos, e.tgt, 0),
468 }
469 }
470 Ok(())
471}
472
473#[derive(Clone, Debug, PartialEq, Eq, Error)]
475#[non_exhaustive]
476pub enum ReplaceError<HostNode = Node> {
477 #[error("Cannot replace the root node of the Hugr")]
480 CantReplaceRoot,
481 #[error("Removed nodes had different parents {0:?}")]
483 MultipleParents(Vec<HostNode>),
484 #[error("Expected replacement root with tag {removed} but found {replacement}")]
486 WrongRootNodeTag {
487 removed: OpTag,
489 replacement: OpTag,
491 },
492 #[error("Node {0} was not an empty container node in the replacement")]
495 InvalidAdoptingParent(Node),
496 #[error("Nodes not free to be moved into new locations: {0:?}")]
500 AdopteesNotSeparateDescendants(Vec<HostNode>),
501 #[error("{0:?} end of edge {1:?} not found in {which_hugr}", which_hugr = .1.which_hugr(*.0))]
503 BadEdgeSpec(Direction, WhichEdgeSpec<HostNode>),
504 #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")]
507 NoRemovedEdge(WhichEdgeSpec<HostNode>),
508 #[error("The edge kind was not applicable to the {0:?} node: {1:?}")]
510 BadEdgeKind(Direction, WhichEdgeSpec<HostNode>),
511}
512
513#[derive(Clone, Debug, PartialEq, Eq)]
515pub enum WhichEdgeSpec<HostNode> {
516 HostToRepl(NewEdgeSpec<HostNode, Node>),
519 ReplToHost(NewEdgeSpec<Node, HostNode>),
521 HostToHost(NewEdgeSpec<HostNode, HostNode>),
524}
525
526impl<HostNode> WhichEdgeSpec<HostNode> {
527 fn which_hugr(&self, d: Direction) -> &str {
528 match (self, d) {
529 (Self::HostToRepl(_), Direction::Incoming)
530 | (Self::ReplToHost(_), Direction::Outgoing) => "replacement Hugr",
531 _ => "retained portion of Hugr",
532 }
533 }
534}
535
536#[cfg(test)]
537mod test {
538 use std::collections::HashMap;
539
540 use cool_asserts::assert_matches;
541 use itertools::Itertools;
542
543 use crate::builder::{
544 endo_sig, BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr,
545 DataflowSubContainer, HugrBuilder, SubContainer,
546 };
547 use crate::extension::prelude::{bool_t, usize_t};
548 use crate::extension::{ExtensionRegistry, PRELUDE};
549 use crate::hugr::internal::HugrMutInternals;
550 use crate::hugr::patch::PatchVerification;
551 use crate::hugr::{HugrMut, Patch};
552 use crate::ops::custom::ExtensionOp;
553 use crate::ops::dataflow::DataflowOpTrait;
554 use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
555 use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
556 use crate::std_extensions::collections::list;
557 use crate::types::{Signature, Type, TypeRow};
558 use crate::utils::{depth, test_quantum_extension};
559 use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort};
560
561 use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement, WhichEdgeSpec};
562
563 #[test]
564 #[ignore] fn cfg() -> Result<(), Box<dyn std::error::Error>> {
566 let reg = ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]);
567 reg.validate()?;
568 let listy = list::list_type(usize_t());
569 let pop: ExtensionOp = list::ListOp::pop
570 .with_type(usize_t())
571 .to_extension_op()
572 .unwrap();
573 let push: ExtensionOp = list::ListOp::push
574 .with_type(usize_t())
575 .to_extension_op()
576 .unwrap();
577 let just_list = TypeRow::from(vec![listy.clone()]);
578 let intermed = TypeRow::from(vec![listy.clone(), usize_t()]);
579
580 let mut cfg = CFGBuilder::new(endo_sig(just_list.clone()))?;
581
582 let pred_const = cfg.add_constant(ops::Value::unary_unit_sum());
583
584 let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
585 let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;
586
587 let exit = cfg.exit_block();
588 cfg.branch(&entry, 0, &bb2)?;
589 cfg.branch(&bb2, 0, &exit)?;
590
591 let mut h = cfg.finish_hugr().unwrap();
592 {
593 let pop = find_node(&h, "pop");
594 let push = find_node(&h, "push");
595 assert_eq!(depth(&h, pop), 2); assert_eq!(depth(&h, push), 2);
597
598 let popp = h.get_parent(pop).unwrap();
599 let pushp = h.get_parent(push).unwrap();
600 assert_ne!(popp, pushp); assert!(h.get_optype(popp).is_dataflow_block());
602 assert!(h.get_optype(pushp).is_dataflow_block());
603
604 assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
605 }
606
607 let mut replacement = Hugr::new_with_entrypoint(ops::CFG {
611 signature: Signature::new_endo(just_list.clone()),
612 })
613 .expect("CFG is a valid entrypoint");
614 let r_bb = replacement.add_node_with_parent(
615 replacement.entrypoint(),
616 DataflowBlock {
617 inputs: vec![listy.clone()].into(),
618 sum_rows: vec![type_row![]],
619 other_outputs: vec![listy.clone()].into(),
620 },
621 );
622 let r_df1 = replacement.add_node_with_parent(
623 r_bb,
624 DFG {
625 signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())),
626 },
627 );
628 let r_df2 = replacement.add_node_with_parent(
629 r_bb,
630 DFG {
631 signature: Signature::new(intermed, simple_unary_plus(just_list.clone())),
632 },
633 );
634 [0, 1]
635 .iter()
636 .for_each(|p| replacement.connect(r_df1, *p + 1, r_df2, *p));
637
638 {
639 let inp = replacement.add_node_before(
640 r_df1,
641 ops::Input {
642 types: just_list.clone(),
643 },
644 );
645 let out = replacement.add_node_before(
646 r_df1,
647 ops::Output {
648 types: simple_unary_plus(just_list),
649 },
650 );
651 replacement.connect(inp, 0, r_df1, 0);
652 replacement.connect(r_df2, 0, out, 0);
653 replacement.connect(r_df2, 1, out, 1);
654 }
655
656 h.apply_patch(Replacement {
657 removal: vec![entry.node(), bb2.node()],
658 replacement,
659 adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]),
660 mu_inp: vec![],
661 mu_out: vec![NewEdgeSpec {
662 src: r_bb,
663 tgt: exit.node(),
664 kind: NewEdgeKind::ControlFlow {
665 src_pos: OutgoingPort::from(0),
666 },
667 }],
668 mu_new: vec![],
669 })?;
670 h.validate()?;
671 {
672 let pop = find_node(&h, "pop");
673 let push = find_node(&h, "push");
674 assert_eq!(depth(&h, pop), 3); assert_eq!(depth(&h, push), 3);
676
677 let popp = h.get_parent(pop).unwrap();
678 let pushp = h.get_parent(push).unwrap();
679 assert_ne!(popp, pushp); assert!(h.get_optype(popp).is_dfg());
681 assert!(h.get_optype(pushp).is_dfg());
682
683 let grandp = h.get_parent(popp).unwrap();
684 assert_eq!(grandp, h.get_parent(pushp).unwrap());
685 assert!(h.get_optype(grandp).is_dataflow_block());
686 }
687
688 Ok(())
689 }
690
691 fn find_node(h: &Hugr, s: &str) -> crate::Node {
692 h.entry_descendants()
693 .filter(|n| format!("{}", h.get_optype(*n)).contains(s))
694 .exactly_one()
695 .ok()
696 .unwrap()
697 }
698
699 fn single_node_block<T: AsRef<Hugr> + AsMut<Hugr>, O: DataflowOpTrait + Into<OpType>>(
700 h: &mut CFGBuilder<T>,
701 op: O,
702 pred_const: &ConstID,
703 entry: bool,
704 ) -> Result<BasicBlockID, BuildError> {
705 let op_sig = op.signature();
706 let mut bb = if entry {
707 assert_eq!(
708 match h.hugr().get_optype(h.container_node()) {
709 OpType::CFG(c) => &c.signature.input,
710 _ => panic!(),
711 },
712 op_sig.input()
713 );
714 h.simple_entry_builder(op_sig.output.clone(), 1)?
715 } else {
716 h.simple_block_builder(op_sig.into_owned(), 1)?
717 };
718 let op: OpType = op.into();
719 let op = bb.add_dataflow_op(op, bb.input_wires())?;
720 let load_pred = bb.load_const(pred_const);
721 bb.finish_with_outputs(load_pred, op.outputs())
722 }
723
724 fn simple_unary_plus(t: TypeRow) -> TypeRow {
725 let mut v = t.into_owned();
726 v.insert(0, Type::new_unit_sum(1));
727 v.into()
728 }
729
730 #[test]
731 fn test_invalid() {
732 let utou = Signature::new_endo(vec![usize_t()]);
733 let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| {
734 ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref)
735 .unwrap();
736 ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref)
737 .unwrap();
738 ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref)
739 .unwrap();
740 });
741 let foo = ext.instantiate_extension_op("foo", []).unwrap();
742 let bar = ext.instantiate_extension_op("bar", []).unwrap();
743 let baz = ext.instantiate_extension_op("baz", []).unwrap();
744 let mut registry = test_quantum_extension::REG.clone();
745 registry.register(ext).unwrap();
746
747 let mut h =
748 DFGBuilder::new(Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])).unwrap();
749 let [i, b] = h.input_wires_arr();
750 let mut cond = h
751 .conditional_builder(
752 (vec![type_row![]; 2], b),
753 [(usize_t(), i)],
754 vec![usize_t()].into(),
755 )
756 .unwrap();
757 let mut case1 = cond.case_builder(0).unwrap();
758 let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap();
759 let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
760 let mut case2 = cond.case_builder(1).unwrap();
761 let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap();
762 let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs()).unwrap();
763 let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap();
764 let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
765 let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
766 let cond = cond.finish_sub_container().unwrap();
767 let h = h.finish_hugr_with_outputs(cond.outputs()).unwrap();
768
769 let mut r_hugr = Hugr::new_with_entrypoint(h.get_optype(cond.node()).clone()).unwrap();
770 let r1 = r_hugr.add_node_with_parent(
771 r_hugr.entrypoint(),
772 Case {
773 signature: utou.clone(),
774 },
775 );
776 let r2 = r_hugr.add_node_with_parent(
777 r_hugr.entrypoint(),
778 Case {
779 signature: utou.clone(),
780 },
781 );
782 let rep: Replacement = Replacement {
783 removal: vec![case1, case2],
784 replacement: r_hugr,
785 adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
786 mu_inp: vec![],
787 mu_out: vec![],
788 mu_new: vec![],
789 };
790 assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
791 rep.verify(&h).unwrap();
792 {
793 let mut target = h.clone();
794 let node_map = rep.clone().apply(&mut target).unwrap();
795 let new_case2 = *node_map.get(&r2).unwrap();
796 assert_eq!(target.get_parent(baz.node()), Some(new_case2));
797 }
798
799 let check_same_errors = |r: Replacement| {
801 let verify_res = r.verify(&h).unwrap_err();
802 let apply_res = r.apply(&mut h.clone()).unwrap_err();
803 assert_eq!(verify_res, apply_res);
804 apply_res
805 };
806 let mut rep2 = rep.clone();
808 rep2.replacement
809 .replace_op(rep2.replacement.entrypoint(), h.entrypoint_optype().clone());
810 assert_eq!(
811 check_same_errors(rep2),
812 ReplaceError::WrongRootNodeTag {
813 removed: OpTag::Conditional,
814 replacement: OpTag::Dfg
815 }
816 );
817 assert_eq!(
819 check_same_errors(Replacement {
820 removal: vec![h.module_root()],
821 ..rep.clone()
822 }),
823 ReplaceError::CantReplaceRoot
824 );
825 assert_eq!(
826 check_same_errors(Replacement {
827 removal: vec![case1, baz_dfg.node()],
828 ..rep.clone()
829 }),
830 ReplaceError::MultipleParents(vec![cond.node(), case2])
831 );
832 assert_eq!(
834 check_same_errors(Replacement {
835 adoptions: HashMap::from([(r1, case1), (rep.replacement.entrypoint(), case2)]),
836 ..rep.clone()
837 }),
838 ReplaceError::InvalidAdoptingParent(rep.replacement.entrypoint())
839 );
840 assert_eq!(
841 check_same_errors(Replacement {
842 adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
843 ..rep.clone()
844 }),
845 ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
846 );
847 assert_eq!(
848 check_same_errors(Replacement {
849 adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
850 ..rep.clone()
851 }),
852 ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
853 );
854 let edge_from_removed = NewEdgeSpec {
856 src: case1,
857 tgt: r2,
858 kind: NewEdgeKind::Order,
859 };
860 assert_eq!(
861 check_same_errors(Replacement {
862 mu_inp: vec![edge_from_removed.clone()],
863 ..rep.clone()
864 }),
865 ReplaceError::BadEdgeSpec(
866 Direction::Outgoing,
867 WhichEdgeSpec::HostToRepl(edge_from_removed)
868 )
869 );
870 let bad_out_edge = NewEdgeSpec {
871 src: h.nodes().max().unwrap(), tgt: cond.node(),
873 kind: NewEdgeKind::Order,
874 };
875 assert_eq!(
876 check_same_errors(Replacement {
877 mu_out: vec![bad_out_edge.clone()],
878 ..rep.clone()
879 }),
880 ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichEdgeSpec::ReplToHost(bad_out_edge),)
881 );
882 let bad_order_edge = NewEdgeSpec {
883 src: cond.node(),
884 tgt: h.get_io(h.entrypoint()).unwrap()[1],
885 kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
886 };
887 assert_matches!(
888 check_same_errors(Replacement {
889 mu_new: vec![bad_order_edge.clone()],
890 ..rep.clone()
891 }),
892 ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, WhichEdgeSpec::HostToHost(bad_order_edge))
893 );
894 let op = OutgoingPort::from(0);
895 let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
896 let new_out_edge = NewEdgeSpec {
897 src: r1.node(),
898 tgt,
899 kind: NewEdgeKind::Value {
900 src_pos: op,
901 tgt_pos: ip,
902 },
903 };
904 assert_eq!(
905 check_same_errors(Replacement {
906 mu_out: vec![new_out_edge.clone()],
907 ..rep.clone()
908 }),
909 ReplaceError::BadEdgeKind(Direction::Outgoing, WhichEdgeSpec::ReplToHost(new_out_edge))
910 );
911 }
912}