1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef,
10 crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
11 ::proptest::prelude::*,
12 ::proptest_derive::Arbitrary,
13};
14
15use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
16use crate::types::{type_param::TypeArg, Signature};
17use crate::{ops, IncomingPort, Node};
18
19use super::dataflow::DataflowOpTrait;
20use super::tag::OpTag;
21use super::{NamedOp, OpName, OpNameRef};
22
/// An operation defined by an [`OpDef`], instantiated with concrete type
/// arguments.
///
/// The op holds a shared reference to its definition, the type arguments it
/// was instantiated with, and a cached [`Signature`].
///
/// This type only derives `Serialize`, and serialization goes through
/// [`OpaqueOp`] (`#[serde(into = "OpaqueOp")]`): the definition is replaced by
/// its extension id, name and description.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(into = "OpaqueOp")]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ExtensionOp {
    /// Shared definition of the operation.
    // In property tests the definition is generated from an arbitrary `SimpleOpDef`.
    #[cfg_attr(
        test,
        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
    )]
    def: Arc<OpDef>,
    /// Type arguments the definition is instantiated with.
    args: Vec<TypeArg>,
    /// Cached signature of the operation, normally computed from `def` and
    /// `args` when the op is constructed.
    signature: Signature,
}
41
42impl ExtensionOp {
43 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
45 let args: Vec<TypeArg> = args.into();
46 let signature = def.compute_signature(&args)?;
47 Ok(Self {
48 def,
49 args,
50 signature,
51 })
52 }
53
54 pub(crate) fn new_with_cached(
56 def: Arc<OpDef>,
57 args: impl IntoIterator<Item = TypeArg>,
58 opaque: &OpaqueOp,
59 ) -> Result<Self, SignatureError> {
60 let args: Vec<TypeArg> = args.into_iter().collect();
61 let signature = match def.compute_signature(&args) {
64 Ok(sig) => sig,
65 Err(SignatureError::MissingComputeFunc) => {
66 opaque.signature().into_owned()
68 }
69 Err(e) => return Err(e),
70 };
71 Ok(Self {
72 def,
73 args,
74 signature,
75 })
76 }
77
78 pub fn args(&self) -> &[TypeArg] {
80 &self.args
81 }
82
83 pub fn def(&self) -> &OpDef {
85 self.def.as_ref()
86 }
87
88 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
90 self.def().constant_fold(self.args(), consts)
91 }
92
93 pub fn make_opaque(&self) -> OpaqueOp {
102 OpaqueOp {
103 extension: self.def.extension_id().clone(),
104 name: self.def.name().clone(),
105 description: self.def.description().into(),
106 args: self.args.clone(),
107 signature: self.signature.clone(),
108 }
109 }
110
111 pub fn signature_mut(&mut self) -> &mut Signature {
113 &mut self.signature
114 }
115
116 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
118 self.args.as_mut_slice()
119 }
120}
121
122impl From<ExtensionOp> for OpaqueOp {
123 fn from(op: ExtensionOp) -> Self {
124 let ExtensionOp {
125 def,
126 args,
127 signature,
128 } = op;
129 OpaqueOp {
130 extension: def.extension_id().clone(),
131 name: def.name().clone(),
132 description: def.description().into(),
133 args,
134 signature,
135 }
136 }
137}
138
139impl PartialEq for ExtensionOp {
140 fn eq(&self, other: &Self) -> bool {
141 Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
142 }
143}
144
145impl Eq for ExtensionOp {}
146
147impl NamedOp for ExtensionOp {
148 fn name(&self) -> OpName {
150 qualify_name(self.def.extension_id(), self.def.name())
151 }
152}
153
154impl DataflowOpTrait for ExtensionOp {
155 const TAG: OpTag = OpTag::Leaf;
156
157 fn description(&self) -> &str {
158 self.def().description()
159 }
160
161 fn signature(&self) -> Cow<'_, Signature> {
162 Cow::Borrowed(&self.signature)
163 }
164
165 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
166 let args = self
167 .args
168 .iter()
169 .map(|ta| ta.substitute(subst))
170 .collect::<Vec<_>>();
171 let signature = self.signature.substitute(subst);
172 Self {
173 def: self.def.clone(),
174 args,
175 signature,
176 }
177 }
178}
179
/// An extension operation identified by extension id and name, without a
/// reference to its [`OpDef`].
///
/// It stores everything needed to describe the operation (extension id, name,
/// description, type arguments and signature). It derives both `Serialize` and
/// `Deserialize`, and is the form an [`ExtensionOp`] is serialized as.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
    /// Extension the operation belongs to.
    extension: ExtensionId,
    /// Unqualified name of the operation within `extension`.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
    name: OpName,
    /// Human-readable description, returned by `DataflowOpTrait::description`.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
    description: String,
    /// Type arguments the operation is instantiated with.
    args: Vec<TypeArg>,
    /// Signature of the operation.
    signature: Signature,
}
203
204fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
205 format!("{}.{}", res_id, name).into()
206}
207
208impl OpaqueOp {
209 pub fn new(
211 extension: ExtensionId,
212 name: impl Into<OpName>,
213 description: String,
214 args: impl Into<Vec<TypeArg>>,
215 signature: Signature,
216 ) -> Self {
217 let signature = signature.with_extension_delta(extension.clone());
218 Self {
219 extension,
220 name: name.into(),
221 description,
222 args: args.into(),
223 signature,
224 }
225 }
226
227 pub fn signature_mut(&mut self) -> &mut Signature {
229 &mut self.signature
230 }
231}
232
233impl NamedOp for OpaqueOp {
234 fn name(&self) -> OpName {
236 qualify_name(&self.extension, &self.name)
237 }
238}
239impl OpaqueOp {
240 pub fn op_name(&self) -> &OpName {
242 &self.name
243 }
244
245 pub fn args(&self) -> &[TypeArg] {
247 &self.args
248 }
249
250 pub fn extension(&self) -> &ExtensionId {
252 &self.extension
253 }
254
255 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
257 self.args.as_mut_slice()
258 }
259}
260
261impl DataflowOpTrait for OpaqueOp {
262 const TAG: OpTag = OpTag::Leaf;
263
264 fn description(&self) -> &str {
265 &self.description
266 }
267
268 fn signature(&self) -> Cow<'_, Signature> {
269 Cow::Borrowed(&self.signature)
270 }
271
272 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
273 Self {
274 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
275 signature: self.signature.substitute(subst),
276 ..self.clone()
277 }
278 }
279}
280
/// Errors raised while handling [`OpaqueOp`]s, for example when resolving them
/// against the definitions in an extension.
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum OpaqueOpError {
    /// The operation was not found in the extension it refers to.
    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
            available_ops.iter().join(", ")
    )]
    OpNotFoundInExtension {
        /// The node containing the operation.
        node: Node,
        /// Name of the operation that was looked up.
        op: OpName,
        /// The extension that was searched.
        extension: ExtensionId,
        /// Names of the operations the extension does provide.
        available_ops: Vec<OpName>,
    },
    /// The signature computed from the resolved definition differs from the
    /// signature stored with the operation.
    #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
    #[allow(missing_docs)]
    SignatureMismatch {
        node: Node,
        extension: ExtensionId,
        op: OpName,
        stored: Signature,
        computed: Signature,
    },
    /// Computing the signature of the operation failed.
    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
    #[allow(missing_docs)]
    SignatureError {
        node: Node,
        name: OpName,
        #[source]
        cause: SignatureError,
    },
    /// An opaque operation was encountered where a resolved one was expected.
    /// Fields: the node, the operation name, and its extension.
    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
    UnresolvedOp(Node, OpName, ExtensionId),
    /// An error occurred while updating the extension registry.
    #[error("Error updating extension registry: {0}")]
    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
}
326
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::ExtensionRegistry;
    use crate::std_extensions::arithmetic::conversions;
    use crate::std_extensions::STD_REG;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t, usize_t},
            SignatureFunc,
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
        Extension,
    };

    use super::*;

    /// Returns the [`OpDef`] of an op that has been resolved into an
    /// [`ExtensionOp`].
    ///
    /// # Panics
    ///
    /// Panics if `res` is not an extension op.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(
            op.signature().as_ref(),
            &sig.with_extension_delta(op.extension().clone())
        );
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(
            ext_id.clone(),
            val_name,
            "".into(),
            vec![],
            endo_sig.clone(),
        );
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}