hugr_core/builder/
circuit.rs

1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
/// Builder to build regions of dataflow graphs that look like Circuits,
/// where some inputs of operations directly correspond to some outputs.
///
/// Allows appending operations by indexing a vector of input wires.
///
/// Each tracked wire is addressed by a "unit" index, which is its position in
/// the internal vector. When an operation is appended on a unit, the wire
/// stored for that unit is replaced by the operation's output at the same port.
#[derive(Debug, PartialEq)]
pub struct CircuitBuilder<'a, T: ?Sized> {
    /// List of wires that are being tracked, identified by their index in the vector.
    ///
    /// Terminating wires may create holes in the vector, but the indices are stable.
    /// A hole is represented by `None`.
    wires: Vec<Option<Wire>>,
    /// The underlying builder that appended operations and constants are added to.
    builder: &'a mut T,
}
24
25#[derive(Debug, Clone, PartialEq, Error)]
26/// Error in [`CircuitBuilder`]
27#[non_exhaustive]
28pub enum CircuitBuildError {
29    /// Invalid index for stored wires.
30    #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|op| op.name()).unwrap_or_default())]
31    InvalidWireIndex {
32        /// The operation.
33        op: Option<Box<OpType>>,
34        /// The invalid indices.
35        invalid_index: usize,
36    },
37    /// Some linear inputs had no corresponding output wire.
38    #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
39    MismatchedLinearInputs {
40        /// The operation.
41        op: Box<OpType>,
42        /// The index of the input that had no corresponding output wire.
43        index: Vec<usize>,
44    },
45}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48    /// Construct a new [`CircuitBuilder`] from a vector of incoming wires and the
49    /// builder for the graph.
50    pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51        Self {
52            wires: wires.into_iter().map(Some).collect(),
53            builder,
54        }
55    }
56
57    /// Returns the number of wires tracked.
58    #[must_use]
59    pub fn n_wires(&self) -> usize {
60        self.wires.iter().flatten().count()
61    }
62
63    /// Returns the wire associated with the given index.
64    #[must_use]
65    pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66        self.wires.get(index).copied().flatten()
67    }
68
69    /// Returns an iterator over the tracked linear units.
70    pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71        self.wires
72            .iter()
73            .enumerate()
74            .filter_map(|(i, w)| w.map(|_| i))
75    }
76
77    /// Returns an array with the tracked linear units.
78    ///
79    /// # Panics
80    ///
81    /// If the number of outputs does not match `N`.
82    #[must_use]
83    pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84        collect_array(self.tracked_units())
85    }
86
87    #[inline]
88    /// Append an op to the wires in the inner vector with given `indices`.
89    /// The outputs of the operation become the new wires at those indices.
90    /// Only valid for operations that have the same input type row as output
91    /// type row.
92    /// Returns a handle to self to allow chaining.
93    pub fn append(
94        &mut self,
95        op: impl Into<OpType>,
96        indices: impl IntoIterator<Item = usize> + Clone,
97    ) -> Result<&mut Self, BuildError> {
98        self.append_and_consume(op, indices)
99    }
100
101    #[inline]
102    /// The same as [`CircuitBuilder::append_with_outputs`] except it assumes no outputs and
103    /// instead returns a reference to self to allow chaining.
104    pub fn append_and_consume<A: Into<CircuitUnit>>(
105        &mut self,
106        op: impl Into<OpType>,
107        inputs: impl IntoIterator<Item = A>,
108    ) -> Result<&mut Self, BuildError> {
109        self.append_with_outputs(op, inputs)?;
110        Ok(self)
111    }
112
113    /// Append an `op` with some inputs being the stored wires.
114    /// Any inputs of the form [`CircuitUnit::Linear`] are used to index the
115    /// stored wires.
116    /// The outputs at those indices are used to replace the stored wire.
117    /// The remaining outputs are returned.
118    ///
119    /// # Errors
120    ///
121    /// Returns an error on an invalid input unit.
122    pub fn append_with_outputs<A: Into<CircuitUnit>>(
123        &mut self,
124        op: impl Into<OpType>,
125        inputs: impl IntoIterator<Item = A>,
126    ) -> Result<Vec<Wire>, BuildError> {
127        // map of linear port offset to wire vector index
128        let mut linear_inputs = HashMap::new();
129        let op = op.into();
130
131        let input_wires: Result<Vec<Wire>, usize> = inputs
132            .into_iter()
133            .map(Into::into)
134            .enumerate()
135            .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136                CircuitUnit::Wire(wire) => Ok(wire),
137                CircuitUnit::Linear(wire_index) => {
138                    linear_inputs.insert(input_port, wire_index);
139                    self.tracked_wire(wire_index).ok_or(wire_index)
140                }
141            })
142            .collect();
143
144        let input_wires =
145            input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146                op: Some(Box::new(op.clone())),
147                invalid_index,
148            })?;
149
150        let output_wires = self
151            .builder
152            .add_dataflow_op(
153                op.clone(), // TODO: Add extension param
154                input_wires,
155            )?
156            .outputs();
157        let nonlinear_outputs: Vec<Wire> = output_wires
158            .enumerate()
159            .filter_map(|(output_port, wire)| {
160                if let Some(wire_index) = linear_inputs.remove(&output_port) {
161                    // output at output_port replaces input wire from same port
162                    self.wires[wire_index] = Some(wire);
163                    None
164                } else {
165                    Some(wire)
166                }
167            })
168            .collect();
169
170        if !linear_inputs.is_empty() {
171            return Err(CircuitBuildError::MismatchedLinearInputs {
172                op: Box::new(op),
173                index: linear_inputs.values().copied().collect(),
174            }
175            .into());
176        }
177
178        Ok(nonlinear_outputs)
179    }
180
181    /// Append an `op` with some inputs being the stored wires.
182    /// Any inputs of the form [`CircuitUnit::Linear`] are used to index the
183    /// stored wires.
184    /// The outputs at those indices are used to replace the stored wire.
185    /// The remaining outputs are returned as an array.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error on an invalid input unit.
190    ///
191    /// # Panics
192    ///
193    /// If the number of outputs does not match `N`.
194    pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195        &mut self,
196        op: impl Into<OpType>,
197        inputs: impl IntoIterator<Item = A>,
198    ) -> Result<[Wire; N], BuildError> {
199        let outputs = self.append_with_outputs(op, inputs)?;
200        Ok(collect_array(outputs))
201    }
202
203    /// Adds a constant value to the circuit and loads it into a wire.
204    pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205        self.builder.add_load_value(value)
206    }
207
208    /// Add a wire to the list of tracked wires.
209    ///
210    /// Returns the new unit index.
211    pub fn track_wire(&mut self, wire: Wire) -> usize {
212        self.wires.push(Some(wire));
213        self.wires.len() - 1
214    }
215
216    /// Stops tracking a linear unit, and returns the last wire corresponding to it.
217    ///
218    /// Returns the new unit index.
219    ///
220    /// # Errors
221    ///
222    /// Returns a [`CircuitBuildError::InvalidWireIndex`] if the index is invalid.
223    pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224        self.wires
225            .get_mut(index)
226            .and_then(mem::take)
227            .ok_or(CircuitBuildError::InvalidWireIndex {
228                op: None,
229                invalid_index: index,
230            })
231    }
232
233    #[inline]
234    /// Finish building the circuit region and return the dangling wires
235    /// that correspond to the initially provided wires.
236    #[must_use]
237    pub fn finish(self) -> Vec<Wire> {
238        self.wires.into_iter().flatten().collect()
239    }
240}
241
242#[cfg(test)]
243mod test {
244    use super::*;
245    use cool_asserts::assert_matches;
246
247    use crate::Extension;
248    use crate::builder::{HugrBuilder, ModuleBuilder};
249    use crate::extension::ExtensionId;
250    use crate::extension::prelude::{qb_t, usize_t};
251    use crate::std_extensions::arithmetic::float_types::ConstF64;
252    use crate::utils::test_quantum_extension::{
253        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
254    };
255    use crate::{
256        builder::{DataflowSubContainer, test::build_main},
257        extension::prelude::bool_t,
258        types::Signature,
259    };
260
261    #[test]
262    fn simple_linear() {
263        let build_res = build_main(
264            Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(),
265            |mut f_build| {
266                let wires = f_build.input_wires().map(Some).collect();
267
268                let mut linear = CircuitBuilder {
269                    wires,
270                    builder: &mut f_build,
271                };
272
273                assert_eq!(linear.n_wires(), 2);
274
275                linear
276                    .append(h_gate(), [0])?
277                    .append(cx_gate(), [0, 1])?
278                    .append(cx_gate(), [1, 0])?;
279
280                let angle = linear.add_constant(ConstF64::new(0.5));
281                linear.append_and_consume(
282                    rz_f64(),
283                    [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
284                )?;
285
286                let outs = linear.finish();
287                f_build.finish_with_outputs(outs)
288            },
289        );
290
291        assert_matches!(build_res, Ok(_));
292    }
293
294    #[test]
295    fn with_nonlinear_and_outputs() {
296        let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
297        let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
298            ext.add_op(
299                "MyOp".into(),
300                String::new(),
301                Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
302                extension_ref,
303            )
304            .unwrap();
305        });
306        let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();
307
308        let mut module_builder = ModuleBuilder::new();
309        let mut f_build = module_builder
310            .define_function(
311                "main",
312                Signature::new(
313                    vec![qb_t(), qb_t(), usize_t()],
314                    vec![qb_t(), qb_t(), bool_t()],
315                ),
316            )
317            .unwrap();
318
319        let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();
320
321        let mut linear = f_build.as_circuit([q0, q1]);
322
323        let measure_out = linear
324            .append(cx_gate(), [0, 1])
325            .unwrap()
326            .append_and_consume(
327                my_custom_op,
328                [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
329            )
330            .unwrap()
331            .append_with_outputs(measure(), [0])
332            .unwrap();
333
334        let out_qbs = linear.finish();
335        f_build
336            .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
337            .unwrap();
338
339        let mut registry = test_quantum_extension::REG.clone();
340        registry.register(my_ext).unwrap();
341        let build_res = module_builder.finish_hugr();
342
343        assert_matches!(build_res, Ok(_));
344    }
345
    /// Exercises tracking a freshly allocated wire as a new unit and
    /// untracking an existing unit.
    #[test]
    fn ancillae() {
        let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| {
            let mut circ = f_build.as_circuit(f_build.input_wires());
            assert_eq!(circ.n_wires(), 1);

            let [q0] = circ.tracked_units_arr();
            // The allocation takes no inputs, so its single output is returned
            // rather than stored; track it explicitly as a new unit.
            let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
            let ancilla = circ.track_wire(ancilla);

            assert_ne!(ancilla, 0);
            assert_eq!(circ.n_wires(), 2);
            assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

            circ.append(cx_gate(), [q0, ancilla])?;
            let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

            // From here on `q0` is a plain wire, no longer a unit index.
            let q0 = circ.untrack_wire(q0)?;

            assert_eq!(circ.tracked_units_arr(), [ancilla]);

            // The untracked wire is passed directly as an input.
            circ.append_and_consume(q_discard(), [q0])?;

            let outs = circ.finish();

            // Only the ancilla is still tracked.
            assert_eq!(outs.len(), 1);

            f_build.finish_with_outputs(outs)
        });

        assert_matches!(build_res, Ok(_));
    }
378
    /// Checks the errors returned for invalid unit indices and for linear
    /// inputs that have no matching output.
    #[test]
    fn circuit_builder_errors() {
        let _build_res = build_main(
            Signature::new_endo(vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                let [q0, q1] = circ.tracked_units_arr();
                // Only two units are tracked, so this index is out of bounds.
                let invalid_index = 0xff;

                // Passing an invalid linear index returns an error
                assert_matches!(
                    circ.append(cx_gate(), [q0, invalid_index]),
                    Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
                    if op == Some(Box::new(cx_gate().into())) && idx == invalid_index,
                );

                // Untracking an invalid index returns an error
                assert_matches!(
                    circ.untrack_wire(invalid_index),
                    Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
                    if idx == invalid_index,
                );

                // Passing a linear index to an operation without a corresponding output returns an error
                assert_matches!(
                    circ.append(q_discard(), [q1]),
                    Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                    if *op == q_discard().into() && index == [q1],
                );

                let outs = circ.finish();

                // Neither failed call removed a unit from the tracked list.
                assert_eq!(outs.len(), 2);

                f_build.finish_with_outputs(outs)
            },
        );

        // We do not test the build output, as the internal errors may have left
        // the hugr in an invalid state.
    }
420}