1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef,
10 crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
11 ::proptest::prelude::*,
12 ::proptest_derive::Arbitrary,
13};
14
15use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
16use crate::types::{type_param::TypeArg, Signature};
17use crate::{ops, IncomingPort, Node};
18
19use super::dataflow::DataflowOpTrait;
20use super::tag::OpTag;
21use super::{NamedOp, OpName, OpNameRef};
22
23#[derive(Clone, Debug, serde::Serialize)]
30#[serde(into = "OpaqueOp")]
31#[cfg_attr(test, derive(Arbitrary))]
32pub struct ExtensionOp {
33 #[cfg_attr(
34 test,
35 proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
36 )]
37 def: Arc<OpDef>,
38 args: Vec<TypeArg>,
39 signature: Signature, }
41
42impl ExtensionOp {
43 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
45 let args: Vec<TypeArg> = args.into();
46 let signature = def.compute_signature(&args)?;
47 Ok(Self {
48 def,
49 args,
50 signature,
51 })
52 }
53
54 pub(crate) fn new_with_cached(
56 def: Arc<OpDef>,
57 args: impl IntoIterator<Item = TypeArg>,
58 opaque: &OpaqueOp,
59 ) -> Result<Self, SignatureError> {
60 let args: Vec<TypeArg> = args.into_iter().collect();
61 let signature = match def.compute_signature(&args) {
64 Ok(sig) => sig,
65 Err(SignatureError::MissingComputeFunc) => {
66 opaque.signature().into_owned()
68 }
69 Err(e) => return Err(e),
70 };
71 Ok(Self {
72 def,
73 args,
74 signature,
75 })
76 }
77
78 pub fn args(&self) -> &[TypeArg] {
80 &self.args
81 }
82
83 pub fn def(&self) -> &OpDef {
85 self.def.as_ref()
86 }
87
88 pub fn def_arc(&self) -> &Arc<OpDef> {
91 &self.def
92 }
93
94 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
96 self.def().constant_fold(self.args(), consts)
97 }
98
99 pub fn make_opaque(&self) -> OpaqueOp {
108 OpaqueOp {
109 extension: self.def.extension_id().clone(),
110 name: self.def.name().clone(),
111 description: self.def.description().into(),
112 args: self.args.clone(),
113 signature: self.signature.clone(),
114 }
115 }
116
117 pub fn signature_mut(&mut self) -> &mut Signature {
119 &mut self.signature
120 }
121
122 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
124 self.args.as_mut_slice()
125 }
126}
127
128impl From<ExtensionOp> for OpaqueOp {
129 fn from(op: ExtensionOp) -> Self {
130 let ExtensionOp {
131 def,
132 args,
133 signature,
134 } = op;
135 OpaqueOp {
136 extension: def.extension_id().clone(),
137 name: def.name().clone(),
138 description: def.description().into(),
139 args,
140 signature,
141 }
142 }
143}
144
145impl PartialEq for ExtensionOp {
146 fn eq(&self, other: &Self) -> bool {
147 Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
148 }
149}
150
151impl Eq for ExtensionOp {}
152
153impl NamedOp for ExtensionOp {
154 fn name(&self) -> OpName {
156 qualify_name(self.def.extension_id(), self.def.name())
157 }
158}
159
160impl DataflowOpTrait for ExtensionOp {
161 const TAG: OpTag = OpTag::Leaf;
162
163 fn description(&self) -> &str {
164 self.def().description()
165 }
166
167 fn signature(&self) -> Cow<'_, Signature> {
168 Cow::Borrowed(&self.signature)
169 }
170
171 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
172 let args = self
173 .args
174 .iter()
175 .map(|ta| ta.substitute(subst))
176 .collect::<Vec<_>>();
177 let signature = self.signature.substitute(subst);
178 Self {
179 def: self.def.clone(),
180 args,
181 signature,
182 }
183 }
184}
185
186#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
196#[cfg_attr(test, derive(Arbitrary))]
197pub struct OpaqueOp {
198 extension: ExtensionId,
199 #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
200 name: OpName,
201 #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
202 description: String, args: Vec<TypeArg>,
204 signature: Signature,
208}
209
210fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
211 format!("{}.{}", res_id, name).into()
212}
213
214impl OpaqueOp {
215 pub fn new(
217 extension: ExtensionId,
218 name: impl Into<OpName>,
219 description: String,
220 args: impl Into<Vec<TypeArg>>,
221 signature: Signature,
222 ) -> Self {
223 let signature = signature.with_extension_delta(extension.clone());
224 Self {
225 extension,
226 name: name.into(),
227 description,
228 args: args.into(),
229 signature,
230 }
231 }
232
233 pub fn signature_mut(&mut self) -> &mut Signature {
235 &mut self.signature
236 }
237}
238
239impl NamedOp for OpaqueOp {
240 fn name(&self) -> OpName {
242 qualify_name(&self.extension, &self.name)
243 }
244}
245impl OpaqueOp {
246 pub fn op_name(&self) -> &OpName {
248 &self.name
249 }
250
251 pub fn args(&self) -> &[TypeArg] {
253 &self.args
254 }
255
256 pub fn extension(&self) -> &ExtensionId {
258 &self.extension
259 }
260
261 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
263 self.args.as_mut_slice()
264 }
265}
266
267impl DataflowOpTrait for OpaqueOp {
268 const TAG: OpTag = OpTag::Leaf;
269
270 fn description(&self) -> &str {
271 &self.description
272 }
273
274 fn signature(&self) -> Cow<'_, Signature> {
275 Cow::Borrowed(&self.signature)
276 }
277
278 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
279 Self {
280 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
281 signature: self.signature.substitute(subst),
282 ..self.clone()
283 }
284 }
285}
286
287#[derive(Clone, Debug, Error, PartialEq)]
290#[non_exhaustive]
291pub enum OpaqueOpError {
292 #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
294 available_ops.iter().join(", ")
295 )]
296 OpNotFoundInExtension {
297 node: Node,
299 op: OpName,
301 extension: ExtensionId,
303 available_ops: Vec<OpName>,
305 },
306 #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
308 #[allow(missing_docs)]
309 SignatureMismatch {
310 node: Node,
311 extension: ExtensionId,
312 op: OpName,
313 stored: Signature,
314 computed: Signature,
315 },
316 #[error("Error in signature of operation '{name}' in {node}: {cause}")]
318 #[allow(missing_docs)]
319 SignatureError {
320 node: Node,
321 name: OpName,
322 #[source]
323 cause: SignatureError,
324 },
325 #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
327 UnresolvedOp(Node, OpName, ExtensionId),
328 #[error("Error updating extension registry: {0}")]
330 ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
331}
332
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::ExtensionRegistry;
    use crate::std_extensions::arithmetic::conversions;
    use crate::std_extensions::STD_REG;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t, usize_t},
            SignatureFunc,
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
        Extension,
    };

    use super::*;

    /// Unwrap `res` as an `ExtensionOp` (panicking otherwise) and return its
    /// definition.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(
            op.signature().as_ref(),
            &sig.with_extension_delta(op.extension().clone())
        );
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(
            ext_id.clone(),
            val_name,
            "".into(),
            vec![],
            endo_sig.clone(),
        );
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        // An op whose definition has no compute function must still resolve,
        // using the signature stored in the opaque op.
        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}