1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
/// Builder for dataflow regions shaped like circuits, where a "linear" input of
/// an operation is continued by the output at the same port.
///
/// Wires are tracked by a unit index (their position in an internal vector), so
/// operations can be appended by naming unit indices instead of raw wires.
#[derive(Debug, PartialEq)]
pub struct CircuitBuilder<'a, T: ?Sized> {
    /// Wires currently being tracked, identified by their position in this vector.
    ///
    /// A slot holds `None` once its wire has been untracked
    /// (see [`CircuitBuilder::untrack_wire`]); slots are never removed or reused.
    wires: Vec<Option<Wire>>,
    /// The underlying dataflow builder that operations are added to.
    builder: &'a mut T,
}
24
/// Errors reported by [`CircuitBuilder`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum CircuitBuildError {
    /// A unit index did not refer to a currently tracked wire (out of range, or
    /// already untracked).
    #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|op| op.name()).unwrap_or_default())]
    InvalidWireIndex {
        /// The operation being appended, or `None` when the error did not arise
        /// while appending an operation (e.g. from `untrack_wire`).
        op: Option<Box<OpType>>,
        /// The offending unit index.
        invalid_index: usize,
    },
    /// Some linear inputs of an operation had no output at the same port to
    /// continue their tracked wire. The operation has already been added to the
    /// graph when this is returned.
    #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
    MismatchedLinearInputs {
        /// The operation that was appended.
        op: Box<OpType>,
        /// Unit indices of the linear inputs left without a matching output.
        index: Vec<usize>,
    },
}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48 pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51 Self {
52 wires: wires.into_iter().map(Some).collect(),
53 builder,
54 }
55 }
56
57 #[must_use]
59 pub fn n_wires(&self) -> usize {
60 self.wires.iter().flatten().count()
61 }
62
63 #[must_use]
65 pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66 self.wires.get(index).copied().flatten()
67 }
68
69 pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71 self.wires
72 .iter()
73 .enumerate()
74 .filter_map(|(i, w)| w.map(|_| i))
75 }
76
77 #[must_use]
83 pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84 collect_array(self.tracked_units())
85 }
86
87 #[inline]
88 pub fn append(
94 &mut self,
95 op: impl Into<OpType>,
96 indices: impl IntoIterator<Item = usize> + Clone,
97 ) -> Result<&mut Self, BuildError> {
98 self.append_and_consume(op, indices)
99 }
100
101 #[inline]
102 pub fn append_and_consume<A: Into<CircuitUnit>>(
105 &mut self,
106 op: impl Into<OpType>,
107 inputs: impl IntoIterator<Item = A>,
108 ) -> Result<&mut Self, BuildError> {
109 self.append_with_outputs(op, inputs)?;
110 Ok(self)
111 }
112
113 pub fn append_with_outputs<A: Into<CircuitUnit>>(
123 &mut self,
124 op: impl Into<OpType>,
125 inputs: impl IntoIterator<Item = A>,
126 ) -> Result<Vec<Wire>, BuildError> {
127 let mut linear_inputs = HashMap::new();
129 let op = op.into();
130
131 let input_wires: Result<Vec<Wire>, usize> = inputs
132 .into_iter()
133 .map(Into::into)
134 .enumerate()
135 .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136 CircuitUnit::Wire(wire) => Ok(wire),
137 CircuitUnit::Linear(wire_index) => {
138 linear_inputs.insert(input_port, wire_index);
139 self.tracked_wire(wire_index).ok_or(wire_index)
140 }
141 })
142 .collect();
143
144 let input_wires =
145 input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146 op: Some(Box::new(op.clone())),
147 invalid_index,
148 })?;
149
150 let output_wires = self
151 .builder
152 .add_dataflow_op(
153 op.clone(), input_wires,
155 )?
156 .outputs();
157 let nonlinear_outputs: Vec<Wire> = output_wires
158 .enumerate()
159 .filter_map(|(output_port, wire)| {
160 if let Some(wire_index) = linear_inputs.remove(&output_port) {
161 self.wires[wire_index] = Some(wire);
163 None
164 } else {
165 Some(wire)
166 }
167 })
168 .collect();
169
170 if !linear_inputs.is_empty() {
171 return Err(CircuitBuildError::MismatchedLinearInputs {
172 op: Box::new(op),
173 index: linear_inputs.values().copied().collect(),
174 }
175 .into());
176 }
177
178 Ok(nonlinear_outputs)
179 }
180
181 pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195 &mut self,
196 op: impl Into<OpType>,
197 inputs: impl IntoIterator<Item = A>,
198 ) -> Result<[Wire; N], BuildError> {
199 let outputs = self.append_with_outputs(op, inputs)?;
200 Ok(collect_array(outputs))
201 }
202
203 pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205 self.builder.add_load_value(value)
206 }
207
208 pub fn track_wire(&mut self, wire: Wire) -> usize {
212 self.wires.push(Some(wire));
213 self.wires.len() - 1
214 }
215
216 pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224 self.wires
225 .get_mut(index)
226 .and_then(mem::take)
227 .ok_or(CircuitBuildError::InvalidWireIndex {
228 op: None,
229 invalid_index: index,
230 })
231 }
232
233 #[inline]
234 #[must_use]
237 pub fn finish(self) -> Vec<Wire> {
238 self.wires.into_iter().flatten().collect()
239 }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    use crate::Extension;
    use crate::builder::{HugrBuilder, ModuleBuilder};
    use crate::extension::ExtensionId;
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::std_extensions::arithmetic::float_types::ConstF64;
    use crate::utils::test_quantum_extension::{
        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
    };
    use crate::{
        builder::{DataflowSubContainer, test::build_main},
        extension::prelude::bool_t,
        types::Signature,
    };

    /// Threads two qubits through several gates by unit index, including one op
    /// that also takes a raw (non-tracked) wire input.
    #[test]
    fn simple_linear() {
        let build_res = build_main(
            Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let wires = f_build.input_wires().map(Some).collect();

                // Built directly from the private fields, which this child
                // module can access.
                let mut linear = CircuitBuilder {
                    wires,
                    builder: &mut f_build,
                };

                assert_eq!(linear.n_wires(), 2);

                linear
                    .append(h_gate(), [0])?
                    .append(cx_gate(), [0, 1])?
                    .append(cx_gate(), [1, 0])?;

                // The angle is passed as `CircuitUnit::Wire`, so it is fed to the
                // op but not treated as a linear unit.
                let angle = linear.add_constant(ConstF64::new(0.5));
                linear.append_and_consume(
                    rz_f64(),
                    [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
                )?;

                let outs = linear.finish();
                f_build.finish_with_outputs(outs)
            },
        );

        assert_matches!(build_res, Ok(_));
    }

    /// Mixes linear units with a custom op taking a raw wire input, and
    /// collects the non-linear output of `measure` alongside the tracked qubits.
    #[test]
    fn with_nonlinear_and_outputs() {
        let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
        let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
            ext.add_op(
                "MyOp".into(),
                String::new(),
                Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
                extension_ref,
            )
            .unwrap();
        });
        let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();

        let mut module_builder = ModuleBuilder::new();
        let mut f_build = module_builder
            .define_function(
                "main",
                Signature::new(
                    vec![qb_t(), qb_t(), usize_t()],
                    vec![qb_t(), qb_t(), bool_t()],
                ),
            )
            .unwrap();

        let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

        // Only the two qubits are handed to the circuit builder; `angle` is
        // later passed as a raw wire.
        let mut linear = f_build.as_circuit([q0, q1]);

        let measure_out = linear
            .append(cx_gate(), [0, 1])
            .unwrap()
            .append_and_consume(
                my_custom_op,
                [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
            )
            .unwrap()
            .append_with_outputs(measure(), [0])
            .unwrap();

        let out_qbs = linear.finish();
        f_build
            .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
            .unwrap();

        // NOTE(review): `registry` is never read after `register` (it is not
        // passed to `finish_hugr`); it looks like a leftover — confirm whether
        // it can be removed.
        let mut registry = test_quantum_extension::REG.clone();
        registry.register(my_ext).unwrap();
        let build_res = module_builder.finish_hugr();

        assert_matches!(build_res, Ok(_));
    }

    /// Allocates and tracks an ancilla, then untracks the measured qubit and
    /// discards it, leaving only the ancilla as output.
    #[test]
    fn ancillae() {
        let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            assert_eq!(circ.n_wires(), 1);

            let [q0] = circ.tracked_units_arr();
            // `q_alloc` has no inputs, so its qubit is returned as a non-linear
            // output and must be tracked explicitly.
            let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
            let ancilla = circ.track_wire(ancilla);

            assert_ne!(ancilla, 0);
            assert_eq!(circ.n_wires(), 2);
            assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

            circ.append(cx_gate(), [q0, ancilla])?;
            let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

            // Shadows the unit index `q0` with the wire it was tracking.
            let q0 = circ.untrack_wire(q0)?;

            assert_eq!(circ.tracked_units_arr(), [ancilla]);

            // `q0` is now a raw `Wire`, so it is consumed without re-tracking.
            circ.append_and_consume(q_discard(), [q0])?;

            let outs = circ.finish();

            assert_eq!(outs.len(), 1);

            f_build.finish_with_outputs(outs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// Checks that each `CircuitBuildError` variant is reported with the
    /// expected payload.
    #[test]
    fn circuit_builder_errors() {
        // NOTE(review): the build result is deliberately ignored. After the
        // failed `q_discard` append below, unit `q1` still tracks the wire the
        // discard op already consumed, so the final build is presumably
        // invalid. Only the assertions inside the closure are checked.
        let _build_res = build_main(
            Signature::new_endo(vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                let [q0, q1] = circ.tracked_units_arr();
                let invalid_index = 0xff;

                assert_matches!(
                    // An untracked index fails before the op is added.
                    circ.append(cx_gate(), [q0, invalid_index]),
                    Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
                    if op == Some(Box::new(cx_gate().into())) && idx == invalid_index,
                );

                assert_matches!(
                    // Untracking an unknown index reports no operation.
                    circ.untrack_wire(invalid_index),
                    Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
                    if idx == invalid_index,
                );

                assert_matches!(
                    // `q_discard` has no output at port 0 to continue unit `q1`.
                    circ.append(q_discard(), [q1]),
                    Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                    if *op == q_discard().into() && index == [q1],
                );

                let outs = circ.finish();

                // Both units are still tracked after the failed appends.
                assert_eq!(outs.len(), 2);

                f_build.finish_with_outputs(outs)
            },
        );

    }
}