hugr_core/std_extensions/collections/
static_array.rs

1//! An extension for working with static arrays.
2//!
//! The extension `collections.static_array` models globally available constant
4//! arrays of [`TypeBound::Copyable`] values.
5//!
6//! The type `static_array<T>` is parameterised by its element type. Note that
7//! unlike `collections.array.array` the length of a static array is not tracked
8//! in type args.
9//!
//! The [`CustomConst`] [`StaticArrayValue`] is the only way to obtain a value of
//! `static_array` type.
12//!
13//! Operations provided:
14//!  * `get<T: Copyable>: [static_array<T>, prelude.usize] -> [[] + [T]]`
15//!  * `len<T: Copyable>: [static_array<T>] -> [prelude.usize]`
16use std::{
17    hash::{self, Hash as _},
18    iter,
19    sync::{self, Arc, LazyLock},
20};
21
22use crate::{
23    Extension, Wire,
24    builder::{BuildError, Dataflow},
25    extension::{
26        ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef,
27        prelude::{option_type, usize_t},
28        resolution::{ExtensionResolutionError, WeakExtensionRegistry},
29        simple_op::{
30            HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
31            try_from_name,
32        },
33    },
34    ops::{
35        ExtensionOp, OpName, Value,
36        constant::{CustomConst, TryHash, ValueName, maybe_hash_values},
37    },
38    types::{
39        ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg,
40        TypeBound, TypeName,
41        type_param::{TermTypeError, TypeParam},
42    },
43};
44
45use super::array::ArrayValue;
46
47/// Reported unique name of the extension
48pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.static_array");
49/// Reported unique name of the array type.
50pub const STATIC_ARRAY_TYPENAME: TypeName = TypeName::new_inline("static_array");
51/// Extension version.
52pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
53
54#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, derive_more::From)]
55/// Statically sized array of values, all of the same [`TypeBound::Copyable`]
56/// type.
57pub struct StaticArrayValue {
58    /// The contents of the `StaticArrayValue`.
59    pub value: ArrayValue,
60    /// The name of the `StaticArrayValue`.
61    pub name: String,
62}
63
64impl StaticArrayValue {
65    /// Returns the type of values inside the `[StaticArrayValue]`.
66    #[must_use]
67    pub fn get_element_type(&self) -> &Type {
68        self.value.get_element_type()
69    }
70
71    /// Returns the values contained inside the `[StaticArrayValue]`.
72    #[must_use]
73    pub fn get_contents(&self) -> &[Value] {
74        self.value.get_contents()
75    }
76
77    /// Create a new [`CustomConst`] for an array of values of type `typ`.
78    /// That all values are of type `typ` is not checked here.
79    pub fn try_new(
80        name: impl ToString,
81        typ: Type,
82        contents: impl IntoIterator<Item = Value>,
83    ) -> Result<Self, ConstTypeError> {
84        if !TypeBound::Copyable.contains(typ.least_upper_bound()) {
85            return Err(CustomCheckFailure::Message(format!(
86                "Failed to construct a StaticArrayValue with non-Copyable type: {typ}"
87            ))
88            .into());
89        }
90        Ok(Self {
91            value: ArrayValue::new(typ, contents),
92            name: name.to_string(),
93        })
94    }
95
96    /// Create a new [`CustomConst`] for an empty array of values of type `typ`.
97    pub fn try_new_empty(name: impl ToString, typ: Type) -> Result<Self, ConstTypeError> {
98        Self::try_new(name, typ, iter::empty())
99    }
100
101    /// Returns the type of the `[StaticArrayValue]` as a `[CustomType]`.`
102    #[must_use]
103    pub fn custom_type(&self) -> CustomType {
104        static_array_custom_type(self.get_element_type().clone())
105    }
106}
107
108impl TryHash for StaticArrayValue {
109    fn try_hash(&self, mut st: &mut dyn hash::Hasher) -> bool {
110        maybe_hash_values(self.get_contents(), &mut st) && {
111            self.name.hash(&mut st);
112            self.get_element_type().hash(&mut st);
113            true
114        }
115    }
116}
117
118#[typetag::serde]
119impl CustomConst for StaticArrayValue {
120    fn name(&self) -> ValueName {
121        ValueName::new_inline("const_array")
122    }
123
124    fn get_type(&self) -> Type {
125        self.custom_type().into()
126    }
127
128    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
129        crate::ops::constant::downcast_equal_consts(self, other)
130    }
131
132    fn update_extensions(
133        &mut self,
134        extensions: &WeakExtensionRegistry,
135    ) -> Result<(), ExtensionResolutionError> {
136        self.value.update_extensions(extensions)
137    }
138}
139
140/// Extension for array operations.
141pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
142    use TypeBound::Copyable;
143    Extension::new_arc(EXTENSION_ID.clone(), VERSION, |extension, extension_ref| {
144        extension
145            .add_type(
146                STATIC_ARRAY_TYPENAME,
147                vec![Copyable.into()],
148                "Fixed-length constant array".into(),
149                Copyable.into(),
150                extension_ref,
151            )
152            .unwrap();
153
154        StaticArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
155    })
156});
157
158fn instantiate_const_static_array_custom_type(
159    def: &TypeDef,
160    element_ty: impl Into<TypeArg>,
161) -> CustomType {
162    def.instantiate([element_ty.into()])
163        .unwrap_or_else(|e| panic!("{e}"))
164}
165
166/// Instantiate a new `static_array` [`CustomType`] given an element type.
167pub fn static_array_custom_type(element_ty: impl Into<TypeArg>) -> CustomType {
168    instantiate_const_static_array_custom_type(
169        EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
170        element_ty,
171    )
172}
173
174/// Instantiate a new `static_array` [Type] given an element type.
175pub fn static_array_type(element_ty: impl Into<TypeArg>) -> Type {
176    static_array_custom_type(element_ty).into()
177}
178
179#[derive(
180    Clone,
181    Copy,
182    Debug,
183    Hash,
184    PartialEq,
185    Eq,
186    strum::EnumIter,
187    strum::IntoStaticStr,
188    strum::EnumString,
189)]
190#[allow(non_camel_case_types, missing_docs)]
191#[non_exhaustive]
192pub enum StaticArrayOpDef {
193    get,
194    len,
195}
196
197impl StaticArrayOpDef {
198    fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak<Extension>) -> SignatureFunc {
199        use TypeBound::Copyable;
200        let t_param = TypeParam::from(Copyable);
201        let elem_ty = Type::new_var_use(0, Copyable);
202        let array_ty: Type =
203            instantiate_const_static_array_custom_type(def, elem_ty.clone()).into();
204        match self {
205            Self::get => PolyFuncType::new(
206                [t_param],
207                Signature::new(vec![array_ty, usize_t()], Type::from(option_type(elem_ty))),
208            )
209            .into(),
210            Self::len => PolyFuncType::new([t_param], Signature::new(array_ty, usize_t())).into(),
211        }
212    }
213}
214
215impl MakeOpDef for StaticArrayOpDef {
216    fn opdef_id(&self) -> OpName {
217        <&'static str>::from(self).into()
218    }
219
220    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
221    where
222        Self: Sized,
223    {
224        try_from_name(op_def.name(), op_def.extension_id())
225    }
226
227    fn init_signature(&self, extension_ref: &sync::Weak<Extension>) -> SignatureFunc {
228        self.signature_from_def(
229            EXTENSION.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
230            extension_ref,
231        )
232    }
233
234    fn extension_ref(&self) -> sync::Weak<Extension> {
235        Arc::downgrade(&EXTENSION)
236    }
237
238    fn extension(&self) -> ExtensionId {
239        EXTENSION_ID.clone()
240    }
241
242    fn description(&self) -> String {
243        match self {
244            Self::get => "Get an element from a static array",
245            Self::len => "Get the length of a static array",
246        }
247        .into()
248    }
249
250    // This method is re-defined here since we need to pass the static array
251    // type def while computing the signature, to avoid recursive loops
252    // initializing the extension.
253    fn add_to_extension(
254        &self,
255        extension: &mut Extension,
256        extension_ref: &sync::Weak<Extension>,
257    ) -> Result<(), crate::extension::ExtensionBuildError> {
258        let sig = self.signature_from_def(
259            extension.get_type(&STATIC_ARRAY_TYPENAME).unwrap(),
260            extension_ref,
261        );
262        let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
263
264        self.post_opdef(def);
265
266        Ok(())
267    }
268}
269
270#[derive(Clone, Debug, PartialEq)]
271/// Concrete array operation.
272pub struct StaticArrayOp {
273    /// The operation definition.
274    pub def: StaticArrayOpDef,
275    /// The element type of the array.
276    pub elem_ty: Type,
277}
278
279impl MakeExtensionOp for StaticArrayOp {
280    fn op_id(&self) -> OpName {
281        self.def.opdef_id()
282    }
283
284    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
285    where
286        Self: Sized,
287    {
288        let def = StaticArrayOpDef::from_def(ext_op.def())?;
289        def.instantiate(ext_op.args())
290    }
291
292    fn type_args(&self) -> Vec<TypeArg> {
293        vec![self.elem_ty.clone().into()]
294    }
295}
296
297impl HasDef for StaticArrayOp {
298    type Def = StaticArrayOpDef;
299}
300
301impl HasConcrete for StaticArrayOpDef {
302    type Concrete = StaticArrayOp;
303
304    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
305        use TypeBound::Copyable;
306        match type_args {
307            [arg] => {
308                let elem_ty = arg
309                    .as_runtime()
310                    .filter(|t| Copyable.contains(t.least_upper_bound()))
311                    .ok_or(SignatureError::TypeArgMismatch(
312                        TermTypeError::TypeMismatch {
313                            type_: Box::new(Copyable.into()),
314                            term: Box::new(arg.clone()),
315                        },
316                    ))?;
317
318                Ok(StaticArrayOp {
319                    def: *self,
320                    elem_ty,
321                })
322            }
323            _ => Err(
324                SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1))
325                    .into(),
326            ),
327        }
328    }
329}
330
331impl MakeRegisteredOp for StaticArrayOp {
332    fn extension_id(&self) -> ExtensionId {
333        EXTENSION_ID.clone()
334    }
335
336    fn extension_ref(&self) -> sync::Weak<Extension> {
337        Arc::downgrade(&EXTENSION)
338    }
339}
340
341/// A trait for building static array operations in a dataflow graph.
342pub trait StaticArrayOpBuilder: Dataflow {
343    /// Adds a `get` operation to retrieve an element from a static array.
344    ///
345    /// # Arguments
346    ///
347    /// + `elem_ty` - The type of the elements in the array.
348    /// + `array` - The wire carrying the array.
349    /// + `index` - The wire carrying the index.
350    ///
351    /// # Returns
352    ///
353    /// Returns a `Result` containing the wire for the retrieved element or a `BuildError`.
354    fn add_static_array_get(
355        &mut self,
356        elem_ty: Type,
357        array: Wire,
358        index: Wire,
359    ) -> Result<Wire, BuildError> {
360        Ok(self
361            .add_dataflow_op(
362                StaticArrayOp {
363                    def: StaticArrayOpDef::get,
364                    elem_ty,
365                }
366                .to_extension_op()
367                .unwrap(),
368                [array, index],
369            )?
370            .out_wire(0))
371    }
372
373    /// Adds a `len` operation to get the length of a static array.
374    ///
375    /// # Arguments
376    ///
377    /// + `elem_ty` - The type of the elements in the array.
378    /// + `array` - The wire representing the array.
379    ///
380    /// # Returns
381    ///
382    /// Returns a `Result` containing the wire for the length of the array or a `BuildError`.
383    fn add_static_array_len(&mut self, elem_ty: Type, array: Wire) -> Result<Wire, BuildError> {
384        Ok(self
385            .add_dataflow_op(
386                StaticArrayOp {
387                    def: StaticArrayOpDef::len,
388                    elem_ty,
389                }
390                .to_extension_op()
391                .unwrap(),
392                [array],
393            )?
394            .out_wire(0))
395    }
396}
397
398impl<T: Dataflow> StaticArrayOpBuilder for T {}
399
#[cfg(test)]
mod test {
    use crate::{
        builder::{DFGBuilder, DataflowHugr as _},
        extension::prelude::{ConstUsize, qb_t},
        type_row,
    };

    use super::*;

    #[test]
    fn const_static_array_copyable() {
        let good = StaticArrayValue::try_new_empty("good", Type::UNIT).unwrap();
        assert_eq!(good.get_element_type(), &Type::UNIT);
        assert!(good.get_contents().is_empty());
        assert_eq!(good.get_type(), static_array_type(Type::UNIT));

        // Qubits are not Copyable, so they cannot be static array elements.
        let _bad = StaticArrayValue::try_new_empty("bad", qb_t()).unwrap_err();
    }

    #[test]
    fn op_instantiation() {
        let usize_arg: TypeArg = usize_t().into();
        let qb_arg: TypeArg = qb_t().into();
        for def in [StaticArrayOpDef::get, StaticArrayOpDef::len] {
            let op = def.instantiate(std::slice::from_ref(&usize_arg)).unwrap();
            assert_eq!(
                op,
                StaticArrayOp {
                    def,
                    elem_ty: usize_t()
                }
            );

            // Round trip through a generic `ExtensionOp`.
            let ext_op = op.clone().to_extension_op().unwrap();
            assert_eq!(StaticArrayOp::from_extension_op(&ext_op).unwrap(), op);

            // Non-Copyable element types and the wrong number of arguments are rejected.
            assert!(def.instantiate(std::slice::from_ref(&qb_arg)).is_err());
            assert!(def.instantiate(&[]).is_err());
            assert!(
                def.instantiate(&[usize_arg.clone(), usize_arg.clone()])
                    .is_err()
            );
        }
    }

    #[test]
    fn all_ops() {
        let mut builder = DFGBuilder::new(Signature::new(
            type_row![],
            Type::from(option_type(usize_t())),
        ))
        .unwrap();
        let array = builder.add_load_value(
            StaticArrayValue::try_new(
                "t",
                usize_t(),
                (1..999).map(|x| ConstUsize::new(x).into()),
            )
            .unwrap(),
        );
        let _len = builder.add_static_array_len(usize_t(), array).unwrap();
        let index = builder.add_load_value(ConstUsize::new(777));
        let x = builder
            .add_static_array_get(usize_t(), array, index)
            .unwrap();
        let _hugr = builder.finish_hugr_with_outputs([x]).unwrap();
    }
}
440}