mcp_host/managers/
task.rs

1//! Task management for async task execution
2//!
3//! Provides storage and lifecycle management for MCP tasks.
4
5use std::collections::HashSet;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use chrono::Utc;
10use dashmap::DashMap;
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{oneshot, RwLock};
14use uuid::Uuid;
15
16use crate::protocol::types::{Task, TaskStatus, JsonRpcRequest};
17
/// Errors returned by [`TaskStore`] operations.
#[derive(Debug, Error)]
pub enum TaskError {
    /// No task with the given ID exists. `TaskStore::cancel_task` also returns this when the
    /// task belongs to a different session. The payload is the requested task ID.
    #[error("Task not found: {0}")]
    NotFound(String),

    /// `TaskStore::wait_for_result` gave up; the payload is the timeout that was exceeded.
    #[error("Task timeout after {0:?}")]
    Timeout(Duration),

    /// The task cannot yield a result; the payload is a human-readable reason.
    /// `TaskStore::wait_for_result` uses this when a task is terminal but has no stored result.
    #[error("Task failed: {0}")]
    Failed(String),

    /// A status change or cancellation was attempted on a task that is already `Completed`,
    /// `Failed` or `Cancelled`; the payload is that current status.
    #[error("Task already in terminal status: {0:?}")]
    AlreadyTerminal(TaskStatus),

    /// Internal error (not constructed anywhere in this file)
    #[error("Internal error: {0}")]
    Internal(String),
}
41
/// Internal bookkeeping for one task: the public [`Task`] metadata plus the ownership, result
/// and expiry data that [`TaskStore`] keeps alongside it.
struct TaskEntry {
    /// Public task metadata, updated in place on status changes
    task: Arc<RwLock<Task>>,
    /// Session ID owning this task
    session_id: String,
    /// Original request that created the task (reserved for future use)
    #[allow(dead_code)]
    original_request: JsonRpcRequest,
    /// Sending half of the one-shot result channel; taken (leaving `None`) the first time a
    /// result is stored
    result_tx: Option<oneshot::Sender<Value>>,
    /// Stored result (`None` until `TaskStore::store_result` is called for this task)
    result: Arc<RwLock<Option<Value>>>,
    /// Expiration time; `None` means `TaskStore::cleanup_expired` never removes the task
    expires_at: Option<Instant>,
}
58
/// Concurrent in-memory task store for managing async task execution.
///
/// Tasks live in a [`DashMap`] keyed by task ID; a second map indexes task IDs by the session
/// that created them.
pub struct TaskStore {
    /// All tasks indexed by task ID
    tasks: DashMap<String, TaskEntry>,
    /// Session → task IDs index
    by_session: DashMap<String, HashSet<String>>,
    /// Default TTL for new tasks (used when `create_task` is given no requested TTL)
    default_ttl: Duration,
    /// Default poll interval, reported on every created task as `poll_interval` in milliseconds
    default_poll_interval: Duration,
}
70
71impl TaskStore {
72    /// Create a new task store
73    pub fn new(default_ttl: Duration, default_poll_interval: Duration) -> Self {
74        Self {
75            tasks: DashMap::new(),
76            by_session: DashMap::new(),
77            default_ttl,
78            default_poll_interval,
79        }
80    }
81
82    /// Create a new task
83    pub fn create_task(
84        &self,
85        session_id: &str,
86        original_request: JsonRpcRequest,
87        requested_ttl: Option<u64>,
88    ) -> (Task, oneshot::Receiver<Value>) {
89        let task_id = Uuid::new_v4().to_string();
90        let now = Utc::now().to_rfc3339();
91
92        // Determine TTL (use requested or default)
93        let ttl_duration = requested_ttl
94            .map(|ms| Duration::from_millis(ms))
95            .unwrap_or(self.default_ttl);
96        let ttl_ms = if ttl_duration == Duration::ZERO {
97            None
98        } else {
99            Some(ttl_duration.as_millis() as u64)
100        };
101
102        let task = Task {
103            task_id: task_id.clone(),
104            status: TaskStatus::Working,
105            status_message: None,
106            created_at: now.clone(),
107            last_updated_at: now,
108            ttl: ttl_ms,
109            poll_interval: Some(self.default_poll_interval.as_millis() as u64),
110        };
111
112        // Create result channel
113        let (result_tx, result_rx) = oneshot::channel();
114
115        let entry = TaskEntry {
116            task: Arc::new(RwLock::new(task.clone())),
117            session_id: session_id.to_string(),
118            original_request,
119            result_tx: Some(result_tx),
120            result: Arc::new(RwLock::new(None)),
121            expires_at: ttl_ms.map(|ms| Instant::now() + Duration::from_millis(ms)),
122        };
123
124        // Store task
125        self.tasks.insert(task_id.clone(), entry);
126
127        // Index by session
128        self.by_session
129            .entry(session_id.to_string())
130            .or_insert_with(HashSet::new)
131            .insert(task_id);
132
133        (task, result_rx)
134    }
135
136    /// Get task by ID
137    pub async fn get_task(&self, task_id: &str) -> Option<Task> {
138        let entry = self.tasks.get(task_id)?;
139        Some(entry.task.read().await.clone())
140    }
141
142    /// Get task for specific session (enforces session isolation)
143    pub async fn get_task_for_session(
144        &self,
145        task_id: &str,
146        session_id: &str,
147    ) -> Option<Task> {
148        let entry = self.tasks.get(task_id)?;
149        if entry.session_id == session_id {
150            Some(entry.task.read().await.clone())
151        } else {
152            None
153        }
154    }
155
156    /// List tasks for a session (with pagination)
157    pub async fn list_tasks(
158        &self,
159        session_id: &str,
160        cursor: Option<&str>,
161        limit: usize,
162    ) -> (Vec<Task>, Option<String>) {
163        let task_ids = match self.by_session.get(session_id) {
164            Some(ids) => ids.clone(),
165            None => return (vec![], None),
166        };
167
168        let mut task_ids: Vec<_> = task_ids.into_iter().collect();
169        task_ids.sort(); // Deterministic ordering
170
171        // Find cursor position
172        let start = if let Some(c) = cursor {
173            task_ids.iter().position(|id| id == c).map(|p| p + 1).unwrap_or(0)
174        } else {
175            0
176        };
177
178        let end = (start + limit).min(task_ids.len());
179        let page_ids = &task_ids[start..end];
180
181        // Collect tasks
182        let mut tasks = Vec::new();
183        for id in page_ids {
184            if let Some(task) = self.get_task(id).await {
185                tasks.push(task);
186            }
187        }
188
189        // Next cursor
190        let next_cursor = if end < task_ids.len() {
191            task_ids.get(end).cloned()
192        } else {
193            None
194        };
195
196        (tasks, next_cursor)
197    }
198
199    /// Update task status
200    pub async fn update_status(
201        &self,
202        task_id: &str,
203        new_status: TaskStatus,
204        status_message: Option<String>,
205    ) -> Result<(), TaskError> {
206        let entry = self.tasks.get(task_id)
207            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
208
209        let mut task = entry.task.write().await;
210
211        // Check valid transition
212        match (&task.status, &new_status) {
213            (TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled, _) => {
214                return Err(TaskError::AlreadyTerminal(task.status));
215            }
216            _ => {}
217        }
218
219        task.status = new_status;
220        task.status_message = status_message;
221        task.last_updated_at = Utc::now().to_rfc3339();
222
223        Ok(())
224    }
225
226    /// Store task result
227    pub async fn store_result(&self, task_id: &str, result: Value) -> Result<(), TaskError> {
228        let entry = self.tasks.get(task_id)
229            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
230
231        // Store result
232        *entry.result.write().await = Some(result.clone());
233
234        // Send to waiting receiver (if any)
235        drop(entry); // Drop the read lock before taking result_tx
236        if let Some(mut entry_mut) = self.tasks.get_mut(task_id) {
237            if let Some(tx) = entry_mut.result_tx.take() {
238                let _ = tx.send(result);
239            }
240        }
241
242        Ok(())
243    }
244
245    /// Get stored result
246    pub async fn get_result(&self, task_id: &str) -> Option<Value> {
247        let entry = self.tasks.get(task_id)?;
248        entry.result.read().await.clone()
249    }
250
251    /// Wait for task result (blocks until terminal state)
252    pub async fn wait_for_result(
253        &self,
254        task_id: &str,
255        timeout: Duration,
256    ) -> Result<Value, TaskError> {
257        let entry = self.tasks.get(task_id)
258            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
259
260        let task = entry.task.read().await.clone();
261
262        // If already terminal, return stored result
263        if matches!(
264            task.status,
265            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
266        ) {
267            drop(task);
268            if let Some(result) = self.get_result(task_id).await {
269                return Ok(result);
270            } else {
271                return Err(TaskError::Failed(
272                    "Task in terminal state but no result stored".to_string(),
273                ));
274            }
275        }
276
277        // Otherwise wait on channel
278        drop(entry);
279
280        // Get a new receiver (or clone existing one)
281        // Note: this is a simplification - in production you'd want a broadcast channel
282        // For now, we'll poll
283        let start = Instant::now();
284        loop {
285            if start.elapsed() > timeout {
286                return Err(TaskError::Timeout(timeout));
287            }
288
289            let task = self.get_task(task_id).await
290                .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
291
292            if matches!(
293                task.status,
294                TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
295            ) {
296                if let Some(result) = self.get_result(task_id).await {
297                    return Ok(result);
298                } else {
299                    return Err(TaskError::Failed(
300                        "Task completed but no result stored".to_string(),
301                    ));
302                }
303            }
304
305            tokio::time::sleep(Duration::from_millis(100)).await;
306        }
307    }
308
309    /// Cancel a task
310    pub async fn cancel_task(
311        &self,
312        task_id: &str,
313        session_id: &str,
314    ) -> Result<Task, TaskError> {
315        // Verify session owns task
316        let entry = self.tasks.get(task_id)
317            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
318
319        if entry.session_id != session_id {
320            return Err(TaskError::NotFound(task_id.to_string()));
321        }
322
323        let mut task = entry.task.write().await;
324
325        // Check not already terminal
326        if matches!(
327            task.status,
328            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
329        ) {
330            return Err(TaskError::AlreadyTerminal(task.status));
331        }
332
333        // Update to cancelled
334        task.status = TaskStatus::Cancelled;
335        task.status_message = Some("Cancelled by request".to_string());
336        task.last_updated_at = Utc::now().to_rfc3339();
337
338        Ok(task.clone())
339    }
340
341    /// Cleanup expired tasks (should be called periodically)
342    pub async fn cleanup_expired(&self) {
343        let now = Instant::now();
344        let mut to_remove = Vec::new();
345
346        for entry in self.tasks.iter() {
347            if let Some(expires_at) = entry.expires_at {
348                if now >= expires_at {
349                    to_remove.push(entry.key().clone());
350                }
351            }
352        }
353
354        for task_id in to_remove {
355            if let Some((_, entry)) = self.tasks.remove(&task_id) {
356                // Remove from session index
357                if let Some(mut ids) = self.by_session.get_mut(&entry.session_id) {
358                    ids.remove(&task_id);
359                }
360            }
361        }
362    }
363
364    /// Spawn background cleanup task
365    pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) {
366        tokio::spawn(async move {
367            let mut interval = tokio::time::interval(interval);
368            loop {
369                interval.tick().await;
370                self.cleanup_expired().await;
371            }
372        });
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Store with a 60 s default TTL and a 5 s poll interval, shared by all tests.
    fn new_store() -> TaskStore {
        TaskStore::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    /// Minimal `tools/call` request used as the originating request of test tasks.
    fn request() -> JsonRpcRequest {
        JsonRpcRequest::new(json!(1), "tools/call", Some(json!({})))
    }

    #[tokio::test]
    async fn test_create_task() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        assert_eq!(task.status, TaskStatus::Working);
        assert!(!task.task_id.is_empty());
        assert_eq!(task.ttl, Some(60000));
    }

    #[tokio::test]
    async fn test_get_task() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);
        let retrieved = store.get_task(&task.task_id).await.unwrap();

        assert_eq!(retrieved.task_id, task.task_id);
        assert_eq!(retrieved.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_session_isolation() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        // Should not be visible from different session
        assert!(store
            .get_task_for_session(&task.task_id, "session2")
            .await
            .is_none());

        // Should be visible from same session
        assert!(store
            .get_task_for_session(&task.task_id, "session1")
            .await
            .is_some());
    }

    #[tokio::test]
    async fn test_update_status() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        store
            .update_status(&task.task_id, TaskStatus::Completed, Some("Done".to_string()))
            .await
            .unwrap();

        let updated = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
        assert_eq!(updated.status_message, Some("Done".to_string()));
    }

    #[tokio::test]
    async fn test_update_status_rejects_terminal_task() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);
        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();

        let second = store
            .update_status(&task.task_id, TaskStatus::Working, None)
            .await;
        assert!(matches!(
            second,
            Err(TaskError::AlreadyTerminal(TaskStatus::Completed))
        ));
    }

    #[tokio::test]
    async fn test_store_and_get_result() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();

        let retrieved = store.get_result(&task.task_id).await.unwrap();
        assert_eq!(retrieved, result);
    }

    #[tokio::test]
    async fn test_store_result_notifies_receiver() {
        let store = new_store();

        let (task, rx) = store.create_task("session1", request(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();

        assert_eq!(rx.await.unwrap(), result);
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        let cancelled = store.cancel_task(&task.task_id, "session1").await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_task_wrong_session() {
        let store = new_store();

        let (task, _rx) = store.create_task("session1", request(), None);

        let outcome = store.cancel_task(&task.task_id, "session2").await;
        assert!(matches!(outcome, Err(TaskError::NotFound(_))));

        // The task must be untouched by the rejected cancellation.
        let unchanged = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(unchanged.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let store = new_store();

        // Create multiple tasks
        for _ in 0..5 {
            store.create_task("session1", request(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 10).await;
        assert_eq!(tasks.len(), 5);
        assert!(cursor.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_expired_removes_only_expired_tasks() {
        let store = new_store();

        let (short_lived, _rx1) = store.create_task("session1", request(), Some(1));
        let (long_lived, _rx2) = store.create_task("session1", request(), None);

        // Let the 1 ms TTL elapse.
        tokio::time::sleep(Duration::from_millis(10)).await;
        store.cleanup_expired().await;

        assert!(store.get_task(&short_lived.task_id).await.is_none());
        assert!(store.get_task(&long_lived.task_id).await.is_some());

        let (remaining, _) = store.list_tasks("session1", None, 10).await;
        assert_eq!(remaining.len(), 1);
    }
}
474}