hugr_core/
utils.rs

1//! General utilities.
2
3use std::fmt::{self, Debug, Display};
4
5use itertools::Itertools;
6
7use crate::{Hugr, HugrView, IncomingPort, Node, ops::Value};
8
9/// Write a comma separated list of of some types.
10/// Like `debug_list`, but using the Display instance rather than Debug,
11/// and not adding surrounding square brackets.
12pub fn display_list<T>(ts: impl IntoIterator<Item = T>, f: &mut fmt::Formatter) -> fmt::Result
13where
14    T: Display,
15{
16    display_list_with_separator(ts, f, ", ")
17}
18
/// Write the [`Display`] form of every item in `ts`, separated by `sep`.
///
/// Like [`fmt::Formatter::debug_list`], but uses each item's `Display`
/// implementation instead of `Debug`, takes a caller-chosen separator, and does
/// not add surrounding square brackets. Nothing is written for an empty
/// iterator, and no separator is written before the first item or after the last.
///
/// The formatter `f` (including any width/precision flags it carries) is handed
/// straight to each item's `Display::fmt`; the separator is written verbatim.
pub fn display_list_with_separator<T>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
    sep: &str,
) -> fmt::Result
where
    T: Display,
{
    let mut items = ts.into_iter();
    if let Some(head) = items.next() {
        Display::fmt(&head, f)?;
        // Every item after the first is preceded by exactly one separator.
        for item in items {
            f.write_str(sep)?;
            Display::fmt(&item, f)?;
        }
    }
    Ok(())
}
42
43/// Collect a vector into an array.
44///
45/// This is useful for deconstructing a vectors content.
46///
47/// # Example
48///
49/// ```ignore
50/// let iter = 0..3;
51/// let [a, b, c] = crate::utils::collect_array(iter);
52/// assert_eq!(b, 1);
53/// ```
54///
55/// # Panics
56///
57/// If the length of the slice is not equal to `N`.
58///
59/// See also [`try_collect_array`] for a non-panicking version.
60#[inline]
61#[track_caller]
62pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
63    match try_collect_array(arr) {
64        Ok(v) => v,
65        Err(v) => panic!("Expected {N} elements, got {v:?}"),
66    }
67}
68
69/// Collect a vector into an array.
70///
71/// This is useful for deconstructing a vectors content.
72///
73/// # Example
74///
75/// ```ignore
76/// let iter = 0..3;
77/// let [a, b, c] = crate::utils::try_collect_array(iter)
78///     .unwrap_or_else(|v| panic!("Expected 3 elements, got {:?}", v));
79/// assert_eq!(b, 1);
80/// ```
81///
82/// See also [`collect_array`].
83#[inline]
84#[track_caller]
85pub fn try_collect_array<const N: usize, T>(
86    arr: impl IntoIterator<Item = T>,
87) -> Result<[T; N], Vec<T>> {
88    arr.into_iter().collect_vec().try_into()
89}
90
/// Return `true` when `t` equals `T::default()`.
///
/// Intended for serde's `skip_serializing_if`, so that default-valued fields
/// are omitted from the serialized output:
///
/// ```ignore
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct MyStruct {
///     #[serde(skip_serializing_if = "crate::utils::is_default")]
///     field: i32,
/// }
/// ```
///
/// From <https://github.com/serde-rs/serde/issues/818>.
// NOTE(review): presumably allowed because no type currently references this helper.
#[allow(dead_code)]
pub(crate) fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let fallback = T::default();
    t == &fallback
}
108
/// An uninhabited type: it has no variants, so no value of it can exist.
///
/// A function that receives a `Never` can therefore never actually be called,
/// and an empty `match` on it is exhaustive.
///
/// # Example
///
/// ```ignore
/// fn foo(never: Never) -> ! {
///     match never {}
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Never {}
120
#[cfg(test)]
pub(crate) mod test_quantum_extension {
    //! A minimal `test.quantum` extension, and a registry containing it, for internal tests.

    use std::sync::Arc;

    use crate::ops::{OpName, OpNameRef};
    use crate::std_extensions::arithmetic::float_ops;
    use crate::std_extensions::logic;
    use crate::types::FuncValueType;
    use crate::{
        Extension,
        extension::{
            ExtensionId, ExtensionRegistry, PRELUDE,
            prelude::{bool_t, qb_t},
        },
        ops::ExtensionOp,
        std_extensions::arithmetic::float_types,
        type_row,
        types::{PolyFuncTypeRV, Signature},
    };

    use lazy_static::lazy_static;

    /// The extension identifier.
    pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum");

    /// Signature with a single qubit as both input and output (via `FuncValueType::new_endo`).
    fn one_qb_func() -> PolyFuncTypeRV {
        let qubit_endo = FuncValueType::new_endo(qb_t());
        qubit_endo.into()
    }

    /// Signature with two qubits as both inputs and outputs (via `FuncValueType::new_endo`).
    fn two_qb_func() -> PolyFuncTypeRV {
        let qubit_pair = vec![qb_t(), qb_t()];
        let pair_endo = FuncValueType::new_endo(qubit_pair);
        pair_endo.into()
    }

    /// Build the `test.quantum` extension with its six ops:
    /// `H`, `RzF64`, `CX`, `Measure`, `QAlloc` and `QDiscard`.
    ///
    /// Panics if any op fails to register.
    fn extension() -> Arc<Extension> {
        Extension::new_test_arc(EXTENSION_ID, |ext, ext_ref| {
            ext.add_op(
                OpName::new_inline("H"),
                "Hadamard".into(),
                one_qb_func(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("RzF64"),
                "Rotation specified by float".into(),
                Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("CX"),
                "CX".into(),
                two_qb_func(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("Measure"),
                "Measure a qubit, returning the qubit and the measurement result.".into(),
                Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("QAlloc"),
                "Allocate a new qubit.".into(),
                Signature::new(type_row![], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("QDiscard"),
                "Discard a qubit.".into(),
                Signature::new(vec![qb_t()], type_row![]),
                ext_ref,
            )
            .unwrap();
        })
    }

    lazy_static! {
        /// Quantum extension definition.
        pub static ref EXTENSION: Arc<Extension> = extension();

        /// A registry with all necessary extensions to run tests internally, including the test quantum extension.
        pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([
            EXTENSION.clone(),
            PRELUDE.clone(),
            float_types::EXTENSION.clone(),
            float_ops::EXTENSION.clone(),
            logic::EXTENSION.clone(),
        ]);
    }

    /// Instantiate the op called `name` from `EXTENSION` with no type arguments.
    ///
    /// Panics if instantiation fails.
    fn get_gate(name: &OpNameRef) -> ExtensionOp {
        EXTENSION.instantiate_extension_op(name, []).unwrap()
    }

    /// The `H` (Hadamard) op.
    pub(crate) fn h_gate() -> ExtensionOp {
        get_gate("H")
    }

    /// The two-qubit `CX` op.
    pub(crate) fn cx_gate() -> ExtensionOp {
        get_gate("CX")
    }

    /// The `Measure` op: qubit in, qubit and bool out.
    pub(crate) fn measure() -> ExtensionOp {
        get_gate("Measure")
    }

    /// The `RzF64` op: qubit and float64 in, qubit out.
    pub(crate) fn rz_f64() -> ExtensionOp {
        get_gate("RzF64")
    }

    /// The `QAlloc` op: no inputs, one qubit out.
    pub(crate) fn q_alloc() -> ExtensionOp {
        get_gate("QAlloc")
    }

    /// The `QDiscard` op: one qubit in, no outputs.
    pub(crate) fn q_discard() -> ExtensionOp {
        get_gate("QDiscard")
    }
}
251
252/// Sort folding inputs with [`IncomingPort`] as key
253fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> {
254    let mut v: Vec<_> = consts.iter().collect();
255    v.sort_by_key(|(i, _)| i);
256    v
257}
258
259/// Sort some input constants by port and just return the constants.
260#[must_use]
261pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
262    sort_by_in_port(consts)
263        .into_iter()
264        .map(|(_, c)| c)
265        .collect()
266}
267
268/// Calculate the depth of a node in the hierarchy.
269pub fn depth(h: &Hugr, n: Node) -> u32 {
270    match h.get_parent(n) {
271        Some(p) => 1 + depth(h, p),
272        None => 0,
273    }
274}
275
#[allow(dead_code)]
// Test only utils
#[cfg(test)]
pub(crate) mod test {
    #[allow(unused_imports)]
    use crate::HugrView;
    use crate::{
        Hugr,
        ops::{OpType, Value},
    };

    /// Assert that `h` has been folded down to loading and returning a single
    /// constant equal to `expected_value`.
    ///
    /// This is `assert_fully_folded_with` using `==` on [`Value`] as the check.
    pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
        assert_fully_folded_with(h, |candidate| candidate == expected_value);
    }

    /// Assert that `h` has been folded down to loading and returning a single
    /// constant, validating that constant with `check_value`.
    ///
    /// Every child of the entrypoint must be an `Input`, `Output`, or
    /// `LoadConstant`, or a `Const` whose value satisfies `check_value`.
    /// Any other child causes a panic that reports the offending op and a
    /// Mermaid rendering of `h`. There must be exactly four children in total.
    ///
    /// `CustomConst::equals_const` is not required to be implemented. Use this
    /// function instead of `assert_fully_folded` for values containing such a
    /// `CustomConst`.
    pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
        let is_allowed = |op: &OpType| match op {
            OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => true,
            OpType::Const(c) => check_value(c.value()),
            _ => false,
        };

        let mut node_count = 0;
        for node in h.children(h.entrypoint()) {
            let op = h.get_optype(node);
            if !is_allowed(op) {
                panic!("unexpected op: {}\n{}", op, h.mermaid_string());
            }
            node_count += 1;
        }

        // Intended shape: Input, Output, one LoadConstant and one Const.
        // Only the total number of children is checked here.
        assert_eq!(node_count, 4);
    }
}