1use std::sync::{Arc, Weak};
4
5use strum::IntoEnumIterator;
6
7use crate::ops::{ExtensionOp, OpName, OpNameRef};
8use crate::{Extension, ops::OpType, types::TypeArg};
9
10use super::{ExtensionBuildError, ExtensionId, OpDef, SignatureError, op_def::SignatureFunc};
11use delegate::delegate;
12use thiserror::Error;
13
/// Error type returned by the op-loading functions in this module
/// (`from_def`, `from_extension_op`, `instantiate`, `try_from_name`).
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{0}")]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum OpLoadError {
    /// The given name is not a member of the operation set being loaded;
    /// carries the offending name. Produced by `try_from_name` when parsing
    /// the name fails.
    #[error("Op with name {0} is not a member of this set.")]
    NotMember(String),
    /// The type arguments were rejected; wraps the underlying
    /// `SignatureError` (a `From` conversion is generated by `#[from]`).
    #[error("Type args invalid: {0}.")]
    InvalidArgs(#[from] SignatureError),
    /// The definition belongs to a different extension than expected.
    /// First field: the extension the definition belongs to; second field:
    /// the extension that was expected.
    #[error("OpDef belongs to extension {0}, expected {1}.")]
    WrongExtension(ExtensionId, ExtensionId),
}
27
28pub trait MakeOpDef {
36 fn opdef_id(&self) -> OpName;
45
46 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
48 where
49 Self: Sized;
50
51 fn extension(&self) -> ExtensionId;
53
54 fn extension_ref(&self) -> Weak<Extension>;
56
57 fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
63
64 fn signature(&self) -> SignatureFunc {
66 self.init_signature(&self.extension_ref())
67 }
68
69 fn description(&self) -> String {
71 self.opdef_id().to_string()
72 }
73
74 fn post_opdef(&self, _def: &mut OpDef) {}
76
77 fn add_to_extension(
83 &self,
84 extension: &mut Extension,
85 extension_ref: &Weak<Extension>,
86 ) -> Result<(), ExtensionBuildError> {
87 let def = extension.add_op(
88 self.opdef_id(),
89 self.description(),
90 self.init_signature(extension_ref),
91 extension_ref,
92 )?;
93
94 self.post_opdef(def);
95
96 Ok(())
97 }
98
99 fn load_all_ops(
105 extension: &mut Extension,
106 extension_ref: &Weak<Extension>,
107 ) -> Result<(), ExtensionBuildError>
108 where
109 Self: IntoEnumIterator,
110 {
111 for op in Self::iter() {
112 op.add_to_extension(extension, extension_ref)?;
113 }
114 Ok(())
115 }
116
117 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
119 where
120 Self: Sized + std::str::FromStr,
121 {
122 Self::from_extension_op(ext_op)
123 }
124}
125
/// An operation definition that can be turned into a concrete operation
/// once type arguments are supplied.
pub trait HasConcrete: MakeOpDef {
    /// The concrete operation type produced by `instantiate`.
    type Concrete: MakeExtensionOp;

    /// Build the concrete operation from this definition and `type_args`.
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] when instantiation fails.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
134
135pub trait HasDef: MakeExtensionOp {
137 type Def: HasConcrete<Concrete = Self> + std::str::FromStr;
139
140 fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
142 where
143 Self: Sized,
144 {
145 Self::from_extension_op(ext_op)
146 }
147}
148
149pub trait MakeExtensionOp {
152 fn op_id(&self) -> OpName;
158
159 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
161 where
162 Self: Sized;
163 #[must_use]
166 fn from_optype(op: &OpType) -> Option<Self>
167 where
168 Self: Sized,
169 {
170 let ext: &ExtensionOp = op.as_extension_op()?;
171 Self::from_extension_op(ext).ok()
172 }
173
174 fn type_args(&self) -> Vec<TypeArg>;
176
177 fn to_registered(
180 self,
181 extension_id: ExtensionId,
182 extension: Arc<Extension>,
183 ) -> RegisteredOp<Self>
184 where
185 Self: Sized,
186 {
187 RegisteredOp {
188 extension_id,
189 extension,
190 op: self,
191 }
192 }
193}
194
195impl<T: MakeOpDef> MakeExtensionOp for T {
197 fn op_id(&self) -> OpName {
198 self.opdef_id()
199 }
200
201 #[inline]
202 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
203 where
204 Self: Sized,
205 {
206 Self::from_def(ext_op.def())
207 }
208
209 #[inline]
210 fn type_args(&self) -> Vec<TypeArg> {
211 vec![]
212 }
213}
214
215pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
218where
219 T: std::str::FromStr + MakeOpDef,
220{
221 let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
222 let expected_extension = op.extension();
223 if def_extension != &expected_extension {
224 return Err(OpLoadError::WrongExtension(
225 def_extension.clone(),
226 expected_extension,
227 ));
228 }
229
230 Ok(op)
231}
232
/// An operation of type `T` bundled with an extension identifier and a
/// shared handle to an [`Extension`].
#[derive(Clone, Debug)]
pub struct RegisteredOp<T> {
    /// Identifier of the extension associated with the operation.
    pub extension_id: ExtensionId,
    /// Shared handle to the extension; `to_extension_op` looks the
    /// operation's definition up in it.
    extension: Arc<Extension>,
    /// The wrapped operation.
    op: T,
}
244
245impl<T> RegisteredOp<T> {
246 pub fn to_inner(self) -> T {
248 self.op
249 }
250}
251
252impl<T: MakeExtensionOp> RegisteredOp<T> {
253 pub fn to_extension_op(&self) -> Result<ExtensionOp, SignatureError> {
255 let op_def = self.extension.get_op(&self.op_id()).unwrap_or_else(|| {
256 panic!(
257 "Extension::get_op() called with an invalid name ({}).",
258 self.op_id()
259 )
260 });
261 ExtensionOp::new(op_def.clone(), self.type_args())
262 }
263
264 delegate! {
265 to self.op {
266 pub fn op_id(&self) -> OpName;
268 pub fn type_args(&self) -> Vec<TypeArg>;
270 }
271 }
272}
273
274pub trait MakeRegisteredOp: MakeExtensionOp {
278 fn extension_id(&self) -> ExtensionId;
280 fn extension_ref(&self) -> Arc<Extension>;
282
283 fn to_extension_op(self) -> Result<ExtensionOp, SignatureError>
286 where
287 Self: Sized,
288 {
289 let registered: RegisteredOp<_> = self.into();
290 registered.to_extension_op()
291 }
292}
293
294impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
295 fn from(ext_op: T) -> Self {
296 let extension_id = ext_op.extension_id();
297 let extension = ext_op.extension_ref();
298 ext_op.to_registered(extension_id, extension)
299 }
300}
301
302impl<T: MakeRegisteredOp> From<T> for OpType {
303 fn from(ext_op: T) -> Self {
305 ext_op.to_extension_op().unwrap().into()
306 }
307}
308
#[cfg(test)]
mod test {
    use std::sync::{Arc, LazyLock};

    use crate::{
        const_extension_ids, type_row,
        types::{Signature, Term},
    };

    use super::*;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    // Minimal single-variant op set used to exercise the traits above.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn opdef_id(&self) -> OpName {
            <&'static str>::from(self).into()
        }

        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.clone()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        // The parameter is used, so it carries no underscore prefix.
        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }
    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        // A second identifier, used only to provoke `WrongExtension`.
        const OTHER_EXT_ID: ExtensionId = "OtherExt";
    }

    static EXT: LazyLock<Arc<Extension>> = LazyLock::new(|| {
        Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::Dumb
                .add_to_extension(ext, extension_ref)
                .unwrap();
        })
    });

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.clone()
        }

        fn extension_ref(&self) -> Arc<Extension> {
            EXT.clone()
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.opdef_id()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[Term::from(1u64)]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        let name = DummyEnum::Dumb.opdef_id();

        // Known name, matching extension.
        assert_eq!(
            try_from_name::<DummyEnum>(&name, &EXT_ID),
            Ok(DummyEnum::Dumb)
        );

        // A name that is not a variant is rejected before any extension check.
        let unknown: OpName = "NotAnOp".into();
        assert_eq!(
            try_from_name::<DummyEnum>(&unknown, &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );

        // Known name, but the definition claims a different extension.
        assert_eq!(
            try_from_name::<DummyEnum>(&name, &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(
                OTHER_EXT_ID.clone(),
                EXT_ID.clone()
            ))
        );
    }

    #[test]
    fn test_load_all_ops() {
        let ext = Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::load_all_ops(ext, extension_ref).unwrap();
        });

        assert!(ext.get_op(&DummyEnum::Dumb.opdef_id()).is_some());
    }
}