1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef,
10 crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
11 ::proptest::prelude::*,
12 ::proptest_derive::Arbitrary,
13};
14
15use crate::core::HugrNode;
16use crate::extension::simple_op::MakeExtensionOp;
17use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
18use crate::types::{type_param::TypeArg, Signature};
19use crate::{ops, IncomingPort};
20
21use super::dataflow::DataflowOpTrait;
22use super::tag::OpTag;
23use super::{NamedOp, OpName, OpNameRef};
24
/// An operation defined by an extension, together with the type arguments it
/// is instantiated with and a cached copy of its signature.
///
/// Serialization goes through [`OpaqueOp`] (`#[serde(into = "OpaqueOp")]`),
/// so the serialized form does not contain the [`OpDef`] itself.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(into = "OpaqueOp")]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ExtensionOp {
    /// The definition of the operation, shared behind an `Arc`.
    #[cfg_attr(
        test,
        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
    )]
    def: Arc<OpDef>,
    /// Type arguments instantiating `def`.
    args: Vec<TypeArg>,
    /// Cached signature of the instantiated operation. The constructors compute
    /// it from `def` and `args` (or take it from an `OpaqueOp` when the
    /// definition cannot compute one).
    signature: Signature,
}
43
44impl ExtensionOp {
45 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
47 let args: Vec<TypeArg> = args.into();
48 let signature = def.compute_signature(&args)?;
49 Ok(Self {
50 def,
51 args,
52 signature,
53 })
54 }
55
56 pub(crate) fn new_with_cached(
58 def: Arc<OpDef>,
59 args: impl IntoIterator<Item = TypeArg>,
60 opaque: &OpaqueOp,
61 ) -> Result<Self, SignatureError> {
62 let args: Vec<TypeArg> = args.into_iter().collect();
63 let signature = match def.compute_signature(&args) {
66 Ok(sig) => sig,
67 Err(SignatureError::MissingComputeFunc) => {
68 opaque.signature().into_owned()
70 }
71 Err(e) => return Err(e),
72 };
73 Ok(Self {
74 def,
75 args,
76 signature,
77 })
78 }
79
80 pub fn args(&self) -> &[TypeArg] {
82 &self.args
83 }
84
85 pub fn def(&self) -> &OpDef {
87 self.def.as_ref()
88 }
89
90 pub fn def_arc(&self) -> &Arc<OpDef> {
93 &self.def
94 }
95
96 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
98 self.def().constant_fold(self.args(), consts)
99 }
100
101 pub fn make_opaque(&self) -> OpaqueOp {
109 OpaqueOp {
110 extension: self.def.extension_id().clone(),
111 name: self.def.name().clone(),
112 description: self.def.description().into(),
113 args: self.args.clone(),
114 signature: self.signature.clone(),
115 }
116 }
117
118 pub fn signature_mut(&mut self) -> &mut Signature {
120 &mut self.signature
121 }
122
123 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
125 self.args.as_mut_slice()
126 }
127
128 pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
132 T::from_extension_op(self).ok()
133 }
134
135 pub fn extension_id(&self) -> &ExtensionId {
137 self.def.extension_id()
138 }
139
140 pub fn unqualified_id(&self) -> &OpNameRef {
143 self.def.name()
144 }
145
146 pub fn qualified_id(&self) -> OpName {
148 qualify_name(self.extension_id(), self.unqualified_id())
149 }
150}
151
152impl From<ExtensionOp> for OpaqueOp {
153 fn from(op: ExtensionOp) -> Self {
154 let ExtensionOp {
155 def,
156 args,
157 signature,
158 } = op;
159 OpaqueOp {
160 extension: def.extension_id().clone(),
161 name: def.name().clone(),
162 description: def.description().into(),
163 args,
164 signature,
165 }
166 }
167}
168
169impl PartialEq for ExtensionOp {
170 fn eq(&self, other: &Self) -> bool {
171 Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
172 }
173}
174
175impl Eq for ExtensionOp {}
176
177impl NamedOp for ExtensionOp {
178 fn name(&self) -> OpName {
180 self.qualified_id()
181 }
182}
183
184impl DataflowOpTrait for ExtensionOp {
185 const TAG: OpTag = OpTag::Leaf;
186
187 fn description(&self) -> &str {
188 self.def().description()
189 }
190
191 fn signature(&self) -> Cow<'_, Signature> {
192 Cow::Borrowed(&self.signature)
193 }
194
195 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
196 let args = self
197 .args
198 .iter()
199 .map(|ta| ta.substitute(subst))
200 .collect::<Vec<_>>();
201 let signature = self.signature.substitute(subst);
202 Self {
203 def: self.def.clone(),
204 args,
205 signature,
206 }
207 }
208}
209
/// An operation identified only by its extension id and name, without a
/// resolved [`OpDef`]. It stores its description, type arguments and
/// signature directly.
///
/// [`ExtensionOp`]s are serialized in this form.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
    /// Id of the extension the operation belongs to.
    extension: ExtensionId,
    /// Unqualified name of the operation within `extension`.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
    name: OpName,
    /// Human-readable description of the operation.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
    description: String,
    /// Type arguments the operation is instantiated with.
    args: Vec<TypeArg>,
    /// Signature of the operation as stored (not recomputed).
    signature: Signature,
}
233
234fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
235 format!("{}.{}", res_id, name).into()
236}
237
238impl OpaqueOp {
239 pub fn new(
241 extension: ExtensionId,
242 name: impl Into<OpName>,
243 description: String,
244 args: impl Into<Vec<TypeArg>>,
245 signature: Signature,
246 ) -> Self {
247 Self {
248 extension,
249 name: name.into(),
250 description,
251 args: args.into(),
252 signature,
253 }
254 }
255
256 pub fn signature_mut(&mut self) -> &mut Signature {
258 &mut self.signature
259 }
260}
261
262impl NamedOp for OpaqueOp {
263 fn name(&self) -> OpName {
264 format!("OpaqueOp:{}", self.qualified_id()).into()
265 }
266}
267
268impl OpaqueOp {
269 pub fn unqualified_id(&self) -> &OpName {
271 &self.name
272 }
273
274 pub fn qualified_id(&self) -> OpName {
276 qualify_name(self.extension(), self.unqualified_id())
277 }
278
279 pub fn args(&self) -> &[TypeArg] {
281 &self.args
282 }
283
284 pub fn extension(&self) -> &ExtensionId {
286 &self.extension
287 }
288
289 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
291 self.args.as_mut_slice()
292 }
293}
294
295impl DataflowOpTrait for OpaqueOp {
296 const TAG: OpTag = OpTag::Leaf;
297
298 fn description(&self) -> &str {
299 &self.description
300 }
301
302 fn signature(&self) -> Cow<'_, Signature> {
303 Cow::Borrowed(&self.signature)
304 }
305
306 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
307 Self {
308 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
309 signature: self.signature.substitute(subst),
310 ..self.clone()
311 }
312 }
313}
314
/// Errors that can arise when handling opaque operations, in particular when
/// resolving them against the operation definitions of an extension.
///
/// `N` is the node type used to report where the offending operation lives.
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum OpaqueOpError<N: HugrNode> {
    /// The named operation does not exist in the referenced extension.
    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
            available_ops.iter().join(", ")
    )]
    OpNotFoundInExtension {
        /// The node containing the operation.
        node: N,
        /// The name of the operation that was looked up.
        op: OpName,
        /// The extension that was searched.
        extension: ExtensionId,
        /// The operations the extension does provide (listed in the message).
        available_ops: Vec<OpName>,
    },
    /// The signature computed by the resolved definition differs from the
    /// signature that was stored with the operation.
    #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
    #[allow(missing_docs)]
    SignatureMismatch {
        node: N,
        extension: ExtensionId,
        op: OpName,
        stored: Signature,
        computed: Signature,
    },
    /// Computing the signature of the operation failed; the underlying
    /// [`SignatureError`] is exposed as the error source.
    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
    #[allow(missing_docs)]
    SignatureError {
        node: N,
        name: OpName,
        #[source]
        cause: SignatureError,
    },
    /// An opaque operation was encountered where a resolved one was expected.
    /// Fields: the node, the operation name, and the extension id.
    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
    UnresolvedOp(N, OpName, ExtensionId),
    /// An error raised while updating an extension registry.
    #[error("Error updating extension registry: {0}")]
    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
}
360
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::ExtensionRegistry;
    use crate::std_extensions::arithmetic::conversions;
    use crate::std_extensions::STD_REG;
    use crate::Node;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t, usize_t},
            SignatureFunc,
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
        Extension,
    };

    use super::*;

    /// Return the definition of an op that is expected to have been resolved
    /// into an [`ExtensionOp`]; panics if it was not.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    /// Constructor and accessors of `OpaqueOp` round-trip their inputs.
    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "OpaqueOp:res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(op.signature().as_ref(), &sig);
    }

    /// An opaque op naming a standard-library operation resolves to its definition.
    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    /// Ops whose definitions lack a validate or compute function still resolve.
    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(
            ext_id.clone(),
            val_name,
            "".into(),
            vec![],
            endo_sig.clone(),
        );
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}