1use std::fmt::{self, Debug, Display};
4
5use itertools::Itertools;
6
7use crate::{ops::Value, Hugr, HugrView, IncomingPort, Node};
8
9pub fn display_list<T>(ts: impl IntoIterator<Item = T>, f: &mut fmt::Formatter) -> fmt::Result
13where
14 T: Display,
15{
16 display_list_with_separator(ts, f, ", ")
17}
18
/// Writes every element of `ts` to `f` using its [`Display`] impl, placing
/// `sep` between consecutive elements (never before the first or after the
/// last).
///
/// Each element is formatted with the caller's [`fmt::Formatter`], so
/// flags such as width or precision apply to every element individually.
/// An empty iterator writes nothing.
///
/// # Errors
///
/// Propagates the first error returned by the formatter.
pub fn display_list_with_separator<T>(
    ts: impl IntoIterator<Item = T>,
    f: &mut fmt::Formatter,
    sep: &str,
) -> fmt::Result
where
    T: Display,
{
    let mut items = ts.into_iter();
    if let Some(head) = items.next() {
        Display::fmt(&head, f)?;
        // Every remaining element is preceded by exactly one separator.
        for item in items {
            f.write_str(sep)?;
            Display::fmt(&item, f)?;
        }
    }
    Ok(())
}
42
/// Collects the items of `arr` into a fixed-size array of length `N`.
///
/// See `try_collect_array` for a non-panicking variant.
///
/// # Panics
///
/// If `arr` does not yield exactly `N` items. The panic message lists the
/// items that were actually collected.
#[inline]
pub fn collect_array<const N: usize, T: Debug>(arr: impl IntoIterator<Item = T>) -> [T; N] {
    let items: Vec<T> = arr.into_iter().collect();
    match <[T; N]>::try_from(items) {
        Ok(array) => array,
        Err(leftover) => panic!("Expected {} elements, got {:?}", N, leftover),
    }
}
64
65#[inline]
80pub fn try_collect_array<const N: usize, T>(
81 arr: impl IntoIterator<Item = T>,
82) -> Result<[T; N], Vec<T>> {
83 arr.into_iter().collect_vec().try_into()
84}
85
/// Returns `true` when `t` compares equal to `T::default()`.
///
/// Useful as a predicate, for example to skip serializing default-valued
/// fields.
#[allow(dead_code)]
pub(crate) fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
103
/// A small quantum extension, `test.quantum`, for use in unit tests.
///
/// It provides the gates `H`, `RzF64`, `CX`, `Measure`, `QAlloc` and
/// `QDiscard`, plus a registry holding this extension together with the
/// prelude, float types, float ops and logic extensions.
#[cfg(test)]
pub(crate) mod test_quantum_extension {
    use std::sync::Arc;

    use crate::ops::{OpName, OpNameRef};
    use crate::std_extensions::arithmetic::float_ops;
    use crate::std_extensions::logic;
    use crate::types::FuncValueType;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t},
            ExtensionId, ExtensionRegistry, PRELUDE,
        },
        ops::ExtensionOp,
        std_extensions::arithmetic::float_types,
        type_row,
        types::{PolyFuncTypeRV, Signature},
        Extension,
    };

    use lazy_static::lazy_static;

    /// Endomorphic signature on a single qubit: `qubit -> qubit`.
    fn one_qb_func() -> PolyFuncTypeRV {
        let single_qubit = FuncValueType::new_endo(qb_t());
        single_qubit.into()
    }

    /// Endomorphic signature on two qubits: `(qubit, qubit) -> (qubit, qubit)`.
    fn two_qb_func() -> PolyFuncTypeRV {
        let qubit_pair = vec![qb_t(), qb_t()];
        FuncValueType::new_endo(qubit_pair).into()
    }

    /// Identifier of the test quantum extension.
    pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum");

    /// Builds the test quantum extension, registering each of its ops.
    ///
    /// # Panics
    ///
    /// If any `add_op` call returns an error.
    fn extension() -> Arc<Extension> {
        Extension::new_test_arc(EXTENSION_ID, |ext, ext_ref| {
            // Single-qubit Hadamard gate.
            ext.add_op(
                OpName::new_inline("H"),
                "Hadamard".into(),
                one_qb_func(),
                ext_ref,
            )
            .unwrap();

            // Z rotation whose angle is supplied as a float64 input wire.
            let rz_signature =
                Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]);
            ext.add_op(
                OpName::new_inline("RzF64"),
                "Rotation specified by float".into(),
                rz_signature,
                ext_ref,
            )
            .unwrap();

            // Two-qubit controlled-X gate.
            ext.add_op(
                OpName::new_inline("CX"),
                "CX".into(),
                two_qb_func(),
                ext_ref,
            )
            .unwrap();

            // Measurement keeps the qubit and additionally outputs a bool.
            let measure_signature = Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]);
            ext.add_op(
                OpName::new_inline("Measure"),
                "Measure a qubit, returning the qubit and the measurement result.".into(),
                measure_signature,
                ext_ref,
            )
            .unwrap();

            // Allocation: no inputs, one fresh qubit out.
            let alloc_signature = Signature::new(type_row![], vec![qb_t()]);
            ext.add_op(
                OpName::new_inline("QAlloc"),
                "Allocate a new qubit.".into(),
                alloc_signature,
                ext_ref,
            )
            .unwrap();

            // Discard: consumes a qubit and produces nothing.
            let discard_signature = Signature::new(vec![qb_t()], type_row![]);
            ext.add_op(
                OpName::new_inline("QDiscard"),
                "Discard a qubit.".into(),
                discard_signature,
                ext_ref,
            )
            .unwrap();
        })
    }

    lazy_static! {
        /// The test quantum extension, built once on first use.
        pub static ref EXTENSION: Arc<Extension> = extension();

        /// Registry containing the test quantum extension alongside the
        /// prelude, float types, float ops and logic extensions.
        pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([
            EXTENSION.clone(),
            PRELUDE.clone(),
            float_types::EXTENSION.clone(),
            float_ops::EXTENSION.clone(),
            logic::EXTENSION.clone(),
        ]);
    }

    /// Instantiates the op called `name` from [`EXTENSION`] with no type
    /// arguments.
    ///
    /// # Panics
    ///
    /// If `instantiate_extension_op` returns an error (e.g. presumably for
    /// an unknown op name).
    fn get_gate(name: &OpNameRef) -> ExtensionOp {
        EXTENSION.instantiate_extension_op(name, []).unwrap()
    }

    /// The `H` (Hadamard) gate.
    pub(crate) fn h_gate() -> ExtensionOp {
        get_gate("H")
    }

    /// The two-qubit `CX` gate.
    pub(crate) fn cx_gate() -> ExtensionOp {
        get_gate("CX")
    }

    /// The `Measure` op: `qubit -> (qubit, bool)`.
    pub(crate) fn measure() -> ExtensionOp {
        get_gate("Measure")
    }

    /// The `RzF64` op: `(qubit, float64) -> qubit`.
    pub(crate) fn rz_f64() -> ExtensionOp {
        get_gate("RzF64")
    }

    /// The `QAlloc` op: `() -> qubit`.
    pub(crate) fn q_alloc() -> ExtensionOp {
        get_gate("QAlloc")
    }

    /// The `QDiscard` op: `qubit -> ()`.
    pub(crate) fn q_discard() -> ExtensionOp {
        get_gate("QDiscard")
    }
}
234
235fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> {
237 let mut v: Vec<_> = consts.iter().collect();
238 v.sort_by_key(|(i, _)| i);
239 v
240}
241
242pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
244 sort_by_in_port(consts)
245 .into_iter()
246 .map(|(_, c)| c)
247 .collect()
248}
249
250pub fn depth(h: &Hugr, n: Node) -> u32 {
252 match h.get_parent(n) {
253 Some(p) => 1 + depth(h, p),
254 None => 0,
255 }
256}
257
/// Shared assertions for tests that check a Hugr was constant-folded.
#[allow(dead_code)]
#[cfg(test)]
pub(crate) mod test {
    #[allow(unused_imports)]
    use crate::HugrView;
    use crate::{
        ops::{OpType, Value},
        Hugr,
    };

    /// Asserts that `h` has been folded down to the single constant
    /// `expected_value`.
    ///
    /// Equivalent to [`assert_fully_folded_with`] using equality with
    /// `expected_value` as the check.
    pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
        assert_fully_folded_with(h, |candidate| candidate == expected_value);
    }

    /// Asserts that `h` has been folded down to a single constant.
    ///
    /// Every child of the root of `h` must be an `Input`, an `Output`, a
    /// `LoadConstant`, or a `Const` whose value satisfies `check_value`.
    /// There must be exactly four such children.
    ///
    /// # Panics
    ///
    /// On the first child that is not allowed (the message includes the op
    /// and a mermaid rendering of `h`), or if the child count is not 4.
    pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
        let is_allowed = |op: &OpType| match op {
            OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => true,
            OpType::Const(c) => check_value(c.value()),
            _ => false,
        };

        let mut node_count = 0;
        for node in h.children(h.root()) {
            let op = h.get_optype(node);
            assert!(is_allowed(op), "unexpected op: {}\n{}", op, h.mermaid_string());
            node_count += 1;
        }

        assert_eq!(node_count, 4);
    }
}
293}