use crate::protocol::tables::StubReference;
use crate::RpcTarget;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicI64, Ordering},
Arc, RwLock,
},
};
use tracing::{debug, info, warn};
#[derive(Debug)]
pub struct CapabilityRegistry {
capabilities: RwLock<HashMap<i64, Arc<dyn RpcTarget>>>,
reverse_map: RwLock<HashMap<usize, i64>>,
next_id: AtomicI64,
ref_counts: RwLock<HashMap<i64, u32>>,
}
impl CapabilityRegistry {
pub fn new() -> Self {
Self {
capabilities: RwLock::new(HashMap::new()),
reverse_map: RwLock::new(HashMap::new()),
next_id: AtomicI64::new(1), ref_counts: RwLock::new(HashMap::new()),
}
}
pub fn export_capability(&self, capability: Arc<dyn RpcTarget>) -> i64 {
let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
if let Ok(reverse_map) = self.reverse_map.read() {
if let Some(&existing_id) = reverse_map.get(&ptr_addr) {
if let Ok(mut ref_counts) = self.ref_counts.write() {
*ref_counts.entry(existing_id).or_insert(0) += 1;
}
debug!("Reusing existing capability export: ID {}", existing_id);
return existing_id;
}
}
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
if let (Ok(mut capabilities), Ok(mut reverse_map), Ok(mut ref_counts)) = (
self.capabilities.write(),
self.reverse_map.write(),
self.ref_counts.write(),
) {
capabilities.insert(id, capability);
reverse_map.insert(ptr_addr, id);
ref_counts.insert(id, 1);
info!("Exported new capability: ID {}", id);
}
id
}
pub fn import_capability(&self, id: i64) -> Option<Arc<dyn RpcTarget>> {
if let Ok(capabilities) = self.capabilities.read() {
capabilities.get(&id).cloned()
} else {
None
}
}
pub fn has_capability(&self, id: i64) -> bool {
if let Ok(capabilities) = self.capabilities.read() {
capabilities.contains_key(&id)
} else {
false
}
}
pub fn release_capability(&self, id: i64) -> bool {
if let Ok(mut ref_counts) = self.ref_counts.write() {
if let Some(count) = ref_counts.get_mut(&id) {
*count = count.saturating_sub(1);
if *count == 0 {
ref_counts.remove(&id);
if let (Ok(mut capabilities), Ok(mut reverse_map)) =
(self.capabilities.write(), self.reverse_map.write())
{
if let Some(capability) = capabilities.remove(&id) {
let ptr_addr = Arc::as_ptr(&capability) as *const () as usize;
reverse_map.remove(&ptr_addr);
}
}
info!("Released capability: ID {}", id);
true
} else {
debug!(
"Decremented capability ref count: ID {} (count: {})",
id, count
);
false
}
} else {
warn!("Attempted to release unknown capability: ID {}", id);
false
}
} else {
false
}
}
pub fn get_ref_count(&self, id: i64) -> u32 {
if let Ok(ref_counts) = self.ref_counts.read() {
ref_counts.get(&id).copied().unwrap_or(0)
} else {
0
}
}
pub fn get_exported_ids(&self) -> Vec<i64> {
if let Ok(capabilities) = self.capabilities.read() {
capabilities.keys().copied().collect()
} else {
Vec::new()
}
}
pub fn create_stub_reference(&self, id: i64) -> Option<StubReference> {
self.import_capability(id).map(StubReference::new)
}
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
pub trait RegistrableCapability: RpcTarget {
fn name(&self) -> &str {
"Unknown"
}
fn on_export(&self, _id: i64) {}
fn on_import(&self, _id: i64) {}
fn on_release(&self, _id: i64) {}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::MockRpcTarget;
#[test]
fn test_capability_export_import() {
let registry = CapabilityRegistry::new();
let capability = Arc::new(MockRpcTarget::new());
let id = registry.export_capability(capability.clone());
assert!(id > 0);
let imported = registry.import_capability(id);
assert!(imported.is_some());
let imported = imported.unwrap();
assert_eq!(
Arc::as_ptr(&capability) as *const (),
Arc::as_ptr(&imported) as *const ()
);
}
#[test]
fn test_capability_ref_counting() {
let registry = CapabilityRegistry::new();
let capability = Arc::new(MockRpcTarget::new());
let id1 = registry.export_capability(capability.clone());
let id2 = registry.export_capability(capability.clone());
assert_eq!(id1, id2);
assert_eq!(registry.get_ref_count(id1), 2);
assert!(!registry.release_capability(id1));
assert_eq!(registry.get_ref_count(id1), 1);
assert!(registry.has_capability(id1));
assert!(registry.release_capability(id1));
assert_eq!(registry.get_ref_count(id1), 0);
assert!(!registry.has_capability(id1));
}
#[test]
fn test_stub_reference_creation() {
let registry = CapabilityRegistry::new();
let capability = Arc::new(MockRpcTarget::new());
let id = registry.export_capability(capability);
let stub_ref = registry.create_stub_reference(id);
assert!(stub_ref.is_some());
let invalid_stub = registry.create_stub_reference(999);
assert!(invalid_stub.is_none());
}
}