hugr_core/builder/
build_traits.rs

1use crate::extension::prelude::MakeTuple;
2use crate::hugr::hugrmut::InsertionResult;
3use crate::hugr::views::HugrView;
4use crate::hugr::{NodeMetadata, ValidationError};
5use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop};
6use crate::utils::collect_array;
7use crate::{Extension, IncomingPort, Node, OutgoingPort};
8
9use std::iter;
10use std::sync::Arc;
11
12use super::{BuilderWiringError, ModuleBuilder};
13use super::{
14    CircuitBuilder,
15    handle::{BuildHandle, Outputs},
16};
17
18use crate::{
19    ops::handle::{ConstID, DataflowOpID, FuncID, NodeHandle},
20    types::EdgeKind,
21};
22
23use crate::extension::ExtensionRegistry;
24use crate::types::{Signature, Type, TypeArg, TypeRow};
25
26use itertools::Itertools;
27
28use super::{
29    BuildError, Wire, cfg::CFGBuilder, conditional::ConditionalBuilder, dataflow::DFGBuilder,
30    tail_loop::TailLoopBuilder,
31};
32
33use crate::Hugr;
34
35use crate::hugr::HugrMut;
36
37/// Trait for HUGR container builders.
38/// Containers are nodes that are parents of sibling graphs.
39/// Implementations of this trait allow the child sibling graph to be added to
40/// the HUGR.
41pub trait Container {
42    /// The container node.
43    fn container_node(&self) -> Node;
44    /// The underlying [`Hugr`] being built
45    fn hugr_mut(&mut self) -> &mut Hugr;
46    /// Immutable reference to HUGR being built
47    fn hugr(&self) -> &Hugr;
48    /// Add an [`OpType`] as the final child of the container.
49    ///
50    /// Adds the extensions required by the op to the HUGR, if they are not already present.
51    fn add_child_node(&mut self, node: impl Into<OpType>) -> Node {
52        let node: OpType = node.into();
53
54        // Add the extension the operation is defined in to the HUGR.
55        let used_extensions = node
56            .used_extensions()
57            .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}"));
58        self.use_extensions(used_extensions);
59
60        let parent = self.container_node();
61        self.hugr_mut().add_node_with_parent(parent, node)
62    }
63
64    /// Adds a non-dataflow edge between two nodes. The kind is given by the operation's [`other_inputs`] or  [`other_outputs`]
65    ///
66    /// [`other_inputs`]: crate::ops::OpTrait::other_input
67    /// [`other_outputs`]: crate::ops::OpTrait::other_output
68    fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
69        let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
70        Wire::new(src, src_port)
71    }
72
73    /// Add a constant value to the container and return a handle to it.
74    ///
75    /// Adds the extensions required by the op to the HUGR, if they are not already present.
76    ///
77    /// # Errors
78    ///
79    /// This function will return an error if there is an error in adding the
80    /// [`OpType::Const`] node.
81    fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
82        self.add_child_node(constant.into()).into()
83    }
84
85    /// Insert a HUGR as a child of the container.
86    fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
87        let parent = self.container_node();
88        self.hugr_mut().insert_hugr(parent, child)
89    }
90
91    /// Insert a copy of a HUGR as a child of the container.
92    fn add_hugr_view<H: HugrView>(&mut self, child: &H) -> InsertionResult<H::Node, Node> {
93        let parent = self.container_node();
94        self.hugr_mut().insert_from_view(parent, child)
95    }
96
97    /// Add metadata to the container node.
98    fn set_metadata(&mut self, key: impl AsRef<str>, meta: impl Into<NodeMetadata>) {
99        let parent = self.container_node();
100        // Implementor's container_node() should be a valid node
101        self.hugr_mut().set_metadata(parent, key, meta);
102    }
103
104    /// Add metadata to a child node.
105    ///
106    /// Returns an error if the specified `child` is not a child of this container
107    fn set_child_metadata(
108        &mut self,
109        child: Node,
110        key: impl AsRef<str>,
111        meta: impl Into<NodeMetadata>,
112    ) {
113        self.hugr_mut().set_metadata(child, key, meta);
114    }
115
116    /// Add an extension to the set of extensions used by the hugr.
117    fn use_extension(&mut self, ext: impl Into<Arc<Extension>>) {
118        self.hugr_mut().use_extension(ext);
119    }
120
121    /// Extend the set of extensions used by the hugr with the extensions in the registry.
122    fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
123    where
124        ExtensionRegistry: Extend<Reg>,
125    {
126        self.hugr_mut().use_extensions(registry);
127    }
128}
129
130/// Types implementing this trait can be used to build complete HUGRs
131/// (with varying entrypoint node types)
132pub trait HugrBuilder: Container {
133    /// Allows adding definitions to the module root of which
134    /// this builder is building a part
135    fn module_root_builder(&mut self) -> ModuleBuilder<&mut Hugr> {
136        debug_assert!(
137            self.hugr()
138                .get_optype(self.hugr().module_root())
139                .is_module()
140        );
141        ModuleBuilder(self.hugr_mut())
142    }
143
144    /// Finish building the HUGR, perform any validation checks and return it.
145    fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>>;
146}
147
148/// Types implementing this trait build a container graph region by borrowing a HUGR
149pub trait SubContainer: Container {
150    /// A handle to the finished container node, typically returned when the
151    /// child graph has been finished.
152    type ContainerHandle;
153    /// Consume the container builder and return the handle, may perform some
154    /// checks before finishing.
155    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError>;
156}
157/// Trait for building dataflow regions of a HUGR.
158pub trait Dataflow: Container {
159    /// Return the number of inputs to the dataflow sibling graph.
160    fn num_inputs(&self) -> usize;
161    /// Return indices of input and output nodes.
162    fn io(&self) -> [Node; 2] {
163        self.hugr()
164            .children(self.container_node())
165            .take(2)
166            .collect_vec()
167            .try_into()
168            .expect("First two children should be IO")
169    }
170    /// Handle to input node.
171    fn input(&self) -> BuildHandle<DataflowOpID> {
172        (self.io()[0], self.num_inputs()).into()
173    }
174    /// Handle to output node.
175    fn output(&self) -> DataflowOpID {
176        self.io()[1].into()
177    }
178    /// Return iterator over all input Value wires.
179    fn input_wires(&self) -> Outputs {
180        self.input().outputs()
181    }
182    /// Add a dataflow [`OpType`] to the sibling graph, wiring up the `input_wires` to the
183    /// incoming ports of the resulting node.
184    ///
185    /// Adds the extensions required by the op to the HUGR, if they are not already present.
186    ///
187    /// # Errors
188    ///
189    /// Returns a [`BuildError::OperationWiring`] error if the `input_wires` cannot be connected.
190    fn add_dataflow_op(
191        &mut self,
192        nodetype: impl Into<OpType>,
193        input_wires: impl IntoIterator<Item = Wire>,
194    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
195        let outs = add_node_with_wires(self, nodetype, input_wires)?;
196
197        Ok(outs.into())
198    }
199
200    /// Insert a hugr-defined op to the sibling graph, wiring up the
201    /// `input_wires` to the incoming ports of the resulting root node.
202    ///
203    /// # Errors
204    ///
205    /// This function will return an error if there is an error when adding the
206    /// node.
207    fn add_hugr_with_wires(
208        &mut self,
209        hugr: Hugr,
210        input_wires: impl IntoIterator<Item = Wire>,
211    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
212        let optype = hugr.get_optype(hugr.entrypoint()).clone();
213        let num_outputs = optype.value_output_count();
214        let node = self.add_hugr(hugr).inserted_entrypoint;
215
216        wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
217            op: Box::new(optype),
218            error,
219        })?;
220
221        Ok((node, num_outputs).into())
222    }
223
224    /// Copy a hugr-defined op into the sibling graph, wiring up the
225    /// `input_wires` to the incoming ports of the resulting root node.
226    ///
227    /// # Errors
228    ///
229    /// This function will return an error if there is an error when adding the
230    /// node.
231    fn add_hugr_view_with_wires(
232        &mut self,
233        hugr: &impl HugrView,
234        input_wires: impl IntoIterator<Item = Wire>,
235    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
236        let node = self.add_hugr_view(hugr).inserted_entrypoint;
237        let optype = hugr.get_optype(hugr.entrypoint()).clone();
238        let num_outputs = optype.value_output_count();
239
240        wire_up_inputs(input_wires, node, self).map_err(|error| BuildError::OperationWiring {
241            op: Box::new(optype),
242            error,
243        })?;
244
245        Ok((node, num_outputs).into())
246    }
247
248    /// Wire up the `output_wires` to the input ports of the Output node.
249    ///
250    /// # Errors
251    ///
252    /// This function will return an error if there is an error when wiring up.
253    fn set_outputs(
254        &mut self,
255        output_wires: impl IntoIterator<Item = Wire>,
256    ) -> Result<(), BuildError> {
257        let [_, out] = self.io();
258        wire_up_inputs(output_wires.into_iter().collect_vec(), out, self).map_err(|error| {
259            BuildError::OutputWiring {
260                container_op: Box::new(self.hugr().get_optype(self.container_node()).clone()),
261                container_node: self.container_node(),
262                error,
263            }
264        })
265    }
266
267    /// Return an array of the input wires.
268    ///
269    /// # Panics
270    ///
271    /// Panics if the number of input Wires does not match the size of the array.
272    #[track_caller]
273    fn input_wires_arr<const N: usize>(&self) -> [Wire; N] {
274        collect_array(self.input_wires())
275    }
276
277    /// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
278    /// given a signature describing its input and output types and extension delta,
279    /// and the input wires (which must match the input types)
280    ///
281    /// # Errors
282    ///
283    /// This function will return an error if there is an error when building
284    /// the DFG node.
285    // TODO: Should this be one function, or should there be a temporary "op" one like with the others?
286    fn dfg_builder(
287        &mut self,
288        signature: Signature,
289        input_wires: impl IntoIterator<Item = Wire>,
290    ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
291        let op = ops::DFG {
292            signature: signature.clone(),
293        };
294        let (dfg_n, _) = add_node_with_wires(self, op, input_wires)?;
295
296        DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
297    }
298
299    /// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
300    /// that is endomorphic (the output types are the same as the input types).
301    /// The `inputs` must be an iterable over pairs of the type of the input and
302    /// the corresponding wire.
303    fn dfg_builder_endo(
304        &mut self,
305        inputs: impl IntoIterator<Item = (Type, Wire)>,
306    ) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
307        let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
308        self.dfg_builder(Signature::new_endo(types), input_wires)
309    }
310
311    /// Return a builder for a [`crate::ops::CFG`] node,
312    /// i.e. a nested controlflow subgraph.
313    /// The `inputs` must be an iterable over pairs of the type of the input and
314    /// the corresponding wire.
315    /// The `output_types` are the types of the outputs.
316    ///
317    /// # Errors
318    ///
319    /// This function will return an error if there is an error when building
320    /// the CFG node.
321    fn cfg_builder(
322        &mut self,
323        inputs: impl IntoIterator<Item = (Type, Wire)>,
324        output_types: TypeRow,
325    ) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
326        let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
327
328        let inputs: TypeRow = input_types.into();
329
330        let (cfg_node, _) = add_node_with_wires(
331            self,
332            ops::CFG {
333                signature: Signature::new(inputs.clone(), output_types.clone()),
334            },
335            input_wires,
336        )?;
337        CFGBuilder::create(self.hugr_mut(), cfg_node, inputs, output_types)
338    }
339
340    /// Load a static constant and return the local dataflow wire for that constant.
341    /// Adds a [`OpType::LoadConstant`] node.
342    fn load_const(&mut self, cid: &ConstID) -> Wire {
343        let const_node = cid.node();
344        let nodetype = self.hugr().get_optype(const_node);
345        let op: ops::Const = nodetype
346            .clone()
347            .try_into()
348            .expect("ConstID does not refer to Const op.");
349
350        let load_n = self
351            .add_dataflow_op(
352                ops::LoadConstant {
353                    datatype: op.get_type().clone(),
354                },
355                // Constant wire from the constant value node
356                vec![Wire::new(const_node, OutgoingPort::from(0))],
357            )
358            .expect("The constant type should match the LoadConstant type.");
359
360        load_n.out_wire(0)
361    }
362
363    /// Load a static constant and return the local dataflow wire for that constant.
364    /// Adds a [`ops::Const`] and a [`ops::LoadConstant`] node.
365    fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
366        let cid = self.add_constant(constant);
367        self.load_const(&cid)
368    }
369
370    /// Load a [`ops::Value`] and return the local dataflow wire for that constant.
371    /// Adds a [`ops::Const`] and a [`ops::LoadConstant`] node.
372    fn add_load_value(&mut self, constant: impl Into<ops::Value>) -> Wire {
373        self.add_load_const(constant.into())
374    }
375
376    /// Load a static function and return the local dataflow wire for that function.
377    /// Adds a [`OpType::LoadFunction`] node.
378    ///
379    /// The `DEF` const generic is used to indicate whether the function is defined
380    /// or just declared.
381    fn load_func<const DEFINED: bool>(
382        &mut self,
383        fid: &FuncID<DEFINED>,
384        type_args: &[TypeArg],
385    ) -> Result<Wire, BuildError> {
386        let func_node = fid.node();
387        let func_op = self.hugr().get_optype(func_node);
388        let func_sig = match func_op {
389            OpType::FuncDefn(fd) => fd.signature().clone(),
390            OpType::FuncDecl(fd) => fd.signature().clone(),
391            _ => {
392                return Err(BuildError::UnexpectedType {
393                    node: func_node,
394                    op_desc: "FuncDecl/FuncDefn",
395                });
396            }
397        };
398
399        let load_n = self.add_dataflow_op(
400            ops::LoadFunction::try_new(func_sig, type_args)?,
401            // Static wire from the function node
402            vec![Wire::new(func_node, func_op.static_output_port().unwrap())],
403        )?;
404
405        Ok(load_n.out_wire(0))
406    }
407
408    /// Return a builder for a [`crate::ops::TailLoop`] node.
409    /// The `inputs` must be an iterable over pairs of the type of the input and
410    /// the corresponding wire.
411    /// The `output_types` are the types of the outputs.
412    ///
413    /// # Errors
414    ///
415    /// This function will return an error if there is an error when building
416    /// the [`ops::TailLoop`] node.
417    ///
418    fn tail_loop_builder(
419        &mut self,
420        just_inputs: impl IntoIterator<Item = (Type, Wire)>,
421        inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
422        just_out_types: TypeRow,
423    ) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
424        let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
425            just_inputs.into_iter().unzip();
426        let (rest_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
427            inputs_outputs.into_iter().unzip();
428        input_wires.extend(rest_input_wires);
429
430        let tail_loop = ops::TailLoop {
431            just_inputs: input_types.into(),
432            just_outputs: just_out_types,
433            rest: rest_types.into(),
434        };
435        // TODO: Make input extensions a parameter
436        let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
437
438        TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
439    }
440
441    /// Return a builder for a [`crate::ops::Conditional`] node.
442    /// `sum_input` is a tuple of the type of the Sum
443    /// variants and the corresponding wire.
444    ///
445    /// The `other_inputs` must be an iterable over pairs of the type of the input and
446    /// the corresponding wire.
447    /// The `output_types` are the types of the outputs.
448    ///
449    /// # Errors
450    ///
451    /// This function will return an error if there is an error when building
452    /// the Conditional node.
453    fn conditional_builder(
454        &mut self,
455        (sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
456        other_inputs: impl IntoIterator<Item = (Type, Wire)>,
457        output_types: TypeRow,
458    ) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
459        let mut input_wires = vec![sum_wire];
460        let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
461            other_inputs.into_iter().unzip();
462
463        input_wires.extend(rest_input_wires);
464        let inputs: TypeRow = input_types.into();
465        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
466        let n_cases = sum_rows.len();
467        let n_out_wires = output_types.len();
468
469        let conditional_id = self.add_dataflow_op(
470            ops::Conditional {
471                sum_rows,
472                other_inputs: inputs,
473                outputs: output_types,
474            },
475            input_wires,
476        )?;
477
478        Ok(ConditionalBuilder {
479            base: self.hugr_mut(),
480            conditional_node: conditional_id.node(),
481            n_out_wires,
482            case_nodes: vec![None; n_cases],
483        })
484    }
485
486    /// Add an order edge from `before` to `after`. Assumes any additional edges
487    /// to both nodes will be Order kind.
488    fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
489        self.add_other_wire(before.node(), after.node());
490    }
491
492    /// Get the type of a Value [`Wire`]. If not valid port or of Value kind, returns None.
493    fn get_wire_type(&self, wire: Wire) -> Result<Type, BuildError> {
494        let kind = self.hugr().get_optype(wire.node()).port_kind(wire.source());
495
496        if let Some(EdgeKind::Value(typ)) = kind {
497            Ok(typ)
498        } else {
499            Err(BuildError::WireNotFound(wire))
500        }
501    }
502
503    /// Add a [`MakeTuple`] node and wire in the `values` Wires,
504    /// returning the Wire corresponding to the tuple.
505    ///
506    /// # Errors
507    ///
508    /// This function will return an error if there is an error adding the
509    /// [`MakeTuple`] node.
510    fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
511        let values = values.into_iter().collect_vec();
512        let types: Result<Vec<Type>, _> = values
513            .iter()
514            .map(|&wire| self.get_wire_type(wire))
515            .collect();
516        let types = types?.into();
517        let make_op = self.add_dataflow_op(MakeTuple(types), values)?;
518        Ok(make_op.out_wire(0))
519    }
520
521    /// Add a [`Tag`] node and wire in the `value` Wire,
522    /// to make a value with Sum type, with `tag` and possible types described
523    /// by `variants`.
524    /// Returns the Wire corresponding to the Sum value.
525    ///
526    /// # Errors
527    ///
528    /// This function will return an error if there is an error adding the
529    /// Tag node.
530    fn make_sum(
531        &mut self,
532        tag: usize,
533        variants: impl IntoIterator<Item = TypeRow>,
534        values: impl IntoIterator<Item = Wire>,
535    ) -> Result<Wire, BuildError> {
536        let make_op = self.add_dataflow_op(
537            Tag {
538                tag,
539                variants: variants.into_iter().collect_vec(),
540            },
541            values.into_iter().collect_vec(),
542        )?;
543        Ok(make_op.out_wire(0))
544    }
545
546    /// Use the wires in `values` to return a wire corresponding to the
547    /// "Continue" variant of a [`ops::TailLoop`] with `loop_signature`.
548    ///
549    /// Packs the values in to a tuple and tags appropriately to generate a
550    /// value of Sum type.
551    ///
552    /// # Errors
553    ///
554    /// This function will return an error if there is an error in adding the nodes.
555    fn make_continue(
556        &mut self,
557        tail_loop: ops::TailLoop,
558        values: impl IntoIterator<Item = Wire>,
559    ) -> Result<Wire, BuildError> {
560        self.make_sum(
561            TailLoop::CONTINUE_TAG,
562            [tail_loop.just_inputs, tail_loop.just_outputs],
563            values,
564        )
565    }
566
567    /// Use the wires in `values` to return a wire corresponding to the
568    /// "Break" variant of a [`ops::TailLoop`] with `loop_signature`.
569    ///
570    /// Packs the values in to a tuple and tags appropriately to generate a
571    /// value of Sum type.
572    ///
573    /// # Errors
574    ///
575    /// This function will return an error if there is an error in adding the nodes.
576    fn make_break(
577        &mut self,
578        loop_op: ops::TailLoop,
579        values: impl IntoIterator<Item = Wire>,
580    ) -> Result<Wire, BuildError> {
581        self.make_sum(
582            TailLoop::BREAK_TAG,
583            [loop_op.just_inputs, loop_op.just_outputs],
584            values,
585        )
586    }
587
588    /// Add a [`ops::Call`] node, calling `function`, with inputs
589    /// specified by `input_wires`. Returns a handle to the corresponding Call node.
590    ///
591    /// # Errors
592    ///
593    /// This function will return an error if there is an error adding the Call
594    /// node, or if `function` does not refer to a [`ops::FuncDecl`] or
595    /// [`ops::FuncDefn`] node.
596    fn call<const DEFINED: bool>(
597        &mut self,
598        function: &FuncID<DEFINED>,
599        type_args: &[TypeArg],
600        input_wires: impl IntoIterator<Item = Wire>,
601    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
602        let hugr = self.hugr();
603        let def_op = hugr.get_optype(function.node());
604        let type_scheme = match def_op {
605            OpType::FuncDefn(fd) => fd.signature().clone(),
606            OpType::FuncDecl(fd) => fd.signature().clone(),
607            _ => {
608                return Err(BuildError::UnexpectedType {
609                    node: function.node(),
610                    op_desc: "FuncDecl/FuncDefn",
611                });
612            }
613        };
614        let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into();
615        let const_in_port = op.static_input_port().unwrap();
616        let op_id = self.add_dataflow_op(op, input_wires)?;
617        let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
618
619        self.hugr_mut()
620            .connect(function.node(), src_port, op_id.node(), const_in_port);
621        Ok(op_id)
622    }
623
624    /// For the vector of `wires`, produce a `CircuitBuilder` where ops can be
625    /// added using indices in to the vector.
626    fn as_circuit(&mut self, wires: impl IntoIterator<Item = Wire>) -> CircuitBuilder<Self> {
627        CircuitBuilder::new(wires, self)
628    }
629
630    /// Add a [Barrier] to a set of wires and return them in the same order.
631    ///
632    /// [Barrier]: crate::extension::prelude::Barrier
633    ///
634    /// # Errors
635    ///
636    /// This function will return an error if there is an error adding the Barrier node
637    /// or retrieving the type of the incoming wires.
638    fn add_barrier(
639        &mut self,
640        wires: impl IntoIterator<Item = Wire>,
641    ) -> Result<BuildHandle<DataflowOpID>, BuildError> {
642        let wires = wires.into_iter().collect_vec();
643        let types: Result<Vec<Type>, _> =
644            wires.iter().map(|&wire| self.get_wire_type(wire)).collect();
645        let types = types?;
646        let barrier_op =
647            self.add_dataflow_op(crate::extension::prelude::Barrier::new(types), wires)?;
648        Ok(barrier_op)
649    }
650}
651
652/// Add a node to the graph, wiring up the `inputs` to the input ports of the resulting node.
653///
654/// Adds the extensions required by the op to the HUGR, if they are not already present.
655///
656/// # Errors
657///
658/// Returns a [`BuildError::OperationWiring`] if any of the connections produces an
659/// invalid edge.
660fn add_node_with_wires<T: Dataflow + ?Sized>(
661    data_builder: &mut T,
662    nodetype: impl Into<OpType>,
663    inputs: impl IntoIterator<Item = Wire>,
664) -> Result<(Node, usize), BuildError> {
665    let op: OpType = nodetype.into();
666    let num_outputs = op.value_output_count();
667    let op_node = data_builder.add_child_node(op.clone());
668
669    wire_up_inputs(inputs, op_node, data_builder).map_err(|error| BuildError::OperationWiring {
670        op: Box::new(op),
671        error,
672    })?;
673
674    Ok((op_node, num_outputs))
675}
676
677/// Connect each of the `inputs` wires sequentially to the input ports of
678/// `op_node`.
679///
680/// # Errors
681///
682/// Returns a [`BuilderWiringError`] if any of the connections produces an
683/// invalid edge.
684fn wire_up_inputs<T: Dataflow + ?Sized>(
685    inputs: impl IntoIterator<Item = Wire>,
686    op_node: Node,
687    data_builder: &mut T,
688) -> Result<(), BuilderWiringError> {
689    for (dst_port, wire) in inputs.into_iter().enumerate() {
690        wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
691    }
692    Ok(())
693}
694
695/// Add edge from src to dst.
696///
697/// # Errors
698///
699/// Returns a [`BuilderWiringError`] if the edge is invalid.
700fn wire_up<T: Dataflow + ?Sized>(
701    data_builder: &mut T,
702    src: Node,
703    src_port: impl Into<OutgoingPort>,
704    dst: Node,
705    dst_port: impl Into<IncomingPort>,
706) -> Result<bool, BuilderWiringError> {
707    let src_port = src_port.into();
708    let dst_port = dst_port.into();
709    let base = data_builder.hugr_mut();
710
711    let src_parent = base.get_parent(src);
712    let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
713    let dst_parent = base.get_parent(dst);
714    let local_source = src_parent == dst_parent;
715    if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
716        if !local_source {
717            // Non-local value sources require a state edge to an ancestor of dst
718            if !typ.copyable() {
719                return Err(BuilderWiringError::NonCopyableIntergraph {
720                    src,
721                    src_offset: src_port.into(),
722                    dst,
723                    dst_offset: dst_port.into(),
724                    typ: Box::new(typ),
725                });
726            }
727
728            let src_parent = src_parent.expect("Node has no parent");
729            let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
730                .tuple_windows()
731                .find_map(|(ancestor, ancestor_parent)| {
732                    (ancestor_parent == src_parent ||
733                        // Dom edge - in CFGs
734                        Some(ancestor_parent) == src_parent_parent)
735                        .then_some(ancestor)
736                })
737            else {
738                return Err(BuilderWiringError::NoRelationIntergraph {
739                    src,
740                    src_offset: src_port.into(),
741                    dst,
742                    dst_offset: dst_port.into(),
743                });
744            };
745
746            if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
747                && !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
748            {
749                // Add a state order constraint unless one of the nodes is a CFG BasicBlock
750                base.add_other_edge(src, src_sibling);
751            }
752        } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
753            // Don't copy linear edges.
754            return Err(BuilderWiringError::NoCopyLinear {
755                typ: Box::new(typ),
756                src,
757                src_offset: src_port.into(),
758            });
759        }
760    }
761
762    data_builder
763        .hugr_mut()
764        .connect(src, src_port, dst, dst_port);
765    Ok(local_source
766        && matches!(
767            data_builder
768                .hugr_mut()
769                .get_optype(dst)
770                .port_kind(dst_port)
771                .unwrap(),
772            EdgeKind::Value(_)
773        ))
774}
775
776/// Trait implemented by builders of Dataflow Hugrs
777pub trait DataflowHugr: HugrBuilder + Dataflow {
778    /// Set outputs of dataflow HUGR and return validated HUGR
779    /// # Errors
780    ///
781    /// * if there is an error when setting outputs
782    /// * if the Hugr does not validate
783    fn finish_hugr_with_outputs(
784        mut self,
785        outputs: impl IntoIterator<Item = Wire>,
786    ) -> Result<Hugr, BuildError>
787    where
788        Self: Sized,
789    {
790        self.set_outputs(outputs)?;
791        Ok(self.finish_hugr()?)
792    }
793}
794
795/// Trait implemented by builders of Dataflow container regions of a HUGR
796pub trait DataflowSubContainer: SubContainer + Dataflow {
797    /// Set the outputs of the graph and consume the builder, while returning a
798    /// handle to the parent.
799    ///
800    /// # Errors
801    ///
802    /// This function will return an error if there is an error when setting outputs.
803    fn finish_with_outputs(
804        mut self,
805        outputs: impl IntoIterator<Item = Wire>,
806    ) -> Result<Self::ContainerHandle, BuildError>
807    where
808        Self: Sized,
809    {
810        self.set_outputs(outputs)?;
811        self.finish_sub_container()
812    }
813}
814
815impl<T: HugrBuilder + Dataflow> DataflowHugr for T {}
816impl<T: SubContainer + Dataflow> DataflowSubContainer for T {}