1use std::{
17 hash::{self, Hash as _},
18 iter,
19 sync::{self, Arc},
20};
21
22use crate::{
23 Extension, Wire,
24 builder::{BuildError, Dataflow},
25 extension::{
26 ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef,
27 prelude::{option_type, usize_t},
28 resolution::{ExtensionResolutionError, WeakExtensionRegistry},
29 simple_op::{
30 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
31 try_from_name,
32 },
33 },
34 ops::{
35 ExtensionOp, OpName, Value,
36 constant::{CustomConst, TryHash, ValueName, maybe_hash_values},
37 },
38 types::{
39 ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg,
40 TypeBound, TypeName,
41 type_param::{TermTypeError, TypeParam},
42 },
43};
44
45use lazy_static::lazy_static;
46
47use super::array::ArrayValue;
48
49pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.static_array");
51pub const STATIC_ARRAY_TYPENAME: TypeName = TypeName::new_inline("static_array");
53pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
55
/// Constant value of a static array: a named [`ArrayValue`].
///
/// Construct it with [`StaticArrayValue::try_new`], which rejects element
/// types that are not [`TypeBound::Copyable`]. The fields are public, so direct
/// construction bypasses that check.
// NOTE: field order matters — `derive_more::From` generates
// `From<(ArrayValue, String)>` from it.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, derive_more::From)]
pub struct StaticArrayValue {
    /// The array contents together with their element type.
    pub value: ArrayValue,
    /// Name of the array. It takes part in the derived `PartialEq` and in
    /// `TryHash`.
    pub name: String,
}
65
66impl StaticArrayValue {
67 #[must_use]
69 pub fn get_element_type(&self) -> &Type {
70 self.value.get_element_type()
71 }
72
73 #[must_use]
75 pub fn get_contents(&self) -> &[Value] {
76 self.value.get_contents()
77 }
78
79 pub fn try_new(
82 name: impl ToString,
83 typ: Type,
84 contents: impl IntoIterator<Item = Value>,
85 ) -> Result<Self, ConstTypeError> {
86 if !TypeBound::Copyable.contains(typ.least_upper_bound()) {
87 return Err(CustomCheckFailure::Message(format!(
88 "Failed to construct a StaticArrayValue with non-Copyable type: {typ}"
89 ))
90 .into());
91 }
92 Ok(Self {
93 value: ArrayValue::new(typ, contents),
94 name: name.to_string(),
95 })
96 }
97
98 pub fn try_new_empty(name: impl ToString, typ: Type) -> Result<Self, ConstTypeError> {
100 Self::try_new(name, typ, iter::empty())
101 }
102
103 #[must_use]
105 pub fn custom_type(&self) -> CustomType {
106 static_array_custom_type(self.get_element_type().clone())
107 }
108}
109
110impl TryHash for StaticArrayValue {
111 fn try_hash(&self, mut st: &mut dyn hash::Hasher) -> bool {
112 maybe_hash_values(self.get_contents(), &mut st) && {
113 self.name.hash(&mut st);
114 self.get_element_type().hash(&mut st);
115 true
116 }
117 }
118}
119
120#[typetag::serde]
121impl CustomConst for StaticArrayValue {
122 fn name(&self) -> ValueName {
123 ValueName::new_inline("const_array")
124 }
125
126 fn get_type(&self) -> Type {
127 self.custom_type().into()
128 }
129
130 fn equal_consts(&self, other: &dyn CustomConst) -> bool {
131 crate::ops::constant::downcast_equal_consts(self, other)
132 }
133
134 fn update_extensions(
135 &mut self,
136 extensions: &WeakExtensionRegistry,
137 ) -> Result<(), ExtensionResolutionError> {
138 self.value.update_extensions(extensions)
139 }
140}
141
142lazy_static! {
143 pub static ref EXTENSION: Arc<Extension> = {
145 use TypeBound::Copyable;
146 Extension::new_arc(EXTENSION_ID.clone(), VERSION, |extension, extension_ref| {
147 extension.add_type(
148 STATIC_ARRAY_TYPENAME,
149 vec![Copyable.into()],
150 "Fixed-length constant array".into(),
151 Copyable.into(),
152 extension_ref,
153 )
154 .unwrap();
155
156 StaticArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
157 })
158 };
159}
160
161fn instantiate_const_static_array_custom_type(
162 def: &TypeDef,
163 element_ty: impl Into<TypeArg>,
164) -> CustomType {
165 def.instantiate([element_ty.into()])
166 .unwrap_or_else(|e| panic!("{e}"))
167}
168
169pub fn static_array_custom_type(element_ty: impl Into<TypeArg>) -> CustomType {
171 instantiate_const_static_array_custom_type(
172 EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
173 element_ty,
174 )
175}
176
177pub fn static_array_type(element_ty: impl Into<TypeArg>) -> Type {
179 static_array_custom_type(element_ty).into()
180}
181
/// Operations defined by the static array extension.
///
/// The variant names double as the op names: `opdef_id` converts them via
/// `strum::IntoStaticStr`. Variant order is the `strum::EnumIter` order
/// used when all ops are loaded into the extension.
#[derive(
    Clone,
    Copy,
    Debug,
    Hash,
    PartialEq,
    Eq,
    strum::EnumIter,
    strum::IntoStaticStr,
    strum::EnumString,
)]
#[allow(non_camel_case_types, missing_docs)]
#[non_exhaustive]
pub enum StaticArrayOpDef {
    /// Get an element from a static array: `(static_array<T>, usize) -> option<T>`.
    get,
    /// Get the length of a static array: `static_array<T> -> usize`.
    len,
}
199
200impl StaticArrayOpDef {
201 fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak<Extension>) -> SignatureFunc {
202 use TypeBound::Copyable;
203 let t_param = TypeParam::from(Copyable);
204 let elem_ty = Type::new_var_use(0, Copyable);
205 let array_ty: Type =
206 instantiate_const_static_array_custom_type(def, elem_ty.clone()).into();
207 match self {
208 Self::get => PolyFuncType::new(
209 [t_param],
210 Signature::new(vec![array_ty, usize_t()], Type::from(option_type(elem_ty))),
211 )
212 .into(),
213 Self::len => PolyFuncType::new([t_param], Signature::new(array_ty, usize_t())).into(),
214 }
215 }
216}
217
218impl MakeOpDef for StaticArrayOpDef {
219 fn opdef_id(&self) -> OpName {
220 <&'static str>::from(self).into()
221 }
222
223 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
224 where
225 Self: Sized,
226 {
227 try_from_name(op_def.name(), op_def.extension_id())
228 }
229
230 fn init_signature(&self, extension_ref: &sync::Weak<Extension>) -> SignatureFunc {
231 self.signature_from_def(
232 EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
233 extension_ref,
234 )
235 }
236
237 fn extension_ref(&self) -> sync::Weak<Extension> {
238 Arc::downgrade(&EXTENSION)
239 }
240
241 fn extension(&self) -> ExtensionId {
242 EXTENSION_ID.clone()
243 }
244
245 fn description(&self) -> String {
246 match self {
247 Self::get => "Get an element from a static array",
248 Self::len => "Get the length of a static array",
249 }
250 .into()
251 }
252
253 fn add_to_extension(
257 &self,
258 extension: &mut Extension,
259 extension_ref: &sync::Weak<Extension>,
260 ) -> Result<(), crate::extension::ExtensionBuildError> {
261 let sig = self.signature_from_def(
262 extension.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
263 extension_ref,
264 );
265 let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
266
267 self.post_opdef(def);
268
269 Ok(())
270 }
271}
272
/// A static array operation instantiated at a concrete element type.
#[derive(Clone, Debug, PartialEq)]
pub struct StaticArrayOp {
    /// Which static array operation this is.
    pub def: StaticArrayOpDef,
    /// Element type of the array operated on. `HasConcrete::instantiate`
    /// checks that it is copyable; direct construction does not.
    pub elem_ty: Type,
}
281
282impl MakeExtensionOp for StaticArrayOp {
283 fn op_id(&self) -> OpName {
284 self.def.opdef_id()
285 }
286
287 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
288 where
289 Self: Sized,
290 {
291 let def = StaticArrayOpDef::from_def(ext_op.def())?;
292 def.instantiate(ext_op.args())
293 }
294
295 fn type_args(&self) -> Vec<TypeArg> {
296 vec![self.elem_ty.clone().into()]
297 }
298}
299
300impl HasDef for StaticArrayOp {
301 type Def = StaticArrayOpDef;
302}
303
304impl HasConcrete for StaticArrayOpDef {
305 type Concrete = StaticArrayOp;
306
307 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
308 use TypeBound::Copyable;
309 match type_args {
310 [arg] => {
311 let elem_ty = arg
312 .as_runtime()
313 .filter(|t| Copyable.contains(t.least_upper_bound()))
314 .ok_or(SignatureError::TypeArgMismatch(
315 TermTypeError::TypeMismatch {
316 type_: Box::new(Copyable.into()),
317 term: Box::new(arg.clone()),
318 },
319 ))?;
320
321 Ok(StaticArrayOp {
322 def: *self,
323 elem_ty,
324 })
325 }
326 _ => Err(
327 SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1))
328 .into(),
329 ),
330 }
331 }
332}
333
334impl MakeRegisteredOp for StaticArrayOp {
335 fn extension_id(&self) -> ExtensionId {
336 EXTENSION_ID.clone()
337 }
338
339 fn extension_ref(&self) -> sync::Weak<Extension> {
340 Arc::downgrade(&EXTENSION)
341 }
342}
343
344pub trait StaticArrayOpBuilder: Dataflow {
346 fn add_static_array_get(
358 &mut self,
359 elem_ty: Type,
360 array: Wire,
361 index: Wire,
362 ) -> Result<Wire, BuildError> {
363 Ok(self
364 .add_dataflow_op(
365 StaticArrayOp {
366 def: StaticArrayOpDef::get,
367 elem_ty,
368 }
369 .to_extension_op()
370 .unwrap(),
371 [array, index],
372 )?
373 .out_wire(0))
374 }
375
376 fn add_static_array_len(&mut self, elem_ty: Type, array: Wire) -> Result<Wire, BuildError> {
387 Ok(self
388 .add_dataflow_op(
389 StaticArrayOp {
390 def: StaticArrayOpDef::len,
391 elem_ty,
392 }
393 .to_extension_op()
394 .unwrap(),
395 [array],
396 )?
397 .out_wire(0))
398 }
399}
400
401impl<T: Dataflow> StaticArrayOpBuilder for T {}
402
#[cfg(test)]
mod test {
    use crate::{
        builder::{DFGBuilder, DataflowHugr as _},
        extension::prelude::{ConstUsize, qb_t},
        type_row,
    };
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn const_static_array_copyable() {
        let _good = StaticArrayValue::try_new_empty("good", Type::UNIT).unwrap();
        let _bad = StaticArrayValue::try_new_empty("bad", qb_t()).unwrap_err();
    }

    #[test]
    fn op_instantiate_roundtrip() {
        for def in StaticArrayOpDef::iter() {
            let arg: TypeArg = usize_t().into();
            let op = def.instantiate(&[arg]).unwrap();
            assert_eq!(op.def, def);
            assert_eq!(op.elem_ty, usize_t());

            // Converting to an `ExtensionOp` and back must be lossless.
            let ext_op = op.clone().to_extension_op().unwrap();
            assert_eq!(StaticArrayOp::from_extension_op(&ext_op).unwrap(), op);

            // Non-copyable element types and wrong arities are rejected.
            let linear: TypeArg = qb_t().into();
            assert!(def.instantiate(&[linear]).is_err());
            assert!(def.instantiate(&[]).is_err());
        }
    }

    #[test]
    fn all_ops() {
        let _ = {
            let mut builder = DFGBuilder::new(Signature::new(
                type_row![],
                Type::from(option_type(usize_t())),
            ))
            .unwrap();
            let array = builder.add_load_value(
                StaticArrayValue::try_new(
                    "t",
                    usize_t(),
                    (1..999).map(|x| ConstUsize::new(x).into()),
                )
                .unwrap(),
            );
            let _ = builder.add_static_array_len(usize_t(), array).unwrap();
            let index = builder.add_load_value(ConstUsize::new(777));
            let x = builder
                .add_static_array_get(usize_t(), array, index)
                .unwrap();
            builder.finish_hugr_with_outputs([x]).unwrap()
        };
    }
}
443}