hugr_core/
utils.rs

1//! General utilities.
2
3use std::fmt::{self, Debug, Display};
4
5use itertools::Itertools;
6
7use crate::{Hugr, HugrView, IncomingPort, Node, ops::Value};
8
9/// Write a comma separated list of of some types.
10/// Like `debug_list`, but using the Display instance rather than Debug,
11/// and not adding surrounding square brackets.
12pub fn display_list<T>(ts: impl IntoIterator<Item = T>, f: &mut fmt::Formatter) -> fmt::Result
13where
14    T: Display,
15{
16    display_list_with_separator(ts, f, ", ")
17}
18
/// Write the items of `ts` to `f` using their `Display` implementations,
/// placing `sep` between each pair of consecutive items.
///
/// This is like [`fmt::Formatter::debug_list`], except that it uses `Display`
/// rather than `Debug` and does not add surrounding square brackets. An empty
/// iterator writes nothing. Each item is formatted with the same formatter,
/// so any formatting options set on `f` apply to every item.
pub fn display_list_with_separator<T>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
    sep: &str,
) -> fmt::Result
where
    T: Display,
{
    for (index, item) in ts.into_iter().enumerate() {
        // The separator only goes *between* items, never before the first.
        if index > 0 {
            f.write_str(sep)?;
        }
        item.fmt(f)?;
    }
    Ok(())
}
42
43/// Collect a vector into an array.
44///
45/// This is useful for deconstructing a vectors content.
46///
47/// # Example
48///
49/// ```ignore
50/// let iter = 0..3;
51/// let [a, b, c] = crate::utils::collect_array(iter);
52/// assert_eq!(b, 1);
53/// ```
54///
55/// # Panics
56///
57/// If the length of the slice is not equal to `N`.
58///
59/// See also [`try_collect_array`] for a non-panicking version.
60#[inline]
61#[track_caller]
62pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
63    match try_collect_array(arr) {
64        Ok(v) => v,
65        Err(v) => panic!("Expected {N} elements, got {v:?}"),
66    }
67}
68
/// Collect the items of an iterator into a fixed-size array, without panicking.
///
/// Handy for destructuring a known number of items into named bindings.
///
/// # Example
///
/// ```ignore
/// let iter = 0..3;
/// let [a, b, c] = crate::utils::try_collect_array(iter)
///     .unwrap_or_else(|v| panic!("Expected 3 elements, got {:?}", v));
/// assert_eq!(b, 1);
/// ```
///
/// # Errors
///
/// If the iterator does not yield exactly `N` items, every collected item is
/// handed back, in order, as the `Vec` inside `Err`.
///
/// See also [`collect_array`].
#[inline]
#[track_caller]
pub fn try_collect_array<const N: usize, T>(
    arr: impl IntoIterator<Item = T>,
) -> Result<[T; N], Vec<T>> {
    let items: Vec<T> = arr.into_iter().collect();
    <[T; N]>::try_from(items)
}
90
/// Return `true` when `t` equals `T::default()`.
///
/// Intended for serde's `skip_serializing_if`, so that fields holding their
/// default value are omitted from the serialized output:
///
/// ```ignore
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct MyStruct {
///     #[serde(skip_serializing_if = "crate::utils::is_default")]
///     field: i32,
/// }
/// ```
///
/// From <https://github.com/serde-rs/serde/issues/818>.
#[allow(dead_code)]
pub(crate) fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
108
/// An uninhabited type: it has no variants, so no value of it can exist.
///
/// Because there is nothing to match on, a `match` over a `Never` needs no
/// arms and can produce any type, including `!`.
///
/// # Example
///
/// ```ignore
/// fn foo(never: Never) -> ! {
///     match never {}
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Never {}
120
#[cfg(test)]
pub(crate) mod test_quantum_extension {
    //! A small test-only quantum extension (`test.quantum`) plus helpers that
    //! instantiate each of its operations.
    use std::sync::{Arc, LazyLock};

    use crate::ops::{OpName, OpNameRef};
    use crate::std_extensions::arithmetic::float_ops;
    use crate::std_extensions::logic;
    use crate::types::FuncValueType;
    use crate::{
        Extension,
        extension::{
            ExtensionId, ExtensionRegistry, PRELUDE,
            prelude::{bool_t, qb_t},
        },
        ops::ExtensionOp,
        std_extensions::arithmetic::float_types,
        type_row,
        types::{PolyFuncTypeRV, Signature},
    };

    /// The extension identifier.
    pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum");

    /// Quantum extension definition, built lazily on first use.
    pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(build_extension);

    /// A registry with all necessary extensions to run tests internally, including the test quantum extension.
    pub static REG: LazyLock<ExtensionRegistry> = LazyLock::new(|| {
        ExtensionRegistry::new([
            EXTENSION.clone(),
            PRELUDE.clone(),
            float_types::EXTENSION.clone(),
            float_ops::EXTENSION.clone(),
            logic::EXTENSION.clone(),
        ])
    });

    /// Signature built with `FuncValueType::new_endo` over a single qubit.
    fn single_qubit_endo() -> PolyFuncTypeRV {
        FuncValueType::new_endo(qb_t()).into()
    }

    /// Signature built with `FuncValueType::new_endo` over two qubits.
    fn two_qubit_endo() -> PolyFuncTypeRV {
        FuncValueType::new_endo(vec![qb_t(), qb_t()]).into()
    }

    /// Construct the extension, registering its operations in a fixed order.
    /// Each registration is `unwrap`ped: a failure here is a bug in the test setup.
    fn build_extension() -> Arc<Extension> {
        Extension::new_test_arc(EXTENSION_ID, |ext, ext_ref| {
            ext.add_op(
                OpName::new_inline("H"),
                "Hadamard".into(),
                single_qubit_endo(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("RzF64"),
                "Rotation specified by float".into(),
                Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("CX"),
                "CX".into(),
                two_qubit_endo(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("Measure"),
                "Measure a qubit, returning the qubit and the measurement result.".into(),
                Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("QAlloc"),
                "Allocate a new qubit.".into(),
                Signature::new(type_row![], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("QDiscard"),
                "Discard a qubit.".into(),
                Signature::new(vec![qb_t()], type_row![]),
                ext_ref,
            )
            .unwrap();
        })
    }

    /// Instantiate the named operation of [`EXTENSION`] with no type
    /// arguments, panicking if that fails.
    fn instantiate(op_name: &OpNameRef) -> ExtensionOp {
        EXTENSION.instantiate_extension_op(op_name, []).unwrap()
    }

    /// The `H` (Hadamard) operation.
    pub(crate) fn h_gate() -> ExtensionOp {
        instantiate("H")
    }

    /// The `CX` operation.
    pub(crate) fn cx_gate() -> ExtensionOp {
        instantiate("CX")
    }

    /// The `Measure` operation.
    pub(crate) fn measure() -> ExtensionOp {
        instantiate("Measure")
    }

    /// The `RzF64` operation.
    pub(crate) fn rz_f64() -> ExtensionOp {
        instantiate("RzF64")
    }

    /// The `QAlloc` operation.
    pub(crate) fn q_alloc() -> ExtensionOp {
        instantiate("QAlloc")
    }

    /// The `QDiscard` operation.
    pub(crate) fn q_discard() -> ExtensionOp {
        instantiate("QDiscard")
    }
}
248
249/// Sort folding inputs with [`IncomingPort`] as key
250fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> {
251    let mut v: Vec<_> = consts.iter().collect();
252    v.sort_by_key(|(i, _)| i);
253    v
254}
255
256/// Sort some input constants by port and just return the constants.
257#[must_use]
258pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
259    sort_by_in_port(consts)
260        .into_iter()
261        .map(|(_, c)| c)
262        .collect()
263}
264
265/// Calculate the depth of a node in the hierarchy.
266pub fn depth(h: &Hugr, n: Node) -> u32 {
267    match h.get_parent(n) {
268        Some(p) => 1 + depth(h, p),
269        None => 0,
270    }
271}
272
#[allow(dead_code)]
// Test only utils
#[cfg(test)]
pub(crate) mod test {
    #[allow(unused_imports)]
    use crate::HugrView;
    use crate::{
        Hugr,
        ops::{OpType, Value},
    };

    /// Assert that the entrypoint region of `h` only loads and outputs a
    /// single constant, which must compare equal to `expected_value`.
    pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
        assert_fully_folded_with(h, |v| v == expected_value);
    }

    /// Assert that the entrypoint region of `h` only loads and outputs a
    /// single constant, accepting that constant's value via `check_value`.
    ///
    /// `CustomConst::equals_const` is not required to be implemented, so use
    /// this function for values containing such a `CustomConst`.
    ///
    /// # Panics
    ///
    /// The first check panics if a child of the entrypoint is anything other
    /// than an `Input`, `Output`, `LoadConstant`, or a `Const` whose value
    /// satisfies `check_value`. Its message includes the op and a mermaid
    /// rendering of `h`.
    ///
    /// The second check panics if the entrypoint does not have exactly four
    /// children.
    pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
        for node in h.children(h.entrypoint()) {
            let op = h.get_optype(node);
            let accepted = match op {
                OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => true,
                OpType::Const(c) => check_value(c.value()),
                _ => false,
            };
            assert!(accepted, "unexpected op: {}\n{}", op, h.mermaid_string());
        }

        // Every child was accepted above, so this is the count of accepted nodes.
        assert_eq!(h.children(h.entrypoint()).count(), 4);
    }
}