1use std::sync::Weak;
4
5use strum::IntoEnumIterator;
6
7use crate::ops::{ExtensionOp, OpName, OpNameRef};
8use crate::{Extension, ops::OpType, types::TypeArg};
9
10use super::{ExtensionBuildError, ExtensionId, OpDef, SignatureError, op_def::SignatureFunc};
11use delegate::delegate;
12use thiserror::Error;
13
14#[derive(Debug, Error, PartialEq, Clone)]
16#[error("{0}")]
17#[allow(missing_docs)]
18#[non_exhaustive]
19pub enum OpLoadError {
20 #[error("Op with name {0} is not a member of this set.")]
21 NotMember(String),
22 #[error("Type args invalid: {0}.")]
23 InvalidArgs(#[from] SignatureError),
24 #[error("OpDef belongs to extension {0}, expected {1}.")]
25 WrongExtension(ExtensionId, ExtensionId),
26}
27
28pub trait MakeOpDef {
36 fn opdef_id(&self) -> OpName;
45
46 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
48 where
49 Self: Sized;
50
51 fn extension(&self) -> ExtensionId;
53
54 fn extension_ref(&self) -> Weak<Extension>;
56
57 fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
63
64 fn signature(&self) -> SignatureFunc {
66 self.init_signature(&self.extension_ref())
67 }
68
69 fn description(&self) -> String {
71 self.opdef_id().to_string()
72 }
73
74 fn post_opdef(&self, _def: &mut OpDef) {}
76
77 fn add_to_extension(
83 &self,
84 extension: &mut Extension,
85 extension_ref: &Weak<Extension>,
86 ) -> Result<(), ExtensionBuildError> {
87 let def = extension.add_op(
88 self.opdef_id(),
89 self.description(),
90 self.init_signature(extension_ref),
91 extension_ref,
92 )?;
93
94 self.post_opdef(def);
95
96 Ok(())
97 }
98
99 fn load_all_ops(
105 extension: &mut Extension,
106 extension_ref: &Weak<Extension>,
107 ) -> Result<(), ExtensionBuildError>
108 where
109 Self: IntoEnumIterator,
110 {
111 for op in Self::iter() {
112 op.add_to_extension(extension, extension_ref)?;
113 }
114 Ok(())
115 }
116
117 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
119 where
120 Self: Sized + std::str::FromStr,
121 {
122 Self::from_extension_op(ext_op)
123 }
124}
125
/// An operation definition that can be instantiated into a concrete op by
/// supplying type arguments.
pub trait HasConcrete: MakeOpDef {
    /// The concrete operation produced by [`HasConcrete::instantiate`].
    type Concrete: MakeExtensionOp;

    /// Builds the concrete operation from `type_args`.
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] when the arguments cannot be used to build
    /// `Self::Concrete`.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
134
135pub trait HasDef: MakeExtensionOp {
137 type Def: HasConcrete<Concrete = Self> + std::str::FromStr;
139
140 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
142 where
143 Self: Sized,
144 {
145 Self::from_extension_op(ext_op)
146 }
147}
148
149pub trait MakeExtensionOp {
152 fn op_id(&self) -> OpName;
158
159 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
161 where
162 Self: Sized;
163 #[must_use]
166 fn from_optype(op: &OpType) -> Option<Self>
167 where
168 Self: Sized,
169 {
170 let ext: &ExtensionOp = op.as_extension_op()?;
171 Self::from_extension_op(ext).ok()
172 }
173
174 fn type_args(&self) -> Vec<TypeArg>;
176
177 fn to_registered(
180 self,
181 extension_id: ExtensionId,
182 extension: Weak<Extension>,
183 ) -> RegisteredOp<Self>
184 where
185 Self: Sized,
186 {
187 RegisteredOp {
188 extension_id,
189 extension,
190 op: self,
191 }
192 }
193}
194
195impl<T: MakeOpDef> MakeExtensionOp for T {
197 fn op_id(&self) -> OpName {
198 self.opdef_id()
199 }
200
201 #[inline]
202 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
203 where
204 Self: Sized,
205 {
206 Self::from_def(ext_op.def())
207 }
208
209 #[inline]
210 fn type_args(&self) -> Vec<TypeArg> {
211 vec![]
212 }
213}
214
215pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
218where
219 T: std::str::FromStr + MakeOpDef,
220{
221 let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
222 let expected_extension = op.extension();
223 if def_extension != &expected_extension {
224 return Err(OpLoadError::WrongExtension(
225 def_extension.clone(),
226 expected_extension,
227 ));
228 }
229
230 Ok(op)
231}
232
/// An operation value paired with the extension it is registered in.
///
/// Built by [`MakeExtensionOp::to_registered`] (or `From` for any
/// [`MakeRegisteredOp`]). [`RegisteredOp::to_extension_op`] uses the stored
/// handle to look the definition up again.
#[derive(Clone, Debug)]
pub struct RegisteredOp<T> {
    /// Identifier of the extension the op was registered with.
    pub extension_id: ExtensionId,
    /// Weak handle to that extension; upgrading it may fail if the extension was dropped.
    extension: Weak<Extension>,
    /// The wrapped operation value.
    op: T,
}
244
245impl<T> RegisteredOp<T> {
246 pub fn to_inner(self) -> T {
248 self.op
249 }
250}
251
252impl<T: MakeExtensionOp> RegisteredOp<T> {
253 pub fn to_extension_op(&self) -> Option<ExtensionOp> {
255 ExtensionOp::new(
256 self.extension.upgrade()?.get_op(&self.op_id())?.clone(),
257 self.type_args(),
258 )
259 .ok()
260 }
261
262 delegate! {
263 to self.op {
264 pub fn op_id(&self) -> OpName;
266 pub fn type_args(&self) -> Vec<TypeArg>;
268 }
269 }
270}
271
272pub trait MakeRegisteredOp: MakeExtensionOp {
276 fn extension_id(&self) -> ExtensionId;
278 fn extension_ref(&self) -> Weak<Extension>;
280
281 fn to_extension_op(self) -> Option<ExtensionOp>
284 where
285 Self: Sized,
286 {
287 let registered: RegisteredOp<_> = self.into();
288 registered.to_extension_op()
289 }
290}
291
292impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
293 fn from(ext_op: T) -> Self {
294 let extension_id = ext_op.extension_id();
295 let extension = ext_op.extension_ref();
296 ext_op.to_registered(extension_id, extension)
297 }
298}
299
300impl<T: MakeRegisteredOp> From<T> for OpType {
301 fn from(ext_op: T) -> Self {
303 ext_op.to_extension_op().unwrap().into()
304 }
305}
306
#[cfg(test)]
mod test {
    use std::sync::{Arc, LazyLock};

    use crate::{
        const_extension_ids, type_row,
        types::{Signature, Term},
    };

    use super::*;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    /// Minimal single-op set used to exercise the traits in this module.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn opdef_id(&self) -> OpName {
            <&'static str>::from(self).into()
        }

        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.clone()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            // The dummy op is not parametric: only an empty argument list is accepted.
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }
    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
    }
    // An extension id that does not own `DummyEnum`, used for negative tests.
    const_extension_ids! {
        const OTHER_EXT_ID: ExtensionId = "OtherExt";
    }

    static EXT: LazyLock<Arc<Extension>> = LazyLock::new(|| {
        Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::Dumb
                .add_to_extension(ext, extension_ref)
                .unwrap();
        })
    });

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.clone()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.opdef_id()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[Term::from(1u64)]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(OTHER_EXT_ID, EXT_ID))
        );
    }
}