1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr,
10 ::proptest::prelude::*, ::proptest_derive::Arbitrary,
11};
12
13use crate::core::HugrNode;
14use crate::extension::simple_op::MakeExtensionOp;
15use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
16use crate::types::{Signature, type_param::TypeArg};
17use crate::{IncomingPort, ops};
18
19use super::dataflow::DataflowOpTrait;
20use super::tag::OpTag;
21use super::{NamedOp, OpName, OpNameRef};
22
23#[derive(Clone, Debug, serde::Serialize)]
30#[serde(into = "OpaqueOp")]
31#[cfg_attr(test, derive(Arbitrary))]
32pub struct ExtensionOp {
33 #[cfg_attr(
34 test,
35 proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
36 )]
37 def: Arc<OpDef>,
38 args: Vec<TypeArg>,
39 signature: Signature, }
41
42impl ExtensionOp {
43 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
45 let args: Vec<TypeArg> = args.into();
46 let signature = def.compute_signature(&args)?;
47 Ok(Self {
48 def,
49 args,
50 signature,
51 })
52 }
53
54 pub(crate) fn new_with_cached(
56 def: Arc<OpDef>,
57 args: impl IntoIterator<Item = TypeArg>,
58 opaque: &OpaqueOp,
59 ) -> Result<Self, SignatureError> {
60 let args: Vec<TypeArg> = args.into_iter().collect();
61 let signature = match def.compute_signature(&args) {
64 Ok(sig) => sig,
65 Err(SignatureError::MissingComputeFunc) => {
66 opaque.signature().into_owned()
68 }
69 Err(e) => return Err(e),
70 };
71 Ok(Self {
72 def,
73 args,
74 signature,
75 })
76 }
77
78 #[must_use]
80 pub fn args(&self) -> &[TypeArg] {
81 &self.args
82 }
83
84 #[must_use]
86 pub fn def(&self) -> &OpDef {
87 self.def.as_ref()
88 }
89
90 #[must_use]
93 pub fn def_arc(&self) -> &Arc<OpDef> {
94 &self.def
95 }
96
97 #[must_use]
99 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
100 self.def().constant_fold(self.args(), consts)
101 }
102
103 #[must_use]
111 pub fn make_opaque(&self) -> OpaqueOp {
112 OpaqueOp {
113 extension: self.def.extension_id().clone(),
114 name: self.def.name().clone(),
115 args: self.args.clone(),
116 signature: self.signature.clone(),
117 }
118 }
119
120 pub fn signature_mut(&mut self) -> &mut Signature {
122 &mut self.signature
123 }
124
125 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
127 self.args.as_mut_slice()
128 }
129
130 #[must_use]
134 pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
135 T::from_extension_op(self).ok()
136 }
137
138 #[must_use]
140 pub fn extension_id(&self) -> &ExtensionId {
141 self.def.extension_id()
142 }
143
144 #[must_use]
147 pub fn unqualified_id(&self) -> &OpNameRef {
148 self.def.name()
149 }
150
151 #[must_use]
153 pub fn qualified_id(&self) -> OpName {
154 qualify_name(self.extension_id(), self.unqualified_id())
155 }
156}
157
158impl From<ExtensionOp> for OpaqueOp {
159 fn from(op: ExtensionOp) -> Self {
160 let ExtensionOp {
161 def,
162 args,
163 signature,
164 } = op;
165 OpaqueOp {
166 extension: def.extension_id().clone(),
167 name: def.name().clone(),
168 args,
169 signature,
170 }
171 }
172}
173
174impl PartialEq for ExtensionOp {
175 fn eq(&self, other: &Self) -> bool {
176 if Arc::<OpDef>::ptr_eq(&self.def, &other.def) {
177 self.args() == other.args()
179 } else {
180 self.args() == other.args()
181 && self.signature() == other.signature()
182 && self.def.name() == other.def.name()
183 && self.def.extension_id() == other.def.extension_id()
184 }
185 }
186}
187
188impl Eq for ExtensionOp {}
189
190impl NamedOp for ExtensionOp {
191 fn name(&self) -> OpName {
193 self.qualified_id()
194 }
195}
196
197impl DataflowOpTrait for ExtensionOp {
198 const TAG: OpTag = OpTag::Leaf;
199
200 fn description(&self) -> &str {
201 self.def().description()
202 }
203
204 fn signature(&self) -> Cow<'_, Signature> {
205 Cow::Borrowed(&self.signature)
206 }
207
208 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
209 let args = self
210 .args
211 .iter()
212 .map(|ta| ta.substitute(subst))
213 .collect::<Vec<_>>();
214 let signature = self.signature.substitute(subst);
215 Self {
216 def: self.def.clone(),
217 args,
218 signature,
219 }
220 }
221}
222
223#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
233#[cfg_attr(test, derive(Arbitrary))]
234pub struct OpaqueOp {
235 extension: ExtensionId,
236 #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
237 name: OpName,
238 args: Vec<TypeArg>,
239 signature: Signature,
243}
244
245fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
246 format!("{res_id}.{name}").into()
247}
248
249impl OpaqueOp {
250 pub fn new(
252 extension: ExtensionId,
253 name: impl Into<OpName>,
254 args: impl Into<Vec<TypeArg>>,
255 signature: Signature,
256 ) -> Self {
257 Self {
258 extension,
259 name: name.into(),
260 args: args.into(),
261 signature,
262 }
263 }
264
265 pub fn signature_mut(&mut self) -> &mut Signature {
267 &mut self.signature
268 }
269}
270
271impl NamedOp for OpaqueOp {
272 fn name(&self) -> OpName {
273 format!("OpaqueOp:{}", self.qualified_id()).into()
274 }
275}
276
277impl OpaqueOp {
278 #[must_use]
280 pub fn unqualified_id(&self) -> &OpName {
281 &self.name
282 }
283
284 #[must_use]
286 pub fn qualified_id(&self) -> OpName {
287 qualify_name(self.extension(), self.unqualified_id())
288 }
289
290 #[must_use]
292 pub fn args(&self) -> &[TypeArg] {
293 &self.args
294 }
295
296 #[must_use]
298 pub fn extension(&self) -> &ExtensionId {
299 &self.extension
300 }
301
302 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
304 self.args.as_mut_slice()
305 }
306}
307
308impl DataflowOpTrait for OpaqueOp {
309 const TAG: OpTag = OpTag::Leaf;
310
311 fn description(&self) -> &str {
312 "Opaque operation"
313 }
314
315 fn signature(&self) -> Cow<'_, Signature> {
316 Cow::Borrowed(&self.signature)
317 }
318
319 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
320 Self {
321 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
322 signature: self.signature.substitute(subst),
323 ..self.clone()
324 }
325 }
326}
327
328#[derive(Clone, Debug, Error, PartialEq)]
331#[non_exhaustive]
332pub enum OpaqueOpError<N: HugrNode> {
333 #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
335 available_ops.iter().join(", ")
336 )]
337 OpNotFoundInExtension {
338 node: N,
340 op: OpName,
342 extension: ExtensionId,
344 available_ops: Vec<OpName>,
346 },
347 #[error(
349 "Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}"
350 )]
351 #[allow(missing_docs)]
352 SignatureMismatch {
353 node: N,
354 extension: ExtensionId,
355 op: OpName,
356 stored: Signature,
357 computed: Signature,
358 },
359 #[error("Error in signature of operation '{name}' in {node}: {cause}")]
361 #[allow(missing_docs)]
362 SignatureError {
363 node: N,
364 name: OpName,
365 #[source]
366 cause: SignatureError,
367 },
368 #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
370 UnresolvedOp(N, OpName, ExtensionId),
371 #[error("Error updating extension registry: {0}")]
373 ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
374}
375
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::Node;
    use crate::extension::ExtensionRegistry;
    use crate::extension::resolution::resolve_op_extensions;
    use crate::std_extensions::STD_REG;
    use crate::std_extensions::arithmetic::conversions::{self};
    use crate::{
        Extension,
        extension::{
            SignatureFunc,
            prelude::{bool_t, qb_t, usize_t},
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
    };

    use super::*;

    /// Returns the [`OpDef`] of an op that resolution turned into an
    /// [`ExtensionOp`]. Panics if `res` is not an extension op.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "OpaqueOp:res.op");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(op.signature().as_ref(), &sig);
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    /// Ops whose definitions lack a validate or compute function still resolve.
    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                String::new(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                String::new(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(ext_id.clone(), val_name, vec![], endo_sig.clone());
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}