hugr_core/std_extensions/collections/
array.rs

1//! Fixed-length array type and operations extension.
2
3mod array_op;
4mod array_repeat;
5mod array_scan;
6
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use lazy_static::lazy_static;
11use serde::{Deserialize, Serialize};
12use std::hash::{Hash, Hasher};
13
14use crate::extension::resolution::{
15    resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError,
16    WeakExtensionRegistry,
17};
18use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
19use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound};
20use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName};
21use crate::ops::{ExtensionOp, OpName, Value};
22use crate::types::type_param::{TypeArg, TypeParam};
23use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName};
24use crate::Extension;
25
26pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter};
27pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID};
28pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID};
29
30/// Reported unique name of the array type.
31pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array");
32/// Reported unique name of the extension
33pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.array");
34/// Extension version.
35pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
38/// Statically sized array of values, all of the same type.
39pub struct ArrayValue {
40    values: Vec<Value>,
41    typ: Type,
42}
43
44impl ArrayValue {
45    /// Name of the constructor for creating constant arrays.
46    #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))]
47    pub(crate) const CTR_NAME: &'static str = "collections.array.const";
48
49    /// Create a new [CustomConst] for an array of values of type `typ`.
50    /// That all values are of type `typ` is not checked here.
51    pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
52        Self {
53            values: contents.into_iter().collect_vec(),
54            typ,
55        }
56    }
57
58    /// Create a new [CustomConst] for an empty array of values of type `typ`.
59    pub fn new_empty(typ: Type) -> Self {
60        Self {
61            values: vec![],
62            typ,
63        }
64    }
65
66    /// Returns the type of the `[ArrayValue]` as a `[CustomType]`.`
67    pub fn custom_type(&self) -> CustomType {
68        array_custom_type(self.values.len() as u64, self.typ.clone())
69    }
70
71    /// Returns the type of values inside the `[ArrayValue]`.
72    pub fn get_element_type(&self) -> &Type {
73        &self.typ
74    }
75
76    /// Returns the values contained inside the `[ArrayValue]`.
77    pub fn get_contents(&self) -> &[Value] {
78        &self.values
79    }
80}
81
82impl TryHash for ArrayValue {
83    fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
84        maybe_hash_values(&self.values, &mut st) && {
85            self.typ.hash(&mut st);
86            true
87        }
88    }
89}
90
91#[typetag::serde]
92impl CustomConst for ArrayValue {
93    fn name(&self) -> ValueName {
94        ValueName::new_inline("array")
95    }
96
97    fn get_type(&self) -> Type {
98        self.custom_type().into()
99    }
100
101    fn validate(&self) -> Result<(), CustomCheckFailure> {
102        let typ = self.custom_type();
103
104        EXTENSION
105            .get_type(&ARRAY_TYPENAME)
106            .unwrap()
107            .check_custom(&typ)
108            .map_err(|_| {
109                CustomCheckFailure::Message(format!(
110                    "Custom typ {typ} is not a valid instantiation of array."
111                ))
112            })?;
113
114        // constant can only hold classic type.
115        let ty = match typ.args() {
116            [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]
117                if *n as usize == self.values.len() =>
118            {
119                ty
120            }
121            _ => {
122                return Err(CustomCheckFailure::Message(format!(
123                    "Invalid array type arguments: {:?}",
124                    typ.args()
125                )))
126            }
127        };
128
129        // check all values are instances of the element type
130        for v in &self.values {
131            if v.get_type() != *ty {
132                return Err(CustomCheckFailure::Message(format!(
133                    "Array element {v:?} is not of expected type {ty}"
134                )));
135            }
136        }
137
138        Ok(())
139    }
140
141    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
142        crate::ops::constant::downcast_equal_consts(self, other)
143    }
144
145    fn extension_reqs(&self) -> ExtensionSet {
146        ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs))
147            .union(EXTENSION_ID.into())
148    }
149
150    fn update_extensions(
151        &mut self,
152        extensions: &WeakExtensionRegistry,
153    ) -> Result<(), ExtensionResolutionError> {
154        for val in &mut self.values {
155            resolve_value_extensions(val, extensions)?;
156        }
157        resolve_type_extensions(&mut self.typ, extensions)
158    }
159}
160
161lazy_static! {
162    /// Extension for array operations.
163    pub static ref EXTENSION: Arc<Extension> = {
164        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
165            extension.add_type(
166                    ARRAY_TYPENAME,
167                    vec![ TypeParam::max_nat(), TypeBound::Any.into()],
168                    "Fixed-length array".into(),
169                    TypeDefBound::from_params(vec![1] ),
170                    extension_ref,
171                )
172                .unwrap();
173
174            array_op::ArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
175            array_repeat::ArrayRepeatDef.add_to_extension(extension, extension_ref).unwrap();
176            array_scan::ArrayScanDef.add_to_extension(extension, extension_ref).unwrap();
177        })
178    };
179}
180
181fn array_type_def() -> &'static TypeDef {
182    EXTENSION.get_type(&ARRAY_TYPENAME).unwrap()
183}
184
185/// Instantiate a new array type given a size argument and element type.
186///
187/// This method is equivalent to [`array_type_parametric`], but uses concrete
188/// arguments types to ensure no errors are possible.
189pub fn array_type(size: u64, element_ty: Type) -> Type {
190    array_custom_type(size, element_ty).into()
191}
192
193/// Instantiate a new array type given the size and element type parameters.
194///
195/// This is a generic version of [`array_type`].
196pub fn array_type_parametric(
197    size: impl Into<TypeArg>,
198    element_ty: impl Into<TypeArg>,
199) -> Result<Type, SignatureError> {
200    instantiate_array(array_type_def(), size, element_ty)
201}
202
203fn array_custom_type(size: impl Into<TypeArg>, element_ty: impl Into<TypeArg>) -> CustomType {
204    instantiate_array_custom(array_type_def(), size, element_ty)
205        .expect("array parameters are valid")
206}
207
208fn instantiate_array_custom(
209    array_def: &TypeDef,
210    size: impl Into<TypeArg>,
211    element_ty: impl Into<TypeArg>,
212) -> Result<CustomType, SignatureError> {
213    array_def.instantiate(vec![size.into(), element_ty.into()])
214}
215
216fn instantiate_array(
217    array_def: &TypeDef,
218    size: impl Into<TypeArg>,
219    element_ty: impl Into<TypeArg>,
220) -> Result<Type, SignatureError> {
221    instantiate_array_custom(array_def, size, element_ty).map(Into::into)
222}
223
224/// Name of the operation in the prelude for creating new arrays.
225pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array");
226
227/// Initialize a new array op of element type `element_ty` of length `size`
228pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
229    let op = array_op::ArrayOpDef::new_array.to_concrete(element_ty, size);
230    op.to_extension_op().unwrap()
231}
232
#[cfg(test)]
mod test {
    use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr};
    use crate::extension::prelude::{qb_t, usize_t, ConstUsize};
    use crate::ops::constant::CustomConst;
    use crate::std_extensions::arithmetic::float_types::ConstF64;

    use super::{array_type, new_array_op, ArrayValue};

    /// Test building a HUGR involving a new_array operation.
    #[test]
    fn test_new_array() {
        // Two qubits in, one array of two qubits out.
        let signature = inout_sig(vec![qb_t(), qb_t()], array_type(2, qb_t()));
        let mut builder = DFGBuilder::new(signature).unwrap();

        let [first, second] = builder.input_wires_arr();
        let new_array = builder
            .add_dataflow_op(new_array_op(qb_t(), 2), [first, second])
            .unwrap();

        builder
            .finish_hugr_with_outputs(new_array.outputs())
            .unwrap();
    }

    /// Validation accepts elements of the declared type and rejects others.
    #[test]
    fn test_array_value() {
        // A usize element in a usize array validates.
        let matching = ArrayValue::new(usize_t(), vec![ConstUsize::new(3).into()]);
        matching.validate().unwrap();

        // A float element in a usize array must be rejected.
        let mismatched = ArrayValue::new(usize_t(), vec![ConstF64::new(1.2).into()]);
        assert!(mismatched.validate().is_err());
    }
}
272}