hugr_core/extension/resolution/
extension.rs1#![allow(dead_code, unused_variables)]
6
7use std::mem;
8use std::sync::Arc;
9
10use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef};
11
12use super::types_mut::resolve_signature_exts;
13use super::{ExtensionResolutionError, WeakExtensionRegistry};
14
15impl ExtensionRegistry {
16 pub fn new_with_extension_resolution(
28 extensions: impl IntoIterator<Item = Extension>,
29 other_extensions: &WeakExtensionRegistry,
30 ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
31 Self::new_cyclic(extensions, |mut exts, weak_registry| {
32 let mut weak_registry = weak_registry.clone();
33 for (other_id, other) in other_extensions.iter() {
34 weak_registry.register(other_id.clone(), other.clone());
35 }
36 for ext in &mut exts {
37 ext.resolve_references(&weak_registry)?;
38 }
39 Ok(exts)
40 })
41 }
42}
43
44impl Extension {
45 fn resolve_references(
54 &mut self,
55 extensions: &WeakExtensionRegistry,
56 ) -> Result<(), ExtensionResolutionError> {
57 let mut used_extensions = WeakExtensionRegistry::default();
58
59 for type_def in self.types.values_mut() {
60 resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?;
61 }
62
63 let ops = mem::take(&mut self.operations);
64 for (op_id, mut op_def) in ops {
65 let op_ref = Arc::<OpDef>::get_mut(&mut op_def).expect("OpDef is not unique");
68 resolve_opdef_exts(&self.name, op_ref, extensions, &mut used_extensions)?;
69 self.operations.insert(op_id, op_def);
70 }
71
72 Ok(())
73 }
74}
75
76pub(super) fn resolve_typedef_exts(
81 extension: &ExtensionId,
82 def: &mut TypeDef,
83 extensions: &WeakExtensionRegistry,
84 used_extensions: &mut WeakExtensionRegistry,
85) -> Result<(), ExtensionResolutionError> {
86 match extensions.get(def.extension_id()) {
87 Some(ext) => {
88 *def.extension_mut() = ext.clone();
89 }
90 None => {
91 return Err(ExtensionResolutionError::WrongTypeDefExtension {
92 extension: extension.clone(),
93 def: def.name().clone(),
94 wrong_extension: def.extension_id().clone(),
95 });
96 }
97 }
98
99 Ok(())
100}
101
102pub(super) fn resolve_opdef_exts(
107 extension: &ExtensionId,
108 def: &mut OpDef,
109 extensions: &WeakExtensionRegistry,
110 used_extensions: &mut WeakExtensionRegistry,
111) -> Result<(), ExtensionResolutionError> {
112 match extensions.get(def.extension_id()) {
113 Some(ext) => {
114 *def.extension_mut() = ext.clone();
115 }
116 None => {
117 return Err(ExtensionResolutionError::WrongOpDefExtension {
118 extension: extension.clone(),
119 def: def.name().clone(),
120 wrong_extension: def.extension_id().clone(),
121 });
122 }
123 }
124
125 resolve_signature_func_exts(
126 extension,
127 def.signature_func_mut(),
128 extensions,
129 used_extensions,
130 )?;
131
132 Ok(())
137}
138
139pub(super) fn resolve_signature_func_exts(
144 extension: &ExtensionId,
145 signature: &mut SignatureFunc,
146 extensions: &WeakExtensionRegistry,
147 used_extensions: &mut WeakExtensionRegistry,
148) -> Result<(), ExtensionResolutionError> {
149 let signature_body = match signature {
150 SignatureFunc::PolyFuncType(p) => p.body_mut(),
151 SignatureFunc::CustomValidator(v) => v.poly_func_mut().body_mut(),
152 SignatureFunc::MissingValidateFunc(p) => p.body_mut(),
153 SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => {
155 return Ok(());
156 }
157 };
158 resolve_signature_exts(None, signature_body, extensions, used_extensions)
159}