1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::views::HugrView;
4use crate::hugr::{NodeMetadata, ValidationError};
5use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
6use crate::utils::collect_array;
7use crate::{Extension, IncomingPort, Node, OutgoingPort};
8
9use std::iter;
10use std::sync::Arc;
11
12use super::{BuilderWiringError, ModuleBuilder};
13use super::{
14 CircuitBuilder,
15 handle::{BuildHandle, Outputs},
16};
17
18use crate::{
19 ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
20 types::EdgeKind,
21};
22
23use crate::extension::ExtensionRegistry;
24use crate::types::{Signature, Type, TypeArg, TypeRow};
25
26use itertools::Itertools;
27
28use super::{
29 BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
30 tail_loop::TailLoopBuilder,
31};
32
33use crate::Hugr;
34
35use crate::hugr::HugrMut;
36
37pub trait Container {
42 fn container_node(&self) -> Node;
44 fn hugr_mut(&mut self) -> &mut Hugr;
46 fn hugr(&self) -> &Hugr;
48 fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
52 let node: OpType = node.into();
53
54 let used_extensions = node
56 .used_extensions()
57 .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
58 self.use_extensions(used_extensions);
59
60 let parent = self.container_node();
61 self.hugr_mut().add_node_with_parent(parent, node)
62 }
63
64 fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
69 let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
70 Wire::new(src, src_port)
71 }
72
73 fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
82 self.add_child_node(constant.into()).into()
83 }
84
85 fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
87 let parent = self.container_node();
88 self.hugr_mut().insert_hugr(parent, child)
89 }
90
91 fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
93 let parent = self.container_node();
94 self.hugr_mut().insert_from_view(parent, child)
95 }
96
97 fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
99 let parent = self.container_node();
100 self.hugr_mut().set_metadata(parent, key, meta);
102 }
103
104 fn set_child_metadata(
108 &mut self,
109 child: Node,
110 key: impl AsRef<str>,
111 meta: impl Into<NodeMetadata>,
112 ) {
113 self.hugr_mut().set_metadata(child, key, meta);
114 }
115
116 fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
118 self.hugr_mut().use_extension(ext);
119 }
120
121 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
123 where
124 ExtensionRegistry: Extend<Reg>,
125 {
126 self.hugr_mut().use_extensions(registry);
127 }
128}
129
/// A [`Container`] that can be turned into a finished [`Hugr`].
pub trait HugrBuilder: Container {
    /// Returns a [`ModuleBuilder`] wrapping exclusive access to the HUGR
    /// being built.
    ///
    /// In debug builds this asserts that the operation at the HUGR's module
    /// root is a module operation.
    fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
        // Debug-only sanity check; compiled out in release builds.
        debug_assert!(
            self.hugr()
                .get_optype(self.hugr().module_root())
                .is_module()
        );
        ModuleBuilder(self.hugr_mut())
    }

    /// Consumes the builder and returns the built [`Hugr`].
    ///
    /// # Errors
    ///
    /// Returns a [`ValidationError`] when the implementor rejects the HUGR
    /// (presumably because validation failed; see the implementing type).
    fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
}
147
/// A [`Container`] that is completed by consuming it and handing back a
/// handle, rather than a whole HUGR.
pub trait SubContainer: Container {
    /// The handle type returned once the container has been finished.
    type ContainerHandle;

    /// Consumes the builder and returns the handle for the finished
    /// container.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] if the implementor cannot finish the
    /// container.
    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
}
157pub trait Dataflow: Container {
159 fn num_inputs(&self) -> usize;
161 fn io(&self) -> [Node; 2] {
163 self.hugr()
164 .children(self.container_node())
165 .take(2)
166 .collect_vec()
167 .try_into()
168 .expect("First two children should be IO")
169 }
170 fn input(&self) -> BuildHandle<DataflowOpID> {
172 (self.io()[0], self.num_inputs()).into()
173 }
174 fn output(&self) -> DataflowOpID {
176 self.io()[1].into()
177 }
178 fn input_wires(&self) -> Outputs {
180 self.input().outputs()
181 }
182 fn add_dataflow_op(
191 &mut self,
192 nodetype: impl Into<OpType>,
193 input_wires: impl IntoIterator<Item = Wire>,
194 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
195 let outs = add_node_with_wires(self, nodetype, input_wires)?;
196
197 Ok(outs.into())
198 }
199
200 fn add_hugr_with_wires(
208 &mut self,
209 hugr: Hugr,
210 input_wires: impl IntoIterator<Item = Wire>,
211 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
212 let optype = hugr.get_optype(hugr.entrypoint()).clone();
213 let num_outputs = optype.value_output_count();
214 let node = self.add_hugr(hugr).inserted_entrypoint;
215
216 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
217 op: Box::new(optype),
218 error,
219 })?;
220
221 Ok((node, num_outputs).into())
222 }
223
224 fn add_hugr_view_with_wires(
232 &mut self,
233 hugr: &impl HugrView,
234 input_wires: impl IntoIterator<Item = Wire>,
235 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
236 let node = self.add_hugr_view(hugr).inserted_entrypoint;
237 let optype = hugr.get_optype(hugr.entrypoint()).clone();
238 let num_outputs = optype.value_output_count();
239
240 wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
241 op: Box::new(optype),
242 error,
243 })?;
244
245 Ok((node, num_outputs).into())
246 }
247
248 fn set_outputs(
254 &mut self,
255 output_wires: impl IntoIterator<Item = Wire>,
256 ) -> Result<(), BuildError> {
257 let [_, out] = self.io();
258 wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
259 BuildError::OutputWiring {
260 container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
261 container_node: self.container_node(),
262 error,
263 }
264 })
265 }
266
267 #[track_caller]
273 fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
274 collect_array(self.input_wires())
275 }
276
277 fn dfg_builder(
287 &mut self,
288 signature: Signature,
289 input_wires: impl IntoIterator<Item = Wire>,
290 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
291 let op = ops::DFG {
292 signature: signature.clone(),
293 };
294 let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
295
296 DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
297 }
298
299 fn dfg_builder_endo(
304 &mut self,
305 inputs: impl IntoIterator<Item = (Type, Wire)>,
306 ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
307 let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
308 self.dfg_builder(Signature::new_endo(types), input_wires)
309 }
310
311 fn cfg_builder(
322 &mut self,
323 inputs: impl IntoIterator<Item = (Type, Wire)>,
324 output_types: TypeRow,
325 ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
326 let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
327
328 let inputs: TypeRow = input_types.into();
329
330 let (cfg_node, _) = add_node_with_wires(
331 self,
332 ops::CFG {
333 signature: Signature::new(inputs.clone(), output_types.clone()),
334 },
335 input_wires,
336 )?;
337 CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
338 }
339
340 fn load_const(&mut self, cid: &ConstID) -> Wire {
343 let const_node = cid.node();
344 let nodetype = self.hugr().get_optype(const_node);
345 let op: ops::Const = nodetype
346 .clone()
347 .try_into()
348 .expect("ConstID does not refer to Const op.");
349
350 let load_n = self
351 .add_dataflow_op(
352 ops::LoadConstant {
353 datatype: op.get_type().clone(),
354 },
355 vec![Wire::new(const_node, OutgoingPort::from(0))],
357 )
358 .expect("The constant type should match the LoadConstant type.");
359
360 load_n.out_wire(0)
361 }
362
363 fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
366 let cid = self.add_constant(constant);
367 self.load_const(&cid)
368 }
369
370 fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
373 self.add_load_const(constant.into())
374 }
375
376 fn load_func<const DEFINED: bool>(
382 &mut self,
383 fid: &FuncID<DEFINED>,
384 type_args: &[TypeArg],
385 ) -> Result<Wire, BuildError> {
386 let func_node = fid.node();
387 let func_op = self.hugr().get_optype(func_node);
388 let func_sig = match func_op {
389 OpType::FuncDefn(fd) => fd.signature().clone(),
390 OpType::FuncDecl(fd) => fd.signature().clone(),
391 _ => {
392 return Err(BuildError::UnexpectedType {
393 node: func_node,
394 op_desc: "FuncDecl/FuncDefn",
395 });
396 }
397 };
398
399 let load_n = self.add_dataflow_op(
400 ops::LoadFunction::try_new(func_sig, type_args)?,
401 vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
403 )?;
404
405 Ok(load_n.out_wire(0))
406 }
407
408 fn tail_loop_builder(
419 &mut self,
420 just_inputs: impl IntoIterator<Item = (Type, Wire)>,
421 inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
422 just_out_types: TypeRow,
423 ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
424 let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
425 just_inputs.into_iter().unzip();
426 let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
427 inputs_outputs.into_iter().unzip();
428 input_wires.extend(rest_input_wires);
429
430 let tail_loop = ops::TailLoop {
431 just_inputs: input_types.into(),
432 just_outputs: just_out_types,
433 rest: rest_types.into(),
434 };
435 let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
437
438 TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
439 }
440
441 fn conditional_builder(
454 &mut self,
455 (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
456 other_inputs: impl IntoIterator<Item = (Type, Wire)>,
457 output_types: TypeRow,
458 ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
459 let mut input_wires = vec![sum_wire];
460 let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
461 other_inputs.into_iter().unzip();
462
463 input_wires.extend(rest_input_wires);
464 let inputs: TypeRow = input_types.into();
465 let sum_rows: Vec<_> = sum_rows.into_iter().collect();
466 let n_cases = sum_rows.len();
467 let n_out_wires = output_types.len();
468
469 let conditional_id = self.add_dataflow_op(
470 ops::Conditional {
471 sum_rows,
472 other_inputs: inputs,
473 outputs: output_types,
474 },
475 input_wires,
476 )?;
477
478 Ok(ConditionalBuilder {
479 base: self.hugr_mut(),
480 conditional_node: conditional_id.node(),
481 n_out_wires,
482 case_nodes: vec![None; n_cases],
483 })
484 }
485
486 fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
489 self.add_other_wire(before.node(), after.node());
490 }
491
492 fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
494 let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
495
496 if let Some(EdgeKind::Value(typ)) = kind {
497 Ok(typ)
498 } else {
499 Err(BuildError::WireNotFound(wire))
500 }
501 }
502
503 fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
511 let values = values.into_iter().collect_vec();
512 let types: Result<Vec<Type>, _> = values
513 .iter()
514 .map(|&wire| self.get_wire_type(wire))
515 .collect();
516 let types = types?.into();
517 let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
518 Ok(make_op.out_wire(0))
519 }
520
521 fn make_sum(
531 &mut self,
532 tag: usize,
533 variants: impl IntoIterator<Item = TypeRow>,
534 values: impl IntoIterator<Item = Wire>,
535 ) -> Result<Wire, BuildError> {
536 let make_op = self.add_dataflow_op(
537 Tag {
538 tag,
539 variants: variants.into_iter().collect_vec(),
540 },
541 values.into_iter().collect_vec(),
542 )?;
543 Ok(make_op.out_wire(0))
544 }
545
546 fn make_continue(
556 &mut self,
557 tail_loop: ops::TailLoop,
558 values: impl IntoIterator<Item = Wire>,
559 ) -> Result<Wire, BuildError> {
560 self.make_sum(
561 TailLoop::CONTINUE_TAG,
562 [tail_loop.just_inputs, tail_loop.just_outputs],
563 values,
564 )
565 }
566
567 fn make_break(
577 &mut self,
578 loop_op: ops::TailLoop,
579 values: impl IntoIterator<Item = Wire>,
580 ) -> Result<Wire, BuildError> {
581 self.make_sum(
582 TailLoop::BREAK_TAG,
583 [loop_op.just_inputs, loop_op.just_outputs],
584 values,
585 )
586 }
587
588 fn call<const DEFINED: bool>(
597 &mut self,
598 function: &FuncID<DEFINED>,
599 type_args: &[TypeArg],
600 input_wires: impl IntoIterator<Item = Wire>,
601 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
602 let hugr = self.hugr();
603 let def_op = hugr.get_optype(function.node());
604 let type_scheme = match def_op {
605 OpType::FuncDefn(fd) => fd.signature().clone(),
606 OpType::FuncDecl(fd) => fd.signature().clone(),
607 _ => {
608 return Err(BuildError::UnexpectedType {
609 node: function.node(),
610 op_desc: "FuncDecl/FuncDefn",
611 });
612 }
613 };
614 let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
615 let const_in_port = op.static_input_port().unwrap();
616 let op_id = self.add_dataflow_op(op, input_wires)?;
617 let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
618
619 self.hugr_mut()
620 .connect(function.node(), src_port, op_id.node(), const_in_port);
621 Ok(op_id)
622 }
623
624 fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<Self> {
627 CircuitBuilder::new(wires, self)
628 }
629
630 fn add_barrier(
639 &mut self,
640 wires: impl IntoIterator<Item = Wire>,
641 ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
642 let wires = wires.into_iter().collect_vec();
643 let types: Result<Vec<Type>, _> =
644 wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
645 let types = types?;
646 let barrier_op =
647 self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
648 Ok(barrier_op)
649 }
650}
651
652fn add_node_with_wires<T: Dataflow + ?Sized>(
661 data_builder: &mut T,
662 nodetype: impl Into<OpType>,
663 inputs: impl IntoIterator<Item = Wire>,
664) -> Result<(Node, usize), BuildError> {
665 let op: OpType = nodetype.into();
666 let num_outputs = op.value_output_count();
667 let op_node = data_builder.add_child_node(op.clone());
668
669 wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
670 op: Box::new(op),
671 error,
672 })?;
673
674 Ok((op_node, num_outputs))
675}
676
677fn wire_up_inputs<T: Dataflow + ?Sized>(
685 inputs: impl IntoIterator<Item = Wire>,
686 op_node: Node,
687 data_builder: &mut T,
688) -> Result<(), BuilderWiringError> {
689 for (dst_port, wire) in inputs.into_iter().enumerate() {
690 wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
691 }
692 Ok(())
693}
694
695fn wire_up<T: Dataflow + ?Sized>(
701 data_builder: &mut T,
702 src: Node,
703 src_port: impl Into<OutgoingPort>,
704 dst: Node,
705 dst_port: impl Into<IncomingPort>,
706) -> Result<bool, BuilderWiringError> {
707 let src_port = src_port.into();
708 let dst_port = dst_port.into();
709 let base = data_builder.hugr_mut();
710
711 let src_parent = base.get_parent(src);
712 let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
713 let dst_parent = base.get_parent(dst);
714 let local_source = src_parent == dst_parent;
715 if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
716 if !local_source {
717 if !typ.copyable() {
719 return Err(BuilderWiringError::NonCopyableIntergraph {
720 src,
721 src_offset: src_port.into(),
722 dst,
723 dst_offset: dst_port.into(),
724 typ: Box::new(typ),
725 });
726 }
727
728 let src_parent = src_parent.expect("Node has no parent");
729 let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
730 .tuple_windows()
731 .find_map(|(ancestor, ancestor_parent)| {
732 (ancestor_parent == src_parent ||
733 Some(ancestor_parent) == src_parent_parent)
735 .then_some(ancestor)
736 })
737 else {
738 return Err(BuilderWiringError::NoRelationIntergraph {
739 src,
740 src_offset: src_port.into(),
741 dst,
742 dst_offset: dst_port.into(),
743 });
744 };
745
746 if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
747 && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
748 {
749 base.add_other_edge(src, src_sibling);
751 }
752 } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
753 return Err(BuilderWiringError::NoCopyLinear {
755 typ: Box::new(typ),
756 src,
757 src_offset: src_port.into(),
758 });
759 }
760 }
761
762 data_builder
763 .hugr_mut()
764 .connect(src, src_port, dst, dst_port);
765 Ok(local_source
766 && matches!(
767 data_builder
768 .hugr_mut()
769 .get_optype(dst)
770 .port_kind(dst_port)
771 .unwrap(),
772 EdgeKind::Value(_)
773 ))
774}
775
776pub trait DataflowHugr: HugrBuilder + Dataflow {
778 fn finish_hugr_with_outputs(
784 mut self,
785 outputs: impl IntoIterator<Item = Wire>,
786 ) -> Result<Hugr, BuildError>
787 where
788 Self: Sized,
789 {
790 self.set_outputs(outputs)?;
791 Ok(self.finish_hugr()?)
792 }
793}
794
795pub trait DataflowSubContainer: SubContainer + Dataflow {
797 fn finish_with_outputs(
804 mut self,
805 outputs: impl IntoIterator<Item = Wire>,
806 ) -> Result<Self::ContainerHandle, BuildError>
807 where
808 Self: Sized,
809 {
810 self.set_outputs(outputs)?;
811 self.finish_sub_container()
812 }
813}
814
815impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
816impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}