hugr_core/extension/resolution/
extension.rs

1//! Resolve weak links inside `CustomType`s in an extension definition.
2//!
3//! This module is used when loading serialized extensions, to ensure that all
4//! weak links are resolved.
5#![allow(dead_code, unused_variables)]
6
7use std::mem;
8use std::sync::Arc;
9
10use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef};
11
12use super::types_mut::resolve_signature_exts;
13use super::{ExtensionResolutionError, WeakExtensionRegistry};
14
15impl ExtensionRegistry {
16    /// Given a list of extensions that has been deserialized, create a new
17    /// registry while updating any internal `Weak<Extension>` reference to
18    /// point to the newly created [`Arc`]s in the registry.
19    ///
20    /// # Errors
21    ///
22    /// - If an opaque operation cannot be resolved to an extension operation.
23    /// - If an extension operation references an extension that is missing from
24    ///   the registry.
25    /// - If a custom type references an extension that is missing from the
26    ///   registry.
27    pub fn new_with_extension_resolution(
28        extensions: impl IntoIterator<Item = Extension>,
29        other_extensions: &WeakExtensionRegistry,
30    ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
31        Self::new_cyclic(extensions, |mut exts, weak_registry| {
32            let mut weak_registry = weak_registry.clone();
33            for (other_id, other) in other_extensions.iter() {
34                weak_registry.register(other_id.clone(), other.clone());
35            }
36            for ext in &mut exts {
37                ext.resolve_references(&weak_registry)?;
38            }
39            Ok(exts)
40        })
41    }
42}
43
44impl Extension {
45    /// Resolve all extension references inside the extension.
46    ///
47    /// This is internally called when after deserializing an extension, to
48    /// update all `Weak` references that were dropped from the types.
49    ///
50    /// This method will clone all opdef `Arc`s in the extension, so it should
51    /// not be called on locally defined extensions to prevent unnecessary
52    /// cloning.
53    fn resolve_references(
54        &mut self,
55        extensions: &WeakExtensionRegistry,
56    ) -> Result<(), ExtensionResolutionError> {
57        let mut used_extensions = WeakExtensionRegistry::default();
58
59        for type_def in self.types.values_mut() {
60            resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?;
61        }
62
63        let ops = mem::take(&mut self.operations);
64        for (op_id, mut op_def) in ops {
65            // TODO: We should be able to clone the definition if needed by using `make_mut`,
66            // but `OpDef` does not implement clone -.-
67            let op_ref = Arc::<OpDef>::get_mut(&mut op_def).expect("OpDef is not unique");
68            resolve_opdef_exts(&self.name, op_ref, extensions, &mut used_extensions)?;
69            self.operations.insert(op_id, op_def);
70        }
71
72        Ok(())
73    }
74}
75
76/// Update all weak Extension pointers in the [`CustomType`]s inside a type
77/// definition.
78///
79/// Adds the extensions used in the type to the `used_extensions` registry.
80pub(super) fn resolve_typedef_exts(
81    extension: &ExtensionId,
82    def: &mut TypeDef,
83    extensions: &WeakExtensionRegistry,
84    used_extensions: &mut WeakExtensionRegistry,
85) -> Result<(), ExtensionResolutionError> {
86    match extensions.get(def.extension_id()) {
87        Some(ext) => {
88            *def.extension_mut() = ext.clone();
89        }
90        None => {
91            return Err(ExtensionResolutionError::WrongTypeDefExtension {
92                extension: extension.clone(),
93                def: def.name().clone(),
94                wrong_extension: def.extension_id().clone(),
95            });
96        }
97    }
98
99    Ok(())
100}
101
102/// Update all weak Extension pointers in the [`CustomType`]s inside an
103/// operation definition.
104///
105/// Adds the extensions used in the type to the `used_extensions` registry.
106pub(super) fn resolve_opdef_exts(
107    extension: &ExtensionId,
108    def: &mut OpDef,
109    extensions: &WeakExtensionRegistry,
110    used_extensions: &mut WeakExtensionRegistry,
111) -> Result<(), ExtensionResolutionError> {
112    match extensions.get(def.extension_id()) {
113        Some(ext) => {
114            *def.extension_mut() = ext.clone();
115        }
116        None => {
117            return Err(ExtensionResolutionError::WrongOpDefExtension {
118                extension: extension.clone(),
119                def: def.name().clone(),
120                wrong_extension: def.extension_id().clone(),
121            });
122        }
123    }
124
125    resolve_signature_func_exts(
126        extension,
127        def.signature_func_mut(),
128        extensions,
129        used_extensions,
130    )?;
131
132    // We ignore the lowering functions in the operation definition.
133    // They may contain an unresolved hugr, but it's the responsibility of the
134    // lowering call to resolve it, is it may use a different set of extensions.
135
136    Ok(())
137}
138
139/// Update all weak Extension pointers in the [`CustomType`]s inside a
140/// signature computation function.
141///
142/// Adds the extensions used in the type to the `used_extensions` registry.
143pub(super) fn resolve_signature_func_exts(
144    extension: &ExtensionId,
145    signature: &mut SignatureFunc,
146    extensions: &WeakExtensionRegistry,
147    used_extensions: &mut WeakExtensionRegistry,
148) -> Result<(), ExtensionResolutionError> {
149    let signature_body = match signature {
150        SignatureFunc::PolyFuncType(p) => p.body_mut(),
151        SignatureFunc::CustomValidator(v) => v.poly_func_mut().body_mut(),
152        SignatureFunc::MissingValidateFunc(p) => p.body_mut(),
153        // Binary computation functions should always return valid types.
154        SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => {
155            return Ok(());
156        }
157    };
158    resolve_signature_exts(None, signature_body, extensions, used_extensions)
159}