1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::linking::{HugrLinking, NodeLinkingDirective};
4use crate::hugr::views::HugrView;
5use crate::hugr::{NodeMetadata, ValidationError};
6use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
7use crate::utils::collect_array;
8use crate::{Extension, IncomingPort, Node, OutgoingPort};
9
10use std::collections::HashMap;
11use std::iter;
12use std::sync::Arc;
13
14use super::{BuilderWiringError, ModuleBuilder};
15use super::{
16 CircuitBuilder,
17 handle::{BuildHandle, Outputs},
18};
19
20use crate::{
21 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
22 types::EdgeKind,
23};
24
25use crate::extension::ExtensionRegistry;
26use crate::types::{Signature, Type, TypeArg, TypeRow};
27
28use itertools::Itertools;
29
30use super::{
31 BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
32 tail_loop::TailLoopBuilder,
33};
34
35use crate::Hugr;
36
37use crate::hugr::HugrMut;
38
39pub trait Container {
44 fn container_node(&self) -> Node;
46 fn hugr_mut(&mut self) -> &mut Hugr;
48 fn hugr(&self) -> &Hugr;
50 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
54 let node: OpType = node.into();
55
56 let used_extensions = node
58 .used_extensions()
59 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
60 self.use_extensions(used_extensions);
61
62 let parent = self.container_node();
63 self.hugr_mut().add_node_with_parent(parent, node)
64 }
65
66 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
71 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
72 Wire::new(src, src_port)
73 }
74
75 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
84 self.add_child_node(constant.into()).into()
85 }
86
87 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
91 let region = child.entrypoint();
92 self.add_hugr_region(child, region)
93 }
94
95 fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult {
99 let parent = self.container_node();
100 self.hugr_mut().insert_region(parent, child, region)
101 }
102
103 fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
107 let parent = self.container_node();
108 self.hugr_mut().insert_from_view(parent, child)
109 }
110
111 fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
113 let parent = self.container_node();
114 self.hugr_mut().set_metadata(parent, key, meta);
116 }
117
118 fn set_child_metadata(
122 &mut self,
123 child: Node,
124 key: impl AsRef<str>,
125 meta: impl Into<NodeMetadata>,
126 ) {
127 self.hugr_mut().set_metadata(child, key, meta);
128 }
129
130 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
132 self.hugr_mut().use_extension(ext);
133 }
134
135 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
137 where
138 ExtensionRegistry: Extend<Reg>,
139 {
140 self.hugr_mut().use_extensions(registry);
141 }
142}
143
144pub trait HugrBuilder: Container {
147 fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
150 debug_assert!(
151 self.hugr()
152 .get_optype(self.hugr().module_root())
153 .is_module()
154 );
155 ModuleBuilder(self.hugr_mut())
156 }
157
158 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
160}
161
162pub trait SubContainer: Container {
164 type ContainerHandle;
167 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
170}
171pub trait Dataflow: Container {
173 fn num_inputs(&self) -> usize;
175 fn io(&self) -> [Node; 2] {
177 self.hugr()
178 .children(self.container_node())
179 .take(2)
180 .collect_vec()
181 .try_into()
182 .expect("First two children should be IO")
183 }
184 fn input(&self) -> BuildHandle<DataflowOpID> {
186 (self.io()[0], self.num_inputs()).into()
187 }
188 fn output(&self) -> DataflowOpID {
190 self.io()[1].into()
191 }
192 fn input_wires(&self) -> Outputs {
194 self.input().outputs()
195 }
196 fn add_dataflow_op(
205 &mut self,
206 nodetype: impl Into<OpType>,
207 input_wires: impl IntoIterator<Item = Wire>,
208 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
209 let outs = add_node_with_wires(self, nodetype, input_wires)?;
210
211 Ok(outs.into())
212 }
213
214 fn add_hugr_with_wires(
226 &mut self,
227 hugr: Hugr,
228 input_wires: impl IntoIterator<Item = Wire>,
229 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
230 let region = hugr.entrypoint();
231 self.add_hugr_region_with_wires(hugr, region, input_wires)
232 }
233
234 fn add_hugr_region_with_wires(
245 &mut self,
246 hugr: Hugr,
247 region: Node,
248 input_wires: impl IntoIterator<Item = Wire>,
249 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
250 let node = self.add_hugr_region(hugr, region).inserted_entrypoint;
251
252 wire_ins_return_outs(input_wires, node, self)
253 }
254
255 fn add_link_hugr_by_node_with_wires(
260 &mut self,
261 hugr: Hugr,
262 input_wires: impl IntoIterator<Item = Wire>,
263 defns: HashMap<Node, NodeLinkingDirective>,
264 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
265 let parent = Some(self.container_node());
266 let ep = hugr.entrypoint();
267 let node = self
268 .hugr_mut()
269 .insert_link_hugr_by_node(parent, hugr, defns)?
270 .node_map[&ep];
271 wire_ins_return_outs(input_wires, node, self)
272 }
273
274 fn add_hugr_view_with_wires(
284 &mut self,
285 hugr: &impl HugrView,
286 input_wires: impl IntoIterator<Item = Wire>,
287 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
288 let node = self.add_hugr_view(hugr).inserted_entrypoint;
289 wire_ins_return_outs(input_wires, node, self)
290 }
291
292 fn add_link_view_by_node_with_wires<H: HugrView>(
296 &mut self,
297 hugr: &H,
298 input_wires: impl IntoIterator<Item = Wire>,
299 defns: HashMap<H::Node, NodeLinkingDirective>,
300 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
301 let parent = Some(self.container_node());
302 let node = self
303 .hugr_mut()
304 .insert_link_view_by_node(parent, hugr, defns)
305 .map_err(|ins_err| BuildError::HugrViewInsertionError(ins_err.to_string()))?
306 .node_map[&hugr.entrypoint()];
307 wire_ins_return_outs(input_wires, node, self)
308 }
309
310 fn set_outputs(
316 &mut self,
317 output_wires: impl IntoIterator<Item = Wire>,
318 ) -> Result<(), BuildError> {
319 let [_, out] = self.io();
320 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
321 BuildError::OutputWiring {
322 container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
323 container_node: self.container_node(),
324 error,
325 }
326 })
327 }
328
329 #[track_caller]
335 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
336 collect_array(self.input_wires())
337 }
338
339 fn dfg_builder(
349 &mut self,
350 signature: Signature,
351 input_wires: impl IntoIterator<Item = Wire>,
352 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
353 let op = ops::DFG {
354 signature: signature.clone(),
355 };
356 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
357
358 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
359 }
360
361 fn dfg_builder_endo(
366 &mut self,
367 inputs: impl IntoIterator<Item = (Type, Wire)>,
368 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
369 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
370 self.dfg_builder(Signature::new_endo(types), input_wires)
371 }
372
373 fn cfg_builder(
384 &mut self,
385 inputs: impl IntoIterator<Item = (Type, Wire)>,
386 output_types: TypeRow,
387 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
388 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
389
390 let inputs: TypeRow = input_types.into();
391
392 let (cfg_node, _) = add_node_with_wires(
393 self,
394 ops::CFG {
395 signature: Signature::new(inputs.clone(), output_types.clone()),
396 },
397 input_wires,
398 )?;
399 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
400 }
401
402 fn load_const(&mut self, cid: &ConstID) -> Wire {
405 let const_node = cid.node();
406 let nodetype = self.hugr().get_optype(const_node);
407 let op: ops::Const = nodetype
408 .clone()
409 .try_into()
410 .expect("ConstID does not refer to Const op.");
411
412 let load_n = self
413 .add_dataflow_op(
414 ops::LoadConstant {
415 datatype: op.get_type().clone(),
416 },
417 vec![Wire::new(const_node, OutgoingPort::from(0))],
419 )
420 .expect("The constant type should match the LoadConstant type.");
421
422 load_n.out_wire(0)
423 }
424
425 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
428 let cid = self.add_constant(constant);
429 self.load_const(&cid)
430 }
431
432 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
435 self.add_load_const(constant.into())
436 }
437
438 fn load_func<const DEFINED: bool>(
444 &mut self,
445 fid: &FuncID<DEFINED>,
446 type_args: &[TypeArg],
447 ) -> Result<Wire, BuildError> {
448 let func_node = fid.node();
449 let func_op = self.hugr().get_optype(func_node);
450 let func_sig = match func_op {
451 OpType::FuncDefn(fd) => fd.signature().clone(),
452 OpType::FuncDecl(fd) => fd.signature().clone(),
453 _ => {
454 return Err(BuildError::UnexpectedType {
455 node: func_node,
456 op_desc: "FuncDecl/FuncDefn",
457 });
458 }
459 };
460
461 let load_n = self.add_dataflow_op(
462 ops::LoadFunction::try_new(func_sig, type_args)?,
463 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
465 )?;
466
467 Ok(load_n.out_wire(0))
468 }
469
470 fn tail_loop_builder(
481 &mut self,
482 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
483 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
484 just_out_types: TypeRow,
485 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
486 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
487 just_inputs.into_iter().unzip();
488 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
489 inputs_outputs.into_iter().unzip();
490 input_wires.extend(rest_input_wires);
491
492 let tail_loop = ops::TailLoop {
493 just_inputs: input_types.into(),
494 just_outputs: just_out_types,
495 rest: rest_types.into(),
496 };
497 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
499
500 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
501 }
502
503 fn conditional_builder(
516 &mut self,
517 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
518 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
519 output_types: TypeRow,
520 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
521 let mut input_wires = vec![sum_wire];
522 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
523 other_inputs.into_iter().unzip();
524
525 input_wires.extend(rest_input_wires);
526 let inputs: TypeRow = input_types.into();
527 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
528 let n_cases = sum_rows.len();
529 let n_out_wires = output_types.len();
530
531 let conditional_id = self.add_dataflow_op(
532 ops::Conditional {
533 sum_rows,
534 other_inputs: inputs,
535 outputs: output_types,
536 },
537 input_wires,
538 )?;
539
540 Ok(ConditionalBuilder {
541 base: self.hugr_mut(),
542 conditional_node: conditional_id.node(),
543 n_out_wires,
544 case_nodes: vec![None; n_cases],
545 })
546 }
547
548 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
551 self.add_other_wire(before.node(), after.node());
552 }
553
554 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
556 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
557
558 if let Some(EdgeKind::Value(typ)) = kind {
559 Ok(typ)
560 } else {
561 Err(BuildError::WireNotFound(wire))
562 }
563 }
564
565 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
573 let values = values.into_iter().collect_vec();
574 let types: Result<Vec<Type>, _> = values
575 .iter()
576 .map(|&wire| self.get_wire_type(wire))
577 .collect();
578 let types = types?.into();
579 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
580 Ok(make_op.out_wire(0))
581 }
582
583 fn make_sum(
593 &mut self,
594 tag: usize,
595 variants: impl IntoIterator<Item = TypeRow>,
596 values: impl IntoIterator<Item = Wire>,
597 ) -> Result<Wire, BuildError> {
598 let make_op = self.add_dataflow_op(
599 Tag {
600 tag,
601 variants: variants.into_iter().collect_vec(),
602 },
603 values.into_iter().collect_vec(),
604 )?;
605 Ok(make_op.out_wire(0))
606 }
607
608 fn make_continue(
618 &mut self,
619 tail_loop: ops::TailLoop,
620 values: impl IntoIterator<Item = Wire>,
621 ) -> Result<Wire, BuildError> {
622 self.make_sum(
623 TailLoop::CONTINUE_TAG,
624 [tail_loop.just_inputs, tail_loop.just_outputs],
625 values,
626 )
627 }
628
629 fn make_break(
639 &mut self,
640 loop_op: ops::TailLoop,
641 values: impl IntoIterator<Item = Wire>,
642 ) -> Result<Wire, BuildError> {
643 self.make_sum(
644 TailLoop::BREAK_TAG,
645 [loop_op.just_inputs, loop_op.just_outputs],
646 values,
647 )
648 }
649
650 fn call<const DEFINED: bool>(
659 &mut self,
660 function: &FuncID<DEFINED>,
661 type_args: &[TypeArg],
662 input_wires: impl IntoIterator<Item = Wire>,
663 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
664 let hugr = self.hugr();
665 let def_op = hugr.get_optype(function.node());
666 let type_scheme = match def_op {
667 OpType::FuncDefn(fd) => fd.signature().clone(),
668 OpType::FuncDecl(fd) => fd.signature().clone(),
669 _ => {
670 return Err(BuildError::UnexpectedType {
671 node: function.node(),
672 op_desc: "FuncDecl/FuncDefn",
673 });
674 }
675 };
676 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
677 let const_in_port = op.static_input_port().unwrap();
678 let op_id = self.add_dataflow_op(op, input_wires)?;
679 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
680
681 self.hugr_mut()
682 .connect(function.node(), src_port, op_id.node(), const_in_port);
683 Ok(op_id)
684 }
685
686 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<'_, Self> {
689 CircuitBuilder::new(wires, self)
690 }
691
692 fn add_barrier(
701 &mut self,
702 wires: impl IntoIterator<Item = Wire>,
703 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
704 let wires = wires.into_iter().collect_vec();
705 let types: Result<Vec<Type>, _> =
706 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
707 let types = types?;
708 let barrier_op =
709 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
710 Ok(barrier_op)
711 }
712}
713
714fn add_node_with_wires<T: Dataflow + ?Sized>(
723 data_builder: &mut T,
724 nodetype: impl Into<OpType>,
725 inputs: impl IntoIterator<Item = Wire>,
726) -> Result<(Node, usize), BuildError> {
727 let op: OpType = nodetype.into();
728 let num_outputs = op.value_output_count();
729 let op_node = data_builder.add_child_node(op.clone());
730
731 wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
732 op: Box::new(op),
733 error,
734 })?;
735
736 Ok((op_node, num_outputs))
737}
738
739fn wire_up_inputs<T: Dataflow + ?Sized>(
747 inputs: impl IntoIterator<Item = Wire>,
748 op_node: Node,
749 data_builder: &mut T,
750) -> Result<(), BuilderWiringError> {
751 for (dst_port, wire) in inputs.into_iter().enumerate() {
752 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
753 }
754 Ok(())
755}
756
757fn wire_ins_return_outs<T: Dataflow + ?Sized>(
758 inputs: impl IntoIterator<Item = Wire>,
759 node: Node,
760 data_builder: &mut T,
761) -> Result<BuildHandle<DataflowOpID>, BuildError> {
762 let op = data_builder.hugr().get_optype(node).clone();
763 let num_outputs = op.value_output_count();
764 wire_up_inputs(inputs, node, data_builder).map_err(|error| BuildError::OperationWiring {
765 op: Box::new(op),
766 error,
767 })?;
768 Ok((node, num_outputs).into())
769}
770
771fn wire_up<T: Dataflow + ?Sized>(
777 data_builder: &mut T,
778 src: Node,
779 src_port: impl Into<OutgoingPort>,
780 dst: Node,
781 dst_port: impl Into<IncomingPort>,
782) -> Result<bool, BuilderWiringError> {
783 let src_port = src_port.into();
784 let dst_port = dst_port.into();
785 let base = data_builder.hugr_mut();
786
787 let src_parent = base.get_parent(src);
788 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
789 let dst_parent = base.get_parent(dst);
790 let local_source = src_parent == dst_parent;
791 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
792 if !local_source {
793 if !typ.copyable() {
795 return Err(BuilderWiringError::NonCopyableIntergraph {
796 src,
797 src_offset: src_port.into(),
798 dst,
799 dst_offset: dst_port.into(),
800 typ: Box::new(typ),
801 });
802 }
803
804 let src_parent = src_parent.expect("Node has no parent");
805 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
806 .tuple_windows()
807 .find_map(|(ancestor, ancestor_parent)| {
808 (ancestor_parent == src_parent ||
809 Some(ancestor_parent) == src_parent_parent)
811 .then_some(ancestor)
812 })
813 else {
814 return Err(BuilderWiringError::NoRelationIntergraph {
815 src,
816 src_offset: src_port.into(),
817 dst,
818 dst_offset: dst_port.into(),
819 });
820 };
821
822 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
823 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
824 {
825 base.add_other_edge(src, src_sibling);
827 }
828 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
829 return Err(BuilderWiringError::NoCopyLinear {
831 typ: Box::new(typ),
832 src,
833 src_offset: src_port.into(),
834 });
835 }
836 }
837
838 data_builder
839 .hugr_mut()
840 .connect(src, src_port, dst, dst_port);
841 Ok(local_source
842 && matches!(
843 data_builder
844 .hugr_mut()
845 .get_optype(dst)
846 .port_kind(dst_port)
847 .unwrap(),
848 EdgeKind::Value(_)
849 ))
850}
851
852pub trait DataflowHugr: HugrBuilder + Dataflow {
854 fn finish_hugr_with_outputs(
860 mut self,
861 outputs: impl IntoIterator<Item = Wire>,
862 ) -> Result<Hugr, BuildError>
863 where
864 Self: Sized,
865 {
866 self.set_outputs(outputs)?;
867 Ok(self.finish_hugr()?)
868 }
869}
870
871pub trait DataflowSubContainer: SubContainer + Dataflow {
873 fn finish_with_outputs(
880 mut self,
881 outputs: impl IntoIterator<Item = Wire>,
882 ) -> Result<Self::ContainerHandle, BuildError>
883 where
884 Self: Sized,
885 {
886 self.set_outputs(outputs)?;
887 self.finish_sub_container()
888 }
889}
890
891impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
892impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}