// hugr_core/std_extensions/collections/array/array_conversion.rs

//! Operations for converting between the different array extensions.
2
3use std::marker::PhantomData;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use crate::Extension;
8use crate::extension::simple_op::{
9    HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
10};
11use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
12use crate::ops::{ExtensionOp, NamedOp, OpName};
13use crate::types::type_param::{TypeArg, TypeParam};
14use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
15
16use super::array_kind::ArrayKind;
17
/// Direction of an array conversion, relative to the array type whose extension defines the
/// operation.
///
/// [INTO] turns that array type into the other one; [FROM] obtains that array type from the
/// other one.
///
/// Stable Rust only accepts integer, `char` and `bool` types for const generic parameters, which
/// rules out an enum for the `DIR` parameter used below.
pub type Direction = bool;

/// Convert the current array type [INTO] the other array type.
pub const INTO: Direction = true;

/// Obtain the current array type [FROM] the other array type (always the opposite of [INTO]).
pub const FROM: Direction = !INTO;
29
30/// Definition of array conversion operations.
31///
32/// Generic over the concrete array implementation of the extension containing the operation, as
33/// well as over another array implementation that should be converted between. Also generic over
34/// the conversion [Direction].
35#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
36pub struct GenericArrayConvertDef<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
37    PhantomData<AK>,
38    PhantomData<OtherAK>,
39);
40
41impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
42    GenericArrayConvertDef<AK, DIR, OtherAK>
43{
44    /// Creates a new array conversion definition.
45    #[must_use]
46    pub fn new() -> Self {
47        GenericArrayConvertDef(PhantomData, PhantomData)
48    }
49}
50
51impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> Default
52    for GenericArrayConvertDef<AK, DIR, OtherAK>
53{
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> FromStr
60    for GenericArrayConvertDef<AK, DIR, OtherAK>
61{
62    type Err = ();
63
64    fn from_str(s: &str) -> Result<Self, Self::Err> {
65        let def = GenericArrayConvertDef::new();
66        if s == def.opdef_id() {
67            Ok(def)
68        } else {
69            Err(())
70        }
71    }
72}
73
74impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
75    GenericArrayConvertDef<AK, DIR, OtherAK>
76{
77    /// To avoid recursion when defining the extension, take the type definition as an argument.
78    fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
79        let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()];
80        let size = TypeArg::new_var_use(0, TypeParam::max_nat_type());
81        let element_ty = Type::new_var_use(1, TypeBound::Linear);
82
83        let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone())
84            .expect("Array type instantiation failed");
85        let other_ty =
86            OtherAK::ty_parametric(size, element_ty).expect("Array type instantiation failed");
87
88        let sig = match DIR {
89            INTO => FuncValueType::new(this_ty, other_ty),
90            FROM => FuncValueType::new(other_ty, this_ty),
91        };
92        PolyFuncTypeRV::new(params, sig).into()
93    }
94}
95
96impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeOpDef
97    for GenericArrayConvertDef<AK, DIR, OtherAK>
98{
99    fn opdef_id(&self) -> OpName {
100        match DIR {
101            INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
102            FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
103        }
104    }
105    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
106    where
107        Self: Sized,
108    {
109        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
110    }
111
112    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
113        self.signature_from_def(AK::type_def())
114    }
115
116    fn extension_ref(&self) -> Weak<Extension> {
117        Arc::downgrade(AK::extension())
118    }
119
120    fn extension(&self) -> ExtensionId {
121        AK::EXTENSION_ID
122    }
123
124    fn description(&self) -> String {
125        match DIR {
126            INTO => format!("Turns `{}` into `{}`", AK::TYPE_NAME, OtherAK::TYPE_NAME),
127            FROM => format!("Turns `{}` into `{}`", OtherAK::TYPE_NAME, AK::TYPE_NAME),
128        }
129    }
130
131    /// Add an operation implemented as a [`MakeOpDef`], which can provide the data
132    /// required to define an [`OpDef`], to an extension.
133    //
134    // This method is re-defined here since we need to pass the array type def while
135    // computing the signature, to avoid recursive loops initializing the extension.
136    fn add_to_extension(
137        &self,
138        extension: &mut Extension,
139        extension_ref: &Weak<Extension>,
140    ) -> Result<(), crate::extension::ExtensionBuildError> {
141        let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap());
142        let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
143        self.post_opdef(def);
144        Ok(())
145    }
146}
147
148/// Definition of the array conversion op.
149///
150/// Generic over the concrete array implementation of the extension containing the operation, as
151/// well as over another array implementation that should be converted between. Also generic over
152/// the conversion [Direction].
153#[derive(Clone, Debug, PartialEq)]
154pub struct GenericArrayConvert<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> {
155    /// The element type of the array.
156    pub elem_ty: Type,
157    /// Size of the array.
158    pub size: u64,
159    _kind: PhantomData<AK>,
160    _other_kind: PhantomData<OtherAK>,
161}
162
163impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
164    GenericArrayConvert<AK, DIR, OtherAK>
165{
166    /// Creates a new array conversion op.
167    #[must_use]
168    pub fn new(elem_ty: Type, size: u64) -> Self {
169        GenericArrayConvert {
170            elem_ty,
171            size,
172            _kind: PhantomData,
173            _other_kind: PhantomData,
174        }
175    }
176}
177
178impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> NamedOp
179    for GenericArrayConvert<AK, DIR, OtherAK>
180{
181    fn name(&self) -> OpName {
182        match DIR {
183            INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
184            FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
185        }
186    }
187}
188
189impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeExtensionOp
190    for GenericArrayConvert<AK, DIR, OtherAK>
191{
192    fn op_id(&self) -> OpName {
193        GenericArrayConvertDef::<AK, DIR, OtherAK>::new().opdef_id()
194    }
195
196    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
197    where
198        Self: Sized,
199    {
200        let def = GenericArrayConvertDef::<AK, DIR, OtherAK>::from_def(ext_op.def())?;
201        def.instantiate(ext_op.args())
202    }
203
204    fn type_args(&self) -> Vec<TypeArg> {
205        vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()]
206    }
207}
208
209impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeRegisteredOp
210    for GenericArrayConvert<AK, DIR, OtherAK>
211{
212    fn extension_id(&self) -> ExtensionId {
213        AK::EXTENSION_ID
214    }
215
216    fn extension_ref(&self) -> Weak<Extension> {
217        Arc::downgrade(AK::extension())
218    }
219}
220
221impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasDef
222    for GenericArrayConvert<AK, DIR, OtherAK>
223{
224    type Def = GenericArrayConvertDef<AK, DIR, OtherAK>;
225}
226
227impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasConcrete
228    for GenericArrayConvertDef<AK, DIR, OtherAK>
229{
230    type Concrete = GenericArrayConvert<AK, DIR, OtherAK>;
231
232    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
233        match type_args {
234            [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => {
235                Ok(GenericArrayConvert::new(ty.clone(), *n))
236            }
237            _ => Err(SignatureError::InvalidTypeArgs.into()),
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::extension::prelude::bool_t;
    use crate::ops::{OpTrait, OpType};
    use crate::std_extensions::collections::array::Array;
    use crate::std_extensions::collections::borrow_array::BorrowArray;
    use crate::std_extensions::collections::value_array::ValueArray;

    use super::*;

    /// Converts `op` into an [`OpType`], casts it back, and checks the result is unchanged.
    fn assert_optype_roundtrip<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
        op: GenericArrayConvert<AK, DIR, OtherAK>,
    ) {
        let optype: OpType = op.clone().into();
        let recovered: GenericArrayConvert<AK, DIR, OtherAK> = optype.cast().unwrap();
        assert_eq!(recovered, op);
    }

    /// Checks that the dataflow signature of `op` takes `input` and produces `output`.
    fn assert_io<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
        op: GenericArrayConvert<AK, DIR, OtherAK>,
        input: Type,
        output: Type,
    ) {
        let optype: OpType = op.into();
        let sig = optype.dataflow_signature().unwrap();
        assert_eq!(sig.io(), (&vec![input].into(), &vec![output].into()));
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_from_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        assert_optype_roundtrip(GenericArrayConvert::<AK, FROM, OtherAK>::new(bool_t(), 2));
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_into_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        assert_optype_roundtrip(GenericArrayConvert::<AK, INTO, OtherAK>::new(bool_t(), 2));
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_from<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let size = 2;
        let elem = bool_t();
        // FROM consumes the other array type and produces the current one.
        assert_io(
            GenericArrayConvert::<AK, FROM, OtherAK>::new(elem.clone(), size),
            OtherAK::ty(size, elem.clone()),
            AK::ty(size, elem),
        );
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_into<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let size = 2;
        let elem = bool_t();
        // INTO consumes the current array type and produces the other one.
        assert_io(
            GenericArrayConvert::<AK, INTO, OtherAK>::new(elem.clone(), size),
            AK::ty(size, elem.clone()),
            OtherAK::ty(size, elem),
        );
    }
}