hugr_core/std_extensions/collections/
array.rs1mod array_op;
4mod array_repeat;
5mod array_scan;
6
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use lazy_static::lazy_static;
11use serde::{Deserialize, Serialize};
12use std::hash::{Hash, Hasher};
13
14use crate::extension::resolution::{
15 resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError,
16 WeakExtensionRegistry,
17};
18use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
19use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound};
20use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName};
21use crate::ops::{ExtensionOp, OpName, Value};
22use crate::types::type_param::{TypeArg, TypeParam};
23use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName};
24use crate::Extension;
25
26pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter};
27pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID};
28pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID};
29
/// Name of the fixed-length array type defined by this extension.
pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array");
/// Identifier of the array extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.array");
/// Version of the array extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
36
/// A constant array value: a list of element values together with their
/// declared element type.
///
/// The constructors do not check that the elements match `typ`;
/// `CustomConst::validate` performs that check.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ArrayValue {
    /// The array elements, in order. The array length is `values.len()`.
    values: Vec<Value>,
    /// The declared element type of the array.
    typ: Type,
}
43
44impl ArrayValue {
45 #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))]
47 pub(crate) const CTR_NAME: &'static str = "collections.array.const";
48
49 pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
52 Self {
53 values: contents.into_iter().collect_vec(),
54 typ,
55 }
56 }
57
58 pub fn new_empty(typ: Type) -> Self {
60 Self {
61 values: vec![],
62 typ,
63 }
64 }
65
66 pub fn custom_type(&self) -> CustomType {
68 array_custom_type(self.values.len() as u64, self.typ.clone())
69 }
70
71 pub fn get_element_type(&self) -> &Type {
73 &self.typ
74 }
75
76 pub fn get_contents(&self) -> &[Value] {
78 &self.values
79 }
80}
81
82impl TryHash for ArrayValue {
83 fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
84 maybe_hash_values(&self.values, &mut st) && {
85 self.typ.hash(&mut st);
86 true
87 }
88 }
89}
90
91#[typetag::serde]
92impl CustomConst for ArrayValue {
93 fn name(&self) -> ValueName {
94 ValueName::new_inline("array")
95 }
96
97 fn get_type(&self) -> Type {
98 self.custom_type().into()
99 }
100
101 fn validate(&self) -> Result<(), CustomCheckFailure> {
102 let typ = self.custom_type();
103
104 EXTENSION
105 .get_type(&ARRAY_TYPENAME)
106 .unwrap()
107 .check_custom(&typ)
108 .map_err(|_| {
109 CustomCheckFailure::Message(format!(
110 "Custom typ {typ} is not a valid instantiation of array."
111 ))
112 })?;
113
114 let ty = match typ.args() {
116 [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]
117 if *n as usize == self.values.len() =>
118 {
119 ty
120 }
121 _ => {
122 return Err(CustomCheckFailure::Message(format!(
123 "Invalid array type arguments: {:?}",
124 typ.args()
125 )))
126 }
127 };
128
129 for v in &self.values {
131 if v.get_type() != *ty {
132 return Err(CustomCheckFailure::Message(format!(
133 "Array element {v:?} is not of expected type {ty}"
134 )));
135 }
136 }
137
138 Ok(())
139 }
140
141 fn equal_consts(&self, other: &dyn CustomConst) -> bool {
142 crate::ops::constant::downcast_equal_consts(self, other)
143 }
144
145 fn extension_reqs(&self) -> ExtensionSet {
146 ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs))
147 .union(EXTENSION_ID.into())
148 }
149
150 fn update_extensions(
151 &mut self,
152 extensions: &WeakExtensionRegistry,
153 ) -> Result<(), ExtensionResolutionError> {
154 for val in &mut self.values {
155 resolve_value_extensions(val, extensions)?;
156 }
157 resolve_type_extensions(&mut self.typ, extensions)
158 }
159}
160
161lazy_static! {
162 pub static ref EXTENSION: Arc<Extension> = {
164 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
165 extension.add_type(
166 ARRAY_TYPENAME,
167 vec![ TypeParam::max_nat(), TypeBound::Any.into()],
168 "Fixed-length array".into(),
169 TypeDefBound::from_params(vec![1] ),
170 extension_ref,
171 )
172 .unwrap();
173
174 array_op::ArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
175 array_repeat::ArrayRepeatDef.add_to_extension(extension, extension_ref).unwrap();
176 array_scan::ArrayScanDef.add_to_extension(extension, extension_ref).unwrap();
177 })
178 };
179}
180
181fn array_type_def() -> &'static TypeDef {
182 EXTENSION.get_type(&ARRAY_TYPENAME).unwrap()
183}
184
185pub fn array_type(size: u64, element_ty: Type) -> Type {
190 array_custom_type(size, element_ty).into()
191}
192
193pub fn array_type_parametric(
197 size: impl Into<TypeArg>,
198 element_ty: impl Into<TypeArg>,
199) -> Result<Type, SignatureError> {
200 instantiate_array(array_type_def(), size, element_ty)
201}
202
203fn array_custom_type(size: impl Into<TypeArg>, element_ty: impl Into<TypeArg>) -> CustomType {
204 instantiate_array_custom(array_type_def(), size, element_ty)
205 .expect("array parameters are valid")
206}
207
208fn instantiate_array_custom(
209 array_def: &TypeDef,
210 size: impl Into<TypeArg>,
211 element_ty: impl Into<TypeArg>,
212) -> Result<CustomType, SignatureError> {
213 array_def.instantiate(vec![size.into(), element_ty.into()])
214}
215
216fn instantiate_array(
217 array_def: &TypeDef,
218 size: impl Into<TypeArg>,
219 element_ty: impl Into<TypeArg>,
220) -> Result<Type, SignatureError> {
221 instantiate_array_custom(array_def, size, element_ty).map(Into::into)
222}
223
/// Name of the `new_array` operation.
pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array");
226
227pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
229 let op = array_op::ArrayOpDef::new_array.to_concrete(element_ty, size);
230 op.to_extension_op().unwrap()
231}
232
#[cfg(test)]
mod test {
    use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr};
    use crate::extension::prelude::{qb_t, usize_t, ConstUsize};
    use crate::ops::constant::CustomConst;
    use crate::std_extensions::arithmetic::float_types::ConstF64;

    use super::{array_type, new_array_op, ArrayValue};

    /// Packs two qubit inputs into an array with `new_array`.
    #[test]
    fn test_new_array() {
        let signature = inout_sig(vec![qb_t(), qb_t()], array_type(2, qb_t()));
        let mut builder = DFGBuilder::new(signature).unwrap();

        let [first, second] = builder.input_wires_arr();
        let pack = new_array_op(qb_t(), 2);
        let packed = builder.add_dataflow_op(pack, [first, second]).unwrap();

        builder.finish_hugr_with_outputs(packed.outputs()).unwrap();
    }

    /// `validate` accepts matching element types and rejects mismatched ones.
    #[test]
    fn test_array_value() {
        let matching = ArrayValue {
            values: vec![ConstUsize::new(3).into()],
            typ: usize_t(),
        };
        matching.validate().unwrap();

        // A float element in a `usize` array must be rejected.
        let mismatched = ArrayValue {
            values: vec![ConstF64::new(1.2).into()],
            typ: usize_t(),
        };
        assert!(mismatched.validate().is_err());
    }
}