hugr_core/extension/resolution/
weak_registry.rs1use std::collections::BTreeMap;
2use std::sync::{Arc, Weak};
3
4use itertools::Itertools;
5
6use derive_more::Display;
7
8use crate::extension::{ExtensionId, ExtensionRegistry};
9use crate::Extension;
10
11#[derive(Debug, Display, Default, Clone)]
17#[display("WeakExtensionRegistry[{}]", exts.keys().join(", "))]
18pub struct WeakExtensionRegistry {
19    exts: BTreeMap<ExtensionId, Weak<Extension>>,
21}
22
23impl WeakExtensionRegistry {
24    pub fn new(extensions: impl IntoIterator<Item = (ExtensionId, Weak<Extension>)>) -> Self {
26        let mut res = Self::default();
27        for (id, ext) in extensions.into_iter() {
28            res.register(id, ext);
29        }
30        res
31    }
32
33    pub fn get(&self, name: &str) -> Option<&Weak<Extension>> {
35        self.exts.get(name)
36    }
37
38    pub fn contains(&self, name: &str) -> bool {
40        self.exts.contains_key(name)
41    }
42
43    pub fn register(&mut self, id: ExtensionId, ext: impl Into<Weak<Extension>>) -> bool {
47        self.exts.insert(id, ext.into()).is_none()
48    }
49
50    pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Weak<Extension>)> {
52        self.exts.iter()
53    }
54
55    pub fn extensions(&self) -> impl Iterator<Item = &Weak<Extension>> {
57        self.exts.values()
58    }
59
60    pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
62        self.exts.keys()
63    }
64}
65
66impl IntoIterator for WeakExtensionRegistry {
67    type Item = Weak<Extension>;
68    type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Weak<Extension>>;
69
70    fn into_iter(self) -> Self::IntoIter {
71        self.exts.into_values()
72    }
73}
74
75impl<'a> TryFrom<&'a WeakExtensionRegistry> for ExtensionRegistry {
76    type Error = ();
77
78    fn try_from(weak: &'a WeakExtensionRegistry) -> Result<Self, Self::Error> {
79        let exts: Vec<Arc<Extension>> = weak
80            .extensions()
81            .map(|w| w.upgrade().ok_or(()))
82            .try_collect()?;
83        Ok(ExtensionRegistry::new(exts))
84    }
85}
86
87impl TryFrom<WeakExtensionRegistry> for ExtensionRegistry {
88    type Error = ();
89
90    fn try_from(weak: WeakExtensionRegistry) -> Result<Self, Self::Error> {
91        let exts: Vec<Arc<Extension>> = weak
92            .into_iter()
93            .map(|w| w.upgrade().ok_or(()))
94            .try_collect()?;
95        Ok(ExtensionRegistry::new(exts))
96    }
97}
98
99impl<'a> From<&'a ExtensionRegistry> for WeakExtensionRegistry {
100    fn from(reg: &'a ExtensionRegistry) -> Self {
101        let exts = reg
102            .iter()
103            .map(|ext| (ext.name().clone(), Arc::downgrade(ext)))
104            .collect();
105        Self { exts }
106    }
107}