1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef,
10 crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
11 ::proptest::prelude::*,
12 ::proptest_derive::Arbitrary,
13};
14
15use crate::extension::simple_op::MakeExtensionOp;
16use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
17use crate::types::{type_param::TypeArg, Signature};
18use crate::{ops, IncomingPort, Node};
19
20use super::dataflow::DataflowOpTrait;
21use super::tag::OpTag;
22use super::{NamedOp, OpName, OpNameRef};
23
24#[derive(Clone, Debug, serde::Serialize)]
31#[serde(into = "OpaqueOp")]
32#[cfg_attr(test, derive(Arbitrary))]
33pub struct ExtensionOp {
34 #[cfg_attr(
35 test,
36 proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
37 )]
38 def: Arc<OpDef>,
39 args: Vec<TypeArg>,
40 signature: Signature, }
42
43impl ExtensionOp {
44 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
46 let args: Vec<TypeArg> = args.into();
47 let signature = def.compute_signature(&args)?;
48 Ok(Self {
49 def,
50 args,
51 signature,
52 })
53 }
54
55 pub(crate) fn new_with_cached(
57 def: Arc<OpDef>,
58 args: impl IntoIterator<Item = TypeArg>,
59 opaque: &OpaqueOp,
60 ) -> Result<Self, SignatureError> {
61 let args: Vec<TypeArg> = args.into_iter().collect();
62 let signature = match def.compute_signature(&args) {
65 Ok(sig) => sig,
66 Err(SignatureError::MissingComputeFunc) => {
67 opaque.signature().into_owned()
69 }
70 Err(e) => return Err(e),
71 };
72 Ok(Self {
73 def,
74 args,
75 signature,
76 })
77 }
78
79 pub fn args(&self) -> &[TypeArg] {
81 &self.args
82 }
83
84 pub fn def(&self) -> &OpDef {
86 self.def.as_ref()
87 }
88
89 pub fn def_arc(&self) -> &Arc<OpDef> {
92 &self.def
93 }
94
95 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
97 self.def().constant_fold(self.args(), consts)
98 }
99
100 pub fn make_opaque(&self) -> OpaqueOp {
109 OpaqueOp {
110 extension: self.def.extension_id().clone(),
111 name: self.def.name().clone(),
112 description: self.def.description().into(),
113 args: self.args.clone(),
114 signature: self.signature.clone(),
115 }
116 }
117
118 pub fn signature_mut(&mut self) -> &mut Signature {
120 &mut self.signature
121 }
122
123 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
125 self.args.as_mut_slice()
126 }
127
128 pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
132 T::from_extension_op(self).ok()
133 }
134
135 pub fn extension_id(&self) -> &ExtensionId {
137 self.def.extension_id()
138 }
139}
140
141impl From<ExtensionOp> for OpaqueOp {
142 fn from(op: ExtensionOp) -> Self {
143 let ExtensionOp {
144 def,
145 args,
146 signature,
147 } = op;
148 OpaqueOp {
149 extension: def.extension_id().clone(),
150 name: def.name().clone(),
151 description: def.description().into(),
152 args,
153 signature,
154 }
155 }
156}
157
158impl PartialEq for ExtensionOp {
159 fn eq(&self, other: &Self) -> bool {
160 Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
161 }
162}
163
164impl Eq for ExtensionOp {}
165
166impl NamedOp for ExtensionOp {
167 fn name(&self) -> OpName {
169 qualify_name(self.def.extension_id(), self.def.name())
170 }
171}
172
173impl DataflowOpTrait for ExtensionOp {
174 const TAG: OpTag = OpTag::Leaf;
175
176 fn description(&self) -> &str {
177 self.def().description()
178 }
179
180 fn signature(&self) -> Cow<'_, Signature> {
181 Cow::Borrowed(&self.signature)
182 }
183
184 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
185 let args = self
186 .args
187 .iter()
188 .map(|ta| ta.substitute(subst))
189 .collect::<Vec<_>>();
190 let signature = self.signature.substitute(subst);
191 Self {
192 def: self.def.clone(),
193 args,
194 signature,
195 }
196 }
197}
198
199#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
209#[cfg_attr(test, derive(Arbitrary))]
210pub struct OpaqueOp {
211 extension: ExtensionId,
212 #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
213 name: OpName,
214 #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
215 description: String, args: Vec<TypeArg>,
217 signature: Signature,
221}
222
223fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
224 format!("{}.{}", res_id, name).into()
225}
226
227impl OpaqueOp {
228 pub fn new(
230 extension: ExtensionId,
231 name: impl Into<OpName>,
232 description: String,
233 args: impl Into<Vec<TypeArg>>,
234 signature: Signature,
235 ) -> Self {
236 let signature = signature.with_extension_delta(extension.clone());
237 Self {
238 extension,
239 name: name.into(),
240 description,
241 args: args.into(),
242 signature,
243 }
244 }
245
246 pub fn signature_mut(&mut self) -> &mut Signature {
248 &mut self.signature
249 }
250}
251
252impl NamedOp for OpaqueOp {
253 fn name(&self) -> OpName {
255 qualify_name(&self.extension, &self.name)
256 }
257}
258impl OpaqueOp {
259 pub fn op_name(&self) -> &OpName {
261 &self.name
262 }
263
264 pub fn args(&self) -> &[TypeArg] {
266 &self.args
267 }
268
269 pub fn extension(&self) -> &ExtensionId {
271 &self.extension
272 }
273
274 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
276 self.args.as_mut_slice()
277 }
278}
279
280impl DataflowOpTrait for OpaqueOp {
281 const TAG: OpTag = OpTag::Leaf;
282
283 fn description(&self) -> &str {
284 &self.description
285 }
286
287 fn signature(&self) -> Cow<'_, Signature> {
288 Cow::Borrowed(&self.signature)
289 }
290
291 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
292 Self {
293 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
294 signature: self.signature.substitute(subst),
295 ..self.clone()
296 }
297 }
298}
299
300#[derive(Clone, Debug, Error, PartialEq)]
303#[non_exhaustive]
304pub enum OpaqueOpError {
305 #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
307 available_ops.iter().join(", ")
308 )]
309 OpNotFoundInExtension {
310 node: Node,
312 op: OpName,
314 extension: ExtensionId,
316 available_ops: Vec<OpName>,
318 },
319 #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
321 #[allow(missing_docs)]
322 SignatureMismatch {
323 node: Node,
324 extension: ExtensionId,
325 op: OpName,
326 stored: Signature,
327 computed: Signature,
328 },
329 #[error("Error in signature of operation '{name}' in {node}: {cause}")]
331 #[allow(missing_docs)]
332 SignatureError {
333 node: Node,
334 name: OpName,
335 #[source]
336 cause: SignatureError,
337 },
338 #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
340 UnresolvedOp(Node, OpName, ExtensionId),
341 #[error("Error updating extension registry: {0}")]
343 ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
344}
345
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::ExtensionRegistry;
    use crate::std_extensions::arithmetic::conversions::{self};
    use crate::std_extensions::STD_REG;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t, usize_t},
            SignatureFunc,
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
        Extension,
    };

    use super::*;

    /// Returns the definition of a resolved extension op; panics if `res` is not one.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        // `OpaqueOp::new` adds the op's own extension to the signature's delta.
        assert_eq!(
            op.signature().as_ref(),
            &sig.with_extension_delta(op.extension().clone())
        );
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(
            ext_id.clone(),
            val_name,
            "".into(),
            vec![],
            endo_sig.clone(),
        );
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        // `comp_name` has no compute function; resolution is expected to succeed
        // anyway (presumably by falling back to the signature cached in the opaque op).
        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}