1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::views::HugrView;
4use crate::hugr::{NodeMetadata, ValidationError};
5use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
6use crate::utils::collect_array;
7use crate::{Extension, IncomingPort, Node, OutgoingPort};
8
9use std::iter;
10use std::sync::Arc;
11
12use super::{BuilderWiringError, ModuleBuilder};
13use super::{
14 CircuitBuilder,
15 handle::{BuildHandle, Outputs},
16};
17
18use crate::{
19 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
20 types::EdgeKind,
21};
22
23use crate::extension::ExtensionRegistry;
24use crate::types::{Signature, Type, TypeArg, TypeRow};
25
26use itertools::Itertools;
27
28use super::{
29 BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
30 tail_loop::TailLoopBuilder,
31};
32
33use crate::Hugr;
34
35use crate::hugr::HugrMut;
36
37pub trait Container {
42 fn container_node(&self) -> Node;
44 fn hugr_mut(&mut self) -> &mut Hugr;
46 fn hugr(&self) -> &Hugr;
48 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
52 let node: OpType = node.into();
53
54 let used_extensions = node
56 .used_extensions()
57 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
58 self.use_extensions(used_extensions);
59
60 let parent = self.container_node();
61 self.hugr_mut().add_node_with_parent(parent, node)
62 }
63
64 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
69 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
70 Wire::new(src, src_port)
71 }
72
73 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
82 self.add_child_node(constant.into()).into()
83 }
84
85 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
89 let region = child.entrypoint();
90 self.add_hugr_region(child, region)
91 }
92
93 fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult {
97 let parent = self.container_node();
98 self.hugr_mut().insert_region(parent, child, region)
99 }
100
101 fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
103 let parent = self.container_node();
104 self.hugr_mut().insert_from_view(parent, child)
105 }
106
107 fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
109 let parent = self.container_node();
110 self.hugr_mut().set_metadata(parent, key, meta);
112 }
113
114 fn set_child_metadata(
118 &mut self,
119 child: Node,
120 key: impl AsRef<str>,
121 meta: impl Into<NodeMetadata>,
122 ) {
123 self.hugr_mut().set_metadata(child, key, meta);
124 }
125
126 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
128 self.hugr_mut().use_extension(ext);
129 }
130
131 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
133 where
134 ExtensionRegistry: Extend<Reg>,
135 {
136 self.hugr_mut().use_extensions(registry);
137 }
138}
139
140pub trait HugrBuilder: Container {
143 fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
146 debug_assert!(
147 self.hugr()
148 .get_optype(self.hugr().module_root())
149 .is_module()
150 );
151 ModuleBuilder(self.hugr_mut())
152 }
153
154 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
156}
157
158pub trait SubContainer: Container {
160 type ContainerHandle;
163 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
166}
167pub trait Dataflow: Container {
169 fn num_inputs(&self) -> usize;
171 fn io(&self) -> [Node; 2] {
173 self.hugr()
174 .children(self.container_node())
175 .take(2)
176 .collect_vec()
177 .try_into()
178 .expect("First two children should be IO")
179 }
180 fn input(&self) -> BuildHandle<DataflowOpID> {
182 (self.io()[0], self.num_inputs()).into()
183 }
184 fn output(&self) -> DataflowOpID {
186 self.io()[1].into()
187 }
188 fn input_wires(&self) -> Outputs {
190 self.input().outputs()
191 }
192 fn add_dataflow_op(
201 &mut self,
202 nodetype: impl Into<OpType>,
203 input_wires: impl IntoIterator<Item = Wire>,
204 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
205 let outs = add_node_with_wires(self, nodetype, input_wires)?;
206
207 Ok(outs.into())
208 }
209
210 fn add_hugr_with_wires(
222 &mut self,
223 hugr: Hugr,
224 input_wires: impl IntoIterator<Item = Wire>,
225 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
226 let region = hugr.entrypoint();
227 self.add_hugr_region_with_wires(hugr, region, input_wires)
228 }
229
230 fn add_hugr_region_with_wires(
241 &mut self,
242 hugr: Hugr,
243 region: Node,
244 input_wires: impl IntoIterator<Item = Wire>,
245 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
246 let optype = hugr.get_optype(region).clone();
247 let num_outputs = optype.value_output_count();
248 let node = self.add_hugr_region(hugr, region).inserted_entrypoint;
249
250 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
251 op: Box::new(optype),
252 error,
253 })?;
254
255 Ok((node, num_outputs).into())
256 }
257
258 fn add_hugr_view_with_wires(
266 &mut self,
267 hugr: &impl HugrView,
268 input_wires: impl IntoIterator<Item = Wire>,
269 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
270 let node = self.add_hugr_view(hugr).inserted_entrypoint;
271 let optype = hugr.get_optype(hugr.entrypoint()).clone();
272 let num_outputs = optype.value_output_count();
273
274 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
275 op: Box::new(optype),
276 error,
277 })?;
278
279 Ok((node, num_outputs).into())
280 }
281
282 fn set_outputs(
288 &mut self,
289 output_wires: impl IntoIterator<Item = Wire>,
290 ) -> Result<(), BuildError> {
291 let [_, out] = self.io();
292 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
293 BuildError::OutputWiring {
294 container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
295 container_node: self.container_node(),
296 error,
297 }
298 })
299 }
300
301 #[track_caller]
307 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
308 collect_array(self.input_wires())
309 }
310
311 fn dfg_builder(
321 &mut self,
322 signature: Signature,
323 input_wires: impl IntoIterator<Item = Wire>,
324 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
325 let op = ops::DFG {
326 signature: signature.clone(),
327 };
328 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
329
330 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
331 }
332
333 fn dfg_builder_endo(
338 &mut self,
339 inputs: impl IntoIterator<Item = (Type, Wire)>,
340 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
341 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
342 self.dfg_builder(Signature::new_endo(types), input_wires)
343 }
344
345 fn cfg_builder(
356 &mut self,
357 inputs: impl IntoIterator<Item = (Type, Wire)>,
358 output_types: TypeRow,
359 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
360 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
361
362 let inputs: TypeRow = input_types.into();
363
364 let (cfg_node, _) = add_node_with_wires(
365 self,
366 ops::CFG {
367 signature: Signature::new(inputs.clone(), output_types.clone()),
368 },
369 input_wires,
370 )?;
371 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
372 }
373
374 fn load_const(&mut self, cid: &ConstID) -> Wire {
377 let const_node = cid.node();
378 let nodetype = self.hugr().get_optype(const_node);
379 let op: ops::Const = nodetype
380 .clone()
381 .try_into()
382 .expect("ConstID does not refer to Const op.");
383
384 let load_n = self
385 .add_dataflow_op(
386 ops::LoadConstant {
387 datatype: op.get_type().clone(),
388 },
389 vec![Wire::new(const_node, OutgoingPort::from(0))],
391 )
392 .expect("The constant type should match the LoadConstant type.");
393
394 load_n.out_wire(0)
395 }
396
397 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
400 let cid = self.add_constant(constant);
401 self.load_const(&cid)
402 }
403
404 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
407 self.add_load_const(constant.into())
408 }
409
410 fn load_func<const DEFINED: bool>(
416 &mut self,
417 fid: &FuncID<DEFINED>,
418 type_args: &[TypeArg],
419 ) -> Result<Wire, BuildError> {
420 let func_node = fid.node();
421 let func_op = self.hugr().get_optype(func_node);
422 let func_sig = match func_op {
423 OpType::FuncDefn(fd) => fd.signature().clone(),
424 OpType::FuncDecl(fd) => fd.signature().clone(),
425 _ => {
426 return Err(BuildError::UnexpectedType {
427 node: func_node,
428 op_desc: "FuncDecl/FuncDefn",
429 });
430 }
431 };
432
433 let load_n = self.add_dataflow_op(
434 ops::LoadFunction::try_new(func_sig, type_args)?,
435 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
437 )?;
438
439 Ok(load_n.out_wire(0))
440 }
441
442 fn tail_loop_builder(
453 &mut self,
454 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
455 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
456 just_out_types: TypeRow,
457 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
458 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
459 just_inputs.into_iter().unzip();
460 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
461 inputs_outputs.into_iter().unzip();
462 input_wires.extend(rest_input_wires);
463
464 let tail_loop = ops::TailLoop {
465 just_inputs: input_types.into(),
466 just_outputs: just_out_types,
467 rest: rest_types.into(),
468 };
469 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
471
472 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
473 }
474
475 fn conditional_builder(
488 &mut self,
489 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
490 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
491 output_types: TypeRow,
492 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
493 let mut input_wires = vec![sum_wire];
494 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
495 other_inputs.into_iter().unzip();
496
497 input_wires.extend(rest_input_wires);
498 let inputs: TypeRow = input_types.into();
499 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
500 let n_cases = sum_rows.len();
501 let n_out_wires = output_types.len();
502
503 let conditional_id = self.add_dataflow_op(
504 ops::Conditional {
505 sum_rows,
506 other_inputs: inputs,
507 outputs: output_types,
508 },
509 input_wires,
510 )?;
511
512 Ok(ConditionalBuilder {
513 base: self.hugr_mut(),
514 conditional_node: conditional_id.node(),
515 n_out_wires,
516 case_nodes: vec![None; n_cases],
517 })
518 }
519
520 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
523 self.add_other_wire(before.node(), after.node());
524 }
525
526 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
528 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
529
530 if let Some(EdgeKind::Value(typ)) = kind {
531 Ok(typ)
532 } else {
533 Err(BuildError::WireNotFound(wire))
534 }
535 }
536
537 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
545 let values = values.into_iter().collect_vec();
546 let types: Result<Vec<Type>, _> = values
547 .iter()
548 .map(|&wire| self.get_wire_type(wire))
549 .collect();
550 let types = types?.into();
551 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
552 Ok(make_op.out_wire(0))
553 }
554
555 fn make_sum(
565 &mut self,
566 tag: usize,
567 variants: impl IntoIterator<Item = TypeRow>,
568 values: impl IntoIterator<Item = Wire>,
569 ) -> Result<Wire, BuildError> {
570 let make_op = self.add_dataflow_op(
571 Tag {
572 tag,
573 variants: variants.into_iter().collect_vec(),
574 },
575 values.into_iter().collect_vec(),
576 )?;
577 Ok(make_op.out_wire(0))
578 }
579
580 fn make_continue(
590 &mut self,
591 tail_loop: ops::TailLoop,
592 values: impl IntoIterator<Item = Wire>,
593 ) -> Result<Wire, BuildError> {
594 self.make_sum(
595 TailLoop::CONTINUE_TAG,
596 [tail_loop.just_inputs, tail_loop.just_outputs],
597 values,
598 )
599 }
600
601 fn make_break(
611 &mut self,
612 loop_op: ops::TailLoop,
613 values: impl IntoIterator<Item = Wire>,
614 ) -> Result<Wire, BuildError> {
615 self.make_sum(
616 TailLoop::BREAK_TAG,
617 [loop_op.just_inputs, loop_op.just_outputs],
618 values,
619 )
620 }
621
622 fn call<const DEFINED: bool>(
631 &mut self,
632 function: &FuncID<DEFINED>,
633 type_args: &[TypeArg],
634 input_wires: impl IntoIterator<Item = Wire>,
635 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
636 let hugr = self.hugr();
637 let def_op = hugr.get_optype(function.node());
638 let type_scheme = match def_op {
639 OpType::FuncDefn(fd) => fd.signature().clone(),
640 OpType::FuncDecl(fd) => fd.signature().clone(),
641 _ => {
642 return Err(BuildError::UnexpectedType {
643 node: function.node(),
644 op_desc: "FuncDecl/FuncDefn",
645 });
646 }
647 };
648 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
649 let const_in_port = op.static_input_port().unwrap();
650 let op_id = self.add_dataflow_op(op, input_wires)?;
651 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
652
653 self.hugr_mut()
654 .connect(function.node(), src_port, op_id.node(), const_in_port);
655 Ok(op_id)
656 }
657
658 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<'_, Self> {
661 CircuitBuilder::new(wires, self)
662 }
663
664 fn add_barrier(
673 &mut self,
674 wires: impl IntoIterator<Item = Wire>,
675 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
676 let wires = wires.into_iter().collect_vec();
677 let types: Result<Vec<Type>, _> =
678 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
679 let types = types?;
680 let barrier_op =
681 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
682 Ok(barrier_op)
683 }
684}
685
686fn add_node_with_wires<T: Dataflow + ?Sized>(
695 data_builder: &mut T,
696 nodetype: impl Into<OpType>,
697 inputs: impl IntoIterator<Item = Wire>,
698) -> Result<(Node, usize), BuildError> {
699 let op: OpType = nodetype.into();
700 let num_outputs = op.value_output_count();
701 let op_node = data_builder.add_child_node(op.clone());
702
703 wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
704 op: Box::new(op),
705 error,
706 })?;
707
708 Ok((op_node, num_outputs))
709}
710
711fn wire_up_inputs<T: Dataflow + ?Sized>(
719 inputs: impl IntoIterator<Item = Wire>,
720 op_node: Node,
721 data_builder: &mut T,
722) -> Result<(), BuilderWiringError> {
723 for (dst_port, wire) in inputs.into_iter().enumerate() {
724 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
725 }
726 Ok(())
727}
728
729fn wire_up<T: Dataflow + ?Sized>(
735 data_builder: &mut T,
736 src: Node,
737 src_port: impl Into<OutgoingPort>,
738 dst: Node,
739 dst_port: impl Into<IncomingPort>,
740) -> Result<bool, BuilderWiringError> {
741 let src_port = src_port.into();
742 let dst_port = dst_port.into();
743 let base = data_builder.hugr_mut();
744
745 let src_parent = base.get_parent(src);
746 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
747 let dst_parent = base.get_parent(dst);
748 let local_source = src_parent == dst_parent;
749 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
750 if !local_source {
751 if !typ.copyable() {
753 return Err(BuilderWiringError::NonCopyableIntergraph {
754 src,
755 src_offset: src_port.into(),
756 dst,
757 dst_offset: dst_port.into(),
758 typ: Box::new(typ),
759 });
760 }
761
762 let src_parent = src_parent.expect("Node has no parent");
763 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
764 .tuple_windows()
765 .find_map(|(ancestor, ancestor_parent)| {
766 (ancestor_parent == src_parent ||
767 Some(ancestor_parent) == src_parent_parent)
769 .then_some(ancestor)
770 })
771 else {
772 return Err(BuilderWiringError::NoRelationIntergraph {
773 src,
774 src_offset: src_port.into(),
775 dst,
776 dst_offset: dst_port.into(),
777 });
778 };
779
780 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
781 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
782 {
783 base.add_other_edge(src, src_sibling);
785 }
786 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
787 return Err(BuilderWiringError::NoCopyLinear {
789 typ: Box::new(typ),
790 src,
791 src_offset: src_port.into(),
792 });
793 }
794 }
795
796 data_builder
797 .hugr_mut()
798 .connect(src, src_port, dst, dst_port);
799 Ok(local_source
800 && matches!(
801 data_builder
802 .hugr_mut()
803 .get_optype(dst)
804 .port_kind(dst_port)
805 .unwrap(),
806 EdgeKind::Value(_)
807 ))
808}
809
810pub trait DataflowHugr: HugrBuilder + Dataflow {
812 fn finish_hugr_with_outputs(
818 mut self,
819 outputs: impl IntoIterator<Item = Wire>,
820 ) -> Result<Hugr, BuildError>
821 where
822 Self: Sized,
823 {
824 self.set_outputs(outputs)?;
825 Ok(self.finish_hugr()?)
826 }
827}
828
829pub trait DataflowSubContainer: SubContainer + Dataflow {
831 fn finish_with_outputs(
838 mut self,
839 outputs: impl IntoIterator<Item = Wire>,
840 ) -> Result<Self::ContainerHandle, BuildError>
841 where
842 Self: Sized,
843 {
844 self.set_outputs(outputs)?;
845 self.finish_sub_container()
846 }
847}
848
849impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
850impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}