// capnweb_core/protocol/capability_registry.rs

use crate::protocol::tables::StubReference;
use crate::RpcTarget;
use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicI64, Ordering},
        Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard,
    },
};
use tracing::{debug, info, warn};

/// Registry mapping numeric `i64` IDs to exported [`RpcTarget`] capabilities,
/// with a per-ID reference count.
///
/// The same `Arc` allocation is recognised by its pointer address, so exporting
/// it repeatedly yields one ID whose count is incremented on each export.
///
/// NOTE(review): the three maps describe one logical state but are guarded by
/// separate locks; any operation touching more than one of them must acquire
/// the locks in a single consistent order to avoid deadlock.
#[derive(Debug)]
pub struct CapabilityRegistry {
    /// Exported capabilities, keyed by registry ID.
    capabilities: RwLock<HashMap<i64, Arc<dyn RpcTarget>>>,

    /// Reverse index from the `Arc` pointee address (cast to `usize`) to the
    /// ID it was exported under; used to reuse IDs for repeated exports.
    reverse_map: RwLock<HashMap<usize, i64>>,

    /// Next ID to hand out for a newly exported capability.
    next_id: AtomicI64,

    /// Outstanding export count per ID; the entry is removed once it drops
    /// to zero.
    ref_counts: RwLock<HashMap<i64, u32>>,
}

32impl CapabilityRegistry {
33 pub fn new() -> Self {
34 Self {
35 capabilities: RwLock::new(HashMap::new()),
36 reverse_map: RwLock::new(HashMap::new()),
37 next_id: AtomicI64::new(1), ref_counts: RwLock::new(HashMap::new()),
39 }
40 }
41
42 pub fn export_capability(&self, capability: Arc<dyn RpcTarget>) -> i64 {
44 let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
45
46 if let Ok(reverse_map) = self.reverse_map.read() {
48 if let Some(&existing_id) = reverse_map.get(&ptr_addr) {
49 if let Ok(mut ref_counts) = self.ref_counts.write() {
51 *ref_counts.entry(existing_id).or_insert(0) += 1;
52 }
53 debug!("Reusing existing capability export: ID {}", existing_id);
54 return existing_id;
55 }
56 }
57
58 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
60
61 if let (Ok(mut capabilities), Ok(mut reverse_map), Ok(mut ref_counts)) = (
63 self.capabilities.write(),
64 self.reverse_map.write(),
65 self.ref_counts.write(),
66 ) {
67 capabilities.insert(id, capability);
68 reverse_map.insert(ptr_addr, id);
69 ref_counts.insert(id, 1);
70
71 info!("Exported new capability: ID {}", id);
72 }
73
74 id
75 }
76
77 pub fn import_capability(&self, id: i64) -> Option<Arc<dyn RpcTarget>> {
79 if let Ok(capabilities) = self.capabilities.read() {
80 capabilities.get(&id).cloned()
81 } else {
82 None
83 }
84 }
85
86 pub fn has_capability(&self, id: i64) -> bool {
88 if let Ok(capabilities) = self.capabilities.read() {
89 capabilities.contains_key(&id)
90 } else {
91 false
92 }
93 }
94
95 pub fn release_capability(&self, id: i64) -> bool {
97 if let Ok(mut ref_counts) = self.ref_counts.write() {
98 if let Some(count) = ref_counts.get_mut(&id) {
99 *count = count.saturating_sub(1);
100
101 if *count == 0 {
102 ref_counts.remove(&id);
104
105 if let (Ok(mut capabilities), Ok(mut reverse_map)) =
106 (self.capabilities.write(), self.reverse_map.write())
107 {
108 if let Some(capability) = capabilities.remove(&id) {
109 let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
110 reverse_map.remove(&ptr_addr);
111 }
112 }
113
114 info!("Released capability: ID {}", id);
115 true
116 } else {
117 debug!(
118 "Decremented capability ref count: ID {} (count: {})",
119 id, count
120 );
121 false
122 }
123 } else {
124 warn!("Attempted to release unknown capability: ID {}", id);
125 false
126 }
127 } else {
128 false
129 }
130 }
131
132 pub fn get_ref_count(&self, id: i64) -> u32 {
134 if let Ok(ref_counts) = self.ref_counts.read() {
135 ref_counts.get(&id).copied().unwrap_or(0)
136 } else {
137 0
138 }
139 }
140
141 pub fn get_exported_ids(&self) -> Vec<i64> {
143 if let Ok(capabilities) = self.capabilities.read() {
144 capabilities.keys().copied().collect()
145 } else {
146 Vec::new()
147 }
148 }
149
150 pub fn create_stub_reference(&self, id: i64) -> Option<StubReference> {
152 self.import_capability(id).map(StubReference::new)
153 }
154}
155
156impl Default for CapabilityRegistry {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162pub trait RegistrableCapability: RpcTarget {
164 fn name(&self) -> &str {
166 "Unknown"
167 }
168
169 fn on_export(&self, _id: i64) {}
171
172 fn on_import(&self, _id: i64) {}
174
175 fn on_release(&self, _id: i64) {}
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockRpcTarget;

    /// An exported capability can be imported back as the very same allocation.
    #[test]
    fn test_capability_export_import() {
        let registry = CapabilityRegistry::new();
        let capability = Arc::new(MockRpcTarget::new());

        let id = registry.export_capability(capability.clone());
        assert!(id > 0);

        let imported = registry
            .import_capability(id)
            .expect("freshly exported ID must be importable");
        assert_eq!(
            Arc::as_ptr(&capability) as *const (),
            Arc::as_ptr(&imported) as *const ()
        );
    }

    /// Re-exporting the same Arc reuses its ID and bumps the ref count; only
    /// the final release removes it.
    #[test]
    fn test_capability_ref_counting() {
        let registry = CapabilityRegistry::new();
        let capability = Arc::new(MockRpcTarget::new());

        let id1 = registry.export_capability(capability.clone());
        let id2 = registry.export_capability(capability.clone());
        assert_eq!(id1, id2);
        assert_eq!(registry.get_ref_count(id1), 2);

        assert!(!registry.release_capability(id1));
        assert_eq!(registry.get_ref_count(id1), 1);
        assert!(registry.has_capability(id1));

        assert!(registry.release_capability(id1));
        assert_eq!(registry.get_ref_count(id1), 0);
        assert!(!registry.has_capability(id1));
    }

    /// Stub references exist only for exported IDs.
    #[test]
    fn test_stub_reference_creation() {
        let registry = CapabilityRegistry::new();
        let capability = Arc::new(MockRpcTarget::new());

        let id = registry.export_capability(capability);
        assert!(registry.create_stub_reference(id).is_some());
        assert!(registry.create_stub_reference(999).is_none());
    }

    /// Releasing an ID that was never exported is a no-op reporting `false`.
    #[test]
    fn test_release_unknown_capability() {
        let registry = CapabilityRegistry::new();
        assert!(!registry.release_capability(42));
        assert_eq!(registry.get_ref_count(42), 0);
        assert!(!registry.has_capability(42));
    }

    /// Separate allocations get separate IDs, and both are listed as exported.
    #[test]
    fn test_distinct_capabilities_get_distinct_ids() {
        let registry = CapabilityRegistry::new();
        let first = registry.export_capability(Arc::new(MockRpcTarget::new()));
        let second = registry.export_capability(Arc::new(MockRpcTarget::new()));
        assert_ne!(first, second);

        let mut exported = registry.get_exported_ids();
        exported.sort_unstable();
        let mut expected = vec![first, second];
        expected.sort_unstable();
        assert_eq!(exported, expected);
    }

    /// After a full release the reverse mapping is gone, so exporting the same
    /// Arc again allocates a fresh ID.
    #[test]
    fn test_reexport_after_release_allocates_new_id() {
        let registry = CapabilityRegistry::new();
        let capability = Arc::new(MockRpcTarget::new());

        let id = registry.export_capability(capability.clone());
        assert!(registry.release_capability(id));

        let new_id = registry.export_capability(capability.clone());
        assert_ne!(id, new_id);
        assert_eq!(registry.get_ref_count(new_id), 1);
    }
}