hugr_core/extension/resolution/weak_registry.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, Weak};
3
4use itertools::Itertools;
5
6use derive_more::Display;
7
8use crate::Extension;
9use crate::extension::{ExtensionId, ExtensionRegistry};
10
11/// The equivalent to an [`ExtensionRegistry`] that only contains weak
12/// references.
13///
14/// This is used to resolve extensions pointers while the extensions themselves
15/// (and the [`Arc`] that contains them) are being initialized.
16#[derive(Debug, Display, Default, Clone)]
17#[display("WeakExtensionRegistry[{}]", exts.keys().join(", "))]
18pub struct WeakExtensionRegistry {
19    /// The extensions in the registry.
20    exts: BTreeMap<ExtensionId, Weak<Extension>>,
21}
22
23impl WeakExtensionRegistry {
24    /// Create a new weak registry from a list of extensions and their ids.
25    pub fn new(extensions: impl IntoIterator<Item = (ExtensionId, Weak<Extension>)>) -> Self {
26        let mut res = Self::default();
27        for (id, ext) in extensions {
28            res.register(id, ext);
29        }
30        res
31    }
32
33    /// Gets the Extension with the given name
34    #[must_use]
35    pub fn get(&self, name: &str) -> Option<&Weak<Extension>> {
36        self.exts.get(name)
37    }
38
39    /// Returns `true` if the registry contains an extension with the given name.
40    #[must_use]
41    pub fn contains(&self, name: &str) -> bool {
42        self.exts.contains_key(name)
43    }
44
45    /// Register a new extension in the registry.
46    ///
47    /// Returns `true` if the extension was added, `false` if it was already present.
48    pub fn register(&mut self, id: ExtensionId, ext: impl Into<Weak<Extension>>) -> bool {
49        self.exts.insert(id, ext.into()).is_none()
50    }
51
52    /// Returns an iterator over the weak references in the registry and their ids.
53    pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Weak<Extension>)> {
54        self.exts.iter()
55    }
56
57    /// Returns an iterator over the weak references in the registry.
58    pub fn extensions(&self) -> impl Iterator<Item = &Weak<Extension>> {
59        self.exts.values()
60    }
61
62    /// Returns an iterator over the extension ids in the registry.
63    pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
64        self.exts.keys()
65    }
66}
67
68impl IntoIterator for WeakExtensionRegistry {
69    type Item = Weak<Extension>;
70    type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Weak<Extension>>;
71
72    fn into_iter(self) -> Self::IntoIter {
73        self.exts.into_values()
74    }
75}
76
77impl<'a> TryFrom<&'a WeakExtensionRegistry> for ExtensionRegistry {
78    type Error = ();
79
80    fn try_from(weak: &'a WeakExtensionRegistry) -> Result<Self, Self::Error> {
81        let exts: Vec<Arc<Extension>> = weak
82            .extensions()
83            .map(|w| w.upgrade().ok_or(()))
84            .try_collect()?;
85        Ok(ExtensionRegistry::new(exts))
86    }
87}
88
89impl TryFrom<WeakExtensionRegistry> for ExtensionRegistry {
90    type Error = ();
91
92    fn try_from(weak: WeakExtensionRegistry) -> Result<Self, Self::Error> {
93        let exts: Vec<Arc<Extension>> = weak
94            .into_iter()
95            .map(|w| w.upgrade().ok_or(()))
96            .try_collect()?;
97        Ok(ExtensionRegistry::new(exts))
98    }
99}
100
101impl<'a> From<&'a ExtensionRegistry> for WeakExtensionRegistry {
102    fn from(reg: &'a ExtensionRegistry) -> Self {
103        let exts = reg
104            .iter()
105            .map(|ext| (ext.name().clone(), Arc::downgrade(ext)))
106            .collect();
107        Self { exts }
108    }
109}