1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
/// Builder for regions of a dataflow graph that look like circuits.
///
/// A set of wires is tracked by index. Operations are appended by referring
/// to those indices: a tracked ("linear") input on port `p` is replaced by the
/// operation's output on the same port `p`.
#[derive(Debug, PartialEq)]
pub struct CircuitBuilder<'a, T: ?Sized> {
    /// Wires currently being tracked, identified by their position in this vector.
    ///
    /// A slot is `None` once its wire has been untracked; slots are never reused.
    wires: Vec<Option<Wire>>,
    /// The underlying dataflow builder that operations are appended to.
    builder: &'a mut T,
}
24
/// Errors produced while building a circuit with a [`CircuitBuilder`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum CircuitBuildError {
    /// A wire index did not refer to a currently tracked wire.
    #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|op| op.name()).unwrap_or_default())]
    InvalidWireIndex {
        /// The operation being appended when the invalid index was used, or
        /// `None` when no operation was involved (e.g. `untrack_wire`).
        op: Option<Box<OpType>>,
        /// The offending wire index.
        invalid_index: usize,
    },
    /// Some linear inputs of an operation had no output on the same port to
    /// replace them.
    #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
    MismatchedLinearInputs {
        /// The operation that was appended.
        op: Box<OpType>,
        /// Tracked wire indices of the linear inputs left without a matching output.
        index: Vec<usize>,
    },
}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48    pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51        Self {
52            wires: wires.into_iter().map(Some).collect(),
53            builder,
54        }
55    }
56
57    #[must_use]
59    pub fn n_wires(&self) -> usize {
60        self.wires.iter().flatten().count()
61    }
62
63    #[must_use]
65    pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66        self.wires.get(index).copied().flatten()
67    }
68
69    pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71        self.wires
72            .iter()
73            .enumerate()
74            .filter_map(|(i, w)| w.map(|_| i))
75    }
76
77    #[must_use]
83    pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84        collect_array(self.tracked_units())
85    }
86
87    #[inline]
88    pub fn append(
94        &mut self,
95        op: impl Into<OpType>,
96        indices: impl IntoIterator<Item = usize> + Clone,
97    ) -> Result<&mut Self, BuildError> {
98        self.append_and_consume(op, indices)
99    }
100
101    #[inline]
102    pub fn append_and_consume<A: Into<CircuitUnit>>(
105        &mut self,
106        op: impl Into<OpType>,
107        inputs: impl IntoIterator<Item = A>,
108    ) -> Result<&mut Self, BuildError> {
109        self.append_with_outputs(op, inputs)?;
110        Ok(self)
111    }
112
113    pub fn append_with_outputs<A: Into<CircuitUnit>>(
123        &mut self,
124        op: impl Into<OpType>,
125        inputs: impl IntoIterator<Item = A>,
126    ) -> Result<Vec<Wire>, BuildError> {
127        let mut linear_inputs = HashMap::new();
129        let op = op.into();
130
131        let input_wires: Result<Vec<Wire>, usize> = inputs
132            .into_iter()
133            .map(Into::into)
134            .enumerate()
135            .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136                CircuitUnit::Wire(wire) => Ok(wire),
137                CircuitUnit::Linear(wire_index) => {
138                    linear_inputs.insert(input_port, wire_index);
139                    self.tracked_wire(wire_index).ok_or(wire_index)
140                }
141            })
142            .collect();
143
144        let input_wires =
145            input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146                op: Some(Box::new(op.clone())),
147                invalid_index,
148            })?;
149
150        let output_wires = self
151            .builder
152            .add_dataflow_op(
153                op.clone(), input_wires,
155            )?
156            .outputs();
157        let nonlinear_outputs: Vec<Wire> = output_wires
158            .enumerate()
159            .filter_map(|(output_port, wire)| {
160                if let Some(wire_index) = linear_inputs.remove(&output_port) {
161                    self.wires[wire_index] = Some(wire);
163                    None
164                } else {
165                    Some(wire)
166                }
167            })
168            .collect();
169
170        if !linear_inputs.is_empty() {
171            return Err(CircuitBuildError::MismatchedLinearInputs {
172                op: Box::new(op),
173                index: linear_inputs.values().copied().collect(),
174            }
175            .into());
176        }
177
178        Ok(nonlinear_outputs)
179    }
180
181    pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195        &mut self,
196        op: impl Into<OpType>,
197        inputs: impl IntoIterator<Item = A>,
198    ) -> Result<[Wire; N], BuildError> {
199        let outputs = self.append_with_outputs(op, inputs)?;
200        Ok(collect_array(outputs))
201    }
202
203    pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205        self.builder.add_load_value(value)
206    }
207
208    pub fn track_wire(&mut self, wire: Wire) -> usize {
212        self.wires.push(Some(wire));
213        self.wires.len() - 1
214    }
215
216    pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224        self.wires
225            .get_mut(index)
226            .and_then(mem::take)
227            .ok_or(CircuitBuildError::InvalidWireIndex {
228                op: None,
229                invalid_index: index,
230            })
231    }
232
233    #[inline]
234    #[must_use]
237    pub fn finish(self) -> Vec<Wire> {
238        self.wires.into_iter().flatten().collect()
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    use crate::Extension;
    use crate::builder::{HugrBuilder, ModuleBuilder};
    use crate::extension::ExtensionId;
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::std_extensions::arithmetic::float_types::ConstF64;
    use crate::utils::test_quantum_extension::{
        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
    };
    use crate::{
        builder::{DataflowSubContainer, test::build_main},
        extension::prelude::bool_t,
        types::Signature,
    };

    /// Chain linear gates on two tracked qubits, then apply a gate whose
    /// second input is an untracked constant wire.
    #[test]
    fn simple_linear() {
        let build_res = build_main(
            Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let wires = f_build.input_wires().map(Some).collect();

                // Built from its fields directly; the test module can see them.
                let mut linear = CircuitBuilder {
                    wires,
                    builder: &mut f_build,
                };

                assert_eq!(linear.n_wires(), 2);

                linear
                    .append(h_gate(), [0])?
                    .append(cx_gate(), [0, 1])?
                    .append(cx_gate(), [1, 0])?;

                // The angle is a plain wire, not a tracked linear unit.
                let angle = linear.add_constant(ConstF64::new(0.5));
                linear.append_and_consume(
                    rz_f64(),
                    [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
                )?;

                let outs = linear.finish();
                f_build.finish_with_outputs(outs)
            },
        );

        assert_matches!(build_res, Ok(_));
    }

    /// Mix tracked qubits with an untracked `usize` input on a custom op
    /// (`[qb, usize] -> [qb]`), and collect the non-linear outputs of `measure`.
    #[test]
    fn with_nonlinear_and_outputs() {
        let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
        let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
            ext.add_op(
                "MyOp".into(),
                String::new(),
                Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
                extension_ref,
            )
            .unwrap();
        });
        let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();

        let mut module_builder = ModuleBuilder::new();
        let mut f_build = module_builder
            .define_function(
                "main",
                Signature::new(
                    vec![qb_t(), qb_t(), usize_t()],
                    vec![qb_t(), qb_t(), bool_t()],
                ),
            )
            .unwrap();

        let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

        // Only the two qubits are tracked; `angle` stays a plain wire.
        let mut linear = f_build.as_circuit([q0, q1]);

        // NOTE(review): `measure_out` is assumed to be the single bool output
        // of `measure` (its qubit output replaces tracked index 0) — confirm.
        let measure_out = linear
            .append(cx_gate(), [0, 1])
            .unwrap()
            .append_and_consume(
                my_custom_op,
                [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
            )
            .unwrap()
            .append_with_outputs(measure(), [0])
            .unwrap();

        let out_qbs = linear.finish();
        f_build
            .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
            .unwrap();

        // NOTE(review): `registry` is populated but not passed to
        // `finish_hugr` below; confirm whether it is still needed.
        let mut registry = test_quantum_extension::REG.clone();
        registry.register(my_ext).unwrap();
        let build_res = module_builder.finish_hugr();

        assert_matches!(build_res, Ok(_));
    }

    /// Allocate and track an ancilla, then untrack and discard the original
    /// qubit, leaving only the ancilla as output.
    #[test]
    fn ancillae() {
        let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            assert_eq!(circ.n_wires(), 1);

            let [q0] = circ.tracked_units_arr();
            let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
            // The freshly allocated wire gets a new index after the input qubit.
            let ancilla = circ.track_wire(ancilla);

            assert_ne!(ancilla, 0);
            assert_eq!(circ.n_wires(), 2);
            assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

            circ.append(cx_gate(), [q0, ancilla])?;
            let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

            // Shadow the index with the untracked wire itself.
            let q0 = circ.untrack_wire(q0)?;

            assert_eq!(circ.tracked_units_arr(), [ancilla]);

            // `q0` is now a plain `Wire`, so it is passed as a non-linear input.
            circ.append_and_consume(q_discard(), [q0])?;

            let outs = circ.finish();

            assert_eq!(outs.len(), 1);

            f_build.finish_with_outputs(outs)
        });

        assert_matches!(build_res, Ok(_));
    }

    /// Check the error variants returned for invalid indices and unmatched
    /// linear inputs.
    #[test]
    fn circuit_builder_errors() {
        let _build_res = build_main(
            Signature::new_endo(vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                let [q0, q1] = circ.tracked_units_arr();
                let invalid_index = 0xff;

                // Appending with an index that is not tracked.
                assert_matches!(
                    circ.append(cx_gate(), [q0, invalid_index]),
                    Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
                    if op == Some(Box::new(cx_gate().into())) && idx == invalid_index,
                );

                // Untracking an index that is not tracked reports no operation.
                assert_matches!(
                    circ.untrack_wire(invalid_index),
                    Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
                    if idx == invalid_index,
                );

                // A linear input with no output on the same port; `q_discard`
                // presumably has no outputs.
                assert_matches!(
                    circ.append(q_discard(), [q1]),
                    Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                    if *op == q_discard().into() && index == [q1],
                );

                // The unmatched wire stays tracked after the error.
                let outs = circ.finish();

                assert_eq!(outs.len(), 2);

                f_build.finish_with_outputs(outs)
            },
        );

        // NOTE(review): the build result is deliberately not checked; the
        // wire still tracked at `q1` was already consumed by `q_discard`, so
        // finishing is presumably expected to fail — confirm.
    }
}