1use std::sync::Weak;
4
5use strum::IntoEnumIterator;
6
7use crate::ops::{ExtensionOp, OpName, OpNameRef};
8use crate::{
9 ops::{NamedOp, OpType},
10 types::TypeArg,
11 Extension,
12};
13
14use super::{op_def::SignatureFunc, ExtensionBuildError, ExtensionId, OpDef, SignatureError};
15use delegate::delegate;
16use thiserror::Error;
17
/// Error raised when an operation cannot be loaded from its definition, its name,
/// or an [`ExtensionOp`].
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{0}")]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum OpLoadError {
    /// The given name does not correspond to any member of the op set.
    #[error("Op with name {0} is not a member of this set.")]
    NotMember(String),
    /// The supplied type arguments were rejected. Convertible from a
    /// [`SignatureError`] via `?` thanks to `#[from]`.
    #[error("Type args invalid: {0}.")]
    InvalidArgs(#[from] SignatureError),
    /// The op definition lives in a different extension than the op expects.
    /// Fields are `(extension the OpDef belongs to, extension expected by the op)`.
    #[error("OpDef belongs to extension {0}, expected {1}.")]
    WrongExtension(ExtensionId, ExtensionId),
}
31
32impl<T> NamedOp for T
33where
34 for<'a> &'a T: Into<&'static str>,
35{
36 fn name(&self) -> OpName {
37 let s = self.into();
38 s.into()
39 }
40}
41
42pub trait MakeOpDef: NamedOp {
48 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
50 where
51 Self: Sized;
52
53 fn extension(&self) -> ExtensionId;
55
56 fn extension_ref(&self) -> Weak<Extension>;
58
59 fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
65
66 fn signature(&self) -> SignatureFunc {
68 self.init_signature(&self.extension_ref())
69 }
70
71 fn description(&self) -> String {
73 self.name().to_string()
74 }
75
76 fn post_opdef(&self, _def: &mut OpDef) {}
78
79 fn add_to_extension(
85 &self,
86 extension: &mut Extension,
87 extension_ref: &Weak<Extension>,
88 ) -> Result<(), ExtensionBuildError> {
89 let def = extension.add_op(
90 self.name(),
91 self.description(),
92 self.init_signature(extension_ref),
93 extension_ref,
94 )?;
95
96 self.post_opdef(def);
97
98 Ok(())
99 }
100
101 fn load_all_ops(
107 extension: &mut Extension,
108 extension_ref: &Weak<Extension>,
109 ) -> Result<(), ExtensionBuildError>
110 where
111 Self: IntoEnumIterator,
112 {
113 for op in Self::iter() {
114 op.add_to_extension(extension, extension_ref)?;
115 }
116 Ok(())
117 }
118
119 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
121 where
122 Self: Sized + std::str::FromStr,
123 {
124 Self::from_extension_op(ext_op)
125 }
126}
127
/// An op definition that can be instantiated with type arguments to yield a
/// concrete op.
pub trait HasConcrete: MakeOpDef {
    /// The concrete op produced by [`HasConcrete::instantiate`].
    type Concrete: MakeExtensionOp;

    /// Instantiate this definition with `type_args`.
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] if the arguments are not acceptable for this
    /// definition (presumably [`OpLoadError::InvalidArgs`]; the exact variant is
    /// implementation-defined).
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
136
137pub trait HasDef: MakeExtensionOp {
139 type Def: HasConcrete<Concrete = Self> + std::str::FromStr;
141
142 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
144 where
145 Self: Sized,
146 {
147 Self::from_extension_op(ext_op)
148 }
149}
150
151pub trait MakeExtensionOp: NamedOp {
154 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
156 where
157 Self: Sized;
158 fn from_optype(op: &OpType) -> Option<Self>
161 where
162 Self: Sized,
163 {
164 let ext: &ExtensionOp = op.as_extension_op()?;
165 Self::from_extension_op(ext).ok()
166 }
167
168 fn type_args(&self) -> Vec<TypeArg>;
170
171 fn to_registered(
174 self,
175 extension_id: ExtensionId,
176 extension: Weak<Extension>,
177 ) -> RegisteredOp<Self>
178 where
179 Self: Sized,
180 {
181 RegisteredOp {
182 extension_id,
183 extension,
184 op: self,
185 }
186 }
187}
188
189impl<T: MakeOpDef> MakeExtensionOp for T {
191 #[inline]
192 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
193 where
194 Self: Sized,
195 {
196 Self::from_def(ext_op.def())
197 }
198
199 #[inline]
200 fn type_args(&self) -> Vec<TypeArg> {
201 vec![]
202 }
203}
204
205pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
208where
209 T: std::str::FromStr + MakeOpDef,
210{
211 let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
212 let expected_extension = op.extension();
213 if def_extension != &expected_extension {
214 return Err(OpLoadError::WrongExtension(
215 def_extension.clone(),
216 expected_extension,
217 ));
218 }
219
220 Ok(op)
221}
222
/// An op bundled with the extension it is registered in, so that it can be turned
/// into an [`ExtensionOp`] on demand.
#[derive(Clone, Debug)]
pub struct RegisteredOp<T> {
    /// Id of the extension the op is registered in.
    pub extension_id: ExtensionId,
    /// Weak reference to that extension; it is upgraded when building an `ExtensionOp`.
    extension: Weak<Extension>,
    /// The wrapped op.
    op: T,
}
234
235impl<T> RegisteredOp<T> {
236 pub fn to_inner(self) -> T {
238 self.op
239 }
240}
241
242impl<T: MakeExtensionOp> RegisteredOp<T> {
243 pub fn to_extension_op(&self) -> Option<ExtensionOp> {
245 ExtensionOp::new(
246 self.extension.upgrade()?.get_op(&self.name())?.clone(),
247 self.type_args(),
248 )
249 .ok()
250 }
251
252 delegate! {
253 to self.op {
254 pub fn name(&self) -> OpName;
256 pub fn type_args(&self) -> Vec<TypeArg>;
258 }
259 }
260}
261
262pub trait MakeRegisteredOp: MakeExtensionOp {
266 fn extension_id(&self) -> ExtensionId;
268 fn extension_ref(&self) -> Weak<Extension>;
270
271 fn to_extension_op(self) -> Option<ExtensionOp>
274 where
275 Self: Sized,
276 {
277 let registered: RegisteredOp<_> = self.into();
278 registered.to_extension_op()
279 }
280}
281
282impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
283 fn from(ext_op: T) -> Self {
284 let extension_id = ext_op.extension_id();
285 let extension = ext_op.extension_ref();
286 ext_op.to_registered(extension_id, extension)
287 }
288}
289
290impl<T: MakeRegisteredOp> From<T> for OpType {
291 fn from(ext_op: T) -> Self {
293 ext_op.to_extension_op().unwrap().into()
294 }
295}
296
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::{const_extension_ids, type_row, types::Signature};

    use super::*;
    use lazy_static::lazy_static;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    /// Minimal, non-parametric, single-variant op set used to exercise the traits above.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            // `DummyEnum` takes no type parameters, so only an empty argument list is valid.
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }

    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        const OTHER_EXT_ID: ExtensionId = "OtherDummyExt";
    }

    lazy_static! {
        /// Test extension containing the single `DummyEnum::Dumb` op.
        static ref EXT: Arc<Extension> = {
            Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
                DummyEnum::Dumb
                    .add_to_extension(ext, extension_ref)
                    .unwrap();
            })
        };
    }

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.name()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[TypeArg::BoundedNat { n: 1 }]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(
                OTHER_EXT_ID.clone(),
                EXT_ID.clone()
            ))
        );
    }
}