capnweb_core/protocol/
capability_registry.rs

// Capability registry for bidirectional capability marshaling.
// Lets capabilities be passed by reference (as numeric IDs) across HTTP batch requests and WebSocket connections.
3
4use crate::protocol::tables::StubReference;
5use crate::RpcTarget;
6use std::{
7    collections::HashMap,
8    sync::{
9        atomic::{AtomicI64, Ordering},
10        Arc, RwLock,
11    },
12};
13use tracing::{debug, info, warn};
14
15/// Registry for managing capability references across protocol boundaries
16/// Supports both import and export of capabilities with proper lifecycle management
17#[derive(Debug)]
18pub struct CapabilityRegistry {
19    /// Map from capability ID to the actual capability implementation
20    capabilities: RwLock<HashMap<i64, Arc<dyn RpcTarget>>>,
21
22    /// Map from Arc pointer address to capability ID (for reverse lookup)
23    reverse_map: RwLock<HashMap<usize, i64>>,
24
25    /// Next capability ID to assign
26    next_id: AtomicI64,
27
28    /// Reference count for each capability
29    ref_counts: RwLock<HashMap<i64, u32>>,
30}
31
32impl CapabilityRegistry {
33    pub fn new() -> Self {
34        Self {
35            capabilities: RwLock::new(HashMap::new()),
36            reverse_map: RwLock::new(HashMap::new()),
37            next_id: AtomicI64::new(1), // Start from 1, 0 is reserved for main capability
38            ref_counts: RwLock::new(HashMap::new()),
39        }
40    }
41
42    /// Export a capability and return its ID for wire marshaling
43    pub fn export_capability(&self, capability: Arc<dyn RpcTarget>) -> i64 {
44        let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
45
46        // Check if this capability is already exported
47        if let Ok(reverse_map) = self.reverse_map.read() {
48            if let Some(&existing_id) = reverse_map.get(&ptr_addr) {
49                // Increment reference count
50                if let Ok(mut ref_counts) = self.ref_counts.write() {
51                    *ref_counts.entry(existing_id).or_insert(0) += 1;
52                }
53                debug!("Reusing existing capability export: ID {}", existing_id);
54                return existing_id;
55            }
56        }
57
58        // Assign new ID
59        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
60
61        // Store both mappings
62        if let (Ok(mut capabilities), Ok(mut reverse_map), Ok(mut ref_counts)) = (
63            self.capabilities.write(),
64            self.reverse_map.write(),
65            self.ref_counts.write(),
66        ) {
67            capabilities.insert(id, capability);
68            reverse_map.insert(ptr_addr, id);
69            ref_counts.insert(id, 1);
70
71            info!("Exported new capability: ID {}", id);
72        }
73
74        id
75    }
76
77    /// Import a capability by ID for method calls
78    pub fn import_capability(&self, id: i64) -> Option<Arc<dyn RpcTarget>> {
79        if let Ok(capabilities) = self.capabilities.read() {
80            capabilities.get(&id).cloned()
81        } else {
82            None
83        }
84    }
85
86    /// Check if a capability ID exists
87    pub fn has_capability(&self, id: i64) -> bool {
88        if let Ok(capabilities) = self.capabilities.read() {
89            capabilities.contains_key(&id)
90        } else {
91            false
92        }
93    }
94
95    /// Release a capability reference (decrement ref count)
96    pub fn release_capability(&self, id: i64) -> bool {
97        if let Ok(mut ref_counts) = self.ref_counts.write() {
98            if let Some(count) = ref_counts.get_mut(&id) {
99                *count = count.saturating_sub(1);
100
101                if *count == 0 {
102                    // Remove from all maps
103                    ref_counts.remove(&id);
104
105                    if let (Ok(mut capabilities), Ok(mut reverse_map)) =
106                        (self.capabilities.write(), self.reverse_map.write())
107                    {
108                        if let Some(capability) = capabilities.remove(&id) {
109                            let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
110                            reverse_map.remove(&ptr_addr);
111                        }
112                    }
113
114                    info!("Released capability: ID {}", id);
115                    true
116                } else {
117                    debug!(
118                        "Decremented capability ref count: ID {} (count: {})",
119                        id, count
120                    );
121                    false
122                }
123            } else {
124                warn!("Attempted to release unknown capability: ID {}", id);
125                false
126            }
127        } else {
128            false
129        }
130    }
131
132    /// Get current reference count for a capability
133    pub fn get_ref_count(&self, id: i64) -> u32 {
134        if let Ok(ref_counts) = self.ref_counts.read() {
135            ref_counts.get(&id).copied().unwrap_or(0)
136        } else {
137            0
138        }
139    }
140
141    /// Get all exported capability IDs
142    pub fn get_exported_ids(&self) -> Vec<i64> {
143        if let Ok(capabilities) = self.capabilities.read() {
144            capabilities.keys().copied().collect()
145        } else {
146            Vec::new()
147        }
148    }
149
150    /// Create a stub reference for a capability (for import table integration)
151    pub fn create_stub_reference(&self, id: i64) -> Option<StubReference> {
152        self.import_capability(id).map(StubReference::new)
153    }
154}
155
156impl Default for CapabilityRegistry {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162/// Trait for capabilities that can be registered in the registry
163pub trait RegistrableCapability: RpcTarget {
164    /// Get a display name for this capability (for debugging)
165    fn name(&self) -> &str {
166        "Unknown"
167    }
168
169    /// Called when the capability is exported
170    fn on_export(&self, _id: i64) {}
171
172    /// Called when the capability is imported
173    fn on_import(&self, _id: i64) {}
174
175    /// Called when the capability is released
176    fn on_release(&self, _id: i64) {}
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockRpcTarget;

    /// Erase a possibly-fat `Arc` pointer to a thin address, so a concrete
    /// `Arc<MockRpcTarget>` can be compared with an `Arc<dyn RpcTarget>`.
    fn thin_ptr<T: ?Sized>(arc: &Arc<T>) -> *const () {
        Arc::as_ptr(arc).cast::<()>()
    }

    #[test]
    fn test_capability_export_import() {
        let registry = CapabilityRegistry::new();
        let original = Arc::new(MockRpcTarget::new());

        let id = registry.export_capability(original.clone());
        assert!(id > 0);

        // Importing must hand back the very same allocation that was exported.
        let fetched = registry
            .import_capability(id)
            .expect("exported capability should be importable");
        assert_eq!(thin_ptr(&original), thin_ptr(&fetched));
    }

    #[test]
    fn test_capability_ref_counting() {
        let registry = CapabilityRegistry::new();
        let target = Arc::new(MockRpcTarget::new());

        let first = registry.export_capability(target.clone());
        let second = registry.export_capability(target.clone());
        assert_eq!(first, second, "re-exporting the same Arc must reuse its ID");
        assert_eq!(registry.get_ref_count(first), 2);

        // Each row: (release returns, ref count afterwards, still registered).
        let expectations: [(bool, u32, bool); 2] = [(false, 1, true), (true, 0, false)];
        for &(fully_released, remaining, present) in expectations.iter() {
            assert_eq!(registry.release_capability(first), fully_released);
            assert_eq!(registry.get_ref_count(first), remaining);
            assert_eq!(registry.has_capability(first), present);
        }
    }

    #[test]
    fn test_stub_reference_creation() {
        let registry = CapabilityRegistry::new();
        let id = registry.export_capability(Arc::new(MockRpcTarget::new()));

        assert!(registry.create_stub_reference(id).is_some());

        // An ID that was never handed out yields no stub.
        assert!(registry.create_stub_reference(999).is_none());
    }
}