1use itertools::Itertools;
2
3use super::build_traits::{HugrBuilder, SubContainer};
4use super::handle::BuildHandle;
5use super::{BuildError, Container, Dataflow, DfgID, FuncID};
6
7use std::marker::PhantomData;
8
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::{HugrView, ValidationError};
11use crate::ops::{self, DataflowParent, FuncDefn, Input, OpParent, Output};
12use crate::types::{PolyFuncType, Signature, Type};
13use crate::{Direction, Hugr, IncomingPort, Node, OutgoingPort, Visibility, Wire, hugr::HugrMut};
14
/// Builder for a dataflow region inside a [`Hugr`], whose parent is e.g. a
/// [`ops::DFG`] or [`FuncDefn`] node.
///
/// `T` stores the HUGR being built. The `impl` blocks in this file require
/// `T: AsMut<Hugr> + AsRef<Hugr>` (for example an owned [`Hugr`]).
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
    /// The HUGR (or a handle to it) that nodes are added to.
    pub(crate) base: T,
    /// The container node whose children form this dataflow region.
    pub(crate) dfg_node: Node,
    /// Number of input wires of the region (the inputs of its signature).
    pub(crate) num_in_wires: usize,
    /// Number of output wires of the region (the outputs of its signature).
    pub(crate) num_out_wires: usize,
}
23
24impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
25 pub(super) fn create_with_io(
30 mut base: T,
31 parent: Node,
32 signature: Signature,
33 ) -> Result<Self, BuildError> {
34 debug_assert_eq!(base.as_ref().children(parent).count(), 0);
35
36 let num_in_wires = signature.input_count();
37 let num_out_wires = signature.output_count();
38 let input = ops::Input {
39 types: signature.input().clone(),
40 };
41 let output = ops::Output {
42 types: signature.output().clone(),
43 };
44 base.as_mut().add_node_with_parent(parent, input);
45 base.as_mut().add_node_with_parent(parent, output);
46
47 Ok(Self {
48 base,
49 dfg_node: parent,
50 num_in_wires,
51 num_out_wires,
52 })
53 }
54
55 pub(super) fn create(base: T, parent: Node) -> Result<Self, BuildError> {
62 let sig = base
63 .as_ref()
64 .get_optype(parent)
65 .inner_function_type()
66 .expect("DFG parent must have an inner function signature.");
67 let num_in_wires = sig.input_count();
68 let num_out_wires = sig.output_count();
69
70 Ok(Self {
71 base,
72 dfg_node: parent,
73 num_in_wires,
74 num_out_wires,
75 })
76 }
77}
78
79impl DFGBuilder<Hugr> {
80 pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
87 let dfg_op = ops::DFG {
88 signature: signature.clone(),
89 };
90 let base = Hugr::new_with_entrypoint(dfg_op).expect("DFG entrypoint should be valid");
91 let root = base.entrypoint();
92 DFGBuilder::create_with_io(base, root, signature)
93 }
94}
95
96impl HugrBuilder for DFGBuilder<Hugr> {
97 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
98 self.base.validate()?;
99 Ok(self.base)
100 }
101}
102
103impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for DFGBuilder<T> {
104 #[inline]
105 fn container_node(&self) -> Node {
106 self.dfg_node
107 }
108
109 #[inline]
110 fn hugr_mut(&mut self) -> &mut Hugr {
111 self.base.as_mut()
112 }
113
114 #[inline]
115 fn hugr(&self) -> &Hugr {
116 self.base.as_ref()
117 }
118}
119
120impl<T: AsMut<Hugr> + AsRef<Hugr>> SubContainer for DFGBuilder<T> {
121 type ContainerHandle = BuildHandle<DfgID>;
122 #[inline]
123 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
124 Ok((self.dfg_node, self.num_out_wires).into())
125 }
126}
127
128impl<T: AsMut<Hugr> + AsRef<Hugr>> Dataflow for DFGBuilder<T> {
129 #[inline]
130 fn num_inputs(&self) -> usize {
131 self.num_in_wires
132 }
133}
134
/// A [`DFGBuilder`] tagged with a phantom type `T`.
///
/// `T` is the handle type produced when the wrapper is finished as a
/// sub-container (converted from the inner builder's `BuildHandle<DfgID>`).
/// For example, [`FunctionBuilder`] uses `BuildHandle<FuncID<true>>`.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGWrapper<B, T>(DFGBuilder<B>, PhantomData<T>);
139
140impl<B, T> DFGWrapper<B, T> {
141 pub(super) fn from_dfg_builder(db: DFGBuilder<B>) -> Self {
142 Self(db, PhantomData)
143 }
144}
145
/// Builder for the body of a function definition ([`FuncDefn`]).
///
/// Finishing it as a sub-container yields a `BuildHandle<FuncID<true>>`.
pub type FunctionBuilder<B> = DFGWrapper<B, BuildHandle<FuncID<true>>>;
148
149impl FunctionBuilder<Hugr> {
150 pub fn new(
157 name: impl Into<String>,
158 signature: impl Into<PolyFuncType>,
159 ) -> Result<Self, BuildError> {
160 Self::new_with_op(FuncDefn::new(name, signature))
161 }
162
163 pub fn new_vis(
170 name: impl Into<String>,
171 signature: impl Into<PolyFuncType>,
172 visibility: Visibility,
173 ) -> Result<Self, BuildError> {
174 Self::new_with_op(FuncDefn::new_vis(name, signature, visibility))
175 }
176
177 fn new_with_op(op: FuncDefn) -> Result<Self, BuildError> {
178 let body = op.signature().body().clone();
179
180 let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid");
181 let root = base.entrypoint();
182
183 let db = DFGBuilder::create_with_io(base, root, body)?;
184 Ok(Self::from_dfg_builder(db))
185 }
186}
187
188impl<B: AsMut<Hugr> + AsRef<Hugr>> FunctionBuilder<B> {
189 pub fn with_hugr(
197 mut hugr: B,
198 name: impl Into<String>,
199 signature: impl Into<PolyFuncType>,
200 ) -> Result<Self, BuildError> {
201 let signature: PolyFuncType = signature.into();
202 let body = signature.body().clone();
203 let op = ops::FuncDefn::new(name, signature);
204
205 let module = hugr.as_ref().module_root();
206 let func = hugr.as_mut().add_node_with_parent(module, op);
207
208 let db = DFGBuilder::create_with_io(hugr, func, body)?;
209 Ok(Self::from_dfg_builder(db))
210 }
211
212 pub fn add_input(&mut self, input_type: Type) -> Wire {
216 let [inp_node, _] = self.io();
217
218 let new_optype = self.update_fn_signature(|mut s| {
220 s.input.to_mut().push(input_type);
221 s
222 });
223
224 let types = new_optype.signature().body().input.clone();
226 self.hugr_mut().replace_op(inp_node, Input { types });
227 let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
228 let new_port = new_port.next().unwrap();
229
230 let new_value_port: OutgoingPort = (new_port - 1).into();
232 let new_order_port: OutgoingPort = new_port.into();
233 let order_edge_targets = self
234 .hugr()
235 .linked_inputs(inp_node, new_value_port)
236 .collect_vec();
237 self.hugr_mut().disconnect(inp_node, new_value_port);
238 for (tgt_node, tgt_port) in order_edge_targets {
239 self.hugr_mut()
240 .connect(inp_node, new_order_port, tgt_node, tgt_port);
241 }
242
243 self.0.num_in_wires += 1;
245
246 self.input_wires().next_back().unwrap()
247 }
248
249 pub fn add_output(&mut self, output_type: Type) {
251 let [_, out_node] = self.io();
252
253 let new_optype = self.update_fn_signature(|mut s| {
255 s.output.to_mut().push(output_type);
256 s
257 });
258
259 let types = new_optype.signature().body().output.clone();
261 self.hugr_mut().replace_op(out_node, Output { types });
262 let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
263 let new_port = new_port.next().unwrap();
264
265 let new_value_port: IncomingPort = (new_port - 1).into();
267 let new_order_port: IncomingPort = new_port.into();
268 let order_edge_sources = self
269 .hugr()
270 .linked_outputs(out_node, new_value_port)
271 .collect_vec();
272 self.hugr_mut().disconnect(out_node, new_value_port);
273 for (src_node, src_port) in order_edge_sources {
274 self.hugr_mut()
275 .connect(src_node, src_port, out_node, new_order_port);
276 }
277
278 self.0.num_out_wires += 1;
280 }
281
282 fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
290 let parent = self.container_node();
291
292 let ops::OpType::FuncDefn(fd) = self.hugr_mut().optype_mut(parent) else {
293 panic!("FunctionBuilder node must be a FuncDefn")
294 };
295 *fd.signature_mut() = f(fd.inner_signature().into_owned()).into();
296 &*fd
297 }
298}
299
300impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
301 #[inline]
302 fn container_node(&self) -> Node {
303 self.0.container_node()
304 }
305
306 #[inline]
307 fn hugr_mut(&mut self) -> &mut Hugr {
308 self.0.hugr_mut()
309 }
310
311 #[inline]
312 fn hugr(&self) -> &Hugr {
313 self.0.hugr()
314 }
315}
316
317impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Dataflow for DFGWrapper<B, T> {
318 #[inline]
319 fn num_inputs(&self) -> usize {
320 self.0.num_inputs()
321 }
322}
323
324impl<B: AsMut<Hugr> + AsRef<Hugr>, T: From<BuildHandle<DfgID>>> SubContainer for DFGWrapper<B, T> {
325 type ContainerHandle = T;
326
327 #[inline]
328 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
329 self.0.finish_sub_container().map(Into::into)
330 }
331}
332
333impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
334 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
335 self.0.finish_hugr()
336 }
337}
338
// Unit tests for `DFGBuilder`, `FunctionBuilder` and related wiring checks.
#[cfg(test)]
pub(crate) mod test {
    use cool_asserts::assert_matches;
    use ops::OpParent;
    use rstest::rstest;
    use serde_json::json;

    use crate::builder::build_traits::DataflowHugr;
    use crate::builder::{
        BuilderWiringError, DataflowSubContainer, ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::SignatureError;
    use crate::extension::prelude::Noop;
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::ops::{OpTag, handle::NodeHandle};
    use crate::ops::{OpTrait, Value};

    use crate::std_extensions::logic::test::and_op;
    use crate::types::type_param::TypeParam;
    use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV};
    use crate::utils::test_quantum_extension::h_gate;
    use crate::{Wire, builder::test::n_identity, type_row};

    use super::super::test::simple_dfg_hugr;
    use super::*;

    /// Builds an identity DFG nested inside an outer DFG, next to a quantum
    /// gate, and checks that finishing the outer HUGR succeeds.
    #[test]
    fn nested_identity() -> Result<(), BuildError> {
        let build_result = {
            let mut outer_builder = DFGBuilder::new(endo_sig(vec![usize_t(), qb_t()]))?;

            let [int, qb] = outer_builder.input_wires_arr();

            let q_out = outer_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?;
            let inner_id = n_identity(inner_builder)?;

            outer_builder.finish_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs()))
        };

        assert_eq!(build_result.err(), None);

        Ok(())
    }

    /// Runs `f` on a fresh `bool -> (bool, bool)` DFG builder and asserts that
    /// `finish_hugr` (which validates) succeeds. `msg` labels the failure.
    fn copy_scaffold<F>(f: F, msg: &'static str) -> Result<(), BuildError>
    where
        F: FnOnce(&mut DFGBuilder<Hugr>) -> Result<(), BuildError>,
    {
        let build_result = {
            let mut builder = DFGBuilder::new(inout_sig(bool_t(), vec![bool_t(), bool_t()]))?;

            f(&mut builder)?;

            builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

        Ok(())
    }

    /// A copyable (bool) wire may be used several times: as repeated outputs
    /// and as repeated operands of binary ops.
    #[test]
    fn copy_insertion() -> Result<(), BuildError> {
        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                f_build.set_outputs([b1, b1])
            },
            "Copy input and output",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                f_build.set_outputs([xor.out_wire(0), b1])
            },
            "Copy input and use with binary function",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?;
                f_build.set_outputs([xor2.out_wire(0), b1])
            },
            "Copy multiple times",
        )?;

        Ok(())
    }

    /// Using a linear qubit wire twice must fail with `NoCopyLinear` when the
    /// outputs are wired.
    #[test]
    fn copy_insertion_qubit() {
        let builder = || {
            let mut module_builder = ModuleBuilder::new();

            let f_build = module_builder
                .define_function("main", Signature::new(vec![qb_t()], vec![qb_t(), qb_t()]))?;

            let [q1] = f_build.input_wires_arr();
            f_build.finish_with_outputs([q1, q1])?;

            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(
            builder(),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoCopyLinear { typ, .. },
                ..
            })
            if *typ == qb_t()
        );
    }

    /// A bool value from the outer function can be consumed inside a nested
    /// DFG that has no inputs of its own.
    #[test]
    fn simple_inter_graph_edge() {
        let builder = || -> Result<Hugr, BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;

            let [i1] = f_build.input_wires_arr();
            let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?;
            let i1 = noop.out_wire(0);

            let mut nested =
                f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?;

            let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?;

            let nested = nested.finish_with_outputs([id.out_wire(0)])?;

            f_build.finish_hugr_with_outputs([nested.out_wire(0)])
        };

        assert_matches!(builder(), Ok(_));
    }

    /// `add_output`/`add_input` extend the function signature after order
    /// edges to and from the I/O nodes exist. The finished HUGR validates and
    /// has the extended signature.
    #[test]
    fn add_inputs_outputs() {
        let builder = || -> Result<(Hugr, Node), BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;
            let f_node = f_build.container_node();

            let [i0] = f_build.input_wires_arr();
            let noop0 = f_build.add_dataflow_op(Noop(bool_t()), [i0])?;

            // Order edges Input -> noop0 -> Output, which must survive the port changes.
            f_build.set_order(&f_build.io()[0], &noop0.node());
            f_build.set_order(&noop0.node(), &f_build.io()[1]);

            // Extend the signature, then route the new input through a Noop.
            f_build.add_output(qb_t());
            let i1 = f_build.add_input(qb_t());
            let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;

            let hugr = f_build.finish_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
            Ok((hugr, f_node))
        };

        let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

        let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
        assert_eq!(
            func_sig.io(),
            (
                &vec![bool_t(), qb_t()].into(),
                &vec![bool_t(), qb_t()].into()
            )
        );
    }

    /// A linear (qubit) value cannot be consumed inside a nested DFG via a
    /// non-local edge. Wiring the op fails with `NonCopyableIntergraph`.
    #[test]
    fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
        let mut f_build = FunctionBuilder::new("main", Signature::new(vec![qb_t()], vec![qb_t()]))?;

        let [i1] = f_build.input_wires_arr();
        let noop = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;
        let i1 = noop.out_wire(0);

        let mut nested = f_build.dfg_builder(Signature::new(type_row![], vec![qb_t()]), [])?;

        let id_res = nested.add_dataflow_op(Noop(qb_t()), [i1]);

        assert_matches!(
            id_res.map(|bh| bh.handle().node()),
            Err(BuildError::OperationWiring {
                error: BuilderWiringError::NonCopyableIntergraph { .. },
                ..
            })
        );

        Ok(())
    }

    /// Sanity-checks node counts and the entrypoint tag of the shared
    /// `simple_dfg_hugr` fixture.
    #[rstest]
    fn dfg_hugr(simple_dfg_hugr: Hugr) {
        assert_eq!(simple_dfg_hugr.num_nodes(), 7);
        assert_eq!(simple_dfg_hugr.entry_descendants().count(), 3);
        assert_matches!(simple_dfg_hugr.entrypoint_optype().tag(), OpTag::Dfg);
    }

    /// Inserting a finished DFG HUGR into a function keeps the DFG's metadata
    /// on the inserted node. Child metadata set via the module builder lands
    /// on the function node.
    #[test]
    fn insert_hugr() -> Result<(), BuildError> {
        let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?;
        let [i1] = dfg_builder.input_wires_arr();
        dfg_builder.set_metadata("x", 42);
        let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?;

        let mut module_builder = ModuleBuilder::new();

        let (dfg_node, f_node) = {
            let mut f_build =
                module_builder.define_function("main", Signature::new_endo(bool_t()))?;

            let [i1] = f_build.input_wires_arr();
            let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?;
            let f = f_build.finish_with_outputs([dfg.out_wire(0)])?;
            module_builder.set_child_metadata(f.node(), "x", "hi");
            (dfg.node(), f.node())
        };

        let hugr = module_builder.finish_hugr()?;
        assert_eq!(hugr.entry_descendants().count(), 7);

        assert_eq!(hugr.get_metadata(hugr.entrypoint(), "x"), None);
        assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42)));
        assert_eq!(hugr.get_metadata(f_node, "x").cloned(), Some(json!("hi")));

        Ok(())
    }

    /// Barriers over one or two wires can be chained inside nested DFGs and
    /// the result finishes successfully.
    #[test]
    fn barrier_node() -> Result<(), BuildError> {
        let mut parent = DFGBuilder::new(endo_sig(bool_t()))?;

        let [w] = parent.input_wires_arr();

        let mut dfg_b = parent.dfg_builder(endo_sig(bool_t()), [w])?;
        let [w] = dfg_b.input_wires_arr();

        let barr0 = dfg_b.add_barrier([w])?;
        let [w] = barr0.outputs_arr();

        let barr1 = dfg_b.add_barrier([w])?;
        let [w] = barr1.outputs_arr();

        let dfg = dfg_b.finish_with_outputs([w])?;
        let [w] = dfg.outputs_arr();

        let mut dfg2_b = parent.dfg_builder(endo_sig(vec![bool_t(), bool_t()]), [w, w])?;
        let [w1, w2] = dfg2_b.input_wires_arr();
        let barr2 = dfg2_b.add_barrier([w1, w2])?;
        let wires: Vec<Wire> = barr2.outputs().collect();

        let dfg2 = dfg2_b.finish_with_outputs(wires)?;
        let [w, _] = dfg2.outputs_arr();
        parent.finish_hugr_with_outputs([w])?;

        Ok(())
    }

    /// Using one nested DFG's Input wire inside a sibling nested DFG passes the
    /// builder but fails validation with `NonCFGAncestor`.
    #[test]
    fn non_cfg_ancestor() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_in_wire = b_child.input().out_wire(0);
        b_child.finish_with_outputs([])?;
        let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;

        let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

        let res = b.finish_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

        assert_matches!(
            res,
            Err(BuildError::InvalidHUGR(
                ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
            ))
        );
        Ok(())
    }

    /// Wiring between doubly-nested regions in unrelated sibling subtrees is
    /// rejected at build time with `NoRelationIntergraph`.
    #[test]
    fn no_relation_edge() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?;
        let b_child_child_in_wire = b_child_child.input().out_wire(0);

        b_child_child.finish_with_outputs([])?;
        b_child.finish_with_outputs([])?;

        let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;
        let b_child_2_child =
            b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?;

        let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

        assert_matches!(
            res.map(|h| h.handle().node()),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoRelationIntergraph { .. },
                ..
            })
        );
        Ok(())
    }

    /// A function type containing a row variable can be used in a function
    /// signature. Instantiating `eval` with that row variable where a single
    /// type is expected fails with `RowVarWhereTypeExpected`.
    #[test]
    fn no_outer_row_variables() -> Result<(), BuildError> {
        let e = crate::hugr::validate::test::extension_with_eval_parallel();
        let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable);
        FunctionBuilder::new(
            "bad_eval",
            PolyFuncType::new(
                [TypeParam::new_list_type(TypeBound::Copyable)],
                Signature::new(
                    Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
                    vec![],
                ),
            ),
        )?;

        let ev = e.instantiate_extension_op(
            "eval",
            [vec![usize_t().into()].into(), vec![tv.into()].into()],
        );
        assert_eq!(
            ev,
            Err(SignatureError::RowVarWhereTypeExpected {
                var: RowVariable(0, TypeBound::Copyable)
            })
        );
        Ok(())
    }

    /// Load-constant and call nodes expose `StateOrder` "other" ports. Linking
    /// them with an order edge still validates.
    #[test]
    fn order_edges() {
        let (mut hugr, load_constant, call) = {
            let mut builder = ModuleBuilder::new();
            let func = builder
                .declare("func", Signature::new_endo(bool_t()).into())
                .unwrap();
            let (load_constant, call) = {
                let mut builder = builder
                    .define_function("main", Signature::new(Type::EMPTY_TYPEROW, bool_t()))
                    .unwrap();
                let load_constant = builder.add_load_value(Value::true_val());
                let [r] = builder
                    .call(&func, &[], [load_constant])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
                (load_constant.node(), r.node())
            };
            (builder.finish_hugr().unwrap(), load_constant, call)
        };

        let lc_optype = hugr.get_optype(load_constant);
        let call_optype = hugr.get_optype(call);
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_output().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_output().unwrap());

        hugr.connect(
            load_constant,
            lc_optype.other_output_port().unwrap(),
            call,
            call_optype.other_input_port().unwrap(),
        );

        hugr.validate().unwrap();
    }
}