hugr_core/builder/
circuit.rs

1use std::collections::HashMap;
2use std::mem;
3
4use thiserror::Error;
5
6use crate::ops::{NamedOp, OpType, Value};
7use crate::utils::collect_array;
8
9use super::{BuildError, Dataflow};
10use crate::{CircuitUnit, Wire};
11
12/// Builder to build regions of dataflow graphs that look like Circuits,
13/// where some inputs of operations directly correspond to some outputs.
14///
15/// Allows appending operations by indexing a vector of input wires.
16#[derive(Debug, PartialEq)]
17pub struct CircuitBuilder<'a, T: ?Sized> {
18    /// List of wires that are being tracked, identified by their index in the vector.
19    ///
20    /// Terminating wires may create holes in the vector, but the indices are stable.
21    wires: Vec<Option<Wire>>,
22    builder: &'a mut T,
23}
24
25#[derive(Debug, Clone, PartialEq, Error)]
26/// Error in [`CircuitBuilder`]
27#[non_exhaustive]
28pub enum CircuitBuildError {
29    /// Invalid index for stored wires.
30    #[error("Invalid wire index {invalid_index} while attempting to add operation {}.", .op.as_ref().map(|o| o.name()).unwrap_or_default())]
31    InvalidWireIndex {
32        /// The operation.
33        op: Option<OpType>,
34        /// The invalid indices.
35        invalid_index: usize,
36    },
37    /// Some linear inputs had no corresponding output wire.
38    #[error("The linear inputs {:?} had no corresponding output wire in operation {}.", .index.as_slice(), .op.name())]
39    MismatchedLinearInputs {
40        /// The operation.
41        op: OpType,
42        /// The index of the input that had no corresponding output wire.
43        index: Vec<usize>,
44    },
45}
46
47impl<'a, T: Dataflow + ?Sized> CircuitBuilder<'a, T> {
48    /// Construct a new [`CircuitBuilder`] from a vector of incoming wires and the
49    /// builder for the graph.
50    pub fn new(wires: impl IntoIterator<Item = Wire>, builder: &'a mut T) -> Self {
51        Self {
52            wires: wires.into_iter().map(Some).collect(),
53            builder,
54        }
55    }
56
57    /// Returns the number of wires tracked.
58    #[must_use]
59    pub fn n_wires(&self) -> usize {
60        self.wires.iter().flatten().count()
61    }
62
63    /// Returns the wire associated with the given index.
64    #[must_use]
65    pub fn tracked_wire(&self, index: usize) -> Option<Wire> {
66        self.wires.get(index).copied().flatten()
67    }
68
69    /// Returns an iterator over the tracked linear units.
70    pub fn tracked_units(&self) -> impl Iterator<Item = usize> + '_ {
71        self.wires
72            .iter()
73            .enumerate()
74            .filter_map(|(i, w)| w.map(|_| i))
75    }
76
77    /// Returns an array with the tracked linear units.
78    ///
79    /// # Panics
80    ///
81    /// If the number of outputs does not match `N`.
82    #[must_use]
83    pub fn tracked_units_arr<const N: usize>(&self) -> [usize; N] {
84        collect_array(self.tracked_units())
85    }
86
87    #[inline]
88    /// Append an op to the wires in the inner vector with given `indices`.
89    /// The outputs of the operation become the new wires at those indices.
90    /// Only valid for operations that have the same input type row as output
91    /// type row.
92    /// Returns a handle to self to allow chaining.
93    pub fn append(
94        &mut self,
95        op: impl Into<OpType>,
96        indices: impl IntoIterator<Item = usize> + Clone,
97    ) -> Result<&mut Self, BuildError> {
98        self.append_and_consume(op, indices)
99    }
100
101    #[inline]
102    /// The same as [`CircuitBuilder::append_with_outputs`] except it assumes no outputs and
103    /// instead returns a reference to self to allow chaining.
104    pub fn append_and_consume<A: Into<CircuitUnit>>(
105        &mut self,
106        op: impl Into<OpType>,
107        inputs: impl IntoIterator<Item = A>,
108    ) -> Result<&mut Self, BuildError> {
109        self.append_with_outputs(op, inputs)?;
110        Ok(self)
111    }
112
113    /// Append an `op` with some inputs being the stored wires.
114    /// Any inputs of the form [`CircuitUnit::Linear`] are used to index the
115    /// stored wires.
116    /// The outputs at those indices are used to replace the stored wire.
117    /// The remaining outputs are returned.
118    ///
119    /// # Errors
120    ///
121    /// Returns an error on an invalid input unit.
122    pub fn append_with_outputs<A: Into<CircuitUnit>>(
123        &mut self,
124        op: impl Into<OpType>,
125        inputs: impl IntoIterator<Item = A>,
126    ) -> Result<Vec<Wire>, BuildError> {
127        // map of linear port offset to wire vector index
128        let mut linear_inputs = HashMap::new();
129        let op = op.into();
130
131        let input_wires: Result<Vec<Wire>, usize> = inputs
132            .into_iter()
133            .map(Into::into)
134            .enumerate()
135            .map(|(input_port, a_w): (usize, CircuitUnit)| match a_w {
136                CircuitUnit::Wire(wire) => Ok(wire),
137                CircuitUnit::Linear(wire_index) => {
138                    linear_inputs.insert(input_port, wire_index);
139                    self.tracked_wire(wire_index).ok_or(wire_index)
140                }
141            })
142            .collect();
143
144        let input_wires =
145            input_wires.map_err(|invalid_index| CircuitBuildError::InvalidWireIndex {
146                op: Some(op.clone()),
147                invalid_index,
148            })?;
149
150        let output_wires = self
151            .builder
152            .add_dataflow_op(
153                op.clone(), // TODO: Add extension param
154                input_wires,
155            )?
156            .outputs();
157        let nonlinear_outputs: Vec<Wire> = output_wires
158            .enumerate()
159            .filter_map(|(output_port, wire)| {
160                if let Some(wire_index) = linear_inputs.remove(&output_port) {
161                    // output at output_port replaces input wire from same port
162                    self.wires[wire_index] = Some(wire);
163                    None
164                } else {
165                    Some(wire)
166                }
167            })
168            .collect();
169
170        if !linear_inputs.is_empty() {
171            return Err(CircuitBuildError::MismatchedLinearInputs {
172                op,
173                index: linear_inputs.values().copied().collect(),
174            }
175            .into());
176        }
177
178        Ok(nonlinear_outputs)
179    }
180
181    /// Append an `op` with some inputs being the stored wires.
182    /// Any inputs of the form [`CircuitUnit::Linear`] are used to index the
183    /// stored wires.
184    /// The outputs at those indices are used to replace the stored wire.
185    /// The remaining outputs are returned as an array.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error on an invalid input unit.
190    ///
191    /// # Panics
192    ///
193    /// If the number of outputs does not match `N`.
194    pub fn append_with_outputs_arr<const N: usize, A: Into<CircuitUnit>>(
195        &mut self,
196        op: impl Into<OpType>,
197        inputs: impl IntoIterator<Item = A>,
198    ) -> Result<[Wire; N], BuildError> {
199        let outputs = self.append_with_outputs(op, inputs)?;
200        Ok(collect_array(outputs))
201    }
202
203    /// Adds a constant value to the circuit and loads it into a wire.
204    pub fn add_constant(&mut self, value: impl Into<Value>) -> Wire {
205        self.builder.add_load_value(value)
206    }
207
208    /// Add a wire to the list of tracked wires.
209    ///
210    /// Returns the new unit index.
211    pub fn track_wire(&mut self, wire: Wire) -> usize {
212        self.wires.push(Some(wire));
213        self.wires.len() - 1
214    }
215
216    /// Stops tracking a linear unit, and returns the last wire corresponding to it.
217    ///
218    /// Returns the new unit index.
219    ///
220    /// # Errors
221    ///
222    /// Returns a [`CircuitBuildError::InvalidWireIndex`] if the index is invalid.
223    pub fn untrack_wire(&mut self, index: usize) -> Result<Wire, CircuitBuildError> {
224        self.wires
225            .get_mut(index)
226            .and_then(mem::take)
227            .ok_or(CircuitBuildError::InvalidWireIndex {
228                op: None,
229                invalid_index: index,
230            })
231    }
232
233    #[inline]
234    /// Finish building the circuit region and return the dangling wires
235    /// that correspond to the initially provided wires.
236    pub fn finish(self) -> Vec<Wire> {
237        self.wires.into_iter().flatten().collect()
238    }
239}
240
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    use crate::builder::{Container, HugrBuilder, ModuleBuilder};
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::extension::{ExtensionId, ExtensionSet};
    use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
    use crate::utils::test_quantum_extension::{
        self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
    };
    use crate::Extension;
    use crate::{
        builder::{test::build_main, DataflowSubContainer},
        extension::prelude::bool_t,
        types::Signature,
    };

    #[test]
    fn simple_linear() {
        let build_res = build_main(
            Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()])
                .with_extension_delta(test_quantum_extension::EXTENSION_ID)
                .with_extension_delta(float_types::EXTENSION_ID)
                .into(),
            |mut f_build| {
                let wires = f_build.input_wires().map(Some).collect();

                let mut linear = CircuitBuilder {
                    wires,
                    builder: &mut f_build,
                };

                assert_eq!(linear.n_wires(), 2);

                linear
                    .append(h_gate(), [0])?
                    .append(cx_gate(), [0, 1])?
                    .append(cx_gate(), [1, 0])?;

                let angle = linear.add_constant(ConstF64::new(0.5));
                linear.append_and_consume(
                    rz_f64(),
                    [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
                )?;

                let outs = linear.finish();
                f_build.finish_with_outputs(outs)
            },
        );

        assert_matches!(build_res, Ok(_));
    }

    #[test]
    fn with_nonlinear_and_outputs() {
        let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
        let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
            ext.add_op(
                "MyOp".into(),
                "".to_string(),
                Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
                extension_ref,
            )
            .unwrap();
        });
        let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap();

        let mut module_builder = ModuleBuilder::new();
        let mut f_build = module_builder
            .define_function(
                "main",
                Signature::new(
                    vec![qb_t(), qb_t(), usize_t()],
                    vec![qb_t(), qb_t(), bool_t()],
                )
                .with_extension_delta(ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    my_ext_name,
                ])),
            )
            .unwrap();

        // The third input is a `usize_t` parameter consumed by `MyOp`.
        let [q0, q1, param]: [Wire; 3] = f_build.input_wires_arr();

        let mut linear = f_build.as_circuit([q0, q1]);

        let measure_out = linear
            .append(cx_gate(), [0, 1])
            .unwrap()
            .append_and_consume(
                my_custom_op,
                [CircuitUnit::Linear(0), CircuitUnit::Wire(param)],
            )
            .unwrap()
            .append_with_outputs(measure(), [0])
            .unwrap();

        let out_qbs = linear.finish();
        f_build
            .finish_with_outputs(out_qbs.into_iter().chain(measure_out))
            .unwrap();

        // `my_ext` stays alive until the end of this scope, which covers the
        // build below.
        let build_res = module_builder.finish_hugr();

        assert_matches!(build_res, Ok(_));
    }

    #[test]
    fn ancillae() {
        let build_res = build_main(
            Signature::new_endo(qb_t())
                .with_extension_delta(test_quantum_extension::EXTENSION_ID)
                .into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                assert_eq!(circ.n_wires(), 1);

                let [q0] = circ.tracked_units_arr();
                let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?;
                let ancilla = circ.track_wire(ancilla);

                // New units are appended after the single initial wire.
                assert_eq!(ancilla, 1);
                assert_eq!(circ.n_wires(), 2);
                assert_eq!(circ.tracked_units_arr(), [q0, ancilla]);

                circ.append(cx_gate(), [q0, ancilla])?;
                let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?;

                let q0 = circ.untrack_wire(q0)?;

                assert_eq!(circ.tracked_units_arr(), [ancilla]);

                circ.append_and_consume(q_discard(), [q0])?;

                let outs = circ.finish();

                assert_eq!(outs.len(), 1);

                f_build.finish_with_outputs(outs)
            },
        );

        assert_matches!(build_res, Ok(_));
    }

    #[test]
    fn circuit_builder_errors() {
        let _build_res = build_main(
            Signature::new_endo(vec![qb_t(), qb_t()]).into(),
            |mut f_build| {
                let mut circ = f_build.as_circuit(f_build.input_wires());
                let [q0, q1] = circ.tracked_units_arr();
                let invalid_index = 0xff;

                // Passing an invalid linear index returns an error
                assert_matches!(
                    circ.append(cx_gate(), [q0, invalid_index]),
                    Err(BuildError::CircuitError(CircuitBuildError::InvalidWireIndex { op, invalid_index: idx }))
                    if op == Some(cx_gate().into()) && idx == invalid_index,
                );

                // Untracking an invalid index returns an error
                assert_matches!(
                    circ.untrack_wire(invalid_index),
                    Err(CircuitBuildError::InvalidWireIndex { op: None, invalid_index: idx })
                    if idx == invalid_index,
                );

                // Passing a linear index to an operation without a corresponding output returns an error
                assert_matches!(
                    circ.append(q_discard(), [q1]),
                    Err(BuildError::CircuitError(CircuitBuildError::MismatchedLinearInputs { op, index }))
                    if op == q_discard().into() && index == [q1],
                );

                let outs = circ.finish();

                assert_eq!(outs.len(), 2);

                f_build.finish_with_outputs(outs)
            },
        );

        // We do not test the build output, as the internal errors may have left
        // the hugr in an invalid state.
    }
}