1use itertools::Itertools;
2
3use super::build_traits::{HugrBuilder, SubContainer};
4use super::handle::BuildHandle;
5use super::{BuildError, Container, Dataflow, DfgID, FuncID};
6
7use std::marker::PhantomData;
8
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::{HugrView, ValidationError};
11use crate::ops::{self, OpParent};
12use crate::ops::{DataflowParent, Input, Output};
13use crate::{Direction, IncomingPort, OutgoingPort, Wire};
14
15use crate::types::{PolyFuncType, Signature, Type};
16
17use crate::Node;
18use crate::{Hugr, hugr::HugrMut};
19
/// Builder for a dataflow region: adds children under the dataflow parent
/// node `dfg_node` inside the HUGR held by `base`.
///
/// `T` is the HUGR storage; the builder's methods require
/// `T: AsMut<Hugr> + AsRef<Hugr>` (for example an owned [`Hugr`]).
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
    /// The HUGR being built.
    pub(crate) base: T,
    /// The dataflow parent node whose children this builder creates.
    pub(crate) dfg_node: Node,
    /// Number of value inputs of the region (reported by `num_inputs`).
    pub(crate) num_in_wires: usize,
    /// Number of value outputs of the region; stored in the handle returned
    /// when the builder is finished as a sub-container.
    pub(crate) num_out_wires: usize,
}
28
29impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
30 pub(super) fn create_with_io(
35 mut base: T,
36 parent: Node,
37 signature: Signature,
38 ) -> Result<Self, BuildError> {
39 debug_assert_eq!(base.as_ref().children(parent).count(), 0);
40
41 let num_in_wires = signature.input_count();
42 let num_out_wires = signature.output_count();
43 let input = ops::Input {
44 types: signature.input().clone(),
45 };
46 let output = ops::Output {
47 types: signature.output().clone(),
48 };
49 base.as_mut().add_node_with_parent(parent, input);
50 base.as_mut().add_node_with_parent(parent, output);
51
52 Ok(Self {
53 base,
54 dfg_node: parent,
55 num_in_wires,
56 num_out_wires,
57 })
58 }
59
60 pub(super) fn create(base: T, parent: Node) -> Result<Self, BuildError> {
67 let sig = base
68 .as_ref()
69 .get_optype(parent)
70 .inner_function_type()
71 .expect("DFG parent must have an inner function signature.");
72 let num_in_wires = sig.input_count();
73 let num_out_wires = sig.output_count();
74
75 Ok(Self {
76 base,
77 dfg_node: parent,
78 num_in_wires,
79 num_out_wires,
80 })
81 }
82}
83
84impl DFGBuilder<Hugr> {
85 pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
92 let dfg_op = ops::DFG {
93 signature: signature.clone(),
94 };
95 let base = Hugr::new_with_entrypoint(dfg_op).expect("DFG entrypoint should be valid");
96 let root = base.entrypoint();
97 DFGBuilder::create_with_io(base, root, signature)
98 }
99}
100
101impl HugrBuilder for DFGBuilder<Hugr> {
102 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
103 self.base.validate()?;
104 Ok(self.base)
105 }
106}
107
108impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for DFGBuilder<T> {
109 #[inline]
110 fn container_node(&self) -> Node {
111 self.dfg_node
112 }
113
114 #[inline]
115 fn hugr_mut(&mut self) -> &mut Hugr {
116 self.base.as_mut()
117 }
118
119 #[inline]
120 fn hugr(&self) -> &Hugr {
121 self.base.as_ref()
122 }
123}
124
125impl<T: AsMut<Hugr> + AsRef<Hugr>> SubContainer for DFGBuilder<T> {
126 type ContainerHandle = BuildHandle<DfgID>;
127 #[inline]
128 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
129 Ok((self.dfg_node, self.num_out_wires).into())
130 }
131}
132
133impl<T: AsMut<Hugr> + AsRef<Hugr>> Dataflow for DFGBuilder<T> {
134 #[inline]
135 fn num_inputs(&self) -> usize {
136 self.num_in_wires
137 }
138}
139
/// Wrapper around a [`DFGBuilder<B>`] that changes the handle type returned
/// when it is finished as a sub-container to `T`.
///
/// `T` is only a [`PhantomData`] marker; all building is delegated to the
/// inner [`DFGBuilder`].
#[derive(Debug, Clone, PartialEq)]
pub struct DFGWrapper<B, T>(DFGBuilder<B>, PhantomData<T>);
144
145impl<B, T> DFGWrapper<B, T> {
146 pub(super) fn from_dfg_builder(db: DFGBuilder<B>) -> Self {
147 Self(db, PhantomData)
148 }
149}
150
/// Builder for the body of an [`ops::FuncDefn`]: a [`DFGWrapper`] that
/// finishes into a `BuildHandle<FuncID<true>>`.
pub type FunctionBuilder<B> = DFGWrapper<B, BuildHandle<FuncID<true>>>;
153
154impl FunctionBuilder<Hugr> {
155 pub fn new(
160 name: impl Into<String>,
161 signature: impl Into<PolyFuncType>,
162 ) -> Result<Self, BuildError> {
163 let signature: PolyFuncType = signature.into();
164 let body = signature.body().clone();
165 let op = ops::FuncDefn::new(name, signature);
166
167 let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid");
168 let root = base.entrypoint();
169
170 let db = DFGBuilder::create_with_io(base, root, body)?;
171 Ok(Self::from_dfg_builder(db))
172 }
173
174 pub fn add_input(&mut self, input_type: Type) -> Wire {
178 let [inp_node, _] = self.io();
179
180 let new_optype = self.update_fn_signature(|mut s| {
182 s.input.to_mut().push(input_type);
183 s
184 });
185
186 let types = new_optype.signature().body().input.clone();
188 self.hugr_mut().replace_op(inp_node, Input { types });
189 let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
190 let new_port = new_port.next().unwrap();
191
192 let new_value_port: OutgoingPort = (new_port - 1).into();
194 let new_order_port: OutgoingPort = new_port.into();
195 let order_edge_targets = self
196 .hugr()
197 .linked_inputs(inp_node, new_value_port)
198 .collect_vec();
199 self.hugr_mut().disconnect(inp_node, new_value_port);
200 for (tgt_node, tgt_port) in order_edge_targets {
201 self.hugr_mut()
202 .connect(inp_node, new_order_port, tgt_node, tgt_port);
203 }
204
205 self.0.num_in_wires += 1;
207
208 self.input_wires().next_back().unwrap()
209 }
210
211 pub fn add_output(&mut self, output_type: Type) {
213 let [_, out_node] = self.io();
214
215 let new_optype = self.update_fn_signature(|mut s| {
217 s.output.to_mut().push(output_type);
218 s
219 });
220
221 let types = new_optype.signature().body().output.clone();
223 self.hugr_mut().replace_op(out_node, Output { types });
224 let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
225 let new_port = new_port.next().unwrap();
226
227 let new_value_port: IncomingPort = (new_port - 1).into();
229 let new_order_port: IncomingPort = new_port.into();
230 let order_edge_sources = self
231 .hugr()
232 .linked_outputs(out_node, new_value_port)
233 .collect_vec();
234 self.hugr_mut().disconnect(out_node, new_value_port);
235 for (src_node, src_port) in order_edge_sources {
236 self.hugr_mut()
237 .connect(src_node, src_port, out_node, new_order_port);
238 }
239
240 self.0.num_out_wires += 1;
242 }
243
244 fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
252 let parent = self.container_node();
253
254 let ops::OpType::FuncDefn(fd) = self.hugr_mut().optype_mut(parent) else {
255 panic!("FunctionBuilder node must be a FuncDefn")
256 };
257 *fd.signature_mut() = f(fd.inner_signature().into_owned()).into();
258 &*fd
259 }
260}
261
262impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
263 #[inline]
264 fn container_node(&self) -> Node {
265 self.0.container_node()
266 }
267
268 #[inline]
269 fn hugr_mut(&mut self) -> &mut Hugr {
270 self.0.hugr_mut()
271 }
272
273 #[inline]
274 fn hugr(&self) -> &Hugr {
275 self.0.hugr()
276 }
277}
278
279impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Dataflow for DFGWrapper<B, T> {
280 #[inline]
281 fn num_inputs(&self) -> usize {
282 self.0.num_inputs()
283 }
284}
285
286impl<B: AsMut<Hugr> + AsRef<Hugr>, T: From<BuildHandle<DfgID>>> SubContainer for DFGWrapper<B, T> {
287 type ContainerHandle = T;
288
289 #[inline]
290 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
291 self.0.finish_sub_container().map(Into::into)
292 }
293}
294
295impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
296 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
297 self.0.finish_hugr()
298 }
299}
300
/// Unit tests for `DFGBuilder`, `FunctionBuilder` and related wiring checks.
#[cfg(test)]
pub(crate) mod test {
    use cool_asserts::assert_matches;
    use ops::OpParent;
    use rstest::rstest;
    use serde_json::json;

    use crate::builder::build_traits::DataflowHugr;
    use crate::builder::{
        BuilderWiringError, DataflowSubContainer, ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::SignatureError;
    use crate::extension::prelude::Noop;
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::ops::{OpTag, handle::NodeHandle};
    use crate::ops::{OpTrait, Value};

    use crate::std_extensions::logic::test::and_op;
    use crate::types::type_param::TypeParam;
    use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV};
    use crate::utils::test_quantum_extension::h_gate;
    use crate::{Wire, builder::test::n_identity, type_row};

    use super::super::test::simple_dfg_hugr;
    use super::*;
    // An outer DFG applies `h_gate` to the qubit and passes the integer through
    // a nested identity DFG; building must succeed.
    #[test]
    fn nested_identity() -> Result<(), BuildError> {
        let build_result = {
            let mut outer_builder = DFGBuilder::new(endo_sig(vec![usize_t(), qb_t()]))?;

            let [int, qb] = outer_builder.input_wires_arr();

            let q_out = outer_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?;
            let inner_id = n_identity(inner_builder)?;

            outer_builder.finish_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs()))
        };

        assert_eq!(build_result.err(), None);

        Ok(())
    }

    // Helper: builds a `bool -> (bool, bool)` DFG, lets `f` wire it up, and
    // asserts that `finish_hugr` succeeds (`msg` labels the failing example).
    fn copy_scaffold<F>(f: F, msg: &'static str) -> Result<(), BuildError>
    where
        F: FnOnce(&mut DFGBuilder<Hugr>) -> Result<(), BuildError>,
    {
        let build_result = {
            let mut builder = DFGBuilder::new(inout_sig(bool_t(), vec![bool_t(), bool_t()]))?;

            f(&mut builder)?;

            builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

        Ok(())
    }
    // Using a copyable `bool` wire several times must be accepted.
    #[test]
    fn copy_insertion() -> Result<(), BuildError> {
        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                f_build.set_outputs([b1, b1])
            },
            "Copy input and output",
        )?;

        // NOTE(review): the locals below are named `xor*` but the op used is `and_op()`.
        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                f_build.set_outputs([xor.out_wire(0), b1])
            },
            "Copy input and use with binary function",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?;
                f_build.set_outputs([xor2.out_wire(0), b1])
            },
            "Copy multiple times",
        )?;

        Ok(())
    }

    // Sending one qubit wire to two outputs must fail with `NoCopyLinear`.
    #[test]
    fn copy_insertion_qubit() {
        let builder = || {
            let mut module_builder = ModuleBuilder::new();

            let f_build = module_builder
                .define_function("main", Signature::new(vec![qb_t()], vec![qb_t(), qb_t()]))?;

            let [q1] = f_build.input_wires_arr();
            f_build.finish_with_outputs([q1, q1])?;

            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(
            builder(),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoCopyLinear { typ, .. },
                ..
            })
            if typ == qb_t()
        );
    }

    // A `bool` wire from the enclosing function can be used inside a nested
    // DFG that does not take it as an input.
    #[test]
    fn simple_inter_graph_edge() {
        let builder = || -> Result<Hugr, BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;

            let [i1] = f_build.input_wires_arr();
            let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?;
            let i1 = noop.out_wire(0);

            let mut nested =
                f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?;

            let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?;

            let nested = nested.finish_with_outputs([id.out_wire(0)])?;

            f_build.finish_hugr_with_outputs([nested.out_wire(0)])
        };

        assert_matches!(builder(), Ok(_));
    }

    // `add_output`/`add_input` extend the function signature; the order edges
    // set on the Input/Output nodes beforehand must still give a valid HUGR.
    #[test]
    fn add_inputs_outputs() {
        let builder = || -> Result<(Hugr, Node), BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;
            let f_node = f_build.container_node();

            let [i0] = f_build.input_wires_arr();
            let noop0 = f_build.add_dataflow_op(Noop(bool_t()), [i0])?;

            // Order edges on the Input/Output nodes' order ports, which the
            // port additions below have to relocate.
            f_build.set_order(&f_build.io()[0], &noop0.node());
            f_build.set_order(&noop0.node(), &f_build.io()[1]);

            f_build.add_output(qb_t());
            let i1 = f_build.add_input(qb_t());
            let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;

            let hugr = f_build.finish_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
            Ok((hugr, f_node))
        };

        let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

        let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
        assert_eq!(
            func_sig.io(),
            (
                &vec![bool_t(), qb_t()].into(),
                &vec![bool_t(), qb_t()].into()
            )
        );
    }

    // Using a qubit wire from the enclosing function inside a nested DFG must
    // fail at wiring time with `NonCopyableIntergraph`.
    #[test]
    fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
        let mut f_build = FunctionBuilder::new("main", Signature::new(vec![qb_t()], vec![qb_t()]))?;

        let [i1] = f_build.input_wires_arr();
        let noop = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;
        let i1 = noop.out_wire(0);

        let mut nested = f_build.dfg_builder(Signature::new(type_row![], vec![qb_t()]), [])?;

        let id_res = nested.add_dataflow_op(Noop(qb_t()), [i1]);

        // Map the Ok side to a plain `Node` so the result can be matched.
        assert_matches!(
            id_res.map(|bh| bh.handle().node()), Err(BuildError::OperationWiring {
                error: BuilderWiringError::NonCopyableIntergraph { .. },
                ..
            })
        );

        Ok(())
    }

    // Shape of the `simple_dfg_hugr` fixture: 7 nodes in total, 3 nodes
    // reported by `entry_descendants`, and a DFG entrypoint.
    #[rstest]
    fn dfg_hugr(simple_dfg_hugr: Hugr) {
        assert_eq!(simple_dfg_hugr.num_nodes(), 7);
        assert_eq!(simple_dfg_hugr.entry_descendants().count(), 3);
        assert_matches!(simple_dfg_hugr.entrypoint_optype().tag(), OpTag::Dfg);
    }

    // Inserting a built DFG HUGR keeps the metadata set on its root, while
    // metadata set on the function node stays on that node only.
    #[test]
    fn insert_hugr() -> Result<(), BuildError> {
        let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?;
        let [i1] = dfg_builder.input_wires_arr();
        dfg_builder.set_metadata("x", 42);
        let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?;

        let mut module_builder = ModuleBuilder::new();

        let (dfg_node, f_node) = {
            let mut f_build =
                module_builder.define_function("main", Signature::new_endo(bool_t()))?;

            let [i1] = f_build.input_wires_arr();
            let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?;
            let f = f_build.finish_with_outputs([dfg.out_wire(0)])?;
            module_builder.set_child_metadata(f.node(), "x", "hi");
            (dfg.node(), f.node())
        };

        let hugr = module_builder.finish_hugr()?;
        assert_eq!(hugr.entry_descendants().count(), 7);

        assert_eq!(hugr.get_metadata(hugr.entrypoint(), "x"), None);
        assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42)));
        assert_eq!(hugr.get_metadata(f_node, "x").cloned(), Some(json!("hi")));

        Ok(())
    }

    // Barriers on one and two wires inside nested DFGs build into a valid HUGR.
    #[test]
    fn barrier_node() -> Result<(), BuildError> {
        let mut parent = DFGBuilder::new(endo_sig(bool_t()))?;

        let [w] = parent.input_wires_arr();

        let mut dfg_b = parent.dfg_builder(endo_sig(bool_t()), [w])?;
        let [w] = dfg_b.input_wires_arr();

        let barr0 = dfg_b.add_barrier([w])?;
        let [w] = barr0.outputs_arr();

        let barr1 = dfg_b.add_barrier([w])?;
        let [w] = barr1.outputs_arr();

        let dfg = dfg_b.finish_with_outputs([w])?;
        let [w] = dfg.outputs_arr();

        let mut dfg2_b = parent.dfg_builder(endo_sig(vec![bool_t(), bool_t()]), [w, w])?;
        let [w1, w2] = dfg2_b.input_wires_arr();
        let barr2 = dfg2_b.add_barrier([w1, w2])?;
        let wires: Vec<Wire> = barr2.outputs().collect();

        let dfg2 = dfg2_b.finish_with_outputs(wires)?;
        let [w, _] = dfg2.outputs_arr();
        parent.finish_hugr_with_outputs([w])?;

        Ok(())
    }

    // A wire from one child DFG's Input node used as the output of a sibling
    // DFG is rejected by validation with `NonCFGAncestor`.
    #[test]
    fn non_cfg_ancestor() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_in_wire = b_child.input().out_wire(0);
        b_child.finish_with_outputs([])?;
        let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;

        let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

        let res = b.finish_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

        assert_matches!(
            res,
            Err(BuildError::InvalidHUGR(
                ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
            ))
        );
        Ok(())
    }

    // A wire from a grandchild of one branch used as an output in another
    // branch is rejected at wiring time with `NoRelationIntergraph`.
    #[test]
    fn no_relation_edge() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?;
        let b_child_child_in_wire = b_child_child.input().out_wire(0);

        b_child_child.finish_with_outputs([])?;
        b_child.finish_with_outputs([])?;

        let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;
        let b_child_2_child =
            b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?;

        let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

        assert_matches!(
            res.map(|h| h.handle().node()), Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoRelationIntergraph { .. },
                ..
            })
        );
        Ok(())
    }

    // Declaring a function that takes a function value whose output row is a
    // row variable succeeds. Instantiating `eval` with that row variable as a
    // list element fails with `RowVarWhereTypeExpected`.
    #[test]
    fn no_outer_row_variables() -> Result<(), BuildError> {
        let e = crate::hugr::validate::test::extension_with_eval_parallel();
        let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable);
        FunctionBuilder::new(
            "bad_eval",
            PolyFuncType::new(
                [TypeParam::new_list(TypeBound::Copyable)],
                Signature::new(
                    Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
                    vec![],
                ),
            ),
        )?;

        let ev = e.instantiate_extension_op(
            "eval",
            [vec![usize_t().into()].into(), vec![tv.into()].into()],
        );
        assert_eq!(
            ev,
            Err(SignatureError::RowVarWhereTypeExpected {
                var: RowVariable(0, TypeBound::Copyable)
            })
        );
        Ok(())
    }

    // LoadConstant and Call nodes expose StateOrder "other" ports, and
    // connecting them with an order edge still validates.
    #[test]
    fn order_edges() {
        let (mut hugr, load_constant, call) = {
            let mut builder = ModuleBuilder::new();
            let func = builder
                .declare("func", Signature::new_endo(bool_t()).into())
                .unwrap();
            let (load_constant, call) = {
                let mut builder = builder
                    .define_function("main", Signature::new(Type::EMPTY_TYPEROW, bool_t()))
                    .unwrap();
                let load_constant = builder.add_load_value(Value::true_val());
                let [r] = builder
                    .call(&func, &[], [load_constant])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
                (load_constant.node(), r.node())
            };
            (builder.finish_hugr().unwrap(), load_constant, call)
        };

        let lc_optype = hugr.get_optype(load_constant);
        let call_optype = hugr.get_optype(call);
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_output().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_output().unwrap());

        hugr.connect(
            load_constant,
            lc_optype.other_output_port().unwrap(),
            call,
            call_optype.other_input_port().unwrap(),
        );

        hugr.validate().unwrap();
    }
}