mcp_execution_server/
state.rs

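//! State management for pending generation sessions.
//!
//! The `StateManager` stores temporary session data between `introspect_server`
//! and `save_categorized_tools` calls. Sessions expire after 30 minutes. Expired
//! sessions are purged lazily whenever a session is stored or taken, or
//! explicitly via `StateManager::cleanup_expired`. Read-only lookups (`get`,
//! `pending_count`) skip expired sessions without removing them.
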
7use crate::types::PendingGeneration;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13/// State manager for pending generation sessions.
14///
15/// Uses an in-memory `HashMap` protected by `RwLock` for thread-safe access.
16/// Sessions expire after 30 minutes and are cleaned up lazily.
17///
18/// # Examples
19///
20/// ```
21/// use mcp_execution_server::state::StateManager;
22/// use mcp_execution_server::types::PendingGeneration;
23/// use mcp_execution_core::{ServerId, ServerConfig};
24/// use mcp_execution_introspector::ServerInfo;
25/// use std::path::PathBuf;
26///
27/// # async fn example() {
28/// let state = StateManager::new();
29///
30/// # let server_info = ServerInfo {
31/// #     id: ServerId::new("test"),
32/// #     name: "Test".to_string(),
33/// #     version: "1.0.0".to_string(),
34/// #     capabilities: mcp_execution_introspector::ServerCapabilities {
35/// #         supports_tools: true,
36/// #         supports_resources: false,
37/// #         supports_prompts: false,
38/// #     },
39/// #     tools: vec![],
40/// # };
41/// let pending = PendingGeneration::new(
42///     ServerId::new("github"),
43///     server_info,
44///     ServerConfig::builder().command("npx".to_string()).build(),
45///     PathBuf::from("/tmp/output"),
46/// );
47///
48/// // Store and get session ID
49/// let session_id = state.store(pending).await;
50///
51/// // Retrieve session data
52/// let retrieved = state.take(session_id).await;
53/// assert!(retrieved.is_some());
54/// # }
55/// ```
56#[derive(Debug, Default)]
57pub struct StateManager {
58    pending: Arc<RwLock<HashMap<Uuid, PendingGeneration>>>,
59}
60
61impl StateManager {
62    /// Creates a new state manager.
63    #[must_use]
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// Stores a pending generation and returns a session ID.
69    ///
70    /// This operation also performs lazy cleanup of expired sessions.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use mcp_execution_server::state::StateManager;
76    /// # use mcp_execution_server::types::PendingGeneration;
77    /// # use mcp_execution_core::{ServerId, ServerConfig};
78    /// # use mcp_execution_introspector::ServerInfo;
79    /// # use std::path::PathBuf;
80    ///
81    /// # async fn example(pending: PendingGeneration) {
82    /// let state = StateManager::new();
83    /// let session_id = state.store(pending).await;
84    /// # }
85    /// ```
86    pub async fn store(&self, generation: PendingGeneration) -> Uuid {
87        let session_id = Uuid::new_v4();
88        let mut pending = self.pending.write().await;
89
90        // Clean up expired sessions
91        pending.retain(|_, g| !g.is_expired());
92
93        pending.insert(session_id, generation);
94        session_id
95    }
96
97    /// Retrieves and removes a pending generation.
98    ///
99    /// Returns `None` if the session is not found or has expired.
100    /// This operation also performs lazy cleanup of expired sessions.
101    ///
102    /// # Examples
103    ///
104    /// ```
105    /// use mcp_execution_server::state::StateManager;
106    /// # use mcp_execution_server::types::PendingGeneration;
107    /// # use mcp_execution_core::{ServerId, ServerConfig};
108    /// # use mcp_execution_introspector::ServerInfo;
109    /// # use std::path::PathBuf;
110    ///
111    /// # async fn example(pending: PendingGeneration) {
112    /// let state = StateManager::new();
113    /// let session_id = state.store(pending).await;
114    ///
115    /// let retrieved = state.take(session_id).await;
116    /// assert!(retrieved.is_some());
117    ///
118    /// // Second take returns None (already removed)
119    /// let second = state.take(session_id).await;
120    /// assert!(second.is_none());
121    /// # }
122    /// ```
123    pub async fn take(&self, session_id: Uuid) -> Option<PendingGeneration> {
124        let generation = {
125            let mut pending = self.pending.write().await;
126
127            // Clean up expired sessions
128            pending.retain(|_, g| !g.is_expired());
129
130            pending.remove(&session_id)?
131        };
132
133        // Verify not expired (lock already released)
134        if generation.is_expired() {
135            return None;
136        }
137
138        Some(generation)
139    }
140
141    /// Gets a pending generation without removing it.
142    ///
143    /// Returns `None` if the session is not found or has expired.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// use mcp_execution_server::state::StateManager;
149    /// # use mcp_execution_server::types::PendingGeneration;
150    /// # use mcp_execution_core::{ServerId, ServerConfig};
151    /// # use mcp_execution_introspector::ServerInfo;
152    /// # use std::path::PathBuf;
153    ///
154    /// # async fn example(pending: PendingGeneration) {
155    /// let state = StateManager::new();
156    /// let session_id = state.store(pending).await;
157    ///
158    /// // Get without removing
159    /// let peeked = state.get(session_id).await;
160    /// assert!(peeked.is_some());
161    ///
162    /// // Still available
163    /// let peeked_again = state.get(session_id).await;
164    /// assert!(peeked_again.is_some());
165    /// # }
166    /// ```
167    pub async fn get(&self, session_id: Uuid) -> Option<PendingGeneration> {
168        let pending = self.pending.read().await;
169        pending
170            .get(&session_id)
171            .filter(|g| !g.is_expired())
172            .cloned()
173    }
174
175    /// Returns the current pending session count (excluding expired).
176    ///
177    /// # Examples
178    ///
179    /// ```
180    /// use mcp_execution_server::state::StateManager;
181    ///
182    /// # async fn example() {
183    /// let state = StateManager::new();
184    /// assert_eq!(state.pending_count().await, 0);
185    /// # }
186    /// ```
187    pub async fn pending_count(&self) -> usize {
188        let pending = self.pending.read().await;
189        pending.values().filter(|g| !g.is_expired()).count()
190    }
191
192    /// Cleans up all expired sessions.
193    ///
194    /// Returns the number of sessions that were removed.
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// use mcp_execution_server::state::StateManager;
200    ///
201    /// # async fn example() {
202    /// let state = StateManager::new();
203    /// let removed = state.cleanup_expired().await;
204    /// assert_eq!(removed, 0);
205    /// # }
206    /// ```
207    pub async fn cleanup_expired(&self) -> usize {
208        let mut pending = self.pending.write().await;
209        let before = pending.len();
210        pending.retain(|_, g| !g.is_expired());
211        before - pending.len()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PendingGeneration;
    use chrono::{Duration, Utc};
    use mcp_execution_core::{ServerConfig, ServerId, ToolName};
    use mcp_execution_introspector::ServerInfo;
    use std::path::PathBuf;

    /// Builds a fresh (non-expired) pending generation with one dummy tool.
    fn create_test_pending() -> PendingGeneration {
        use mcp_execution_introspector::{ServerCapabilities, ToolInfo};

        let server_id = ServerId::new("test");
        let server_info = ServerInfo {
            id: server_id.clone(),
            name: "Test Server".to_string(),
            version: "1.0.0".to_string(),
            capabilities: ServerCapabilities {
                supports_tools: true,
                supports_resources: false,
                supports_prompts: false,
            },
            tools: vec![ToolInfo {
                name: ToolName::new("test_tool"),
                description: "Test tool".to_string(),
                input_schema: serde_json::json!({}),
                output_schema: None,
            }],
        };
        let config = ServerConfig::builder().command("echo".to_string()).build();
        let output_dir = PathBuf::from("/tmp/test");

        PendingGeneration::new(server_id, server_info, config, output_dir)
    }

    /// Builds a pending generation whose expiry is one hour in the past.
    fn create_expired_pending() -> PendingGeneration {
        let mut pending = create_test_pending();
        pending.expires_at = Utc::now() - Duration::hours(1);
        pending
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let state = StateManager::new();
        let pending = create_test_pending();

        let session_id = state.store(pending.clone()).await;
        let retrieved = state
            .take(session_id)
            .await
            .expect("freshly stored session should be retrievable");

        assert_eq!(retrieved.server_id, pending.server_id);
    }

    #[tokio::test]
    async fn test_take_removes_session() {
        let state = StateManager::new();
        let pending = create_test_pending();

        let session_id = state.store(pending).await;

        // First take succeeds
        let first = state.take(session_id).await;
        assert!(first.is_some());

        // Second take returns None
        let second = state.take(session_id).await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_get_does_not_remove() {
        let state = StateManager::new();
        let pending = create_test_pending();

        let session_id = state.store(pending).await;

        // Get multiple times
        let first = state.get(session_id).await;
        assert!(first.is_some());

        let second = state.get(session_id).await;
        assert!(second.is_some());

        // Still available for take
        let taken = state.take(session_id).await;
        assert!(taken.is_some());
    }

    #[tokio::test]
    async fn test_expired_session() {
        let state = StateManager::new();
        let pending = create_expired_pending();

        let session_id = state.store(pending).await;

        // Should return None because expired
        let retrieved = state.take(session_id).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_get_expired_returns_none_without_purging() {
        let state = StateManager::new();

        // `store` purges before inserting, so the expired entry is kept.
        let session_id = state.store(create_expired_pending()).await;

        assert!(state.get(session_id).await.is_none());

        // `get` only reads: the expired entry is still there for cleanup.
        assert_eq!(state.cleanup_expired().await, 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let state = StateManager::new();

        assert_eq!(state.pending_count().await, 0);

        let session_id = state.store(create_test_pending()).await;
        assert_eq!(state.pending_count().await, 1);

        state.take(session_id).await;
        assert_eq!(state.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let state = StateManager::new();

        // Add valid session
        state.store(create_test_pending()).await;

        // Add expired session
        state.store(create_expired_pending()).await;

        assert_eq!(state.pending_count().await, 1); // Only valid session counts

        let removed = state.cleanup_expired().await;
        assert_eq!(removed, 1); // One expired session removed

        // The live session survives, and a second cleanup has nothing to do.
        assert_eq!(state.pending_count().await, 1);
        assert_eq!(state.cleanup_expired().await, 0);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let state = Arc::new(StateManager::new());
        let mut handles = vec![];

        // Spawn 10 concurrent store operations
        for i in 0..10 {
            let state_clone = Arc::clone(&state);
            handles.push(tokio::spawn(async move {
                let mut pending = create_test_pending();
                pending.server_id = ServerId::new(&format!("server-{i}"));
                state_clone.store(pending).await
            }));
        }

        // Wait for all operations and make sure every session got its own ID
        let mut ids = std::collections::HashSet::new();
        for handle in handles {
            ids.insert(handle.await.expect("store task panicked"));
        }

        assert_eq!(ids.len(), 10);
        assert_eq!(state.pending_count().await, 10);
    }

    #[tokio::test]
    async fn test_lazy_cleanup_on_store() {
        let state = StateManager::new();

        // Store expired session directly
        {
            let mut pending = state.pending.write().await;
            pending.insert(Uuid::new_v4(), create_expired_pending());
        }

        // Store new session triggers cleanup
        state.store(create_test_pending()).await;

        // Only the new session should remain
        assert_eq!(state.pending_count().await, 1);
    }

    #[tokio::test]
    async fn test_lazy_cleanup_on_take() {
        let state = StateManager::new();

        // Store expired session directly
        {
            let mut pending = state.pending.write().await;
            pending.insert(Uuid::new_v4(), create_expired_pending());
        }

        // Taking an unknown ID still purges expired sessions
        assert!(state.take(Uuid::new_v4()).await.is_none());
        assert!(state.pending.read().await.is_empty());
    }
}