1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
12#[derive(Debug, PartialEq)]
17pub struct CircuitBuilder<'a, T: ?Sized> {
18 wires: Vec<Option<Wire>>,
22 builder: &'a mut T,
23}
24
25#[derive(Debug, Clone, PartialEq, Error)]
26#[non_exhaustive]
28pub enum CircuitBuildError {
29 #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|o| o.name()).unwrap_or_default())]
31 InvalidWireIndex {
32 op: Option<OpType>,
34 invalid_index: usize,
36 },
37 #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
39 MismatchedLinearInputs {
40 op: OpType,
42 index: Vec<usize>,
44 },
45}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48 pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51 Self {
52 wires: wires.into_iter().map(Some).collect(),
53 builder,
54 }
55 }
56
57 #[must_use]
59 pub fn n_wires(&self) -> usize {
60 self.wires.iter().flatten().count()
61 }
62
63 #[must_use]
65 pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66 self.wires.get(index).copied().flatten()
67 }
68
69 pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71 self.wires
72 .iter()
73 .enumerate()
74 .filter_map(|(i, w)| w.map(|_| i))
75 }
76
77 #[must_use]
83 pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84 collect_array(self.tracked_units())
85 }
86
87 #[inline]
88 pub fn append(
94 &mut self,
95 op: impl Into<OpType>,
96 indices: impl IntoIterator<Item = usize> + Clone,
97 ) -> Result<&mut Self, BuildError> {
98 self.append_and_consume(op, indices)
99 }
100
101 #[inline]
102 pub fn append_and_consume<A: Into<CircuitUnit>>(
105 &mut self,
106 op: impl Into<OpType>,
107 inputs: impl IntoIterator<Item = A>,
108 ) -> Result<&mut Self, BuildError> {
109 self.append_with_outputs(op, inputs)?;
110 Ok(self)
111 }
112
113 pub fn append_with_outputs<A: Into<CircuitUnit>>(
123 &mut self,
124 op: impl Into<OpType>,
125 inputs: impl IntoIterator<Item = A>,
126 ) -> Result<Vec<Wire>, BuildError> {
127 let mut linear_inputs = HashMap::new();
129 let op = op.into();
130
131 let input_wires: Result<Vec<Wire>, usize> = inputs
132 .into_iter()
133 .map(Into::into)
134 .enumerate()
135 .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136 CircuitUnit::Wire(wire) => Ok(wire),
137 CircuitUnit::Linear(wire_index) => {
138 linear_inputs.insert(input_port, wire_index);
139 self.tracked_wire(wire_index).ok_or(wire_index)
140 }
141 })
142 .collect();
143
144 let input_wires =
145 input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146 op: Some(op.clone()),
147 invalid_index,
148 })?;
149
150 let output_wires = self
151 .builder
152 .add_dataflow_op(
153 op.clone(), input_wires,
155 )?
156 .outputs();
157 let nonlinear_outputs: Vec<Wire> = output_wires
158 .enumerate()
159 .filter_map(|(output_port, wire)| {
160 if let Some(wire_index) = linear_inputs.remove(&output_port) {
161 self.wires[wire_index] = Some(wire);
163 None
164 } else {
165 Some(wire)
166 }
167 })
168 .collect();
169
170 if !linear_inputs.is_empty() {
171 return Err(CircuitBuildError::MismatchedLinearInputs {
172 op,
173 index: linear_inputs.values().copied().collect(),
174 }
175 .into());
176 }
177
178 Ok(nonlinear_outputs)
179 }
180
181 pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195 &mut self,
196 op: impl Into<OpType>,
197 inputs: impl IntoIterator<Item = A>,
198 ) -> Result<[Wire; N], BuildError> {
199 let outputs = self.append_with_outputs(op, inputs)?;
200 Ok(collect_array(outputs))
201 }
202
203 pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205 self.builder.add_load_value(value)
206 }
207
208 pub fn track_wire(&mut self, wire: Wire) -> usize {
212 self.wires.push(Some(wire));
213 self.wires.len() - 1
214 }
215
216 pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224 self.wires
225 .get_mut(index)
226 .and_then(mem::take)
227 .ok_or(CircuitBuildError::InvalidWireIndex {
228 op: None,
229 invalid_index: index,
230 })
231 }
232
233 #[inline]
234 pub fn finish(self) -> Vec<Wire> {
237 self.wires.into_iter().flatten().collect()
238 }
239}
240
241#[cfg(test)]
242mod test {
243 use super::*;
244 use cool_asserts::assert_matches;
245
246 use crate::builder::{Container, HugrBuilder, ModuleBuilder};
247 use crate::extension::prelude::{qb_t, usize_t};
248 use crate::extension::ExtensionId;
249 use crate::std_extensions::arithmetic::float_types::ConstF64;
250 use crate::utils::test_quantum_extension::{
251 self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
252 };
253 use crate::Extension;
254 use crate::{
255 builder::{test::build_main, DataflowSubContainer},
256 extension::prelude::bool_t,
257 types::Signature,
258 };
259
260 #[test]
261 fn simple_linear() {
262 let build_res = build_main(
263 Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(),
264 |mut f_build| {
265 let wires = f_build.input_wires().map(Some).collect();
266
267 let mut linear = CircuitBuilder {
268 wires,
269 builder: &mut f_build,
270 };
271
272 assert_eq!(linear.n_wires(), 2);
273
274 linear
275 .append(h_gate(), [0])?
276 .append(cx_gate(), [0, 1])?
277 .append(cx_gate(), [1, 0])?;
278
279 let angle = linear.add_constant(ConstF64::new(0.5));
280 linear.append_and_consume(
281 rz_f64(),
282 [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
283 )?;
284
285 let outs = linear.finish();
286 f_build.finish_with_outputs(outs)
287 },
288 );
289
290 assert_matches!(build_res, Ok(_));
291 }
292
293 #[test]
294 fn with_nonlinear_and_outputs() {
295 let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
296 let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
297 ext.add_op(
298 "MyOp".into(),
299 "".to_string(),
300 Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
301 extension_ref,
302 )
303 .unwrap();
304 });
305 let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();
306
307 let mut module_builder = ModuleBuilder::new();
308 let mut f_build = module_builder
309 .define_function(
310 "main",
311 Signature::new(
312 vec![qb_t(), qb_t(), usize_t()],
313 vec![qb_t(), qb_t(), bool_t()],
314 ),
315 )
316 .unwrap();
317
318 let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();
319
320 let mut linear = f_build.as_circuit([q0, q1]);
321
322 let measure_out = linear
323 .append(cx_gate(), [0, 1])
324 .unwrap()
325 .append_and_consume(
326 my_custom_op,
327 [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
328 )
329 .unwrap()
330 .append_with_outputs(measure(), [0])
331 .unwrap();
332
333 let out_qbs = linear.finish();
334 f_build
335 .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
336 .unwrap();
337
338 let mut registry = test_quantum_extension::REG.clone();
339 registry.register(my_ext).unwrap();
340 let build_res = module_builder.finish_hugr();
341
342 assert_matches!(build_res, Ok(_));
343 }
344
345 #[test]
346 fn ancillae() {
347 let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| {
348 let mut circ = f_build.as_circuit(f_build.input_wires());
349 assert_eq!(circ.n_wires(), 1);
350
351 let [q0] = circ.tracked_units_arr();
352 let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
353 let ancilla = circ.track_wire(ancilla);
354
355 assert_ne!(ancilla, 0);
356 assert_eq!(circ.n_wires(), 2);
357 assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);
358
359 circ.append(cx_gate(), [q0, ancilla])?;
360 let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;
361
362 let q0 = circ.untrack_wire(q0)?;
363
364 assert_eq!(circ.tracked_units_arr(), [ancilla]);
365
366 circ.append_and_consume(q_discard(), [q0])?;
367
368 let outs = circ.finish();
369
370 assert_eq!(outs.len(), 1);
371
372 f_build.finish_with_outputs(outs)
373 });
374
375 assert_matches!(build_res, Ok(_));
376 }
377
378 #[test]
379 fn circuit_builder_errors() {
380 let _build_res = build_main(
381 Signature::new_endo(vec![qb_t(), qb_t()]).into(),
382 |mut f_build| {
383 let mut circ = f_build.as_circuit(f_build.input_wires());
384 let [q0, q1] = circ.tracked_units_arr();
385 let invalid_index = 0xff;
386
387 assert_matches!(
389 circ.append(cx_gate(), [q0, invalid_index]),
390 Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
391 if op == Some(cx_gate().into()) && idx == invalid_index,
392 );
393
394 assert_matches!(
396 circ.untrack_wire(invalid_index),
397 Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
398 if idx == invalid_index,
399 );
400
401 assert_matches!(
403 circ.append(q_discard(), [q1]),
404 Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
405 if op == q_discard().into() && index == [q1],
406 );
407
408 let outs = circ.finish();
409
410 assert_eq!(outs.len(), 2);
411
412 f_build.finish_with_outputs(outs)
413 },
414 );
415
416 }
419}