1use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9 crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr,
10 crate::types::proptest_utils::any_serde_type_arg_vec, ::proptest::prelude::*,
11 ::proptest_derive::Arbitrary,
12};
13
14use crate::core::HugrNode;
15use crate::extension::simple_op::MakeExtensionOp;
16use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
17use crate::types::{Signature, type_param::TypeArg};
18use crate::{IncomingPort, ops};
19
20use super::dataflow::DataflowOpTrait;
21use super::tag::OpTag;
22use super::{NamedOp, OpName, OpNameRef};
23
/// An operation whose definition is an [`OpDef`] from an extension,
/// instantiated with a list of type arguments.
///
/// The signature for `args` is cached in the `signature` field. Serialization
/// goes through [`OpaqueOp`] (`#[serde(into = "OpaqueOp")]`).
#[derive(Clone, Debug, serde::Serialize)]
#[serde(into = "OpaqueOp")]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ExtensionOp {
    /// Shared definition of the operation.
    #[cfg_attr(
        test,
        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
    )]
    def: Arc<OpDef>,
    /// Type arguments instantiating `def`.
    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
    args: Vec<TypeArg>,
    /// Cached signature. The constructors set it, and `signature_mut` may
    /// change it afterwards without re-checking against `def`.
    signature: Signature,
}
43
44impl ExtensionOp {
45 pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
47 let args: Vec<TypeArg> = args.into();
48 let signature = def.compute_signature(&args)?;
49 Ok(Self {
50 def,
51 args,
52 signature,
53 })
54 }
55
56 pub(crate) fn new_with_cached(
58 def: Arc<OpDef>,
59 args: impl IntoIterator<Item = TypeArg>,
60 opaque: &OpaqueOp,
61 ) -> Result<Self, SignatureError> {
62 let args: Vec<TypeArg> = args.into_iter().collect();
63 let signature = match def.compute_signature(&args) {
66 Ok(sig) => sig,
67 Err(SignatureError::MissingComputeFunc) => {
68 opaque.signature().into_owned()
70 }
71 Err(e) => return Err(e),
72 };
73 Ok(Self {
74 def,
75 args,
76 signature,
77 })
78 }
79
80 #[must_use]
82 pub fn args(&self) -> &[TypeArg] {
83 &self.args
84 }
85
86 #[must_use]
88 pub fn def(&self) -> &OpDef {
89 self.def.as_ref()
90 }
91
92 #[must_use]
95 pub fn def_arc(&self) -> &Arc<OpDef> {
96 &self.def
97 }
98
99 #[must_use]
101 pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
102 self.def().constant_fold(self.args(), consts)
103 }
104
105 #[must_use]
113 pub fn make_opaque(&self) -> OpaqueOp {
114 OpaqueOp {
115 extension: self.def.extension_id().clone(),
116 name: self.def.name().clone(),
117 args: self.args.clone(),
118 signature: self.signature.clone(),
119 }
120 }
121
122 pub fn signature_mut(&mut self) -> &mut Signature {
124 &mut self.signature
125 }
126
127 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
129 self.args.as_mut_slice()
130 }
131
132 #[must_use]
136 pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
137 T::from_extension_op(self).ok()
138 }
139
140 #[must_use]
142 pub fn extension_id(&self) -> &ExtensionId {
143 self.def.extension_id()
144 }
145
146 #[must_use]
149 pub fn unqualified_id(&self) -> &OpNameRef {
150 self.def.name()
151 }
152
153 #[must_use]
155 pub fn qualified_id(&self) -> OpName {
156 qualify_name(self.extension_id(), self.unqualified_id())
157 }
158}
159
160impl From<ExtensionOp> for OpaqueOp {
161 fn from(op: ExtensionOp) -> Self {
162 let ExtensionOp {
163 def,
164 args,
165 signature,
166 } = op;
167 OpaqueOp {
168 extension: def.extension_id().clone(),
169 name: def.name().clone(),
170 args,
171 signature,
172 }
173 }
174}
175
176impl PartialEq for ExtensionOp {
177 fn eq(&self, other: &Self) -> bool {
178 if Arc::<OpDef>::ptr_eq(&self.def, &other.def) {
179 self.args() == other.args()
181 } else {
182 self.args() == other.args()
183 && self.signature() == other.signature()
184 && self.def.name() == other.def.name()
185 && self.def.extension_id() == other.def.extension_id()
186 }
187 }
188}
189
190impl Eq for ExtensionOp {}
191
192impl NamedOp for ExtensionOp {
193 fn name(&self) -> OpName {
195 self.qualified_id()
196 }
197}
198
199impl DataflowOpTrait for ExtensionOp {
200 const TAG: OpTag = OpTag::Leaf;
201
202 fn description(&self) -> &str {
203 self.def().description()
204 }
205
206 fn signature(&self) -> Cow<'_, Signature> {
207 Cow::Borrowed(&self.signature)
208 }
209
210 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
211 let args = self
212 .args
213 .iter()
214 .map(|ta| ta.substitute(subst))
215 .collect::<Vec<_>>();
216 let signature = self.signature.substitute(subst);
217 Self {
218 def: self.def.clone(),
219 args,
220 signature,
221 }
222 }
223}
224
/// An operation identified by its extension id and name, carrying its type
/// arguments and a stored signature, without a resolved [`OpDef`].
///
/// [`ExtensionOp`] serializes `into` this type.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
    /// Id of the extension the operation belongs to.
    extension: ExtensionId,
    /// Name of the operation within its extension.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
    name: OpName,
    /// Type arguments of the operation.
    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
    args: Vec<TypeArg>,
    /// Signature of the operation, stored as provided rather than computed.
    signature: Signature,
}
247
248fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
249 format!("{res_id}.{name}").into()
250}
251
252impl OpaqueOp {
253 pub fn new(
255 extension: ExtensionId,
256 name: impl Into<OpName>,
257 args: impl Into<Vec<TypeArg>>,
258 signature: Signature,
259 ) -> Self {
260 Self {
261 extension,
262 name: name.into(),
263 args: args.into(),
264 signature,
265 }
266 }
267
268 pub fn signature_mut(&mut self) -> &mut Signature {
270 &mut self.signature
271 }
272}
273
274impl NamedOp for OpaqueOp {
275 fn name(&self) -> OpName {
276 format!("OpaqueOp:{}", self.qualified_id()).into()
277 }
278}
279
280impl OpaqueOp {
281 #[must_use]
283 pub fn unqualified_id(&self) -> &OpName {
284 &self.name
285 }
286
287 #[must_use]
289 pub fn qualified_id(&self) -> OpName {
290 qualify_name(self.extension(), self.unqualified_id())
291 }
292
293 #[must_use]
295 pub fn args(&self) -> &[TypeArg] {
296 &self.args
297 }
298
299 #[must_use]
301 pub fn extension(&self) -> &ExtensionId {
302 &self.extension
303 }
304
305 pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
307 self.args.as_mut_slice()
308 }
309}
310
311impl DataflowOpTrait for OpaqueOp {
312 const TAG: OpTag = OpTag::Leaf;
313
314 fn description(&self) -> &str {
315 "Opaque operation"
316 }
317
318 fn signature(&self) -> Cow<'_, Signature> {
319 Cow::Borrowed(&self.signature)
320 }
321
322 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
323 Self {
324 args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
325 signature: self.signature.substitute(subst),
326 ..self.clone()
327 }
328 }
329}
330
/// Errors concerning [`OpaqueOp`]s and their resolution against extension
/// definitions. Each variant's message describes the specific failure.
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum OpaqueOpError<N: HugrNode> {
    /// The named operation was not found in the given extension.
    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
            available_ops.iter().join(", ")
    )]
    OpNotFoundInExtension {
        /// The node containing the operation.
        node: N,
        /// The operation that was looked up.
        op: OpName,
        /// The extension that was searched.
        extension: ExtensionId,
        /// The operations listed as available in the error message.
        available_ops: Vec<OpName>,
    },
    /// The signature computed by the resolved definition differs from the
    /// signature that was stored on the operation.
    #[error(
        "Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}"
    )]
    #[allow(missing_docs)]
    SignatureMismatch {
        node: N,
        extension: ExtensionId,
        op: OpName,
        stored: Box<Signature>,
        computed: Box<Signature>,
    },
    /// Computing the signature of the operation failed.
    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
    #[allow(missing_docs)]
    SignatureError {
        node: N,
        name: OpName,
        #[source]
        cause: SignatureError,
    },
    /// An opaque operation was encountered where a resolved one was expected.
    /// Fields: node, operation name, extension id.
    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
    UnresolvedOp(N, OpName, ExtensionId),
    /// Updating the extension registry failed.
    #[error("Error updating extension registry: {0}")]
    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
}
378
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::Node;
    use crate::extension::ExtensionRegistry;
    use crate::extension::resolution::resolve_op_extensions;
    use crate::std_extensions::STD_REG;
    use crate::std_extensions::arithmetic::conversions;
    use crate::{
        Extension,
        extension::{
            SignatureFunc,
            prelude::{bool_t, qb_t, usize_t},
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
    };

    use super::*;

    /// Returns the [`OpDef`] of `res`, panicking if it is not an
    /// [`ExtensionOp`].
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            vec![usize_t().into()],
            sig.clone(),
        );
        assert_eq!(op.name(), "OpaqueOp:res.op");
        assert_eq!(op.args(), &[usize_t().into()]);
        assert_eq!(op.signature().as_ref(), &sig);
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    /// Ops whose definitions lack a validate or compute function must still
    /// resolve against the registry.
    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                String::new(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                String::new(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(ext_id.clone(), val_name, vec![], endo_sig.clone());
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}