hugr_core/extension/resolution/
extension.rs

1//! Resolve weak links inside `CustomType`s in an extension definition.
2//!
3//! This module is used when loading serialized extensions, to ensure that all
4//! weak links are resolved.
5#![allow(dead_code, unused_variables)]
6
7use std::mem;
8use std::sync::Arc;
9
10use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef};
11
12use super::types_mut::{resolve_signature_exts, resolve_value_exts};
13use super::{ExtensionResolutionError, WeakExtensionRegistry};
14
15impl ExtensionRegistry {
16    /// Given a list of extensions that has been deserialized, create a new
17    /// registry while updating any internal `Weak<Extension>` reference to
18    /// point to the newly created [`Arc`]s in the registry.
19    ///
20    /// # Errors
21    ///
22    /// - If an opaque operation cannot be resolved to an extension operation.
23    /// - If an extension operation references an extension that is missing from
24    ///   the registry.
25    /// - If a custom type references an extension that is missing from the
26    ///   registry.
27    pub fn new_with_extension_resolution(
28        extensions: impl IntoIterator<Item = Extension>,
29        other_extensions: &WeakExtensionRegistry,
30    ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
31        Self::new_cyclic(extensions, |mut exts, weak_registry| {
32            let mut weak_registry = weak_registry.clone();
33            for (other_id, other) in other_extensions.iter() {
34                weak_registry.register(other_id.clone(), other.clone());
35            }
36            for ext in &mut exts {
37                ext.resolve_references(&weak_registry)?;
38            }
39            Ok(exts)
40        })
41    }
42}
43
44impl Extension {
45    /// Resolve all extension references inside the extension.
46    ///
47    /// This is internally called when after deserializing an extension, to
48    /// update all `Weak` references that were dropped from the types.
49    ///
50    /// This method will clone all opdef `Arc`s in the extension, so it should
51    /// not be called on locally defined extensions to prevent unnecessary
52    /// cloning.
53    fn resolve_references(
54        &mut self,
55        extensions: &WeakExtensionRegistry,
56    ) -> Result<(), ExtensionResolutionError> {
57        let mut used_extensions = WeakExtensionRegistry::default();
58
59        for type_def in self.types.values_mut() {
60            resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?;
61        }
62        for val in self.values.values_mut() {
63            resolve_value_exts(
64                None,
65                val.typed_value_mut(),
66                extensions,
67                &mut used_extensions,
68            )?;
69        }
70        let ops = mem::take(&mut self.operations);
71        for (op_id, mut op_def) in ops {
72            // TODO: We should be able to clone the definition if needed by using `make_mut`,
73            // but `OpDef` does not implement clone -.-
74            let op_ref = Arc::<OpDef>::get_mut(&mut op_def).expect("OpDef is not unique");
75            resolve_opdef_exts(&self.name, op_ref, extensions, &mut used_extensions)?;
76            self.operations.insert(op_id, op_def);
77        }
78
79        Ok(())
80    }
81}
82
83/// Update all weak Extension pointers in the [`CustomType`]s inside a type
84/// definition.
85///
86/// Adds the extensions used in the type to the `used_extensions` registry.
87pub(super) fn resolve_typedef_exts(
88    extension: &ExtensionId,
89    def: &mut TypeDef,
90    extensions: &WeakExtensionRegistry,
91    used_extensions: &mut WeakExtensionRegistry,
92) -> Result<(), ExtensionResolutionError> {
93    match extensions.get(def.extension_id()) {
94        Some(ext) => {
95            *def.extension_mut() = ext.clone();
96        }
97        None => {
98            return Err(ExtensionResolutionError::WrongTypeDefExtension {
99                extension: extension.clone(),
100                def: def.name().clone(),
101                wrong_extension: def.extension_id().clone(),
102            });
103        }
104    }
105
106    Ok(())
107}
108
109/// Update all weak Extension pointers in the [`CustomType`]s inside an
110/// operation definition.
111///
112/// Adds the extensions used in the type to the `used_extensions` registry.
113pub(super) fn resolve_opdef_exts(
114    extension: &ExtensionId,
115    def: &mut OpDef,
116    extensions: &WeakExtensionRegistry,
117    used_extensions: &mut WeakExtensionRegistry,
118) -> Result<(), ExtensionResolutionError> {
119    match extensions.get(def.extension_id()) {
120        Some(ext) => {
121            *def.extension_mut() = ext.clone();
122        }
123        None => {
124            return Err(ExtensionResolutionError::WrongOpDefExtension {
125                extension: extension.clone(),
126                def: def.name().clone(),
127                wrong_extension: def.extension_id().clone(),
128            });
129        }
130    }
131
132    resolve_signature_func_exts(
133        extension,
134        def.signature_func_mut(),
135        extensions,
136        used_extensions,
137    )?;
138
139    // We ignore the lowering functions in the operation definition.
140    // They may contain an unresolved hugr, but it's the responsibility of the
141    // lowering call to resolve it, is it may use a different set of extensions.
142
143    Ok(())
144}
145
146/// Update all weak Extension pointers in the [`CustomType`]s inside a
147/// signature computation function.
148///
149/// Adds the extensions used in the type to the `used_extensions` registry.
150pub(super) fn resolve_signature_func_exts(
151    extension: &ExtensionId,
152    signature: &mut SignatureFunc,
153    extensions: &WeakExtensionRegistry,
154    used_extensions: &mut WeakExtensionRegistry,
155) -> Result<(), ExtensionResolutionError> {
156    let signature_body = match signature {
157        SignatureFunc::PolyFuncType(p) => p.body_mut(),
158        SignatureFunc::CustomValidator(v) => v.poly_func_mut().body_mut(),
159        SignatureFunc::MissingValidateFunc(p) => p.body_mut(),
160        // Binary computation functions should always return valid types.
161        SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => {
162            return Ok(());
163        }
164    };
165    resolve_signature_exts(None, signature_body, extensions, used_extensions)
166}