1use itertools::Itertools;
2
3use super::build_traits::{HugrBuilder, SubContainer};
4use super::handle::BuildHandle;
5use super::{BuildError, Container, Dataflow, DfgID, FuncID};
6
7use std::marker::PhantomData;
8
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::{HugrView, ValidationError};
11use crate::ops::{self, OpParent};
12use crate::ops::{DataflowParent, Input, Output};
13use crate::{Direction, IncomingPort, OutgoingPort, Wire};
14
15use crate::types::{PolyFuncType, Signature, Type};
16
17use crate::Node;
18use crate::{hugr::HugrMut, Hugr};
19
20#[derive(Debug, Clone, PartialEq)]
22pub struct DFGBuilder<T> {
23 pub(crate) base: T,
24 pub(crate) dfg_node: Node,
25 pub(crate) num_in_wires: usize,
26 pub(crate) num_out_wires: usize,
27}
28
29impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
30 pub(super) fn create_with_io(
35 mut base: T,
36 parent: Node,
37 signature: Signature,
38 ) -> Result<Self, BuildError> {
39 debug_assert_eq!(base.as_ref().children(parent).count(), 0);
40
41 let num_in_wires = signature.input_count();
42 let num_out_wires = signature.output_count();
43 let input = ops::Input {
44 types: signature.input().clone(),
45 };
46 let output = ops::Output {
47 types: signature.output().clone(),
48 };
49 base.as_mut().add_node_with_parent(parent, input);
50 base.as_mut().add_node_with_parent(parent, output);
51
52 Ok(Self {
53 base,
54 dfg_node: parent,
55 num_in_wires,
56 num_out_wires,
57 })
58 }
59
60 pub(super) fn create(base: T, parent: Node) -> Result<Self, BuildError> {
67 let sig = base
68 .as_ref()
69 .get_optype(parent)
70 .inner_function_type()
71 .expect("DFG parent must have an inner function signature.");
72 let num_in_wires = sig.input_count();
73 let num_out_wires = sig.output_count();
74
75 Ok(Self {
76 base,
77 dfg_node: parent,
78 num_in_wires,
79 num_out_wires,
80 })
81 }
82}
83
84impl DFGBuilder<Hugr> {
85 pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
92 let dfg_op = ops::DFG {
93 signature: signature.clone(),
94 };
95 let base = Hugr::new_with_entrypoint(dfg_op).expect("DFG entrypoint should be valid");
96 let root = base.entrypoint();
97 DFGBuilder::create_with_io(base, root, signature)
98 }
99}
100
101impl HugrBuilder for DFGBuilder<Hugr> {
102 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
103 self.base.validate()?;
104 Ok(self.base)
105 }
106}
107
108impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for DFGBuilder<T> {
109 #[inline]
110 fn container_node(&self) -> Node {
111 self.dfg_node
112 }
113
114 #[inline]
115 fn hugr_mut(&mut self) -> &mut Hugr {
116 self.base.as_mut()
117 }
118
119 #[inline]
120 fn hugr(&self) -> &Hugr {
121 self.base.as_ref()
122 }
123}
124
125impl<T: AsMut<Hugr> + AsRef<Hugr>> SubContainer for DFGBuilder<T> {
126 type ContainerHandle = BuildHandle<DfgID>;
127 #[inline]
128 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
129 Ok((self.dfg_node, self.num_out_wires).into())
130 }
131}
132
133impl<T: AsMut<Hugr> + AsRef<Hugr>> Dataflow for DFGBuilder<T> {
134 #[inline]
135 fn num_inputs(&self) -> usize {
136 self.num_in_wires
137 }
138}
139
140#[derive(Debug, Clone, PartialEq)]
143pub struct DFGWrapper<B, T>(DFGBuilder<B>, PhantomData<T>);
144
145impl<B, T> DFGWrapper<B, T> {
146 pub(super) fn from_dfg_builder(db: DFGBuilder<B>) -> Self {
147 Self(db, PhantomData)
148 }
149}
150
151pub type FunctionBuilder<B> = DFGWrapper<B, BuildHandle<FuncID<true>>>;
153
154impl FunctionBuilder<Hugr> {
155 pub fn new(
160 name: impl Into<String>,
161 signature: impl Into<PolyFuncType>,
162 ) -> Result<Self, BuildError> {
163 let signature: PolyFuncType = signature.into();
164 let body = signature.body().clone();
165 let op = ops::FuncDefn {
166 signature,
167 name: name.into(),
168 };
169
170 let base = Hugr::new_with_entrypoint(op).expect("FuncDefn entrypoint should be valid");
171 let root = base.entrypoint();
172
173 let db = DFGBuilder::create_with_io(base, root, body)?;
174 Ok(Self::from_dfg_builder(db))
175 }
176
177 pub fn add_input(&mut self, input_type: Type) -> Wire {
181 let [inp_node, _] = self.io();
182
183 let new_optype = self.update_fn_signature(|mut s| {
185 s.input.to_mut().push(input_type);
186 s
187 });
188
189 let types = new_optype.signature.body().input.clone();
191 self.hugr_mut().replace_op(inp_node, Input { types });
192 let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
193 let new_port = new_port.next().unwrap();
194
195 let new_value_port: OutgoingPort = (new_port - 1).into();
197 let new_order_port: OutgoingPort = new_port.into();
198 let order_edge_targets = self
199 .hugr()
200 .linked_inputs(inp_node, new_value_port)
201 .collect_vec();
202 self.hugr_mut().disconnect(inp_node, new_value_port);
203 for (tgt_node, tgt_port) in order_edge_targets {
204 self.hugr_mut()
205 .connect(inp_node, new_order_port, tgt_node, tgt_port);
206 }
207
208 self.0.num_in_wires += 1;
210
211 self.input_wires().next_back().unwrap()
212 }
213
214 pub fn add_output(&mut self, output_type: Type) {
216 let [_, out_node] = self.io();
217
218 let new_optype = self.update_fn_signature(|mut s| {
220 s.output.to_mut().push(output_type);
221 s
222 });
223
224 let types = new_optype.signature.body().output.clone();
226 self.hugr_mut().replace_op(out_node, Output { types });
227 let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
228 let new_port = new_port.next().unwrap();
229
230 let new_value_port: IncomingPort = (new_port - 1).into();
232 let new_order_port: IncomingPort = new_port.into();
233 let order_edge_sources = self
234 .hugr()
235 .linked_outputs(out_node, new_value_port)
236 .collect_vec();
237 self.hugr_mut().disconnect(out_node, new_value_port);
238 for (src_node, src_port) in order_edge_sources {
239 self.hugr_mut()
240 .connect(src_node, src_port, out_node, new_order_port);
241 }
242
243 self.0.num_out_wires += 1;
245 }
246
247 fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
255 let parent = self.container_node();
256 let old_optype = self
257 .hugr()
258 .get_optype(parent)
259 .as_func_defn()
260 .expect("FunctionBuilder node must be a FuncDefn");
261 let signature = old_optype.inner_signature().into_owned();
262 let name = old_optype.name.clone();
263 self.hugr_mut().replace_op(
264 parent,
265 ops::FuncDefn {
266 signature: f(signature).into(),
267 name,
268 },
269 );
270
271 self.hugr().get_optype(parent).as_func_defn().unwrap()
272 }
273}
274
275impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
276 #[inline]
277 fn container_node(&self) -> Node {
278 self.0.container_node()
279 }
280
281 #[inline]
282 fn hugr_mut(&mut self) -> &mut Hugr {
283 self.0.hugr_mut()
284 }
285
286 #[inline]
287 fn hugr(&self) -> &Hugr {
288 self.0.hugr()
289 }
290}
291
292impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Dataflow for DFGWrapper<B, T> {
293 #[inline]
294 fn num_inputs(&self) -> usize {
295 self.0.num_inputs()
296 }
297}
298
299impl<B: AsMut<Hugr> + AsRef<Hugr>, T: From<BuildHandle<DfgID>>> SubContainer for DFGWrapper<B, T> {
300 type ContainerHandle = T;
301
302 #[inline]
303 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
304 self.0.finish_sub_container().map(Into::into)
305 }
306}
307
308impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
309 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
310 self.0.finish_hugr()
311 }
312}
313
/// Unit tests for [`DFGBuilder`] and [`FunctionBuilder`].
#[cfg(test)]
pub(crate) mod test {
    use cool_asserts::assert_matches;
    use ops::OpParent;
    use rstest::rstest;
    use serde_json::json;

    use crate::builder::build_traits::DataflowHugr;
    use crate::builder::{
        endo_sig, inout_sig, BuilderWiringError, DataflowSubContainer, ModuleBuilder,
    };
    use crate::extension::prelude::Noop;
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::extension::SignatureError;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::ops::{handle::NodeHandle, OpTag};
    use crate::ops::{OpTrait, Value};

    use crate::std_extensions::logic::test::and_op;
    use crate::types::type_param::TypeParam;
    use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV};
    use crate::utils::test_quantum_extension::h_gate;
    use crate::{builder::test::n_identity, type_row, Wire};

    use super::super::test::simple_dfg_hugr;
    use super::*;

    #[test]
    fn nested_identity() -> Result<(), BuildError> {
        let build_result = {
            let mut outer_builder = DFGBuilder::new(endo_sig(vec![usize_t(), qb_t()]))?;

            let [int, qb] = outer_builder.input_wires_arr();

            let q_out = outer_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?;
            let inner_id = n_identity(inner_builder)?;

            outer_builder.finish_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs()))
        };

        assert_eq!(build_result.err(), None);

        Ok(())
    }

    /// Builds a `bool -> (bool, bool)` DFG, lets `f` populate it, and asserts
    /// that the result validates. `msg` labels the example in the failure message.
    fn copy_scaffold<F>(f: F, msg: &'static str) -> Result<(), BuildError>
    where
        F: FnOnce(&mut DFGBuilder<Hugr>) -> Result<(), BuildError>,
    {
        let build_result = {
            let mut builder = DFGBuilder::new(inout_sig(bool_t(), vec![bool_t(), bool_t()]))?;

            f(&mut builder)?;

            builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

        Ok(())
    }

    #[test]
    fn copy_insertion() -> Result<(), BuildError> {
        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                f_build.set_outputs([b1, b1])
            },
            "Copy input and output",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let and_gate = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                f_build.set_outputs([and_gate.out_wire(0), b1])
            },
            "Copy input and use with binary function",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let and_gate1 = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                let and_gate2 = f_build.add_dataflow_op(and_op(), [b1, and_gate1.out_wire(0)])?;
                f_build.set_outputs([and_gate2.out_wire(0), b1])
            },
            "Copy multiple times",
        )?;

        Ok(())
    }

    #[test]
    fn copy_insertion_qubit() {
        let builder = || {
            let mut module_builder = ModuleBuilder::new();

            let f_build = module_builder
                .define_function("main", Signature::new(vec![qb_t()], vec![qb_t(), qb_t()]))?;

            let [q1] = f_build.input_wires_arr();
            f_build.finish_with_outputs([q1, q1])?;

            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(
            builder(),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoCopyLinear { typ, .. },
                ..
            })
            if typ == qb_t()
        );
    }

    #[test]
    fn simple_inter_graph_edge() {
        let builder = || -> Result<Hugr, BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;

            let [i1] = f_build.input_wires_arr();
            let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?;
            let i1 = noop.out_wire(0);

            let mut nested =
                f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?;

            let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?;

            let nested = nested.finish_with_outputs([id.out_wire(0)])?;

            f_build.finish_hugr_with_outputs([nested.out_wire(0)])
        };

        assert_matches!(builder(), Ok(_));
    }

    #[test]
    fn add_inputs_outputs() {
        let builder = || -> Result<(Hugr, Node), BuildError> {
            let mut f_build =
                FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;
            let f_node = f_build.container_node();

            let [i0] = f_build.input_wires_arr();
            let noop0 = f_build.add_dataflow_op(Noop(bool_t()), [i0])?;

            // Order edges out of the Input node and into the Output node, so
            // add_input/add_output must move them onto the new order ports.
            f_build.set_order(&f_build.io()[0], &noop0.node());
            f_build.set_order(&noop0.node(), &f_build.io()[1]);

            // Grow the signature after some of the body has been built.
            f_build.add_output(qb_t());
            let i1 = f_build.add_input(qb_t());
            let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;

            let hugr = f_build.finish_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
            Ok((hugr, f_node))
        };

        let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

        let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
        assert_eq!(
            func_sig.io(),
            (
                &vec![bool_t(), qb_t()].into(),
                &vec![bool_t(), qb_t()].into()
            )
        );
    }

    #[test]
    fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
        let mut f_build = FunctionBuilder::new("main", Signature::new(vec![qb_t()], vec![qb_t()]))?;

        let [i1] = f_build.input_wires_arr();
        let noop = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;
        let i1 = noop.out_wire(0);

        let mut nested = f_build.dfg_builder(Signature::new(type_row![], vec![qb_t()]), [])?;

        // A linear (qubit) wire may not cross into a nested region.
        let id_res = nested.add_dataflow_op(Noop(qb_t()), [i1]);

        assert_matches!(
            id_res.map(|bh| bh.handle().node()), Err(BuildError::OperationWiring {
                error: BuilderWiringError::NonCopyableIntergraph { .. },
                ..
            })
        );

        Ok(())
    }

    #[rstest]
    fn dfg_hugr(simple_dfg_hugr: Hugr) {
        assert_eq!(simple_dfg_hugr.num_nodes(), 7);
        assert_eq!(simple_dfg_hugr.entry_descendants().count(), 3);
        assert_matches!(simple_dfg_hugr.entrypoint_optype().tag(), OpTag::Dfg);
    }

    #[test]
    fn insert_hugr() -> Result<(), BuildError> {
        // Build a small DFG carrying metadata on its entrypoint.
        let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?;
        let [i1] = dfg_builder.input_wires_arr();
        dfg_builder.set_metadata("x", 42);
        let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?;

        // Insert it into a function and check the metadata travels with it.
        let mut module_builder = ModuleBuilder::new();

        let (dfg_node, f_node) = {
            let mut f_build = module_builder
                .define_function("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;

            let [i1] = f_build.input_wires_arr();
            let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?;
            let f = f_build.finish_with_outputs([dfg.out_wire(0)])?;
            module_builder.set_child_metadata(f.node(), "x", "hi");
            (dfg.node(), f.node())
        };

        let hugr = module_builder.finish_hugr()?;
        assert_eq!(hugr.entry_descendants().count(), 7);

        assert_eq!(hugr.get_metadata(hugr.entrypoint(), "x"), None);
        assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42)));
        assert_eq!(hugr.get_metadata(f_node, "x").cloned(), Some(json!("hi")));

        Ok(())
    }

    #[test]
    fn barrier_node() -> Result<(), BuildError> {
        let mut parent = DFGBuilder::new(endo_sig(bool_t()))?;

        let [w] = parent.input_wires_arr();

        let mut dfg_b = parent.dfg_builder(endo_sig(bool_t()), [w])?;
        let [w] = dfg_b.input_wires_arr();

        let barr0 = dfg_b.add_barrier([w])?;
        let [w] = barr0.outputs_arr();

        let barr1 = dfg_b.add_barrier([w])?;
        let [w] = barr1.outputs_arr();

        let dfg = dfg_b.finish_with_outputs([w])?;
        let [w] = dfg.outputs_arr();

        let mut dfg2_b = parent.dfg_builder(endo_sig(vec![bool_t(), bool_t()]), [w, w])?;
        let [w1, w2] = dfg2_b.input_wires_arr();
        let barr2 = dfg2_b.add_barrier([w1, w2])?;
        let wires: Vec<Wire> = barr2.outputs().collect();

        let dfg2 = dfg2_b.finish_with_outputs(wires)?;
        let [w, _] = dfg2.outputs_arr();
        parent.finish_hugr_with_outputs([w])?;

        Ok(())
    }

    #[test]
    fn non_cfg_ancestor() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_in_wire = b_child.input().out_wire(0);
        b_child.finish_with_outputs([])?;
        let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;

        // Wire from the Input of one sibling DFG into another sibling DFG; the
        // builder accepts it, but validation must reject it.
        let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

        let res = b.finish_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

        assert_matches!(
            res,
            Err(BuildError::InvalidHUGR(
                ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
            ))
        );
        Ok(())
    }

    #[test]
    fn no_relation_edge() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?;
        let b_child_child_in_wire = b_child_child.input().out_wire(0);

        b_child_child.finish_with_outputs([])?;
        b_child.finish_with_outputs([])?;

        let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;
        let b_child_2_child =
            b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?;

        let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

        assert_matches!(
            res.map(|h| h.handle().node()), Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoRelationIntergraph { .. },
                ..
            })
        );
        Ok(())
    }

    #[test]
    fn no_outer_row_variables() -> Result<(), BuildError> {
        let e = crate::hugr::validate::test::extension_with_eval_parallel();
        let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable);
        // A row variable nested inside a function-typed input is accepted.
        FunctionBuilder::new(
            "bad_eval",
            PolyFuncType::new(
                [TypeParam::new_list(TypeBound::Copyable)],
                Signature::new(
                    Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
                    vec![],
                ),
            ),
        )?;

        // Instantiating `eval` with a row variable where a type is expected fails.
        let ev = e.instantiate_extension_op(
            "eval",
            [vec![usize_t().into()].into(), vec![tv.into()].into()],
        );
        assert_eq!(
            ev,
            Err(SignatureError::RowVarWhereTypeExpected {
                var: RowVariable(0, TypeBound::Copyable)
            })
        );
        Ok(())
    }

    #[test]
    fn order_edges() {
        let (mut hugr, load_constant, call) = {
            let mut builder = ModuleBuilder::new();
            let func = builder
                .declare("func", Signature::new_endo(bool_t()).into())
                .unwrap();
            let (load_constant, call) = {
                let mut builder = builder
                    .define_function("main", Signature::new(Type::EMPTY_TYPEROW, bool_t()))
                    .unwrap();
                let load_constant = builder.add_load_value(Value::true_val());
                let [r] = builder
                    .call(&func, &[], [load_constant])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
                (load_constant.node(), r.node())
            };
            (builder.finish_hugr().unwrap(), load_constant, call)
        };

        let lc_optype = hugr.get_optype(load_constant);
        let call_optype = hugr.get_optype(call);
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_output().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_output().unwrap());

        hugr.connect(
            load_constant,
            lc_optype.other_output_port().unwrap(),
            call,
            call_optype.other_input_port().unwrap(),
        );

        hugr.validate().unwrap();
    }
}