1use std::marker::PhantomData;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use crate::Extension;
8use crate::extension::simple_op::{
9 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
10};
11use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
12use crate::ops::{ExtensionOp, NamedOp, OpName};
13use crate::types::type_param::{TypeArg, TypeParam};
14use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
15
16use super::array_kind::ArrayKind;
17
/// Which way an array conversion op goes, relative to the array kind whose extension
/// defines the op.
///
/// [`INTO`] turns the defining kind into the other kind; [`FROM`] turns the other kind
/// into the defining kind.
pub type Direction = bool;

/// Conversion from the defining array kind to the other array kind.
pub const INTO: Direction = true;

/// Conversion from the other array kind back to the defining array kind.
pub const FROM: Direction = !INTO;
29
30#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
36pub struct GenericArrayConvertDef<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
37 PhantomData<AK>,
38 PhantomData<OtherAK>,
39);
40
41impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
42 GenericArrayConvertDef<AK, DIR, OtherAK>
43{
44 #[must_use]
46 pub fn new() -> Self {
47 GenericArrayConvertDef(PhantomData, PhantomData)
48 }
49}
50
51impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> Default
52 for GenericArrayConvertDef<AK, DIR, OtherAK>
53{
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> FromStr
60 for GenericArrayConvertDef<AK, DIR, OtherAK>
61{
62 type Err = ();
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 let def = GenericArrayConvertDef::new();
66 if s == def.opdef_id() {
67 Ok(def)
68 } else {
69 Err(())
70 }
71 }
72}
73
74impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
75 GenericArrayConvertDef<AK, DIR, OtherAK>
76{
77 fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
79 let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()];
80 let size = TypeArg::new_var_use(0, TypeParam::max_nat_type());
81 let element_ty = Type::new_var_use(1, TypeBound::Linear);
82
83 let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone())
84 .expect("Array type instantiation failed");
85 let other_ty =
86 OtherAK::ty_parametric(size, element_ty).expect("Array type instantiation failed");
87
88 let sig = match DIR {
89 INTO => FuncValueType::new(this_ty, other_ty),
90 FROM => FuncValueType::new(other_ty, this_ty),
91 };
92 PolyFuncTypeRV::new(params, sig).into()
93 }
94}
95
96impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeOpDef
97 for GenericArrayConvertDef<AK, DIR, OtherAK>
98{
99 fn opdef_id(&self) -> OpName {
100 match DIR {
101 INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
102 FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
103 }
104 }
105 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
106 where
107 Self: Sized,
108 {
109 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
110 }
111
112 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
113 self.signature_from_def(AK::type_def())
114 }
115
116 fn extension_ref(&self) -> Weak<Extension> {
117 Arc::downgrade(AK::extension())
118 }
119
120 fn extension(&self) -> ExtensionId {
121 AK::EXTENSION_ID
122 }
123
124 fn description(&self) -> String {
125 match DIR {
126 INTO => format!("Turns `{}` into `{}`", AK::TYPE_NAME, OtherAK::TYPE_NAME),
127 FROM => format!("Turns `{}` into `{}`", OtherAK::TYPE_NAME, AK::TYPE_NAME),
128 }
129 }
130
131 fn add_to_extension(
137 &self,
138 extension: &mut Extension,
139 extension_ref: &Weak<Extension>,
140 ) -> Result<(), crate::extension::ExtensionBuildError> {
141 let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap());
142 let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
143 self.post_opdef(def);
144 Ok(())
145 }
146}
147
148#[derive(Clone, Debug, PartialEq)]
154pub struct GenericArrayConvert<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> {
155 pub elem_ty: Type,
157 pub size: u64,
159 _kind: PhantomData<AK>,
160 _other_kind: PhantomData<OtherAK>,
161}
162
163impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
164 GenericArrayConvert<AK, DIR, OtherAK>
165{
166 #[must_use]
168 pub fn new(elem_ty: Type, size: u64) -> Self {
169 GenericArrayConvert {
170 elem_ty,
171 size,
172 _kind: PhantomData,
173 _other_kind: PhantomData,
174 }
175 }
176}
177
178impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> NamedOp
179 for GenericArrayConvert<AK, DIR, OtherAK>
180{
181 fn name(&self) -> OpName {
182 match DIR {
183 INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
184 FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
185 }
186 }
187}
188
189impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeExtensionOp
190 for GenericArrayConvert<AK, DIR, OtherAK>
191{
192 fn op_id(&self) -> OpName {
193 GenericArrayConvertDef::<AK, DIR, OtherAK>::new().opdef_id()
194 }
195
196 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
197 where
198 Self: Sized,
199 {
200 let def = GenericArrayConvertDef::<AK, DIR, OtherAK>::from_def(ext_op.def())?;
201 def.instantiate(ext_op.args())
202 }
203
204 fn type_args(&self) -> Vec<TypeArg> {
205 vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()]
206 }
207}
208
209impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeRegisteredOp
210 for GenericArrayConvert<AK, DIR, OtherAK>
211{
212 fn extension_id(&self) -> ExtensionId {
213 AK::EXTENSION_ID
214 }
215
216 fn extension_ref(&self) -> Weak<Extension> {
217 Arc::downgrade(AK::extension())
218 }
219}
220
221impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasDef
222 for GenericArrayConvert<AK, DIR, OtherAK>
223{
224 type Def = GenericArrayConvertDef<AK, DIR, OtherAK>;
225}
226
227impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasConcrete
228 for GenericArrayConvertDef<AK, DIR, OtherAK>
229{
230 type Concrete = GenericArrayConvert<AK, DIR, OtherAK>;
231
232 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
233 match type_args {
234 [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => {
235 Ok(GenericArrayConvert::new(ty.clone(), *n))
236 }
237 _ => Err(SignatureError::InvalidTypeArgs.into()),
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::extension::prelude::bool_t;
    use crate::ops::{OpTrait, OpType};
    use crate::std_extensions::collections::array::Array;
    use crate::std_extensions::collections::borrow_array::BorrowArray;
    use crate::std_extensions::collections::value_array::ValueArray;

    use super::*;

    /// Checks that a conversion op survives a round trip through [`OpType`].
    fn check_roundtrip<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>() {
        let original = GenericArrayConvert::<AK, DIR, OtherAK>::new(bool_t(), 2);
        let as_optype: OpType = original.clone().into();
        let recovered: GenericArrayConvert<AK, DIR, OtherAK> = as_optype.cast().unwrap();
        assert_eq!(recovered, original);
    }

    /// Checks the dataflow signature of a conversion op: `AK -> OtherAK` for `INTO`,
    /// `OtherAK -> AK` for `FROM`.
    fn check_signature<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>() {
        let size = 2;
        let element_ty = bool_t();
        let op = GenericArrayConvert::<AK, DIR, OtherAK>::new(element_ty.clone(), size);
        let optype: OpType = op.into();
        let sig = optype.dataflow_signature().unwrap();

        let this_ty = AK::ty(size, element_ty.clone());
        let other_ty = OtherAK::ty(size, element_ty.clone());
        let (input, output) = if DIR == INTO {
            (this_ty, other_ty)
        } else {
            (other_ty, this_ty)
        };
        assert_eq!(sig.io(), (&vec![input].into(), &vec![output].into()));
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_from_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        check_roundtrip::<AK, FROM, OtherAK>();
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_into_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        check_roundtrip::<AK, INTO, OtherAK>();
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_from<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        check_signature::<AK, FROM, OtherAK>();
    }

    #[rstest]
    #[case(ValueArray, Array)]
    #[case(BorrowArray, Array)]
    fn test_convert_into<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        check_signature::<AK, INTO, OtherAK>();
    }
}