1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
12#[derive(Debug, PartialEq)]
17pub struct CircuitBuilder<'a, T: ?Sized> {
18 wires: Vec<Option<Wire>>,
22 builder: &'a mut T,
23}
24
25#[derive(Debug, Clone, PartialEq, Error)]
26#[non_exhaustive]
28pub enum CircuitBuildError {
29 #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|o| o.name()).unwrap_or_default())]
31 InvalidWireIndex {
32 op: Option<OpType>,
34 invalid_index: usize,
36 },
37 #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
39 MismatchedLinearInputs {
40 op: OpType,
42 index: Vec<usize>,
44 },
45}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48 pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51 Self {
52 wires: wires.into_iter().map(Some).collect(),
53 builder,
54 }
55 }
56
57 #[must_use]
59 pub fn n_wires(&self) -> usize {
60 self.wires.iter().flatten().count()
61 }
62
63 #[must_use]
65 pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66 self.wires.get(index).copied().flatten()
67 }
68
69 pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71 self.wires
72 .iter()
73 .enumerate()
74 .filter_map(|(i, w)| w.map(|_| i))
75 }
76
77 #[must_use]
83 pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84 collect_array(self.tracked_units())
85 }
86
87 #[inline]
88 pub fn append(
94 &mut self,
95 op: impl Into<OpType>,
96 indices: impl IntoIterator<Item = usize> + Clone,
97 ) -> Result<&mut Self, BuildError> {
98 self.append_and_consume(op, indices)
99 }
100
101 #[inline]
102 pub fn append_and_consume<A: Into<CircuitUnit>>(
105 &mut self,
106 op: impl Into<OpType>,
107 inputs: impl IntoIterator<Item = A>,
108 ) -> Result<&mut Self, BuildError> {
109 self.append_with_outputs(op, inputs)?;
110 Ok(self)
111 }
112
113 pub fn append_with_outputs<A: Into<CircuitUnit>>(
123 &mut self,
124 op: impl Into<OpType>,
125 inputs: impl IntoIterator<Item = A>,
126 ) -> Result<Vec<Wire>, BuildError> {
127 let mut linear_inputs = HashMap::new();
129 let op = op.into();
130
131 let input_wires: Result<Vec<Wire>, usize> = inputs
132 .into_iter()
133 .map(Into::into)
134 .enumerate()
135 .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136 CircuitUnit::Wire(wire) => Ok(wire),
137 CircuitUnit::Linear(wire_index) => {
138 linear_inputs.insert(input_port, wire_index);
139 self.tracked_wire(wire_index).ok_or(wire_index)
140 }
141 })
142 .collect();
143
144 let input_wires =
145 input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146 op: Some(op.clone()),
147 invalid_index,
148 })?;
149
150 let output_wires = self
151 .builder
152 .add_dataflow_op(
153 op.clone(), input_wires,
155 )?
156 .outputs();
157 let nonlinear_outputs: Vec<Wire> = output_wires
158 .enumerate()
159 .filter_map(|(output_port, wire)| {
160 if let Some(wire_index) = linear_inputs.remove(&output_port) {
161 self.wires[wire_index] = Some(wire);
163 None
164 } else {
165 Some(wire)
166 }
167 })
168 .collect();
169
170 if !linear_inputs.is_empty() {
171 return Err(CircuitBuildError::MismatchedLinearInputs {
172 op,
173 index: linear_inputs.values().copied().collect(),
174 }
175 .into());
176 }
177
178 Ok(nonlinear_outputs)
179 }
180
181 pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195 &mut self,
196 op: impl Into<OpType>,
197 inputs: impl IntoIterator<Item = A>,
198 ) -> Result<[Wire; N], BuildError> {
199 let outputs = self.append_with_outputs(op, inputs)?;
200 Ok(collect_array(outputs))
201 }
202
203 pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205 self.builder.add_load_value(value)
206 }
207
208 pub fn track_wire(&mut self, wire: Wire) -> usize {
212 self.wires.push(Some(wire));
213 self.wires.len() - 1
214 }
215
216 pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224 self.wires
225 .get_mut(index)
226 .and_then(mem::take)
227 .ok_or(CircuitBuildError::InvalidWireIndex {
228 op: None,
229 invalid_index: index,
230 })
231 }
232
233 #[inline]
234 pub fn finish(self) -> Vec<Wire> {
237 self.wires.into_iter().flatten().collect()
238 }
239}
240
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    use crate::builder::test::build_main;
    use crate::builder::{Container, DataflowSubContainer, HugrBuilder, ModuleBuilder};
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::extension::{ExtensionId, ExtensionSet};
    use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::{
        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
    };
    use crate::Extension;

    /// Gates applied to tracked units only, plus one op mixing a tracked unit
    /// with a plain wire input.
    #[test]
    fn simple_linear() {
        let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()])
            .with_extension_delta(test_quantum_extension::EXTENSION_ID)
            .with_extension_delta(float_types::EXTENSION_ID);

        let build_res = build_main(signature.into(), |mut f_build| {
            // Build the circuit builder directly from its fields.
            let mut circ = CircuitBuilder {
                wires: f_build.input_wires().map(Some).collect(),
                builder: &mut f_build,
            };

            assert_eq!(circ.n_wires(), 2);

            circ.append(h_gate(), [0])?;
            circ.append(cx_gate(), [0, 1])?;
            circ.append(cx_gate(), [1, 0])?;

            // The rotation takes a tracked qubit and an untracked constant wire.
            let angle = circ.add_constant(ConstF64::new(0.5));
            circ.append_and_consume(
                rz_f64(),
                [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
            )?;

            let outs = circ.finish();
            f_build.finish_with_outputs(outs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// A custom op with a non-linear input, followed by an op that returns a
    /// non-linear output.
    #[test]
    fn with_nonlinear_and_outputs() {
        let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
        let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
            ext.add_op(
                "MyOp".into(),
                "".to_string(),
                Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
                extension_ref,
            )
            .unwrap();
        });
        let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();

        let main_signature = Signature::new(
            vec![qb_t(), qb_t(), usize_t()],
            vec![qb_t(), qb_t(), bool_t()],
        )
        .with_extension_delta(ExtensionSet::from_iter([
            test_quantum_extension::EXTENSION_ID,
            my_ext_name,
        ]));

        let mut module_builder = ModuleBuilder::new();
        let mut f_build = module_builder
            .define_function("main", main_signature)
            .unwrap();

        let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

        // Only the two qubits are tracked; `angle` is passed as a plain wire.
        let mut circ = f_build.as_circuit([q0, q1]);

        circ.append(cx_gate(), [0, 1]).unwrap();
        circ.append_and_consume(
            my_custom_op,
            [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
        )
        .unwrap();
        let measure_out = circ.append_with_outputs(measure(), [0]).unwrap();

        let out_qbs = circ.finish();
        f_build
            .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
            .unwrap();

        let mut registry = test_quantum_extension::REG.clone();
        registry.register(my_ext).unwrap();
        let build_res = module_builder.finish_hugr();

        assert_matches!(build_res, Ok(_));
    }

    /// Tracking a freshly allocated wire and untracking an existing unit.
    #[test]
    fn ancillae() {
        let signature = Signature::new_endo(qb_t())
            .with_extension_delta(test_quantum_extension::EXTENSION_ID);

        let build_res = build_main(signature.into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            assert_eq!(circ.n_wires(), 1);

            let [q0] = circ.tracked_units_arr();

            // Allocate a new qubit and start tracking it.
            let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
            let ancilla = circ.track_wire(ancilla);

            assert_ne!(ancilla, 0);
            assert_eq!(circ.n_wires(), 2);
            assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

            circ.append(cx_gate(), [q0, ancilla])?;
            let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

            // Stop tracking the original qubit; from here on `q0` is a wire.
            let q0 = circ.untrack_wire(q0)?;

            assert_eq!(circ.tracked_units_arr(), [ancilla]);

            circ.append_and_consume(q_discard(), [q0])?;

            let outs = circ.finish();

            assert_eq!(outs.len(), 1);

            f_build.finish_with_outputs(outs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// Each error variant of `CircuitBuildError` is reported as expected.
    #[test]
    fn circuit_builder_errors() {
        let signature = Signature::new_endo(vec![qb_t(), qb_t()]);

        // The overall build result is deliberately not checked here.
        let _build_res = build_main(signature.into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            let [q0, q1] = circ.tracked_units_arr();
            let invalid_index = 0xff;

            // Passing an invalid linear index returns an error.
            assert_matches!(
                circ.append(cx_gate(), [q0, invalid_index]),
                Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
                if op == Some(cx_gate().into()) && idx == invalid_index,
            );

            // Untracking an invalid index returns an error.
            assert_matches!(
                circ.untrack_wire(invalid_index),
                Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
                if idx == invalid_index,
            );

            // A linear input without a corresponding output returns an error.
            assert_matches!(
                circ.append(q_discard(), [q1]),
                Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                if op == q_discard().into() && index == [q1],
            );

            let outs = circ.finish();

            assert_eq!(outs.len(), 2);

            f_build.finish_with_outputs(outs)
        });
    }
}