1use std::{
17 hash::{self, Hash as _},
18 iter,
19 sync::{self, Arc, LazyLock},
20};
21
22use crate::{
23 Extension, Wire,
24 builder::{BuildError, Dataflow},
25 extension::{
26 ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef,
27 prelude::{option_type, usize_t},
28 resolution::{ExtensionResolutionError, WeakExtensionRegistry},
29 simple_op::{
30 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
31 try_from_name,
32 },
33 },
34 ops::{
35 ExtensionOp, OpName, Value,
36 constant::{CustomConst, TryHash, ValueName, maybe_hash_values},
37 },
38 types::{
39 ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg,
40 TypeBound, TypeName,
41 type_param::{TermTypeError, TypeParam},
42 },
43};
44
45use super::array::ArrayValue;
46
47pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.static_array");
49pub const STATIC_ARRAY_TYPENAME: TypeName = TypeName::new_inline("static_array");
51pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
53
54#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, derive_more::From)]
55pub struct StaticArrayValue {
58 pub value: ArrayValue,
60 pub name: String,
62}
63
64impl StaticArrayValue {
65 #[must_use]
67 pub fn get_element_type(&self) -> &Type {
68 self.value.get_element_type()
69 }
70
71 #[must_use]
73 pub fn get_contents(&self) -> &[Value] {
74 self.value.get_contents()
75 }
76
77 pub fn try_new(
80 name: impl ToString,
81 typ: Type,
82 contents: impl IntoIterator<Item = Value>,
83 ) -> Result<Self, ConstTypeError> {
84 if !TypeBound::Copyable.contains(typ.least_upper_bound()) {
85 return Err(CustomCheckFailure::Message(format!(
86 "Failed to construct a StaticArrayValue with non-Copyable type: {typ}"
87 ))
88 .into());
89 }
90 Ok(Self {
91 value: ArrayValue::new(typ, contents),
92 name: name.to_string(),
93 })
94 }
95
96 pub fn try_new_empty(name: impl ToString, typ: Type) -> Result<Self, ConstTypeError> {
98 Self::try_new(name, typ, iter::empty())
99 }
100
101 #[must_use]
103 pub fn custom_type(&self) -> CustomType {
104 static_array_custom_type(self.get_element_type().clone())
105 }
106}
107
108impl TryHash for StaticArrayValue {
109 fn try_hash(&self, mut st: &mut dyn hash::Hasher) -> bool {
110 maybe_hash_values(self.get_contents(), &mut st) && {
111 self.name.hash(&mut st);
112 self.get_element_type().hash(&mut st);
113 true
114 }
115 }
116}
117
118#[typetag::serde]
119impl CustomConst for StaticArrayValue {
120 fn name(&self) -> ValueName {
121 ValueName::new_inline("const_array")
122 }
123
124 fn get_type(&self) -> Type {
125 self.custom_type().into()
126 }
127
128 fn equal_consts(&self, other: &dyn CustomConst) -> bool {
129 crate::ops::constant::downcast_equal_consts(self, other)
130 }
131
132 fn update_extensions(
133 &mut self,
134 extensions: &WeakExtensionRegistry,
135 ) -> Result<(), ExtensionResolutionError> {
136 self.value.update_extensions(extensions)
137 }
138}
139
140pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
142 use TypeBound::Copyable;
143 Extension::new_arc(EXTENSION_ID.clone(), VERSION, |extension, extension_ref| {
144 extension
145 .add_type(
146 STATIC_ARRAY_TYPENAME,
147 vec![Copyable.into()],
148 "Fixed-length constant array".into(),
149 Copyable.into(),
150 extension_ref,
151 )
152 .unwrap();
153
154 StaticArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
155 })
156});
157
158fn instantiate_const_static_array_custom_type(
159 def: &TypeDef,
160 element_ty: impl Into<TypeArg>,
161) -> CustomType {
162 def.instantiate([element_ty.into()])
163 .unwrap_or_else(|e| panic!("{e}"))
164}
165
166pub fn static_array_custom_type(element_ty: impl Into<TypeArg>) -> CustomType {
168 instantiate_const_static_array_custom_type(
169 EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
170 element_ty,
171 )
172}
173
174pub fn static_array_type(element_ty: impl Into<TypeArg>) -> Type {
176 static_array_custom_type(element_ty).into()
177}
178
179#[derive(
180 Clone,
181 Copy,
182 Debug,
183 Hash,
184 PartialEq,
185 Eq,
186 strum::EnumIter,
187 strum::IntoStaticStr,
188 strum::EnumString,
189)]
190#[allow(non_camel_case_types, missing_docs)]
191#[non_exhaustive]
192pub enum StaticArrayOpDef {
193 get,
194 len,
195}
196
197impl StaticArrayOpDef {
198 fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak<Extension>) -> SignatureFunc {
199 use TypeBound::Copyable;
200 let t_param = TypeParam::from(Copyable);
201 let elem_ty = Type::new_var_use(0, Copyable);
202 let array_ty: Type =
203 instantiate_const_static_array_custom_type(def, elem_ty.clone()).into();
204 match self {
205 Self::get => PolyFuncType::new(
206 [t_param],
207 Signature::new(vec![array_ty, usize_t()], Type::from(option_type(elem_ty))),
208 )
209 .into(),
210 Self::len => PolyFuncType::new([t_param], Signature::new(array_ty, usize_t())).into(),
211 }
212 }
213}
214
215impl MakeOpDef for StaticArrayOpDef {
216 fn opdef_id(&self) -> OpName {
217 <&'static str>::from(self).into()
218 }
219
220 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
221 where
222 Self: Sized,
223 {
224 try_from_name(op_def.name(), op_def.extension_id())
225 }
226
227 fn init_signature(&self, extension_ref: &sync::Weak<Extension>) -> SignatureFunc {
228 self.signature_from_def(
229 EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
230 extension_ref,
231 )
232 }
233
234 fn extension_ref(&self) -> sync::Weak<Extension> {
235 Arc::downgrade(&EXTENSION)
236 }
237
238 fn extension(&self) -> ExtensionId {
239 EXTENSION_ID.clone()
240 }
241
242 fn description(&self) -> String {
243 match self {
244 Self::get => "Get an element from a static array",
245 Self::len => "Get the length of a static array",
246 }
247 .into()
248 }
249
250 fn add_to_extension(
254 &self,
255 extension: &mut Extension,
256 extension_ref: &sync::Weak<Extension>,
257 ) -> Result<(), crate::extension::ExtensionBuildError> {
258 let sig = self.signature_from_def(
259 extension.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
260 extension_ref,
261 );
262 let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
263
264 self.post_opdef(def);
265
266 Ok(())
267 }
268}
269
270#[derive(Clone, Debug, PartialEq)]
271pub struct StaticArrayOp {
273 pub def: StaticArrayOpDef,
275 pub elem_ty: Type,
277}
278
279impl MakeExtensionOp for StaticArrayOp {
280 fn op_id(&self) -> OpName {
281 self.def.opdef_id()
282 }
283
284 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
285 where
286 Self: Sized,
287 {
288 let def = StaticArrayOpDef::from_def(ext_op.def())?;
289 def.instantiate(ext_op.args())
290 }
291
292 fn type_args(&self) -> Vec<TypeArg> {
293 vec![self.elem_ty.clone().into()]
294 }
295}
296
297impl HasDef for StaticArrayOp {
298 type Def = StaticArrayOpDef;
299}
300
301impl HasConcrete for StaticArrayOpDef {
302 type Concrete = StaticArrayOp;
303
304 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
305 use TypeBound::Copyable;
306 match type_args {
307 [arg] => {
308 let elem_ty = arg
309 .as_runtime()
310 .filter(|t| Copyable.contains(t.least_upper_bound()))
311 .ok_or(SignatureError::TypeArgMismatch(
312 TermTypeError::TypeMismatch {
313 type_: Box::new(Copyable.into()),
314 term: Box::new(arg.clone()),
315 },
316 ))?;
317
318 Ok(StaticArrayOp {
319 def: *self,
320 elem_ty,
321 })
322 }
323 _ => Err(
324 SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1))
325 .into(),
326 ),
327 }
328 }
329}
330
331impl MakeRegisteredOp for StaticArrayOp {
332 fn extension_id(&self) -> ExtensionId {
333 EXTENSION_ID.clone()
334 }
335
336 fn extension_ref(&self) -> sync::Weak<Extension> {
337 Arc::downgrade(&EXTENSION)
338 }
339}
340
341pub trait StaticArrayOpBuilder: Dataflow {
343 fn add_static_array_get(
355 &mut self,
356 elem_ty: Type,
357 array: Wire,
358 index: Wire,
359 ) -> Result<Wire, BuildError> {
360 Ok(self
361 .add_dataflow_op(
362 StaticArrayOp {
363 def: StaticArrayOpDef::get,
364 elem_ty,
365 }
366 .to_extension_op()
367 .unwrap(),
368 [array, index],
369 )?
370 .out_wire(0))
371 }
372
373 fn add_static_array_len(&mut self, elem_ty: Type, array: Wire) -> Result<Wire, BuildError> {
384 Ok(self
385 .add_dataflow_op(
386 StaticArrayOp {
387 def: StaticArrayOpDef::len,
388 elem_ty,
389 }
390 .to_extension_op()
391 .unwrap(),
392 [array],
393 )?
394 .out_wire(0))
395 }
396}
397
398impl<T: Dataflow> StaticArrayOpBuilder for T {}
399
#[cfg(test)]
mod test {
    use crate::{
        builder::{DFGBuilder, DataflowHugr as _},
        extension::prelude::{ConstUsize, qb_t},
        type_row,
    };

    use super::*;

    /// `try_new_empty` accepts a `Copyable` element type and rejects a
    /// non-`Copyable` one.
    #[test]
    fn const_static_array_copyable() {
        let _good = StaticArrayValue::try_new_empty("good", Type::UNIT).unwrap();
        // Qubits are not `Copyable`, so construction must fail.
        let _bad = StaticArrayValue::try_new_empty("bad", qb_t()).unwrap_err();
    }

    /// Builds a DFG that uses both `len` and `get` on a loaded static array constant.
    #[test]
    fn all_ops() {
        let mut builder = DFGBuilder::new(Signature::new(
            type_row![],
            Type::from(option_type(usize_t())),
        ))
        .unwrap();
        let array = builder.add_load_value(
            StaticArrayValue::try_new(
                "t",
                usize_t(),
                (1..999).map(|x| ConstUsize::new(x).into()),
            )
            .unwrap(),
        );
        let _len = builder.add_static_array_len(usize_t(), array).unwrap();
        let index = builder.add_load_value(ConstUsize::new(777));
        let elem = builder
            .add_static_array_get(usize_t(), array, index)
            .unwrap();
        let _hugr = builder.finish_hugr_with_outputs([elem]).unwrap();
    }
}
440}