1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
12#[derive(Debug, PartialEq)]
17pub struct CircuitBuilder<'a, T: ?Sized> {
18    wires: Vec<Option<Wire>>,
22    builder: &'a mut T,
23}
24
25#[derive(Debug, Clone, PartialEq, Error)]
26#[non_exhaustive]
28pub enum CircuitBuildError {
29    #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|o| o.name()).unwrap_or_default())]
31    InvalidWireIndex {
32        op: Option<OpType>,
34        invalid_index: usize,
36    },
37    #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
39    MismatchedLinearInputs {
40        op: OpType,
42        index: Vec<usize>,
44    },
45}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48    pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51        Self {
52            wires: wires.into_iter().map(Some).collect(),
53            builder,
54        }
55    }
56
57    #[must_use]
59    pub fn n_wires(&self) -> usize {
60        self.wires.iter().flatten().count()
61    }
62
63    #[must_use]
65    pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66        self.wires.get(index).copied().flatten()
67    }
68
69    pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71        self.wires
72            .iter()
73            .enumerate()
74            .filter_map(|(i, w)| w.map(|_| i))
75    }
76
77    #[must_use]
83    pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84        collect_array(self.tracked_units())
85    }
86
87    #[inline]
88    pub fn append(
94        &mut self,
95        op: impl Into<OpType>,
96        indices: impl IntoIterator<Item = usize> + Clone,
97    ) -> Result<&mut Self, BuildError> {
98        self.append_and_consume(op, indices)
99    }
100
101    #[inline]
102    pub fn append_and_consume<A: Into<CircuitUnit>>(
105        &mut self,
106        op: impl Into<OpType>,
107        inputs: impl IntoIterator<Item = A>,
108    ) -> Result<&mut Self, BuildError> {
109        self.append_with_outputs(op, inputs)?;
110        Ok(self)
111    }
112
113    pub fn append_with_outputs<A: Into<CircuitUnit>>(
123        &mut self,
124        op: impl Into<OpType>,
125        inputs: impl IntoIterator<Item = A>,
126    ) -> Result<Vec<Wire>, BuildError> {
127        let mut linear_inputs = HashMap::new();
129        let op = op.into();
130
131        let input_wires: Result<Vec<Wire>, usize> = inputs
132            .into_iter()
133            .map(Into::into)
134            .enumerate()
135            .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136                CircuitUnit::Wire(wire) => Ok(wire),
137                CircuitUnit::Linear(wire_index) => {
138                    linear_inputs.insert(input_port, wire_index);
139                    self.tracked_wire(wire_index).ok_or(wire_index)
140                }
141            })
142            .collect();
143
144        let input_wires =
145            input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146                op: Some(op.clone()),
147                invalid_index,
148            })?;
149
150        let output_wires = self
151            .builder
152            .add_dataflow_op(
153                op.clone(), input_wires,
155            )?
156            .outputs();
157        let nonlinear_outputs: Vec<Wire> = output_wires
158            .enumerate()
159            .filter_map(|(output_port, wire)| {
160                if let Some(wire_index) = linear_inputs.remove(&output_port) {
161                    self.wires[wire_index] = Some(wire);
163                    None
164                } else {
165                    Some(wire)
166                }
167            })
168            .collect();
169
170        if !linear_inputs.is_empty() {
171            return Err(CircuitBuildError::MismatchedLinearInputs {
172                op,
173                index: linear_inputs.values().copied().collect(),
174            }
175            .into());
176        }
177
178        Ok(nonlinear_outputs)
179    }
180
181    pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195        &mut self,
196        op: impl Into<OpType>,
197        inputs: impl IntoIterator<Item = A>,
198    ) -> Result<[Wire; N], BuildError> {
199        let outputs = self.append_with_outputs(op, inputs)?;
200        Ok(collect_array(outputs))
201    }
202
203    pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205        self.builder.add_load_value(value)
206    }
207
208    pub fn track_wire(&mut self, wire: Wire) -> usize {
212        self.wires.push(Some(wire));
213        self.wires.len() - 1
214    }
215
216    pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224        self.wires
225            .get_mut(index)
226            .and_then(mem::take)
227            .ok_or(CircuitBuildError::InvalidWireIndex {
228                op: None,
229                invalid_index: index,
230            })
231    }
232
233    #[inline]
234    pub fn finish(self) -> Vec<Wire> {
237        self.wires.into_iter().flatten().collect()
238    }
239}
240
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    use crate::builder::{Container, HugrBuilder, ModuleBuilder};
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::extension::{ExtensionId, ExtensionSet};
    use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
    use crate::utils::test_quantum_extension::{
        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
    };
    use crate::Extension;
    use crate::{
        builder::{test::build_main, DataflowSubContainer},
        extension::prelude::bool_t,
        types::Signature,
    };

    /// Appends a chain of gates to two tracked qubits, including a rotation
    /// whose angle is supplied through an untracked constant wire.
    #[test]
    fn simple_linear() {
        let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()])
            .with_extension_delta(test_quantum_extension::EXTENSION_ID)
            .with_extension_delta(float_types::EXTENSION_ID);

        let build_res = build_main(signature.into(), |mut f_build| {
            // Assemble the builder field by field instead of going through
            // `CircuitBuilder::new`.
            let tracked = f_build.input_wires().map(Some).collect();
            let mut circ = CircuitBuilder {
                wires: tracked,
                builder: &mut f_build,
            };
            assert_eq!(circ.n_wires(), 2);

            circ.append(h_gate(), [0])?
                .append(cx_gate(), [0, 1])?
                .append(cx_gate(), [1, 0])?;

            // The rotation takes unit 0 as a linear input and the angle as a
            // plain wire.
            let half = circ.add_constant(ConstF64::new(0.5));
            let rz_inputs = [CircuitUnit::Linear(0), CircuitUnit::Wire(half)];
            circ.append_and_consume(rz_f64(), rz_inputs)?;

            let outputs = circ.finish();
            f_build.finish_with_outputs(outputs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// Mixes a custom operation that takes a non-linear input with a
    /// measurement whose extra output is returned next to the qubits.
    #[test]
    fn with_nonlinear_and_outputs() {
        let ext_id: ExtensionId = "MyExt".try_into().unwrap();
        let custom_ext = Extension::new_test_arc(ext_id.clone(), |ext, extension_ref| {
            let op_signature = Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]);
            ext.add_op("MyOp".into(), "".to_string(), op_signature, extension_ref)
                .unwrap();
        });
        let custom_op = custom_ext.instantiate_extension_op("MyOp", []).unwrap();

        let mut module = ModuleBuilder::new();
        let fn_signature = Signature::new(
            vec![qb_t(), qb_t(), usize_t()],
            vec![qb_t(), qb_t(), bool_t()],
        )
        .with_extension_delta(ExtensionSet::from_iter([
            test_quantum_extension::EXTENSION_ID,
            ext_id,
        ]));
        let mut func = module.define_function("main", fn_signature).unwrap();

        // Only the two qubits are tracked; `angle` stays a plain wire.
        let [q0, q1, angle]: [Wire; 3] = func.input_wires_arr();
        let mut circ = func.as_circuit([q0, q1]);

        circ.append(cx_gate(), [0, 1]).unwrap();
        circ.append_and_consume(
            custom_op,
            [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
        )
        .unwrap();
        // The outputs of `measure` that do not replace unit 0 are handed back.
        let measured = circ.append_with_outputs(measure(), [0]).unwrap();

        let qubits = circ.finish();
        func.finish_with_outputs(qubits.into_iter().chain(measured))
            .unwrap();

        // NOTE(review): the registry is only used to check that the custom
        // extension can be registered; it is not passed to `finish_hugr`.
        let mut registry = test_quantum_extension::REG.clone();
        registry.register(custom_ext).unwrap();
        let build_res = module.finish_hugr();

        assert_matches!(build_res, Ok(_));
    }

    /// Allocates an ancilla mid-circuit, tracks it, and later stops tracking
    /// the original qubit so that it can be discarded.
    #[test]
    fn ancillae() {
        let signature =
            Signature::new_endo(qb_t()).with_extension_delta(test_quantum_extension::EXTENSION_ID);

        let build_res = build_main(signature.into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            assert_eq!(circ.n_wires(), 1);
            let [q0] = circ.tracked_units_arr();

            // `q_alloc` is called with no inputs; its single output is then
            // tracked as a new unit.
            let no_inputs: [CircuitUnit; 0] = [];
            let [fresh_qubit] = circ.append_with_outputs_arr(q_alloc(), no_inputs)?;
            let ancilla = circ.track_wire(fresh_qubit);

            assert_ne!(ancilla, 0);
            assert_eq!(circ.n_wires(), 2);
            assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

            circ.append(cx_gate(), [q0, ancilla])?;
            let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

            // Stop tracking the original qubit: only the ancilla remains.
            let q0_wire = circ.untrack_wire(q0)?;
            assert_eq!(circ.tracked_units_arr(), [ancilla]);

            // The untracked qubit is now passed around as a plain wire.
            circ.append_and_consume(q_discard(), [q0_wire])?;

            let outputs = circ.finish();
            assert_eq!(outputs.len(), 1);

            f_build.finish_with_outputs(outputs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// Checks the errors reported for invalid unit indices and for linear
    /// inputs that have no matching output.
    #[test]
    fn circuit_builder_errors() {
        // NOTE(review): the build result is deliberately left unchecked. The
        // rejected `q_discard` below has already been inserted by the time the
        // error is reported, so the function is presumably not valid — confirm.
        let _build_res = build_main(
            Signature::new_endo(vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                let [q0, q1] = circ.tracked_units_arr();
                let bad_unit = 0xff;

                // An index that is not tracked is rejected and reported
                // together with the operation being appended.
                assert_matches!(
                    circ.append(cx_gate(), [q0, bad_unit]),
                    Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: reported }))
                    if op == Some(cx_gate().into()) && reported == bad_unit,
                );

                // Untracking an unknown index fails, with no operation attached.
                assert_matches!(
                    circ.untrack_wire(bad_unit),
                    Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: reported })
                    if reported == bad_unit,
                );

                // A linear input without an output at the same port is an error.
                assert_matches!(
                    circ.append(q_discard(), [q1]),
                    Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                    if op == q_discard().into() && index == [q1],
                );

                // Both units are still tracked after the failed appends.
                let outputs = circ.finish();
                assert_eq!(outputs.len(), 2);

                f_build.finish_with_outputs(outputs)
            },
        );
    }
}