1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::views::HugrView;
4use crate::hugr::{NodeMetadata, ValidationError};
5use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
6use crate::utils::collect_array;
7use crate::{Extension, IncomingPort, Node, OutgoingPort};
8
9use std::iter;
10use std::sync::Arc;
11
12use super::{BuilderWiringError, ModuleBuilder};
13use super::{
14 CircuitBuilder,
15 handle::{BuildHandle, Outputs},
16};
17
18use crate::{
19 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
20 types::EdgeKind,
21};
22
23use crate::extension::ExtensionRegistry;
24use crate::types::{Signature, Type, TypeArg, TypeRow};
25
26use itertools::Itertools;
27
28use super::{
29 BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
30 tail_loop::TailLoopBuilder,
31};
32
33use crate::Hugr;
34
35use crate::hugr::HugrMut;
36
37pub trait Container {
42 fn container_node(&self) -> Node;
44 fn hugr_mut(&mut self) -> &mut Hugr;
46 fn hugr(&self) -> &Hugr;
48 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
52 let node: OpType = node.into();
53
54 let used_extensions = node
56 .used_extensions()
57 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
58 self.use_extensions(used_extensions);
59
60 let parent = self.container_node();
61 self.hugr_mut().add_node_with_parent(parent, node)
62 }
63
64 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
69 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
70 Wire::new(src, src_port)
71 }
72
73 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
82 self.add_child_node(constant.into()).into()
83 }
84
85 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
89 let region = child.entrypoint();
90 self.add_hugr_region(child, region)
91 }
92
93 fn add_hugr_region(&mut self, child: Hugr, region: Node) -> InsertionResult {
97 let parent = self.container_node();
98 self.hugr_mut().insert_region(parent, child, region)
99 }
100
101 fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
106 let parent = self.container_node();
107 self.hugr_mut().insert_from_view(parent, child)
108 }
109
110 fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
112 let parent = self.container_node();
113 self.hugr_mut().set_metadata(parent, key, meta);
115 }
116
117 fn set_child_metadata(
121 &mut self,
122 child: Node,
123 key: impl AsRef<str>,
124 meta: impl Into<NodeMetadata>,
125 ) {
126 self.hugr_mut().set_metadata(child, key, meta);
127 }
128
129 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
131 self.hugr_mut().use_extension(ext);
132 }
133
134 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
136 where
137 ExtensionRegistry: Extend<Reg>,
138 {
139 self.hugr_mut().use_extensions(registry);
140 }
141}
142
143pub trait HugrBuilder: Container {
146 fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
149 debug_assert!(
150 self.hugr()
151 .get_optype(self.hugr().module_root())
152 .is_module()
153 );
154 ModuleBuilder(self.hugr_mut())
155 }
156
157 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
159}
160
161pub trait SubContainer: Container {
163 type ContainerHandle;
166 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
169}
170pub trait Dataflow: Container {
172 fn num_inputs(&self) -> usize;
174 fn io(&self) -> [Node; 2] {
176 self.hugr()
177 .children(self.container_node())
178 .take(2)
179 .collect_vec()
180 .try_into()
181 .expect("First two children should be IO")
182 }
183 fn input(&self) -> BuildHandle<DataflowOpID> {
185 (self.io()[0], self.num_inputs()).into()
186 }
187 fn output(&self) -> DataflowOpID {
189 self.io()[1].into()
190 }
191 fn input_wires(&self) -> Outputs {
193 self.input().outputs()
194 }
195 fn add_dataflow_op(
204 &mut self,
205 nodetype: impl Into<OpType>,
206 input_wires: impl IntoIterator<Item = Wire>,
207 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
208 let outs = add_node_with_wires(self, nodetype, input_wires)?;
209
210 Ok(outs.into())
211 }
212
213 fn add_hugr_with_wires(
225 &mut self,
226 hugr: Hugr,
227 input_wires: impl IntoIterator<Item = Wire>,
228 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
229 let region = hugr.entrypoint();
230 self.add_hugr_region_with_wires(hugr, region, input_wires)
231 }
232
233 fn add_hugr_region_with_wires(
244 &mut self,
245 hugr: Hugr,
246 region: Node,
247 input_wires: impl IntoIterator<Item = Wire>,
248 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
249 let optype = hugr.get_optype(region).clone();
250 let num_outputs = optype.value_output_count();
251 let node = self.add_hugr_region(hugr, region).inserted_entrypoint;
252
253 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
254 op: Box::new(optype),
255 error,
256 })?;
257
258 Ok((node, num_outputs).into())
259 }
260
261 fn add_hugr_view_with_wires(
271 &mut self,
272 hugr: &impl HugrView,
273 input_wires: impl IntoIterator<Item = Wire>,
274 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
275 let node = self.add_hugr_view(hugr).inserted_entrypoint;
276 let optype = hugr.get_optype(hugr.entrypoint()).clone();
277 let num_outputs = optype.value_output_count();
278
279 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
280 op: Box::new(optype),
281 error,
282 })?;
283
284 Ok((node, num_outputs).into())
285 }
286
287 fn set_outputs(
293 &mut self,
294 output_wires: impl IntoIterator<Item = Wire>,
295 ) -> Result<(), BuildError> {
296 let [_, out] = self.io();
297 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
298 BuildError::OutputWiring {
299 container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
300 container_node: self.container_node(),
301 error,
302 }
303 })
304 }
305
306 #[track_caller]
312 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
313 collect_array(self.input_wires())
314 }
315
316 fn dfg_builder(
326 &mut self,
327 signature: Signature,
328 input_wires: impl IntoIterator<Item = Wire>,
329 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
330 let op = ops::DFG {
331 signature: signature.clone(),
332 };
333 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
334
335 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
336 }
337
338 fn dfg_builder_endo(
343 &mut self,
344 inputs: impl IntoIterator<Item = (Type, Wire)>,
345 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
346 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
347 self.dfg_builder(Signature::new_endo(types), input_wires)
348 }
349
350 fn cfg_builder(
361 &mut self,
362 inputs: impl IntoIterator<Item = (Type, Wire)>,
363 output_types: TypeRow,
364 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
365 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
366
367 let inputs: TypeRow = input_types.into();
368
369 let (cfg_node, _) = add_node_with_wires(
370 self,
371 ops::CFG {
372 signature: Signature::new(inputs.clone(), output_types.clone()),
373 },
374 input_wires,
375 )?;
376 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
377 }
378
379 fn load_const(&mut self, cid: &ConstID) -> Wire {
382 let const_node = cid.node();
383 let nodetype = self.hugr().get_optype(const_node);
384 let op: ops::Const = nodetype
385 .clone()
386 .try_into()
387 .expect("ConstID does not refer to Const op.");
388
389 let load_n = self
390 .add_dataflow_op(
391 ops::LoadConstant {
392 datatype: op.get_type().clone(),
393 },
394 vec![Wire::new(const_node, OutgoingPort::from(0))],
396 )
397 .expect("The constant type should match the LoadConstant type.");
398
399 load_n.out_wire(0)
400 }
401
402 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
405 let cid = self.add_constant(constant);
406 self.load_const(&cid)
407 }
408
409 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
412 self.add_load_const(constant.into())
413 }
414
415 fn load_func<const DEFINED: bool>(
421 &mut self,
422 fid: &FuncID<DEFINED>,
423 type_args: &[TypeArg],
424 ) -> Result<Wire, BuildError> {
425 let func_node = fid.node();
426 let func_op = self.hugr().get_optype(func_node);
427 let func_sig = match func_op {
428 OpType::FuncDefn(fd) => fd.signature().clone(),
429 OpType::FuncDecl(fd) => fd.signature().clone(),
430 _ => {
431 return Err(BuildError::UnexpectedType {
432 node: func_node,
433 op_desc: "FuncDecl/FuncDefn",
434 });
435 }
436 };
437
438 let load_n = self.add_dataflow_op(
439 ops::LoadFunction::try_new(func_sig, type_args)?,
440 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
442 )?;
443
444 Ok(load_n.out_wire(0))
445 }
446
447 fn tail_loop_builder(
458 &mut self,
459 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
460 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
461 just_out_types: TypeRow,
462 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
463 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
464 just_inputs.into_iter().unzip();
465 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
466 inputs_outputs.into_iter().unzip();
467 input_wires.extend(rest_input_wires);
468
469 let tail_loop = ops::TailLoop {
470 just_inputs: input_types.into(),
471 just_outputs: just_out_types,
472 rest: rest_types.into(),
473 };
474 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
476
477 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
478 }
479
480 fn conditional_builder(
493 &mut self,
494 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
495 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
496 output_types: TypeRow,
497 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
498 let mut input_wires = vec![sum_wire];
499 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
500 other_inputs.into_iter().unzip();
501
502 input_wires.extend(rest_input_wires);
503 let inputs: TypeRow = input_types.into();
504 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
505 let n_cases = sum_rows.len();
506 let n_out_wires = output_types.len();
507
508 let conditional_id = self.add_dataflow_op(
509 ops::Conditional {
510 sum_rows,
511 other_inputs: inputs,
512 outputs: output_types,
513 },
514 input_wires,
515 )?;
516
517 Ok(ConditionalBuilder {
518 base: self.hugr_mut(),
519 conditional_node: conditional_id.node(),
520 n_out_wires,
521 case_nodes: vec![None; n_cases],
522 })
523 }
524
525 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
528 self.add_other_wire(before.node(), after.node());
529 }
530
531 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
533 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
534
535 if let Some(EdgeKind::Value(typ)) = kind {
536 Ok(typ)
537 } else {
538 Err(BuildError::WireNotFound(wire))
539 }
540 }
541
542 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
550 let values = values.into_iter().collect_vec();
551 let types: Result<Vec<Type>, _> = values
552 .iter()
553 .map(|&wire| self.get_wire_type(wire))
554 .collect();
555 let types = types?.into();
556 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
557 Ok(make_op.out_wire(0))
558 }
559
560 fn make_sum(
570 &mut self,
571 tag: usize,
572 variants: impl IntoIterator<Item = TypeRow>,
573 values: impl IntoIterator<Item = Wire>,
574 ) -> Result<Wire, BuildError> {
575 let make_op = self.add_dataflow_op(
576 Tag {
577 tag,
578 variants: variants.into_iter().collect_vec(),
579 },
580 values.into_iter().collect_vec(),
581 )?;
582 Ok(make_op.out_wire(0))
583 }
584
585 fn make_continue(
595 &mut self,
596 tail_loop: ops::TailLoop,
597 values: impl IntoIterator<Item = Wire>,
598 ) -> Result<Wire, BuildError> {
599 self.make_sum(
600 TailLoop::CONTINUE_TAG,
601 [tail_loop.just_inputs, tail_loop.just_outputs],
602 values,
603 )
604 }
605
606 fn make_break(
616 &mut self,
617 loop_op: ops::TailLoop,
618 values: impl IntoIterator<Item = Wire>,
619 ) -> Result<Wire, BuildError> {
620 self.make_sum(
621 TailLoop::BREAK_TAG,
622 [loop_op.just_inputs, loop_op.just_outputs],
623 values,
624 )
625 }
626
627 fn call<const DEFINED: bool>(
636 &mut self,
637 function: &FuncID<DEFINED>,
638 type_args: &[TypeArg],
639 input_wires: impl IntoIterator<Item = Wire>,
640 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
641 let hugr = self.hugr();
642 let def_op = hugr.get_optype(function.node());
643 let type_scheme = match def_op {
644 OpType::FuncDefn(fd) => fd.signature().clone(),
645 OpType::FuncDecl(fd) => fd.signature().clone(),
646 _ => {
647 return Err(BuildError::UnexpectedType {
648 node: function.node(),
649 op_desc: "FuncDecl/FuncDefn",
650 });
651 }
652 };
653 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
654 let const_in_port = op.static_input_port().unwrap();
655 let op_id = self.add_dataflow_op(op, input_wires)?;
656 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
657
658 self.hugr_mut()
659 .connect(function.node(), src_port, op_id.node(), const_in_port);
660 Ok(op_id)
661 }
662
663 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<'_, Self> {
666 CircuitBuilder::new(wires, self)
667 }
668
669 fn add_barrier(
678 &mut self,
679 wires: impl IntoIterator<Item = Wire>,
680 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
681 let wires = wires.into_iter().collect_vec();
682 let types: Result<Vec<Type>, _> =
683 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
684 let types = types?;
685 let barrier_op =
686 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
687 Ok(barrier_op)
688 }
689}
690
691fn add_node_with_wires<T: Dataflow + ?Sized>(
700 data_builder: &mut T,
701 nodetype: impl Into<OpType>,
702 inputs: impl IntoIterator<Item = Wire>,
703) -> Result<(Node, usize), BuildError> {
704 let op: OpType = nodetype.into();
705 let num_outputs = op.value_output_count();
706 let op_node = data_builder.add_child_node(op.clone());
707
708 wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
709 op: Box::new(op),
710 error,
711 })?;
712
713 Ok((op_node, num_outputs))
714}
715
716fn wire_up_inputs<T: Dataflow + ?Sized>(
724 inputs: impl IntoIterator<Item = Wire>,
725 op_node: Node,
726 data_builder: &mut T,
727) -> Result<(), BuilderWiringError> {
728 for (dst_port, wire) in inputs.into_iter().enumerate() {
729 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
730 }
731 Ok(())
732}
733
734fn wire_up<T: Dataflow + ?Sized>(
740 data_builder: &mut T,
741 src: Node,
742 src_port: impl Into<OutgoingPort>,
743 dst: Node,
744 dst_port: impl Into<IncomingPort>,
745) -> Result<bool, BuilderWiringError> {
746 let src_port = src_port.into();
747 let dst_port = dst_port.into();
748 let base = data_builder.hugr_mut();
749
750 let src_parent = base.get_parent(src);
751 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
752 let dst_parent = base.get_parent(dst);
753 let local_source = src_parent == dst_parent;
754 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
755 if !local_source {
756 if !typ.copyable() {
758 return Err(BuilderWiringError::NonCopyableIntergraph {
759 src,
760 src_offset: src_port.into(),
761 dst,
762 dst_offset: dst_port.into(),
763 typ: Box::new(typ),
764 });
765 }
766
767 let src_parent = src_parent.expect("Node has no parent");
768 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
769 .tuple_windows()
770 .find_map(|(ancestor, ancestor_parent)| {
771 (ancestor_parent == src_parent ||
772 Some(ancestor_parent) == src_parent_parent)
774 .then_some(ancestor)
775 })
776 else {
777 return Err(BuilderWiringError::NoRelationIntergraph {
778 src,
779 src_offset: src_port.into(),
780 dst,
781 dst_offset: dst_port.into(),
782 });
783 };
784
785 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
786 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
787 {
788 base.add_other_edge(src, src_sibling);
790 }
791 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
792 return Err(BuilderWiringError::NoCopyLinear {
794 typ: Box::new(typ),
795 src,
796 src_offset: src_port.into(),
797 });
798 }
799 }
800
801 data_builder
802 .hugr_mut()
803 .connect(src, src_port, dst, dst_port);
804 Ok(local_source
805 && matches!(
806 data_builder
807 .hugr_mut()
808 .get_optype(dst)
809 .port_kind(dst_port)
810 .unwrap(),
811 EdgeKind::Value(_)
812 ))
813}
814
815pub trait DataflowHugr: HugrBuilder + Dataflow {
817 fn finish_hugr_with_outputs(
823 mut self,
824 outputs: impl IntoIterator<Item = Wire>,
825 ) -> Result<Hugr, BuildError>
826 where
827 Self: Sized,
828 {
829 self.set_outputs(outputs)?;
830 Ok(self.finish_hugr()?)
831 }
832}
833
834pub trait DataflowSubContainer: SubContainer + Dataflow {
836 fn finish_with_outputs(
843 mut self,
844 outputs: impl IntoIterator<Item = Wire>,
845 ) -> Result<Self::ContainerHandle, BuildError>
846 where
847 Self: Sized,
848 {
849 self.set_outputs(outputs)?;
850 self.finish_sub_container()
851 }
852}
853
854impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
855impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}