hugr_core/
utils.rs

1//! General utilities.
2
3use std::fmt::{self, Debug, Display};
4
5use itertools::Itertools;
6
7use crate::{ops::Value, Hugr, HugrView, IncomingPort, Node};
8
/// Write the items of `ts` separated by `", "`, using each item's [`Display`]
/// implementation.
///
/// This is similar to [`fmt::Formatter::debug_list`], but it uses `Display`
/// rather than `Debug` and does not wrap the output in square brackets.
pub fn display_list<T: Display>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
) -> fmt::Result {
    display_list_with_separator(ts, f, ", ")
}

/// Write the items of `ts` separated by `sep`, using each item's [`Display`]
/// implementation.
///
/// A separator is written only *between* items: never before the first item
/// or after the last. An empty iterator writes nothing. Each item is formatted
/// through `f` itself, so any options set on `f` (width, precision, ...) are
/// passed through to every item's own `fmt`. The separator is written
/// verbatim.
///
/// Like [`fmt::Formatter::debug_list`], but using `Display` instead of `Debug`
/// and without surrounding square brackets.
pub fn display_list_with_separator<T: Display>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
    sep: &str,
) -> fmt::Result {
    for (index, item) in ts.into_iter().enumerate() {
        // Every item except the first is preceded by the separator.
        if index > 0 {
            f.write_str(sep)?;
        }
        Display::fmt(&item, f)?;
    }
    Ok(())
}
42
43/// Collect a vector into an array.
44///
45/// This is useful for deconstructing a vectors content.
46///
47/// # Example
48///
49/// ```ignore
50/// let iter = 0..3;
51/// let [a, b, c] = crate::utils::collect_array(iter);
52/// assert_eq!(b, 1);
53/// ```
54///
55/// # Panics
56///
57/// If the length of the slice is not equal to `N`.
58///
59/// See also [`try_collect_array`] for a non-panicking version.
60#[inline]
61pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
62    try_collect_array(arr).unwrap_or_else(|v| panic!("Expected {} elements, got {:?}", N, v))
63}
64
65/// Collect a vector into an array.
66///
67/// This is useful for deconstructing a vectors content.
68///
69/// # Example
70///
71/// ```ignore
72/// let iter = 0..3;
73/// let [a, b, c] = crate::utils::try_collect_array(iter)
74///     .unwrap_or_else(|v| panic!("Expected 3 elements, got {:?}", v));
75/// assert_eq!(b, 1);
76/// ```
77///
78/// See also [`collect_array`].
79#[inline]
80pub fn try_collect_array<const N: usize, T>(
81    arr: impl IntoIterator<Item = T>,
82) -> Result<[T; N], Vec<T>> {
83    arr.into_iter().collect_vec().try_into()
84}
85
/// Returns `true` when `t` compares equal to `T::default()`.
///
/// Intended for serde's `skip_serializing_if`, so that fields holding their
/// default value are left out of the serialized output:
///
/// ```ignore
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct MyStruct {
///     #[serde(skip_serializing_if = "crate::utils::is_default")]
///     field: i32,
/// }
/// ```
///
/// Adapted from <https://github.com/serde-rs/serde/issues/818>.
// NOTE(review): presumably allowed because this helper may currently have no callers.
#[allow(dead_code)]
pub(crate) fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    PartialEq::eq(t, &default_value)
}
103
/// A small `test.quantum` extension with a handful of qubit operations, plus a
/// registry bundling it with the other extensions used by its signatures.
/// Only compiled for tests.
#[cfg(test)]
pub(crate) mod test_quantum_extension {
    use std::sync::Arc;

    use crate::ops::{OpName, OpNameRef};
    use crate::std_extensions::arithmetic::float_ops;
    use crate::std_extensions::logic;
    use crate::types::FuncValueType;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t},
            ExtensionId, ExtensionRegistry, PRELUDE,
        },
        ops::ExtensionOp,
        std_extensions::arithmetic::float_types,
        type_row,
        types::{PolyFuncTypeRV, Signature},
        Extension,
    };

    use lazy_static::lazy_static;

    /// The extension identifier.
    pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum");

    /// Endomorphic signature on a single qubit.
    fn one_qb_func() -> PolyFuncTypeRV {
        FuncValueType::new_endo(qb_t()).into()
    }

    /// Endomorphic signature on a pair of qubits.
    fn two_qb_func() -> PolyFuncTypeRV {
        let qubits = vec![qb_t(), qb_t()];
        FuncValueType::new_endo(qubits).into()
    }

    /// Build the extension, registering each operation in turn.
    /// Panics if any registration fails.
    fn extension() -> Arc<Extension> {
        Extension::new_test_arc(EXTENSION_ID, |ext, ext_ref| {
            ext.add_op(
                OpName::new_inline("H"),
                "Hadamard".into(),
                one_qb_func(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("RzF64"),
                "Rotation specified by float".into(),
                Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("CX"),
                "CX".into(),
                two_qb_func(),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("Measure"),
                "Measure a qubit, returning the qubit and the measurement result.".into(),
                Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]),
                ext_ref,
            )
            .unwrap();

            // Allocation takes no inputs; discarding produces no outputs.
            ext.add_op(
                OpName::new_inline("QAlloc"),
                "Allocate a new qubit.".into(),
                Signature::new(type_row![], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            ext.add_op(
                OpName::new_inline("QDiscard"),
                "Discard a qubit.".into(),
                Signature::new(vec![qb_t()], type_row![]),
                ext_ref,
            )
            .unwrap();
        })
    }

    lazy_static! {
        /// Quantum extension definition.
        pub static ref EXTENSION: Arc<Extension> = extension();

        /// A registry with all necessary extensions to run tests internally, including the test quantum extension.
        pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([
            EXTENSION.clone(),
            PRELUDE.clone(),
            float_types::EXTENSION.clone(),
            float_ops::EXTENSION.clone(),
            logic::EXTENSION.clone()
        ]);
    }

    /// Instantiate the named operation of [`struct@EXTENSION`] with no arguments.
    /// Panics if instantiation fails.
    fn get_gate(name: &OpNameRef) -> ExtensionOp {
        EXTENSION.instantiate_extension_op(name, []).unwrap()
    }

    /// The single-qubit Hadamard operation.
    pub(crate) fn h_gate() -> ExtensionOp {
        get_gate("H")
    }

    /// The two-qubit CX operation.
    pub(crate) fn cx_gate() -> ExtensionOp {
        get_gate("CX")
    }

    /// Measurement: qubit in, qubit and bool out.
    pub(crate) fn measure() -> ExtensionOp {
        get_gate("Measure")
    }

    /// Rotation of a qubit by a `float64` angle.
    pub(crate) fn rz_f64() -> ExtensionOp {
        get_gate("RzF64")
    }

    /// Allocation of a fresh qubit.
    pub(crate) fn q_alloc() -> ExtensionOp {
        get_gate("QAlloc")
    }

    /// Discarding of a qubit.
    pub(crate) fn q_discard() -> ExtensionOp {
        get_gate("QDiscard")
    }
}
234
235/// Sort folding inputs with [`IncomingPort`] as key
236fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> {
237    let mut v: Vec<_> = consts.iter().collect();
238    v.sort_by_key(|(i, _)| i);
239    v
240}
241
242/// Sort some input constants by port and just return the constants.
243pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
244    sort_by_in_port(consts)
245        .into_iter()
246        .map(|(_, c)| c)
247        .collect()
248}
249
250/// Calculate the depth of a node in the hierarchy.
251pub fn depth(h: &Hugr, n: Node) -> u32 {
252    match h.get_parent(n) {
253        Some(p) => 1 + depth(h, p),
254        None => 0,
255    }
256}
257
// Test-only utilities.
// NOTE(review): presumably allowed because not every helper is used by every test.
#[allow(dead_code)]
#[cfg(test)]
pub(crate) mod test {
    #[allow(unused_imports)]
    use crate::HugrView;
    use crate::{
        ops::{OpType, Value},
        Hugr,
    };

    /// Check that a hugr just loads and returns a single expected constant.
    ///
    /// Equivalent to [`assert_fully_folded_with`] with a check that compares
    /// the constant to `expected_value` using `==`.
    pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
        assert_fully_folded_with(h, |v| v == expected_value)
    }

    /// Check that a hugr just loads and returns a single constant, and validate
    /// that constant using `check_value`.
    ///
    /// Every child of the root must be an `Input`, `Output` or `LoadConstant`
    /// node, or a `Const` whose value `check_value` accepts. There must be
    /// exactly four such children. On failure, the panic message includes the
    /// hugr's mermaid rendering.
    ///
    /// `CustomConst::equals_const` is not required to be implemented. Use this
    /// function for Values containing such a `CustomConst`.
    pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
        let mut node_count = 0;

        for node in h.children(h.root()) {
            let op = h.get_optype(node);
            match op {
                OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
                OpType::Const(c) if check_value(c.value()) => node_count += 1,
                _ => panic!("unexpected op: {}\n{}", op, h.mermaid_string()),
            }
        }

        // Input, Output, the LoadConstant, and the Const it loads.
        assert_eq!(
            node_count,
            4,
            "expected 4 children of the root, found {}\n{}",
            node_count,
            h.mermaid_string()
        );
    }
}