// capnweb_server/lifecycle.rs

use std::sync::Arc;

use async_trait::async_trait;
use capnweb_core::{CapId, RpcError};
use dashmap::DashMap;
use tokio::sync::RwLock;
use tracing::{debug, warn};

8/// Trait for objects that need cleanup when disposed
9#[async_trait]
10pub trait Disposable: Send + Sync {
11    /// Called when the capability is being disposed
12    async fn dispose(&self) -> Result<(), RpcError>;
13}
14
15/// Tracks the lifecycle of capabilities
16pub struct CapabilityLifecycle {
17    /// Reference counts for capabilities
18    ref_counts: Arc<DashMap<CapId, usize>>,
19    /// Disposal callbacks for capabilities
20    dispose_callbacks: Arc<DashMap<CapId, Arc<dyn Disposable>>>,
21    /// Capabilities owned by each session
22    session_caps: Arc<RwLock<DashMap<String, Vec<CapId>>>>,
23}
24
25impl CapabilityLifecycle {
26    pub fn new() -> Self {
27        Self {
28            ref_counts: Arc::new(DashMap::new()),
29            dispose_callbacks: Arc::new(DashMap::new()),
30            session_caps: Arc::new(RwLock::new(DashMap::new())),
31        }
32    }
33
34    /// Register a new capability with optional disposal callback
35    pub async fn register(
36        &self,
37        cap_id: CapId,
38        session_id: Option<String>,
39        disposable: Option<Arc<dyn Disposable>>,
40    ) {
41        // Initialize reference count
42        self.ref_counts.insert(cap_id, 1);
43
44        // Register disposal callback if provided
45        if let Some(disposable) = disposable {
46            self.dispose_callbacks.insert(cap_id, disposable);
47        }
48
49        // Track session ownership if provided
50        if let Some(session_id) = session_id {
51            let session_caps = self.session_caps.write().await;
52            session_caps
53                .entry(session_id)
54                .or_insert_with(Vec::new)
55                .push(cap_id);
56        }
57
58        debug!("Registered capability {:?}", cap_id);
59    }
60
61    /// Increment reference count for a capability
62    pub fn retain(&self, cap_id: &CapId) -> Result<(), RpcError> {
63        if let Some(mut count) = self.ref_counts.get_mut(cap_id) {
64            *count += 1;
65            debug!("Retained capability {:?}, ref_count = {}", cap_id, *count);
66            Ok(())
67        } else {
68            Err(RpcError::not_found(format!(
69                "Capability {:?} not found",
70                cap_id
71            )))
72        }
73    }
74
75    /// Decrement reference count and dispose if it reaches zero
76    pub async fn release(&self, cap_id: &CapId) -> Result<bool, RpcError> {
77        let should_dispose = {
78            let mut should_dispose = false;
79
80            if let Some(mut count) = self.ref_counts.get_mut(cap_id) {
81                *count = count.saturating_sub(1);
82                debug!("Released capability {:?}, ref_count = {}", cap_id, *count);
83
84                if *count == 0 {
85                    should_dispose = true;
86                }
87            } else {
88                return Err(RpcError::not_found(format!(
89                    "Capability {:?} not found",
90                    cap_id
91                )));
92            }
93
94            should_dispose
95        };
96
97        if should_dispose {
98            self.dispose_internal(cap_id).await?;
99            Ok(true)
100        } else {
101            Ok(false)
102        }
103    }
104
105    /// Force disposal of a capability regardless of reference count
106    pub async fn dispose(&self, cap_id: &CapId) -> Result<(), RpcError> {
107        debug!("Force disposing capability {:?}", cap_id);
108        self.dispose_internal(cap_id).await
109    }
110
111    /// Internal disposal implementation
112    async fn dispose_internal(&self, cap_id: &CapId) -> Result<(), RpcError> {
113        // Remove from reference counts
114        self.ref_counts.remove(cap_id);
115
116        // Call disposal callback if exists
117        if let Some((_, disposable)) = self.dispose_callbacks.remove(cap_id) {
118            debug!("Calling disposal callback for capability {:?}", cap_id);
119            if let Err(e) = disposable.dispose().await {
120                warn!(
121                    "Disposal callback failed for capability {:?}: {}",
122                    cap_id, e
123                );
124                return Err(e);
125            }
126        }
127
128        // Remove from session tracking
129        let session_caps = self.session_caps.write().await;
130        for mut caps in session_caps.iter_mut() {
131            caps.retain(|&id| id != *cap_id);
132        }
133
134        debug!("Disposed capability {:?}", cap_id);
135        Ok(())
136    }
137
138    /// Dispose all capabilities owned by a session
139    pub async fn dispose_session(&self, session_id: &str) -> Result<(), RpcError> {
140        debug!("Disposing all capabilities for session {}", session_id);
141
142        let cap_ids = {
143            let session_caps = self.session_caps.read().await;
144            session_caps
145                .get(session_id)
146                .map(|caps| caps.clone())
147                .unwrap_or_default()
148        };
149
150        let mut errors = Vec::new();
151        for cap_id in cap_ids {
152            if let Err(e) = self.dispose(&cap_id).await {
153                errors.push(format!("{:?}: {}", cap_id, e));
154            }
155        }
156
157        // Remove session from tracking
158        {
159            let session_caps = self.session_caps.write().await;
160            session_caps.remove(session_id);
161        }
162
163        if !errors.is_empty() {
164            Err(RpcError::internal(format!(
165                "Failed to dispose some capabilities: {}",
166                errors.join(", ")
167            )))
168        } else {
169            Ok(())
170        }
171    }
172
173    /// Get reference count for a capability
174    pub fn ref_count(&self, cap_id: &CapId) -> Option<usize> {
175        self.ref_counts.get(cap_id).map(|count| *count)
176    }
177
178    /// Check if a capability is alive (has references)
179    pub fn is_alive(&self, cap_id: &CapId) -> bool {
180        self.ref_count(cap_id)
181            .map(|count| count > 0)
182            .unwrap_or(false)
183    }
184
185    /// Get all capabilities for a session
186    pub async fn session_capabilities(&self, session_id: &str) -> Vec<CapId> {
187        let session_caps = self.session_caps.read().await;
188        session_caps
189            .get(session_id)
190            .map(|caps| caps.clone())
191            .unwrap_or_default()
192    }
193
194    /// Get lifecycle statistics
195    pub async fn stats(&self) -> LifecycleStats {
196        let session_caps = self.session_caps.read().await;
197        let total_sessions = session_caps.len();
198        let total_caps = self.ref_counts.len();
199        let with_callbacks = self.dispose_callbacks.len();
200
201        LifecycleStats {
202            total_capabilities: total_caps,
203            with_dispose_callbacks: with_callbacks,
204            total_sessions,
205        }
206    }
207}
208
209impl Default for CapabilityLifecycle {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
/// Point-in-time counters describing a capability lifecycle tracker.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LifecycleStats {
    /// Number of capabilities that currently have a reference-count entry.
    pub total_capabilities: usize,
    /// Number of registered disposal callbacks.
    pub with_dispose_callbacks: usize,
    /// Number of sessions currently tracked, including sessions whose capability
    /// list has become empty.
    pub total_sessions: usize,
}

/// Example [`Disposable`] resource identified only by a human-readable name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DisposableResource {
    /// Label written to the debug log when the resource is disposed.
    name: String,
}

227impl DisposableResource {
228    pub fn new(name: String) -> Self {
229        Self { name }
230    }
231}
232
233#[async_trait]
234impl Disposable for DisposableResource {
235    async fn dispose(&self) -> Result<(), RpcError> {
236        debug!("Disposing resource: {}", self.name);
237        // Cleanup logic here (close files, connections, etc.)
238        Ok(())
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// A disposable whose cleanup always fails, used to exercise error paths.
    struct FailingResource;

    #[async_trait]
    impl Disposable for FailingResource {
        async fn dispose(&self) -> Result<(), RpcError> {
            Err(RpcError::internal("cleanup failed".to_string()))
        }
    }

    #[tokio::test]
    async fn test_capability_registration() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));
        assert!(lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_retain_and_release() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));

        lifecycle.retain(&cap_id).unwrap();
        assert_eq!(lifecycle.ref_count(&cap_id), Some(2));

        let disposed = lifecycle.release(&cap_id).await.unwrap();
        assert!(!disposed);
        assert_eq!(lifecycle.ref_count(&cap_id), Some(1));

        let disposed = lifecycle.release(&cap_id).await.unwrap();
        assert!(disposed);
        assert_eq!(lifecycle.ref_count(&cap_id), None);
    }

    #[tokio::test]
    async fn test_disposal_callback() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        let resource = Arc::new(DisposableResource::new("test".to_string()));
        lifecycle.register(cap_id, None, Some(resource)).await;

        lifecycle.dispose(&cap_id).await.unwrap();
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_session_management() {
        let lifecycle = CapabilityLifecycle::new();
        let session_id = "session1".to_string();

        let cap1 = CapId::new(1);
        let cap2 = CapId::new(2);

        lifecycle
            .register(cap1, Some(session_id.clone()), None)
            .await;
        lifecycle
            .register(cap2, Some(session_id.clone()), None)
            .await;

        let caps = lifecycle.session_capabilities(&session_id).await;
        assert_eq!(caps.len(), 2);

        lifecycle.dispose_session(&session_id).await.unwrap();

        assert!(!lifecycle.is_alive(&cap1));
        assert!(!lifecycle.is_alive(&cap2));

        let caps = lifecycle.session_capabilities(&session_id).await;
        assert_eq!(caps.len(), 0);
    }

    #[tokio::test]
    async fn test_force_dispose() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(1);

        lifecycle.register(cap_id, None, None).await;
        lifecycle.retain(&cap_id).unwrap();
        lifecycle.retain(&cap_id).unwrap();
        assert_eq!(lifecycle.ref_count(&cap_id), Some(3));

        lifecycle.dispose(&cap_id).await.unwrap();
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_lifecycle_stats() {
        let lifecycle = CapabilityLifecycle::new();

        lifecycle
            .register(CapId::new(1), Some("s1".to_string()), None)
            .await;
        lifecycle
            .register(
                CapId::new(2),
                Some("s1".to_string()),
                Some(Arc::new(DisposableResource::new("r1".to_string()))),
            )
            .await;
        lifecycle
            .register(CapId::new(3), Some("s2".to_string()), None)
            .await;

        let stats = lifecycle.stats().await;
        assert_eq!(stats.total_capabilities, 3);
        assert_eq!(stats.with_dispose_callbacks, 1);
        assert_eq!(stats.total_sessions, 2);
    }

    #[tokio::test]
    async fn test_unknown_capability_errors() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(42);

        assert!(lifecycle.retain(&cap_id).is_err());
        assert!(lifecycle.release(&cap_id).await.is_err());
        assert_eq!(lifecycle.ref_count(&cap_id), None);
        assert!(!lifecycle.is_alive(&cap_id));
    }

    #[tokio::test]
    async fn test_release_after_disposal_errors() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(7);

        lifecycle.register(cap_id, None, None).await;
        assert!(lifecycle.release(&cap_id).await.unwrap());

        // The capability is gone, so further retain/release calls must fail.
        assert!(lifecycle.release(&cap_id).await.is_err());
        assert!(lifecycle.retain(&cap_id).is_err());
    }

    #[tokio::test]
    async fn test_failing_disposal_callback_reports_error() {
        let lifecycle = CapabilityLifecycle::new();
        let cap_id = CapId::new(9);

        lifecycle
            .register(cap_id, None, Some(Arc::new(FailingResource)))
            .await;

        assert!(lifecycle.dispose(&cap_id).await.is_err());
        // The capability is still torn down even though its hook failed.
        assert!(!lifecycle.is_alive(&cap_id));
        assert_eq!(lifecycle.stats().await.with_dispose_callbacks, 0);
    }
}
352}