hugr_core/extension/resolution/
extension.rs1#![allow(dead_code, unused_variables)]
6
7use std::mem;
8use std::sync::Arc;
9
10use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef};
11
12use super::types_mut::{resolve_signature_exts, resolve_value_exts};
13use super::{ExtensionResolutionError, WeakExtensionRegistry};
14
15impl ExtensionRegistry {
16 pub fn new_with_extension_resolution(
28 extensions: impl IntoIterator<Item = Extension>,
29 other_extensions: &WeakExtensionRegistry,
30 ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
31 Self::new_cyclic(extensions, |mut exts, weak_registry| {
32 let mut weak_registry = weak_registry.clone();
33 for (other_id, other) in other_extensions.iter() {
34 weak_registry.register(other_id.clone(), other.clone());
35 }
36 for ext in &mut exts {
37 ext.resolve_references(&weak_registry)?;
38 }
39 Ok(exts)
40 })
41 }
42}
43
44impl Extension {
45 fn resolve_references(
54 &mut self,
55 extensions: &WeakExtensionRegistry,
56 ) -> Result<(), ExtensionResolutionError> {
57 let mut used_extensions = WeakExtensionRegistry::default();
58
59 for type_def in self.types.values_mut() {
60 resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?;
61 }
62 for val in self.values.values_mut() {
63 resolve_value_exts(
64 None,
65 val.typed_value_mut(),
66 extensions,
67 &mut used_extensions,
68 )?;
69 }
70 let ops = mem::take(&mut self.operations);
71 for (op_id, mut op_def) in ops {
72 let op_ref = Arc::<OpDef>::get_mut(&mut op_def).expect("OpDef is not unique");
75 resolve_opdef_exts(&self.name, op_ref, extensions, &mut used_extensions)?;
76 self.operations.insert(op_id, op_def);
77 }
78
79 Ok(())
80 }
81}
82
83pub(super) fn resolve_typedef_exts(
88 extension: &ExtensionId,
89 def: &mut TypeDef,
90 extensions: &WeakExtensionRegistry,
91 used_extensions: &mut WeakExtensionRegistry,
92) -> Result<(), ExtensionResolutionError> {
93 match extensions.get(def.extension_id()) {
94 Some(ext) => {
95 *def.extension_mut() = ext.clone();
96 }
97 None => {
98 return Err(ExtensionResolutionError::WrongTypeDefExtension {
99 extension: extension.clone(),
100 def: def.name().clone(),
101 wrong_extension: def.extension_id().clone(),
102 });
103 }
104 }
105
106 Ok(())
107}
108
109pub(super) fn resolve_opdef_exts(
114 extension: &ExtensionId,
115 def: &mut OpDef,
116 extensions: &WeakExtensionRegistry,
117 used_extensions: &mut WeakExtensionRegistry,
118) -> Result<(), ExtensionResolutionError> {
119 match extensions.get(def.extension_id()) {
120 Some(ext) => {
121 *def.extension_mut() = ext.clone();
122 }
123 None => {
124 return Err(ExtensionResolutionError::WrongOpDefExtension {
125 extension: extension.clone(),
126 def: def.name().clone(),
127 wrong_extension: def.extension_id().clone(),
128 });
129 }
130 }
131
132 resolve_signature_func_exts(
133 extension,
134 def.signature_func_mut(),
135 extensions,
136 used_extensions,
137 )?;
138
139 Ok(())
144}
145
146pub(super) fn resolve_signature_func_exts(
151 extension: &ExtensionId,
152 signature: &mut SignatureFunc,
153 extensions: &WeakExtensionRegistry,
154 used_extensions: &mut WeakExtensionRegistry,
155) -> Result<(), ExtensionResolutionError> {
156 let signature_body = match signature {
157 SignatureFunc::PolyFuncType(p) => p.body_mut(),
158 SignatureFunc::CustomValidator(v) => v.poly_func_mut().body_mut(),
159 SignatureFunc::MissingValidateFunc(p) => p.body_mut(),
160 SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => {
162 return Ok(());
163 }
164 };
165 resolve_signature_exts(None, signature_body, extensions, used_extensions)
166}