hugr_core/std_extensions/collections/
static_array.rs

//! An extension for working with static arrays.
//!
//! The extension `collections.static_array` models globally available constant
//! arrays of [`TypeBound::Copyable`] values.
//!
//! The type `static_array<T>` is parameterised by its element type. Note that,
//! unlike `collections.array.array`, the length of a static array is not tracked
//! in type args.
//!
//! The [`CustomConst`] [`StaticArrayValue`] is the only means by which a value of
//! `static_array` type can be obtained.
//!
//! Operations provided:
//!  * `get<T: Copyable>: [static_array<T>, prelude.usize] -> [[] + [T]]`
//!  * `len<T: Copyable>: [static_array<T>] -> [prelude.usize]`
16use std::{
17    hash::{self, Hash as _},
18    iter,
19    sync::{self, Arc},
20};
21
22use crate::{
23    Extension, Wire,
24    builder::{BuildError, Dataflow},
25    extension::{
26        ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef,
27        prelude::{option_type, usize_t},
28        resolution::{ExtensionResolutionError, WeakExtensionRegistry},
29        simple_op::{
30            HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
31            try_from_name,
32        },
33    },
34    ops::{
35        ExtensionOp, OpName, Value,
36        constant::{CustomConst, TryHash, ValueName, maybe_hash_values},
37    },
38    types::{
39        ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg,
40        TypeBound, TypeName,
41        type_param::{TermTypeError, TypeParam},
42    },
43};
44
45use lazy_static::lazy_static;
46
47use super::array::ArrayValue;
48
49/// Reported unique name of the extension
50pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.static_array");
51/// Reported unique name of the array type.
52pub const STATIC_ARRAY_TYPENAME: TypeName = TypeName::new_inline("static_array");
53/// Extension version.
54pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
55
56#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, derive_more::From)]
57/// Statically sized array of values, all of the same [`TypeBound::Copyable`]
58/// type.
59pub struct StaticArrayValue {
60    /// The contents of the `StaticArrayValue`.
61    pub value: ArrayValue,
62    /// The name of the `StaticArrayValue`.
63    pub name: String,
64}
65
66impl StaticArrayValue {
67    /// Returns the type of values inside the `[StaticArrayValue]`.
68    #[must_use]
69    pub fn get_element_type(&self) -> &Type {
70        self.value.get_element_type()
71    }
72
73    /// Returns the values contained inside the `[StaticArrayValue]`.
74    #[must_use]
75    pub fn get_contents(&self) -> &[Value] {
76        self.value.get_contents()
77    }
78
79    /// Create a new [`CustomConst`] for an array of values of type `typ`.
80    /// That all values are of type `typ` is not checked here.
81    pub fn try_new(
82        name: impl ToString,
83        typ: Type,
84        contents: impl IntoIterator<Item = Value>,
85    ) -> Result<Self, ConstTypeError> {
86        if !TypeBound::Copyable.contains(typ.least_upper_bound()) {
87            return Err(CustomCheckFailure::Message(format!(
88                "Failed to construct a StaticArrayValue with non-Copyable type: {typ}"
89            ))
90            .into());
91        }
92        Ok(Self {
93            value: ArrayValue::new(typ, contents),
94            name: name.to_string(),
95        })
96    }
97
98    /// Create a new [`CustomConst`] for an empty array of values of type `typ`.
99    pub fn try_new_empty(name: impl ToString, typ: Type) -> Result<Self, ConstTypeError> {
100        Self::try_new(name, typ, iter::empty())
101    }
102
103    /// Returns the type of the `[StaticArrayValue]` as a `[CustomType]`.`
104    #[must_use]
105    pub fn custom_type(&self) -> CustomType {
106        static_array_custom_type(self.get_element_type().clone())
107    }
108}
109
110impl TryHash for StaticArrayValue {
111    fn try_hash(&self, mut st: &mut dyn hash::Hasher) -> bool {
112        maybe_hash_values(self.get_contents(), &mut st) && {
113            self.name.hash(&mut st);
114            self.get_element_type().hash(&mut st);
115            true
116        }
117    }
118}
119
120#[typetag::serde]
121impl CustomConst for StaticArrayValue {
122    fn name(&self) -> ValueName {
123        ValueName::new_inline("const_array")
124    }
125
126    fn get_type(&self) -> Type {
127        self.custom_type().into()
128    }
129
130    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
131        crate::ops::constant::downcast_equal_consts(self, other)
132    }
133
134    fn update_extensions(
135        &mut self,
136        extensions: &WeakExtensionRegistry,
137    ) -> Result<(), ExtensionResolutionError> {
138        self.value.update_extensions(extensions)
139    }
140}
141
142lazy_static! {
143    /// Extension for array operations.
144    pub static ref EXTENSION: Arc<Extension> = {
145        use TypeBound::Copyable;
146        Extension::new_arc(EXTENSION_ID.clone(), VERSION, |extension, extension_ref| {
147            extension.add_type(
148                    STATIC_ARRAY_TYPENAME,
149                    vec![Copyable.into()],
150                    "Fixed-length constant array".into(),
151                    Copyable.into(),
152                    extension_ref,
153                )
154                .unwrap();
155
156            StaticArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
157        })
158    };
159}
160
161fn instantiate_const_static_array_custom_type(
162    def: &TypeDef,
163    element_ty: impl Into<TypeArg>,
164) -> CustomType {
165    def.instantiate([element_ty.into()])
166        .unwrap_or_else(|e| panic!("{e}"))
167}
168
169/// Instantiate a new `static_array` [`CustomType`] given an element type.
170pub fn static_array_custom_type(element_ty: impl Into<TypeArg>) -> CustomType {
171    instantiate_const_static_array_custom_type(
172        EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
173        element_ty,
174    )
175}
176
177/// Instantiate a new `static_array` [Type] given an element type.
178pub fn static_array_type(element_ty: impl Into<TypeArg>) -> Type {
179    static_array_custom_type(element_ty).into()
180}
181
182#[derive(
183    Clone,
184    Copy,
185    Debug,
186    Hash,
187    PartialEq,
188    Eq,
189    strum::EnumIter,
190    strum::IntoStaticStr,
191    strum::EnumString,
192)]
193#[allow(non_camel_case_types, missing_docs)]
194#[non_exhaustive]
195pub enum StaticArrayOpDef {
196    get,
197    len,
198}
199
200impl StaticArrayOpDef {
201    fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak<Extension>) -> SignatureFunc {
202        use TypeBound::Copyable;
203        let t_param = TypeParam::from(Copyable);
204        let elem_ty = Type::new_var_use(0, Copyable);
205        let array_ty: Type =
206            instantiate_const_static_array_custom_type(def, elem_ty.clone()).into();
207        match self {
208            Self::get => PolyFuncType::new(
209                [t_param],
210                Signature::new(vec![array_ty, usize_t()], Type::from(option_type(elem_ty))),
211            )
212            .into(),
213            Self::len => PolyFuncType::new([t_param], Signature::new(array_ty, usize_t())).into(),
214        }
215    }
216}
217
218impl MakeOpDef for StaticArrayOpDef {
219    fn opdef_id(&self) -> OpName {
220        <&'static str>::from(self).into()
221    }
222
223    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
224    where
225        Self: Sized,
226    {
227        try_from_name(op_def.name(), op_def.extension_id())
228    }
229
230    fn init_signature(&self, extension_ref: &sync::Weak<Extension>) -> SignatureFunc {
231        self.signature_from_def(
232            EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
233            extension_ref,
234        )
235    }
236
237    fn extension_ref(&self) -> sync::Weak<Extension> {
238        Arc::downgrade(&EXTENSION)
239    }
240
241    fn extension(&self) -> ExtensionId {
242        EXTENSION_ID.clone()
243    }
244
245    fn description(&self) -> String {
246        match self {
247            Self::get => "Get an element from a static array",
248            Self::len => "Get the length of a static array",
249        }
250        .into()
251    }
252
253    // This method is re-defined here since we need to pass the static array
254    // type def while computing the signature, to avoid recursive loops
255    // initializing the extension.
256    fn add_to_extension(
257        &self,
258        extension: &mut Extension,
259        extension_ref: &sync::Weak<Extension>,
260    ) -> Result<(), crate::extension::ExtensionBuildError> {
261        let sig = self.signature_from_def(
262            extension.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
263            extension_ref,
264        );
265        let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
266
267        self.post_opdef(def);
268
269        Ok(())
270    }
271}
272
273#[derive(Clone, Debug, PartialEq)]
274/// Concrete array operation.
275pub struct StaticArrayOp {
276    /// The operation definition.
277    pub def: StaticArrayOpDef,
278    /// The element type of the array.
279    pub elem_ty: Type,
280}
281
282impl MakeExtensionOp for StaticArrayOp {
283    fn op_id(&self) -> OpName {
284        self.def.opdef_id()
285    }
286
287    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
288    where
289        Self: Sized,
290    {
291        let def = StaticArrayOpDef::from_def(ext_op.def())?;
292        def.instantiate(ext_op.args())
293    }
294
295    fn type_args(&self) -> Vec<TypeArg> {
296        vec![self.elem_ty.clone().into()]
297    }
298}
299
300impl HasDef for StaticArrayOp {
301    type Def = StaticArrayOpDef;
302}
303
304impl HasConcrete for StaticArrayOpDef {
305    type Concrete = StaticArrayOp;
306
307    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
308        use TypeBound::Copyable;
309        match type_args {
310            [arg] => {
311                let elem_ty = arg
312                    .as_runtime()
313                    .filter(|t| Copyable.contains(t.least_upper_bound()))
314                    .ok_or(SignatureError::TypeArgMismatch(
315                        TermTypeError::TypeMismatch {
316                            type_: Box::new(Copyable.into()),
317                            term: Box::new(arg.clone()),
318                        },
319                    ))?;
320
321                Ok(StaticArrayOp {
322                    def: *self,
323                    elem_ty,
324                })
325            }
326            _ => Err(
327                SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1))
328                    .into(),
329            ),
330        }
331    }
332}
333
334impl MakeRegisteredOp for StaticArrayOp {
335    fn extension_id(&self) -> ExtensionId {
336        EXTENSION_ID.clone()
337    }
338
339    fn extension_ref(&self) -> sync::Weak<Extension> {
340        Arc::downgrade(&EXTENSION)
341    }
342}
343
344/// A trait for building static array operations in a dataflow graph.
345pub trait StaticArrayOpBuilder: Dataflow {
346    /// Adds a `get` operation to retrieve an element from a static array.
347    ///
348    /// # Arguments
349    ///
350    /// + `elem_ty` - The type of the elements in the array.
351    /// + `array` - The wire carrying the array.
352    /// + `index` - The wire carrying the index.
353    ///
354    /// # Returns
355    ///
356    /// Returns a `Result` containing the wire for the retrieved element or a `BuildError`.
357    fn add_static_array_get(
358        &mut self,
359        elem_ty: Type,
360        array: Wire,
361        index: Wire,
362    ) -> Result<Wire, BuildError> {
363        Ok(self
364            .add_dataflow_op(
365                StaticArrayOp {
366                    def: StaticArrayOpDef::get,
367                    elem_ty,
368                }
369                .to_extension_op()
370                .unwrap(),
371                [array, index],
372            )?
373            .out_wire(0))
374    }
375
376    /// Adds a `len` operation to get the length of a static array.
377    ///
378    /// # Arguments
379    ///
380    /// + `elem_ty` - The type of the elements in the array.
381    /// + `array` - The wire representing the array.
382    ///
383    /// # Returns
384    ///
385    /// Returns a `Result` containing the wire for the length of the array or a `BuildError`.
386    fn add_static_array_len(&mut self, elem_ty: Type, array: Wire) -> Result<Wire, BuildError> {
387        Ok(self
388            .add_dataflow_op(
389                StaticArrayOp {
390                    def: StaticArrayOpDef::len,
391                    elem_ty,
392                }
393                .to_extension_op()
394                .unwrap(),
395                [array],
396            )?
397            .out_wire(0))
398    }
399}
400
401impl<T: Dataflow> StaticArrayOpBuilder for T {}
402
#[cfg(test)]
mod test {
    use crate::{
        builder::{DFGBuilder, DataflowHugr as _},
        extension::prelude::{ConstUsize, qb_t},
        type_row,
    };

    use super::*;

    /// Construction succeeds for a copyable element type and fails for a
    /// non-copyable one.
    #[test]
    fn const_static_array_copyable() {
        let copyable = StaticArrayValue::try_new_empty("good", Type::UNIT);
        let _ = copyable.unwrap();

        let non_copyable = StaticArrayValue::try_new_empty("good", qb_t());
        let _ = non_copyable.unwrap_err();
    }

    /// Builds a HUGR that loads a static array and applies both `len` and
    /// `get` to it.
    #[test]
    fn all_ops() {
        let output_ty = Type::from(option_type(usize_t()));
        let mut builder = DFGBuilder::new(Signature::new(type_row![], output_ty)).unwrap();

        let array_value = StaticArrayValue::try_new(
            "t",
            usize_t(),
            (1..999).map(|x| ConstUsize::new(x).into()),
        )
        .unwrap();
        let array = builder.add_load_value(array_value);

        let _ = builder.add_static_array_len(usize_t(), array).unwrap();

        let index = builder.add_load_value(ConstUsize::new(777));
        let x = builder
            .add_static_array_get(usize_t(), array, index)
            .unwrap();

        let _ = builder.finish_hugr_with_outputs([x]).unwrap();
    }
}