1use async_trait::async_trait;
2use capnweb_core::{CapId, RpcError};
3use dashmap::DashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::{debug, warn};
7
/// A resource whose cleanup runs when its associated capability is disposed.
///
/// Implementations are stored by `CapabilityLifecycle` as `Arc<dyn Disposable>`. The
/// `#[async_trait]` attribute is what allows the async method to be called through that
/// trait object. `Send + Sync` is required because the same object may be reached from
/// any task that holds the lifecycle.
#[async_trait]
pub trait Disposable: Send + Sync {
    /// Releases the underlying resource.
    ///
    /// # Errors
    ///
    /// Any `Err` returned here is logged by the lifecycle and propagated to the caller
    /// that triggered disposal (`dispose`, `release`, or aggregated by `dispose_session`).
    async fn dispose(&self) -> Result<(), RpcError>;
}
14
15pub struct CapabilityLifecycle {
17 ref_counts: Arc<DashMap<CapId, usize>>,
19 dispose_callbacks: Arc<DashMap<CapId, Arc<dyn Disposable>>>,
21 session_caps: Arc<RwLock<DashMap<String, Vec<CapId>>>>,
23}
24
25impl CapabilityLifecycle {
26 pub fn new() -> Self {
27 Self {
28 ref_counts: Arc::new(DashMap::new()),
29 dispose_callbacks: Arc::new(DashMap::new()),
30 session_caps: Arc::new(RwLock::new(DashMap::new())),
31 }
32 }
33
34 pub async fn register(
36 &self,
37 cap_id: CapId,
38 session_id: Option<String>,
39 disposable: Option<Arc<dyn Disposable>>,
40 ) {
41 self.ref_counts.insert(cap_id, 1);
43
44 if let Some(disposable) = disposable {
46 self.dispose_callbacks.insert(cap_id, disposable);
47 }
48
49 if let Some(session_id) = session_id {
51 let session_caps = self.session_caps.write().await;
52 session_caps
53 .entry(session_id)
54 .or_insert_with(Vec::new)
55 .push(cap_id);
56 }
57
58 debug!("Registered capability {:?}", cap_id);
59 }
60
61 pub fn retain(&self, cap_id: &CapId) -> Result<(), RpcError> {
63 if let Some(mut count) = self.ref_counts.get_mut(cap_id) {
64 *count += 1;
65 debug!("Retained capability {:?}, ref_count = {}", cap_id, *count);
66 Ok(())
67 } else {
68 Err(RpcError::not_found(format!(
69 "Capability {:?} not found",
70 cap_id
71 )))
72 }
73 }
74
75 pub async fn release(&self, cap_id: &CapId) -> Result<bool, RpcError> {
77 let should_dispose = {
78 let mut should_dispose = false;
79
80 if let Some(mut count) = self.ref_counts.get_mut(cap_id) {
81 *count = count.saturating_sub(1);
82 debug!("Released capability {:?}, ref_count = {}", cap_id, *count);
83
84 if *count == 0 {
85 should_dispose = true;
86 }
87 } else {
88 return Err(RpcError::not_found(format!(
89 "Capability {:?} not found",
90 cap_id
91 )));
92 }
93
94 should_dispose
95 };
96
97 if should_dispose {
98 self.dispose_internal(cap_id).await?;
99 Ok(true)
100 } else {
101 Ok(false)
102 }
103 }
104
105 pub async fn dispose(&self, cap_id: &CapId) -> Result<(), RpcError> {
107 debug!("Force disposing capability {:?}", cap_id);
108 self.dispose_internal(cap_id).await
109 }
110
111 async fn dispose_internal(&self, cap_id: &CapId) -> Result<(), RpcError> {
113 self.ref_counts.remove(cap_id);
115
116 if let Some((_, disposable)) = self.dispose_callbacks.remove(cap_id) {
118 debug!("Calling disposal callback for capability {:?}", cap_id);
119 if let Err(e) = disposable.dispose().await {
120 warn!(
121 "Disposal callback failed for capability {:?}: {}",
122 cap_id, e
123 );
124 return Err(e);
125 }
126 }
127
128 let session_caps = self.session_caps.write().await;
130 for mut caps in session_caps.iter_mut() {
131 caps.retain(|&id| id != *cap_id);
132 }
133
134 debug!("Disposed capability {:?}", cap_id);
135 Ok(())
136 }
137
138 pub async fn dispose_session(&self, session_id: &str) -> Result<(), RpcError> {
140 debug!("Disposing all capabilities for session {}", session_id);
141
142 let cap_ids = {
143 let session_caps = self.session_caps.read().await;
144 session_caps
145 .get(session_id)
146 .map(|caps| caps.clone())
147 .unwrap_or_default()
148 };
149
150 let mut errors = Vec::new();
151 for cap_id in cap_ids {
152 if let Err(e) = self.dispose(&cap_id).await {
153 errors.push(format!("{:?}: {}", cap_id, e));
154 }
155 }
156
157 {
159 let session_caps = self.session_caps.write().await;
160 session_caps.remove(session_id);
161 }
162
163 if !errors.is_empty() {
164 Err(RpcError::internal(format!(
165 "Failed to dispose some capabilities: {}",
166 errors.join(", ")
167 )))
168 } else {
169 Ok(())
170 }
171 }
172
173 pub fn ref_count(&self, cap_id: &CapId) -> Option<usize> {
175 self.ref_counts.get(cap_id).map(|count| *count)
176 }
177
178 pub fn is_alive(&self, cap_id: &CapId) -> bool {
180 self.ref_count(cap_id)
181 .map(|count| count > 0)
182 .unwrap_or(false)
183 }
184
185 pub async fn session_capabilities(&self, session_id: &str) -> Vec<CapId> {
187 let session_caps = self.session_caps.read().await;
188 session_caps
189 .get(session_id)
190 .map(|caps| caps.clone())
191 .unwrap_or_default()
192 }
193
194 pub async fn stats(&self) -> LifecycleStats {
196 let session_caps = self.session_caps.read().await;
197 let total_sessions = session_caps.len();
198 let total_caps = self.ref_counts.len();
199 let with_callbacks = self.dispose_callbacks.len();
200
201 LifecycleStats {
202 total_capabilities: total_caps,
203 with_dispose_callbacks: with_callbacks,
204 total_sessions,
205 }
206 }
207}
208
209impl Default for CapabilityLifecycle {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
/// Point-in-time counters produced by `CapabilityLifecycle::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, and `Default` for an
/// all-zero snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LifecycleStats {
    /// Number of capabilities that currently have a reference-count entry.
    pub total_capabilities: usize,
    /// Number of capabilities that still have a disposal callback attached.
    pub with_dispose_callbacks: usize,
    /// Number of sessions that currently have an entry in the session map.
    pub total_sessions: usize,
}
221
/// A minimal named resource that can be attached to a capability as its disposal hook.
///
/// Derives `Debug` (like any public type should) and `Clone`, since it only holds a name.
#[derive(Debug, Clone)]
pub struct DisposableResource {
    /// Label included in the disposal log message.
    name: String,
}
226
227impl DisposableResource {
228 pub fn new(name: String) -> Self {
229 Self { name }
230 }
231}
232
233#[async_trait]
234impl Disposable for DisposableResource {
235 async fn dispose(&self) -> Result<(), RpcError> {
236 debug!("Disposing resource: {}", self.name);
237 Ok(())
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Disposal hook that always fails, used to exercise error propagation.
    struct FailingResource;

    #[async_trait]
    impl Disposable for FailingResource {
        async fn dispose(&self) -> Result<(), RpcError> {
            Err(RpcError::internal("dispose failed".to_string()))
        }
    }

    #[tokio::test]
    async fn test_capability_registration() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));
        assert!(lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_retain_and_release() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));

        lifecycle.retain(&cap_id).unwrap();
        assert_eq!(lifecycle.ref_count(&cap_id), Some(2));

        let disposed = lifecycle.release(&cap_id).await.unwrap();
        assert!(!disposed);
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));

        let disposed = lifecycle.release(&cap_id).await.unwrap();
        assert!(disposed);
        assert_eq!(lifecycle.ref_count(&cap_id), None);
    }

    #[tokio::test]
    async fn test_disposal_callback() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        let resource = Arc::new(DisposableResource::new("test".to_string()));
        lifecycle.register(cap_id, None, Some(resource)).await;

        lifecycle.dispose(&cap_id).await.unwrap();
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_session_management() {
        let lifecycle = CapabilityLifecycle::new();
        let session_id = "session1".to_string();

        let cap1 = CapId::new(1);
        let cap2 = CapId::new(2);

        lifecycle
            .register(cap1, Some(session_id.clone()), None)
            .await;
        lifecycle
            .register(cap2, Some(session_id.clone()), None)
            .await;

        let caps = lifecycle.session_capabilities(&session_id).await;
        assert_eq!(caps.len(), 2);

        lifecycle.dispose_session(&session_id).await.unwrap();

        assert!(!lifecycle.is_alive(&cap1));
        assert!(!lifecycle.is_alive(&cap2));

        let caps = lifecycle.session_capabilities(&session_id).await;
        assert_eq!(caps.len(), 0);
    }

    #[tokio::test]
    async fn test_force_dispose() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        lifecycle.retain(&cap_id).unwrap();
        lifecycle.retain(&cap_id).unwrap();
        assert_eq!(lifecycle.ref_count(&cap_id), Some(3));

        lifecycle.dispose(&cap_id).await.unwrap();
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_lifecycle_stats() {
        let lifecycle = CapabilityLifecycle::new();

        lifecycle
            .register(CapId::new(1), Some("s1".to_string()), None)
            .await;
        lifecycle
            .register(
                CapId::new(2),
                Some("s1".to_string()),
                Some(Arc::new(DisposableResource::new("r1".to_string()))),
            )
            .await;
        lifecycle
            .register(CapId::new(3), Some("s2".to_string()), None)
            .await;

        let stats = lifecycle.stats().await;
        assert_eq!(stats.total_capabilities, 3);
        assert_eq!(stats.with_dispose_callbacks, 1);
        assert_eq!(stats.total_sessions, 2);
    }

    #[tokio::test]
    async fn test_unknown_capability_errors() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(42);

        assert!(lifecycle.retain(&cap_id).is_err());
        assert!(lifecycle.release(&cap_id).await.is_err());
        assert_eq!(lifecycle.ref_count(&cap_id), None);
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_release_after_disposal_errors() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        assert!(lifecycle.release(&cap_id).await.unwrap());
        // The entry is gone once disposed, so a further release is an error.
        assert!(lifecycle.release(&cap_id).await.is_err());
    }

    #[tokio::test]
    async fn test_failing_disposal_callback_propagates_error() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle
            .register(cap_id, None, Some(Arc::new(FailingResource)))
            .await;

        assert!(lifecycle.dispose(&cap_id).await.is_err());
        assert!(!lifecycle.is_alive(&cap_id));
        assert_eq!(lifecycle.stats().await.with_dispose_callbacks, 0);
    }

    #[tokio::test]
    async fn test_dispose_unknown_session_is_ok() {
        let lifecycle = CapabilityLifecycle::new();

        lifecycle.dispose_session("missing").await.unwrap();
        assert_eq!(lifecycle.stats().await.total_sessions, 0);
    }
}