1use std::fmt::{self, Debug, Display};
4
5use itertools::Itertools;
6
7use crate::{ops::Value, Hugr, HugrView, IncomingPort, Node};
8
9pub fn display_list<T>(ts: impl IntoIterator<Item = T>, f: &mut fmt::Formatter) -> fmt::Result
13where
14 T: Display,
15{
16 display_list_with_separator(ts, f, ", ")
17}
18
/// Writes every item of `ts` to `f` using its [`Display`] implementation,
/// with `sep` written between consecutive items (never before the first or
/// after the last).
///
/// Each item is formatted with the same `f`, so any formatter flags (width,
/// fill, alignment, ...) apply to every item individually. The separator is
/// written verbatim.
///
/// # Errors
///
/// Propagates the first error returned by the formatter.
pub fn display_list_with_separator<T>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
    sep: &str,
) -> fmt::Result
where
    T: Display,
{
    let mut items = ts.into_iter();
    // Emit the head on its own so the separator only precedes later items.
    if let Some(head) = items.next() {
        Display::fmt(&head, f)?;
        for item in items {
            f.write_str(sep)?;
            Display::fmt(&item, f)?;
        }
    }
    Ok(())
}
42
43#[inline]
61#[track_caller]
62pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
63 match try_collect_array(arr) {
64 Ok(v) => v,
65 Err(v) => panic!("Expected {} elements, got {:?}", N, v),
66 }
67}
68
69#[inline]
84#[track_caller]
85pub fn try_collect_array<const N: usize, T>(
86 arr: impl IntoIterator<Item = T>,
87) -> Result<[T; N], Vec<T>> {
88 arr.into_iter().collect_vec().try_into()
89}
90
/// Returns `true` when `t` compares equal to `T::default()`.
// `allow(dead_code)`: some build configurations may not reference this helper.
#[allow(dead_code)]
pub(crate) fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
108
/// An uninhabited type: it has no variants, so no value of `Never` can ever
/// be constructed.
///
/// It can stand in as a type parameter where a value must be statically
/// impossible. Matching on a `Never` needs no arms (`match n {}`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Never {}
120
/// A small quantum-gate extension (`test.quantum`) plus a registry combining
/// it with the prelude, float and logic extensions, for use in tests.
#[cfg(test)]
pub(crate) mod test_quantum_extension {
    use std::sync::Arc;

    use crate::ops::{OpName, OpNameRef};
    use crate::std_extensions::arithmetic::float_ops;
    use crate::std_extensions::logic;
    use crate::types::FuncValueType;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t},
            ExtensionId, ExtensionRegistry, PRELUDE,
        },
        ops::ExtensionOp,
        std_extensions::arithmetic::float_types,
        type_row,
        types::{PolyFuncTypeRV, Signature},
        Extension,
    };

    use lazy_static::lazy_static;

    /// Endomorphic signature over `n` qubits: `n` qubits in, `n` qubits out.
    fn qb_endo(n: usize) -> PolyFuncTypeRV {
        FuncValueType::new_endo(vec![qb_t(); n]).into()
    }

    /// Identifier of the test quantum extension.
    pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum");

    /// Builds the extension, registering each gate with its description and
    /// signature. Panics if any registration fails.
    fn extension() -> Arc<Extension> {
        Extension::new_test_arc(EXTENSION_ID, |ext, ext_ref| {
            // Single-qubit Hadamard.
            ext.add_op(OpName::new_inline("H"), "Hadamard".into(), qb_endo(1), ext_ref)
                .unwrap();

            // Z-rotation whose angle is given as a float64 input.
            ext.add_op(
                OpName::new_inline("RzF64"),
                "Rotation specified by float".into(),
                Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            // Two-qubit controlled-X.
            ext.add_op(OpName::new_inline("CX"), "CX".into(), qb_endo(2), ext_ref)
                .unwrap();

            // Measurement keeps the qubit and additionally yields a bool.
            ext.add_op(
                OpName::new_inline("Measure"),
                "Measure a qubit, returning the qubit and the measurement result.".into(),
                Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]),
                ext_ref,
            )
            .unwrap();

            // Allocation: no inputs, one fresh qubit out.
            ext.add_op(
                OpName::new_inline("QAlloc"),
                "Allocate a new qubit.".into(),
                Signature::new(type_row![], vec![qb_t()]),
                ext_ref,
            )
            .unwrap();

            // Discard: one qubit in, nothing out.
            ext.add_op(
                OpName::new_inline("QDiscard"),
                "Discard a qubit.".into(),
                Signature::new(vec![qb_t()], type_row![]),
                ext_ref,
            )
            .unwrap();
        })
    }

    lazy_static! {
        /// The test quantum extension, built once on first use.
        pub static ref EXTENSION: Arc<Extension> = extension();

        /// Registry holding the test quantum extension together with the
        /// prelude, float types, float ops and logic extensions.
        pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([
            EXTENSION.clone(),
            PRELUDE.clone(),
            float_types::EXTENSION.clone(),
            float_ops::EXTENSION.clone(),
            logic::EXTENSION.clone(),
        ]);
    }

    /// Instantiates the gate called `name` from [`EXTENSION`] with no type
    /// arguments. Panics if instantiation fails.
    fn get_gate(name: &OpNameRef) -> ExtensionOp {
        EXTENSION.instantiate_extension_op(name, []).unwrap()
    }

    /// The `H` (Hadamard) gate.
    pub(crate) fn h_gate() -> ExtensionOp {
        get_gate("H")
    }

    /// The `CX` gate.
    pub(crate) fn cx_gate() -> ExtensionOp {
        get_gate("CX")
    }

    /// The `Measure` operation.
    pub(crate) fn measure() -> ExtensionOp {
        get_gate("Measure")
    }

    /// The `RzF64` rotation gate.
    pub(crate) fn rz_f64() -> ExtensionOp {
        get_gate("RzF64")
    }

    /// The `QAlloc` operation.
    pub(crate) fn q_alloc() -> ExtensionOp {
        get_gate("QAlloc")
    }

    /// The `QDiscard` operation.
    pub(crate) fn q_discard() -> ExtensionOp {
        get_gate("QDiscard")
    }
}
251
252fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> {
254 let mut v: Vec<_> = consts.iter().collect();
255 v.sort_by_key(|(i, _)| i);
256 v
257}
258
259pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
261 sort_by_in_port(consts)
262 .into_iter()
263 .map(|(_, c)| c)
264 .collect()
265}
266
267pub fn depth(h: &Hugr, n: Node) -> u32 {
269 match h.get_parent(n) {
270 Some(p) => 1 + depth(h, p),
271 None => 0,
272 }
273}
274
/// Assertion helpers for checking that constant folding fully reduced a HUGR.
#[allow(dead_code)]
#[cfg(test)]
pub(crate) mod test {
    #[allow(unused_imports)]
    use crate::HugrView;
    use crate::{
        ops::{OpType, Value},
        Hugr,
    };

    /// Asserts that `h` is fully folded into a single constant equal to
    /// `expected_value`. Panics under the same conditions as
    /// [`assert_fully_folded_with`].
    pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
        assert_fully_folded_with(h, |candidate| *candidate == *expected_value)
    }

    /// Asserts that every child of `h`'s entrypoint is an `Input`, `Output`,
    /// `LoadConstant`, or a `Const` whose value satisfies `check_value`, and
    /// that there are exactly four such children.
    ///
    /// # Panics
    ///
    /// - If a child has any other op type, or a `Const` fails `check_value`.
    ///   The message includes the op and the Mermaid rendering of `h`.
    /// - If the number of children is not 4.
    pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
        let mut node_count = 0;

        for node in h.children(h.entrypoint()) {
            let op = h.get_optype(node);
            // `check_value` is only consulted for `Const` nodes.
            let allowed = match op {
                OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => true,
                OpType::Const(c) => check_value(c.value()),
                _ => false,
            };
            assert!(allowed, "unexpected op: {}\n{}", op, h.mermaid_string());
            node_count += 1;
        }

        assert_eq!(node_count, 4);
    }
}