mcp_execution_server/state.rs
1//! State management for pending generation sessions.
2//!
3//! The `StateManager` stores temporary session data between `introspect_server`
4//! and `save_categorized_tools` calls. Sessions expire after 30 minutes and
5//! are cleaned up lazily on each operation.
6
7use crate::types::PendingGeneration;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13/// State manager for pending generation sessions.
14///
15/// Uses an in-memory `HashMap` protected by `RwLock` for thread-safe access.
16/// Sessions expire after 30 minutes and are cleaned up lazily.
17///
18/// # Examples
19///
20/// ```
21/// use mcp_execution_server::state::StateManager;
22/// use mcp_execution_server::types::PendingGeneration;
23/// use mcp_execution_core::{ServerId, ServerConfig};
24/// use mcp_execution_introspector::ServerInfo;
25/// use std::path::PathBuf;
26///
27/// # async fn example() {
28/// let state = StateManager::new();
29///
30/// # let server_info = ServerInfo {
31/// # id: ServerId::new("test"),
32/// # name: "Test".to_string(),
33/// # version: "1.0.0".to_string(),
34/// # capabilities: mcp_execution_introspector::ServerCapabilities {
35/// # supports_tools: true,
36/// # supports_resources: false,
37/// # supports_prompts: false,
38/// # },
39/// # tools: vec![],
40/// # };
41/// let pending = PendingGeneration::new(
42/// ServerId::new("github"),
43/// server_info,
44/// ServerConfig::builder().command("npx".to_string()).build(),
45/// PathBuf::from("/tmp/output"),
46/// );
47///
48/// // Store and get session ID
49/// let session_id = state.store(pending).await;
50///
51/// // Retrieve session data
52/// let retrieved = state.take(session_id).await;
53/// assert!(retrieved.is_some());
54/// # }
55/// ```
56#[derive(Debug, Default)]
57pub struct StateManager {
58 pending: Arc<RwLock<HashMap<Uuid, PendingGeneration>>>,
59}
60
61impl StateManager {
62 /// Creates a new state manager.
63 #[must_use]
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 /// Stores a pending generation and returns a session ID.
69 ///
70 /// This operation also performs lazy cleanup of expired sessions.
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use mcp_execution_server::state::StateManager;
76 /// # use mcp_execution_server::types::PendingGeneration;
77 /// # use mcp_execution_core::{ServerId, ServerConfig};
78 /// # use mcp_execution_introspector::ServerInfo;
79 /// # use std::path::PathBuf;
80 ///
81 /// # async fn example(pending: PendingGeneration) {
82 /// let state = StateManager::new();
83 /// let session_id = state.store(pending).await;
84 /// # }
85 /// ```
86 pub async fn store(&self, generation: PendingGeneration) -> Uuid {
87 let session_id = Uuid::new_v4();
88 let mut pending = self.pending.write().await;
89
90 // Clean up expired sessions
91 pending.retain(|_, g| !g.is_expired());
92
93 pending.insert(session_id, generation);
94 session_id
95 }
96
97 /// Retrieves and removes a pending generation.
98 ///
99 /// Returns `None` if the session is not found or has expired.
100 /// This operation also performs lazy cleanup of expired sessions.
101 ///
102 /// # Examples
103 ///
104 /// ```
105 /// use mcp_execution_server::state::StateManager;
106 /// # use mcp_execution_server::types::PendingGeneration;
107 /// # use mcp_execution_core::{ServerId, ServerConfig};
108 /// # use mcp_execution_introspector::ServerInfo;
109 /// # use std::path::PathBuf;
110 ///
111 /// # async fn example(pending: PendingGeneration) {
112 /// let state = StateManager::new();
113 /// let session_id = state.store(pending).await;
114 ///
115 /// let retrieved = state.take(session_id).await;
116 /// assert!(retrieved.is_some());
117 ///
118 /// // Second take returns None (already removed)
119 /// let second = state.take(session_id).await;
120 /// assert!(second.is_none());
121 /// # }
122 /// ```
123 pub async fn take(&self, session_id: Uuid) -> Option<PendingGeneration> {
124 let generation = {
125 let mut pending = self.pending.write().await;
126
127 // Clean up expired sessions
128 pending.retain(|_, g| !g.is_expired());
129
130 pending.remove(&session_id)?
131 };
132
133 // Verify not expired (lock already released)
134 if generation.is_expired() {
135 return None;
136 }
137
138 Some(generation)
139 }
140
141 /// Gets a pending generation without removing it.
142 ///
143 /// Returns `None` if the session is not found or has expired.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// use mcp_execution_server::state::StateManager;
149 /// # use mcp_execution_server::types::PendingGeneration;
150 /// # use mcp_execution_core::{ServerId, ServerConfig};
151 /// # use mcp_execution_introspector::ServerInfo;
152 /// # use std::path::PathBuf;
153 ///
154 /// # async fn example(pending: PendingGeneration) {
155 /// let state = StateManager::new();
156 /// let session_id = state.store(pending).await;
157 ///
158 /// // Get without removing
159 /// let peeked = state.get(session_id).await;
160 /// assert!(peeked.is_some());
161 ///
162 /// // Still available
163 /// let peeked_again = state.get(session_id).await;
164 /// assert!(peeked_again.is_some());
165 /// # }
166 /// ```
167 pub async fn get(&self, session_id: Uuid) -> Option<PendingGeneration> {
168 let pending = self.pending.read().await;
169 pending
170 .get(&session_id)
171 .filter(|g| !g.is_expired())
172 .cloned()
173 }
174
175 /// Returns the current pending session count (excluding expired).
176 ///
177 /// # Examples
178 ///
179 /// ```
180 /// use mcp_execution_server::state::StateManager;
181 ///
182 /// # async fn example() {
183 /// let state = StateManager::new();
184 /// assert_eq!(state.pending_count().await, 0);
185 /// # }
186 /// ```
187 pub async fn pending_count(&self) -> usize {
188 let pending = self.pending.read().await;
189 pending.values().filter(|g| !g.is_expired()).count()
190 }
191
192 /// Cleans up all expired sessions.
193 ///
194 /// Returns the number of sessions that were removed.
195 ///
196 /// # Examples
197 ///
198 /// ```
199 /// use mcp_execution_server::state::StateManager;
200 ///
201 /// # async fn example() {
202 /// let state = StateManager::new();
203 /// let removed = state.cleanup_expired().await;
204 /// assert_eq!(removed, 0);
205 /// # }
206 /// ```
207 pub async fn cleanup_expired(&self) -> usize {
208 let mut pending = self.pending.write().await;
209 let before = pending.len();
210 pending.retain(|_, g| !g.is_expired());
211 before - pending.len()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PendingGeneration;
    use chrono::{Duration, Utc};
    use mcp_execution_core::{ServerConfig, ServerId, ToolName};
    use mcp_execution_introspector::ServerInfo;
    use std::path::PathBuf;

    /// Builds a session fixture for a server named "test" that exposes a
    /// single tool.
    fn create_test_pending() -> PendingGeneration {
        use mcp_execution_introspector::{ServerCapabilities, ToolInfo};

        let id = ServerId::new("test");
        let tool = ToolInfo {
            name: ToolName::new("test_tool"),
            description: "Test tool".to_string(),
            input_schema: serde_json::json!({}),
            output_schema: None,
        };
        let info = ServerInfo {
            id: id.clone(),
            name: "Test Server".to_string(),
            version: "1.0.0".to_string(),
            capabilities: ServerCapabilities {
                supports_tools: true,
                supports_resources: false,
                supports_prompts: false,
            },
            tools: vec![tool],
        };

        PendingGeneration::new(
            id,
            info,
            ServerConfig::builder().command("echo".to_string()).build(),
            PathBuf::from("/tmp/test"),
        )
    }

    /// Same fixture as `create_test_pending`, with its expiry moved one hour
    /// into the past.
    fn create_expired_pending() -> PendingGeneration {
        let mut fixture = create_test_pending();
        fixture.expires_at = Utc::now() - Duration::hours(1);
        fixture
    }

    /// A stored session comes back from `take` with its data intact.
    #[tokio::test]
    async fn test_store_and_retrieve() {
        let manager = StateManager::new();
        let original = create_test_pending();

        let id = manager.store(original.clone()).await;
        let fetched = manager.take(id).await;

        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().server_id, original.server_id);
    }

    /// `take` consumes the session, so a second call finds nothing.
    #[tokio::test]
    async fn test_take_removes_session() {
        let manager = StateManager::new();
        let id = manager.store(create_test_pending()).await;

        // First take succeeds
        assert!(manager.take(id).await.is_some());

        // Second take returns None
        assert!(manager.take(id).await.is_none());
    }

    /// `get` leaves the session in place for later calls.
    #[tokio::test]
    async fn test_get_does_not_remove() {
        let manager = StateManager::new();
        let id = manager.store(create_test_pending()).await;

        // Get multiple times
        assert!(manager.get(id).await.is_some());
        assert!(manager.get(id).await.is_some());

        // Still available for take
        assert!(manager.take(id).await.is_some());
    }

    /// An already-expired session is never returned by `take`.
    #[tokio::test]
    async fn test_expired_session() {
        let manager = StateManager::new();
        let id = manager.store(create_expired_pending()).await;

        // Should return None because expired
        assert!(manager.take(id).await.is_none());
    }

    /// The count follows stores and takes.
    #[tokio::test]
    async fn test_pending_count() {
        let manager = StateManager::new();
        assert_eq!(manager.pending_count().await, 0);

        let id = manager.store(create_test_pending()).await;
        assert_eq!(manager.pending_count().await, 1);

        let _ = manager.take(id).await;
        assert_eq!(manager.pending_count().await, 0);
    }

    /// `cleanup_expired` reports how many expired sessions it dropped.
    #[tokio::test]
    async fn test_cleanup_expired() {
        let manager = StateManager::new();

        // One valid session followed by one expired session
        manager.store(create_test_pending()).await;
        manager.store(create_expired_pending()).await;

        // Only the valid session counts
        assert_eq!(manager.pending_count().await, 1);

        // One expired session removed
        assert_eq!(manager.cleanup_expired().await, 1);
    }

    /// Ten tasks storing concurrently all end up in the table.
    #[tokio::test]
    async fn test_concurrent_access() {
        let manager = Arc::new(StateManager::new());

        // Spawn 10 concurrent store operations
        let tasks: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&manager);
                tokio::spawn(async move {
                    let mut fixture = create_test_pending();
                    fixture.server_id = ServerId::new(&format!("server-{i}"));
                    shared.store(fixture).await
                })
            })
            .collect();

        // Wait for all operations to complete
        for task in tasks {
            task.await.unwrap();
        }

        assert_eq!(manager.pending_count().await, 10);
    }

    /// `store` sweeps out expired entries that were already in the table.
    #[tokio::test]
    async fn test_lazy_cleanup_on_store() {
        let manager = StateManager::new();

        // Insert an expired session directly, bypassing `store`
        {
            let mut table = manager.pending.write().await;
            table.insert(Uuid::new_v4(), create_expired_pending());
        }

        // Store new session triggers cleanup
        manager.store(create_test_pending()).await;

        // Only the new session should remain
        assert_eq!(manager.pending_count().await, 1);
    }
}