1use itertools::Itertools;
2
3use super::build_traits::{HugrBuilder, SubContainer};
4use super::handle::BuildHandle;
5use super::{BuildError, Container, Dataflow, DfgID, FuncID};
6
7use std::marker::PhantomData;
8
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::{HugrView, ValidationError};
11use crate::ops;
12use crate::ops::{DataflowParent, Input, Output};
13use crate::{Direction, IncomingPort, OutgoingPort, Wire};
14
15use crate::types::{PolyFuncType, Signature, Type};
16
17use crate::Node;
18use crate::{hugr::HugrMut, Hugr};
19
/// Builder for a dataflow-graph container.
///
/// Holds the HUGR under construction (`base`) and the container node
/// (`dfg_node`). `Input` and `Output` children are added to that node when
/// the builder is created. Dataflow operations are then added inside the
/// container through the `Container` / `Dataflow` trait implementations.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
    /// The HUGR under construction. The builder impls access it through
    /// `AsRef<Hugr>` / `AsMut<Hugr>` (e.g. an owned `Hugr`).
    pub(crate) base: T,
    /// The container node whose children this builder adds.
    pub(crate) dfg_node: Node,
    /// Number of value inputs, i.e. the length of the signature's input row.
    pub(crate) num_in_wires: usize,
    /// Number of value outputs, i.e. the length of the signature's output row.
    pub(crate) num_out_wires: usize,
}
28
29impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
30 pub(super) fn create_with_io(
31 mut base: T,
32 parent: Node,
33 signature: Signature,
34 ) -> Result<Self, BuildError> {
35 let num_in_wires = signature.input().len();
36 let num_out_wires = signature.output().len();
37 let input = ops::Input {
50 types: signature.input().clone(),
51 };
52 let output = ops::Output {
53 types: signature.output().clone(),
54 };
55 base.as_mut().add_node_with_parent(parent, input);
56 base.as_mut().add_node_with_parent(parent, output);
57
58 Ok(Self {
59 base,
60 dfg_node: parent,
61 num_in_wires,
62 num_out_wires,
63 })
64 }
65}
66
67impl DFGBuilder<Hugr> {
68 pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
75 let dfg_op = ops::DFG {
76 signature: signature.clone(),
77 };
78 let base = Hugr::new(dfg_op);
79 let root = base.root();
80 DFGBuilder::create_with_io(base, root, signature)
81 }
82}
83
84impl HugrBuilder for DFGBuilder<Hugr> {
85 fn finish_hugr(mut self) -> Result<Hugr, ValidationError> {
86 if cfg!(feature = "extension_inference") {
87 self.base.infer_extensions(false)?;
88 }
89 self.base.validate()?;
90 Ok(self.base)
91 }
92}
93
94impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for DFGBuilder<T> {
95 #[inline]
96 fn container_node(&self) -> Node {
97 self.dfg_node
98 }
99
100 #[inline]
101 fn hugr_mut(&mut self) -> &mut Hugr {
102 self.base.as_mut()
103 }
104
105 #[inline]
106 fn hugr(&self) -> &Hugr {
107 self.base.as_ref()
108 }
109}
110
111impl<T: AsMut<Hugr> + AsRef<Hugr>> SubContainer for DFGBuilder<T> {
112 type ContainerHandle = BuildHandle<DfgID>;
113 #[inline]
114 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
115 Ok((self.dfg_node, self.num_out_wires).into())
116 }
117}
118
119impl<T: AsMut<Hugr> + AsRef<Hugr>> Dataflow for DFGBuilder<T> {
120 #[inline]
121 fn num_inputs(&self) -> usize {
122 self.num_in_wires
123 }
124}
125
/// A [`DFGBuilder`] tagged with a handle type `T`.
///
/// `T` is the `ContainerHandle` returned when this wrapper's
/// `finish_sub_container` is called. The `PhantomData` marker carries `T`
/// without storing a value of it.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGWrapper<B, T>(DFGBuilder<B>, PhantomData<T>);
130
131impl<B, T> DFGWrapper<B, T> {
132 pub(super) fn from_dfg_builder(db: DFGBuilder<B>) -> Self {
133 Self(db, PhantomData)
134 }
135}
136
/// Builder for a function definition (`FuncDefn`) and its dataflow body.
///
/// Finishing it as a sub-container yields a `BuildHandle<FuncID<true>>`.
pub type FunctionBuilder<B> = DFGWrapper<B, BuildHandle<FuncID<true>>>;
139
140impl FunctionBuilder<Hugr> {
141 pub fn new(
146 name: impl Into<String>,
147 signature: impl Into<PolyFuncType>,
148 ) -> Result<Self, BuildError> {
149 let signature = signature.into();
150 let body = signature.body().clone();
151 let op = ops::FuncDefn {
152 signature,
153 name: name.into(),
154 };
155
156 let base = Hugr::new(op);
157 let root = base.root();
158
159 let db = DFGBuilder::create_with_io(base, root, body)?;
160 Ok(Self::from_dfg_builder(db))
161 }
162
163 pub fn add_input(&mut self, input_type: Type) -> Wire {
167 let [inp_node, _] = self.io();
168
169 let new_optype = self.update_fn_signature(|mut s| {
171 s.input.to_mut().push(input_type);
172 s
173 });
174
175 let types = new_optype.signature.body().input.clone();
177 self.hugr_mut()
178 .replace_op(inp_node, Input { types })
179 .unwrap();
180 let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
181 let new_port = new_port.next().unwrap();
182
183 let new_value_port: OutgoingPort = (new_port - 1).into();
185 let new_order_port: OutgoingPort = new_port.into();
186 let order_edge_targets = self
187 .hugr()
188 .linked_inputs(inp_node, new_value_port)
189 .collect_vec();
190 self.hugr_mut().disconnect(inp_node, new_value_port);
191 for (tgt_node, tgt_port) in order_edge_targets {
192 self.hugr_mut()
193 .connect(inp_node, new_order_port, tgt_node, tgt_port);
194 }
195
196 self.0.num_in_wires += 1;
198
199 self.input_wires().last().unwrap()
200 }
201
202 pub fn add_output(&mut self, output_type: Type) {
204 let [_, out_node] = self.io();
205
206 let new_optype = self.update_fn_signature(|mut s| {
208 s.output.to_mut().push(output_type);
209 s
210 });
211
212 let types = new_optype.signature.body().output.clone();
214 self.hugr_mut()
215 .replace_op(out_node, Output { types })
216 .unwrap();
217 let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
218 let new_port = new_port.next().unwrap();
219
220 let new_value_port: IncomingPort = (new_port - 1).into();
222 let new_order_port: IncomingPort = new_port.into();
223 let order_edge_sources = self
224 .hugr()
225 .linked_outputs(out_node, new_value_port)
226 .collect_vec();
227 self.hugr_mut().disconnect(out_node, new_value_port);
228 for (src_node, src_port) in order_edge_sources {
229 self.hugr_mut()
230 .connect(src_node, src_port, out_node, new_order_port);
231 }
232
233 self.0.num_out_wires += 1;
235 }
236
237 fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
245 let parent = self.container_node();
246 let old_optype = self
247 .hugr()
248 .get_optype(parent)
249 .as_func_defn()
250 .expect("FunctionBuilder node must be a FuncDefn");
251 let signature = old_optype.inner_signature().into_owned();
252 let name = old_optype.name.clone();
253 self.hugr_mut()
254 .replace_op(
255 parent,
256 ops::FuncDefn {
257 signature: f(signature).into(),
258 name,
259 },
260 )
261 .expect("Could not replace FunctionBuilder operation");
262
263 self.hugr().get_optype(parent).as_func_defn().unwrap()
264 }
265}
266
267impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
268 #[inline]
269 fn container_node(&self) -> Node {
270 self.0.container_node()
271 }
272
273 #[inline]
274 fn hugr_mut(&mut self) -> &mut Hugr {
275 self.0.hugr_mut()
276 }
277
278 #[inline]
279 fn hugr(&self) -> &Hugr {
280 self.0.hugr()
281 }
282}
283
284impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Dataflow for DFGWrapper<B, T> {
285 #[inline]
286 fn num_inputs(&self) -> usize {
287 self.0.num_inputs()
288 }
289}
290
291impl<B: AsMut<Hugr> + AsRef<Hugr>, T: From<BuildHandle<DfgID>>> SubContainer for DFGWrapper<B, T> {
292 type ContainerHandle = T;
293
294 #[inline]
295 fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
296 self.0.finish_sub_container().map(Into::into)
297 }
298}
299
300impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
301 fn finish_hugr(self) -> Result<Hugr, ValidationError> {
302 self.0.finish_hugr()
303 }
304}
305
/// Unit tests for the DFG and function builders.
#[cfg(test)]
pub(crate) mod test {
    use cool_asserts::assert_matches;
    use ops::OpParent;
    use rstest::rstest;
    use serde_json::json;

    use crate::builder::build_traits::DataflowHugr;
    use crate::builder::{
        endo_sig, inout_sig, BuilderWiringError, DataflowSubContainer, ModuleBuilder,
    };
    use crate::extension::prelude::Noop;
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::extension::SignatureError;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::ops::{handle::NodeHandle, OpTag};
    use crate::ops::{OpTrait, Value};

    use crate::std_extensions::logic::test::and_op;
    use crate::types::type_param::TypeParam;
    use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV};
    use crate::utils::test_quantum_extension::h_gate;
    use crate::{builder::test::n_identity, type_row, Wire};

    use super::super::test::simple_dfg_hugr;
    use super::*;
    // Outer DFG over (usize, qubit): `h_gate()` is applied to the qubit, and
    // the usize is routed through a nested identity DFG. Building must succeed.
    #[test]
    fn nested_identity() -> Result<(), BuildError> {
        let build_result = {
            let mut outer_builder = DFGBuilder::new(endo_sig(vec![usize_t(), qb_t()]))?;

            let [int, qb] = outer_builder.input_wires_arr();

            let q_out = outer_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?;
            let inner_id = n_identity(inner_builder)?;

            outer_builder.finish_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs()))
        };

        assert_eq!(build_result.err(), None);

        Ok(())
    }

    // Helper: build a DFG `bool -> (bool, bool)`, let `f` wire up its body,
    // and assert that `finish_hugr` succeeds. `msg` labels the example in
    // the failure message.
    fn copy_scaffold<F>(f: F, msg: &'static str) -> Result<(), BuildError>
    where
        F: FnOnce(&mut DFGBuilder<Hugr>) -> Result<(), BuildError>,
    {
        let build_result = {
            let mut builder = DFGBuilder::new(inout_sig(bool_t(), vec![bool_t(), bool_t()]))?;

            f(&mut builder)?;

            builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

        Ok(())
    }
    // A copyable (bool) wire may be consumed several times; every scenario
    // must still produce a valid HUGR.
    // NOTE(review): the locals are named `xor*`, but the op used is `and_op()`.
    #[test]
    fn copy_insertion() -> Result<(), BuildError> {
        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                f_build.set_outputs([b1, b1])
            },
            "Copy input and output",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                f_build.set_outputs([xor.out_wire(0), b1])
            },
            "Copy input and use with binary function",
        )?;

        copy_scaffold(
            |f_build| {
                let [b1] = f_build.input_wires_arr();
                let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?;
                let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?;
                f_build.set_outputs([xor2.out_wire(0), b1])
            },
            "Copy multiple times",
        )?;

        Ok(())
    }

    // Using a qubit wire twice as outputs must be rejected while wiring the
    // outputs, with `NoCopyLinear` for type `qb_t()`.
    #[test]
    fn copy_insertion_qubit() {
        let builder = || {
            let mut module_builder = ModuleBuilder::new();

            let f_build = module_builder
                .define_function("main", Signature::new(vec![qb_t()], vec![qb_t(), qb_t()]))?;

            let [q1] = f_build.input_wires_arr();
            f_build.finish_with_outputs([q1, q1])?;

            Ok(module_builder.finish_hugr()?)
        };

        assert_matches!(
            builder(),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoCopyLinear { typ, .. },
                ..
            })
            if typ == qb_t()
        );
    }

    // A nested DFG with no inputs may consume a bool value produced in the
    // enclosing function (an inter-graph edge), and the build succeeds.
    #[test]
    fn simple_inter_graph_edge() {
        let builder = || -> Result<Hugr, BuildError> {
            let mut f_build = FunctionBuilder::new(
                "main",
                Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(),
            )?;

            let [i1] = f_build.input_wires_arr();
            let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?;
            let i1 = noop.out_wire(0);

            let mut nested = f_build.dfg_builder(
                Signature::new(type_row![], vec![bool_t()]).with_prelude(),
                [],
            )?;

            let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?;

            let nested = nested.finish_with_outputs([id.out_wire(0)])?;

            f_build.finish_hugr_with_outputs([nested.out_wire(0)])
        };

        assert_matches!(builder(), Ok(_));
    }

    // `add_output` / `add_input` are called after order edges have already
    // been attached to the Input and Output nodes. Finishing must still
    // succeed, and the function signature must become
    // `(bool, qubit) -> (bool, qubit)`.
    #[test]
    fn add_inputs_outputs() {
        let builder = || -> Result<(Hugr, Node), BuildError> {
            let mut f_build = FunctionBuilder::new(
                "main",
                Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(),
            )?;
            let f_node = f_build.container_node();

            let [i0] = f_build.input_wires_arr();
            let noop0 = f_build.add_dataflow_op(Noop(bool_t()), [i0])?;

            // Order edges Input -> noop0 -> Output, added before the signature changes.
            f_build.set_order(&f_build.io()[0], &noop0.node());
            f_build.set_order(&noop0.node(), &f_build.io()[1]);

            f_build.add_output(qb_t());
            let i1 = f_build.add_input(qb_t());
            let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;

            let hugr = f_build.finish_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
            Ok((hugr, f_node))
        };

        let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

        let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
        assert_eq!(
            func_sig.io(),
            (
                &vec![bool_t(), qb_t()].into(),
                &vec![bool_t(), qb_t()].into()
            )
        );
    }

    // A qubit value from the enclosing function cannot be consumed inside a
    // nested DFG. `add_dataflow_op` must fail with `NonCopyableIntergraph`.
    #[test]
    fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
        let mut f_build = FunctionBuilder::new("main", Signature::new(vec![qb_t()], vec![qb_t()]))?;

        let [i1] = f_build.input_wires_arr();
        let noop = f_build.add_dataflow_op(Noop(qb_t()), [i1])?;
        let i1 = noop.out_wire(0);

        let mut nested = f_build.dfg_builder(Signature::new(type_row![], vec![qb_t()]), [])?;

        let id_res = nested.add_dataflow_op(Noop(qb_t()), [i1]);

        // The Ok value is mapped to its node before matching; presumably this
        // gives a type that implements `Debug` for the assertion message.
        assert_matches!(
            id_res.map(|bh| bh.handle().node()),
            Err(BuildError::OperationWiring {
                error: BuilderWiringError::NonCopyableIntergraph { .. },
                ..
            })
        );

        Ok(())
    }

    // The `simple_dfg_hugr` fixture has 3 nodes and a DFG root.
    #[rstest]
    fn dfg_hugr(simple_dfg_hugr: Hugr) {
        assert_eq!(simple_dfg_hugr.node_count(), 3);
        assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg);
    }

    // A finished DFG HUGR (root metadata "x" = 42) is inserted into a
    // function. The inserted node keeps that metadata, the function gets its
    // own child metadata, and the module root has none.
    #[test]
    fn insert_hugr() -> Result<(), BuildError> {
        let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?;
        let [i1] = dfg_builder.input_wires_arr();
        dfg_builder.set_metadata("x", 42);
        let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?;

        let mut module_builder = ModuleBuilder::new();

        let (dfg_node, f_node) = {
            let mut f_build = module_builder
                .define_function("main", Signature::new(vec![bool_t()], vec![bool_t()]))?;

            let [i1] = f_build.input_wires_arr();
            let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?;
            let f = f_build.finish_with_outputs([dfg.out_wire(0)])?;
            module_builder.set_child_metadata(f.node(), "x", "hi");
            (dfg.node(), f.node())
        };

        let hugr = module_builder.finish_hugr()?;
        assert_eq!(hugr.node_count(), 7);

        assert_eq!(hugr.get_metadata(hugr.root(), "x"), None);
        assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42)));
        assert_eq!(hugr.get_metadata(f_node, "x").cloned(), Some(json!("hi")));

        Ok(())
    }

    // Barriers inside nested DFGs (single-wire, chained, and a two-wire
    // barrier) can be built, and the outer HUGR finishes successfully.
    #[test]
    fn barrier_node() -> Result<(), BuildError> {
        let mut parent = DFGBuilder::new(endo_sig(bool_t()))?;

        let [w] = parent.input_wires_arr();

        let mut dfg_b = parent.dfg_builder(endo_sig(bool_t()), [w])?;
        let [w] = dfg_b.input_wires_arr();

        let barr0 = dfg_b.add_barrier([w])?;
        let [w] = barr0.outputs_arr();

        let barr1 = dfg_b.add_barrier([w])?;
        let [w] = barr1.outputs_arr();

        let dfg = dfg_b.finish_with_outputs([w])?;
        let [w] = dfg.outputs_arr();

        let mut dfg2_b = parent.dfg_builder(endo_sig(vec![bool_t(), bool_t()]), [w, w])?;
        let [w1, w2] = dfg2_b.input_wires_arr();
        let barr2 = dfg2_b.add_barrier([w1, w2])?;
        let wires: Vec<Wire> = barr2.outputs().collect();

        let dfg2 = dfg2_b.finish_with_outputs(wires)?;
        let [w, _] = dfg2.outputs_arr();
        parent.finish_hugr_with_outputs([w])?;

        Ok(())
    }

    // Wiring a child DFG's input into a sibling DFG is accepted by the
    // builder but rejected at validation with `NonCFGAncestor`.
    #[test]
    fn non_cfg_ancestor() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_in_wire = b_child.input().out_wire(0);
        b_child.finish_with_outputs([])?;
        let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;

        let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

        let res = b.finish_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

        assert_matches!(
            res,
            Err(BuildError::InvalidHUGR(
                ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
            ))
        );
        Ok(())
    }

    // Wiring a grandchild DFG's input into an unrelated grandchild is
    // rejected while wiring outputs, with `NoRelationIntergraph`.
    #[test]
    fn no_relation_edge() -> Result<(), BuildError> {
        let unit_sig = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]);
        let mut b = DFGBuilder::new(unit_sig.clone())?;
        let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
        let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?;
        let b_child_child_in_wire = b_child_child.input().out_wire(0);

        b_child_child.finish_with_outputs([])?;
        b_child.finish_with_outputs([])?;

        let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;
        let b_child_2_child =
            b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?;

        let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

        assert_matches!(
            res.map(|h| h.handle().node()),
            Err(BuildError::OutputWiring {
                error: BuilderWiringError::NoRelationIntergraph { .. },
                ..
            })
        );
        Ok(())
    }

    // A function polymorphic over a row of copyable types can take a
    // row-variable function value as input; the builder accepts this. However,
    // instantiating `eval` with a row variable where a type is expected fails
    // with `RowVarWhereTypeExpected`.
    #[test]
    fn no_outer_row_variables() -> Result<(), BuildError> {
        let e = crate::hugr::validate::test::extension_with_eval_parallel();
        let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable);
        FunctionBuilder::new(
            "bad_eval",
            PolyFuncType::new(
                [TypeParam::new_list(TypeBound::Copyable)],
                Signature::new(
                    Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
                    vec![],
                ),
            ),
        )?;

        let ev = e.instantiate_extension_op(
            "eval",
            [vec![usize_t().into()].into(), vec![tv.into()].into()],
        );
        assert_eq!(
            ev,
            Err(SignatureError::RowVarWhereTypeExpected {
                var: RowVariable(0, TypeBound::Copyable)
            })
        );
        Ok(())
    }

    // The LoadConstant and Call ops expose `StateOrder` "other" ports on
    // both sides. Connecting them adds an order edge, and the HUGR still
    // validates.
    #[test]
    fn order_edges() {
        let (mut hugr, load_constant, call) = {
            let mut builder = ModuleBuilder::new();
            let func = builder
                .declare("func", Signature::new_endo(bool_t()).into())
                .unwrap();
            let (load_constant, call) = {
                let mut builder = builder
                    .define_function("main", Signature::new(Type::EMPTY_TYPEROW, bool_t()))
                    .unwrap();
                let load_constant = builder.add_load_value(Value::true_val());
                let [r] = builder
                    .call(&func, &[], [load_constant])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([r]).unwrap();
                (load_constant.node(), r.node())
            };
            (builder.finish_hugr().unwrap(), load_constant, call)
        };

        let lc_optype = hugr.get_optype(load_constant);
        let call_optype = hugr.get_optype(call);
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, lc_optype.other_output().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_input().unwrap());
        assert_eq!(EdgeKind::StateOrder, call_optype.other_output().unwrap());

        hugr.connect(
            load_constant,
            lc_optype.other_output_port().unwrap(),
            call,
            call_optype.other_input_port().unwrap(),
        );

        hugr.validate().unwrap();
    }
}