mcp_host/managers/task.rs

1//! Task management for async task execution
2//!
3//! Provides storage and lifecycle management for MCP tasks.
4
5use std::collections::HashSet;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use chrono::Utc;
10use dashmap::DashMap;
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{RwLock, oneshot};
14use uuid::Uuid;
15
16use crate::protocol::types::{JsonRpcRequest, Task, TaskStatus};
17
18/// Task-related errors
19#[derive(Debug, Error)]
20pub enum TaskError {
21    /// Task not found
22    #[error("Task not found: {0}")]
23    NotFound(String),
24
25    /// Task timeout
26    #[error("Task timeout after {0:?}")]
27    Timeout(Duration),
28
29    /// Task failed
30    #[error("Task failed: {0}")]
31    Failed(String),
32
33    /// Task already in terminal state
34    #[error("Task already in terminal status: {0:?}")]
35    AlreadyTerminal(TaskStatus),
36
37    /// Internal error
38    #[error("Internal error: {0}")]
39    Internal(String),
40}
41
42/// Internal task entry with result channel
43struct TaskEntry {
44    /// Public task metadata
45    task: Arc<RwLock<Task>>,
46    /// Session ID owning this task
47    session_id: String,
48    /// Original request that created the task (reserved for future use)
49    #[allow(dead_code)]
50    original_request: JsonRpcRequest,
51    /// Channel for waiting on result
52    result_tx: Option<oneshot::Sender<Value>>,
53    /// Stored result (when completed)
54    result: Arc<RwLock<Option<Value>>>,
55    /// Expiration time
56    expires_at: Option<Instant>,
57}
58
59/// Task store for managing async task execution
60pub struct TaskStore {
61    /// All tasks indexed by task ID
62    tasks: DashMap<String, TaskEntry>,
63    /// Session → task IDs index
64    by_session: DashMap<String, HashSet<String>>,
65    /// Default TTL for new tasks
66    default_ttl: Duration,
67    /// Default poll interval
68    default_poll_interval: Duration,
69}
70
71impl TaskStore {
72    /// Create a new task store
73    pub fn new(default_ttl: Duration, default_poll_interval: Duration) -> Self {
74        Self {
75            tasks: DashMap::new(),
76            by_session: DashMap::new(),
77            default_ttl,
78            default_poll_interval,
79        }
80    }
81
82    /// Create a new task
83    pub fn create_task(
84        &self,
85        session_id: &str,
86        original_request: JsonRpcRequest,
87        requested_ttl: Option<u64>,
88    ) -> (Task, oneshot::Receiver<Value>) {
89        let task_id = Uuid::new_v4().to_string();
90        let now = Utc::now().to_rfc3339();
91
92        // Determine TTL (use requested or default)
93        let ttl_duration = requested_ttl
94            .map(Duration::from_millis)
95            .unwrap_or(self.default_ttl);
96        let ttl_ms = if ttl_duration == Duration::ZERO {
97            None
98        } else {
99            Some(ttl_duration.as_millis() as u64)
100        };
101
102        let task = Task {
103            task_id: task_id.clone(),
104            status: TaskStatus::Working,
105            status_message: None,
106            created_at: now.clone(),
107            last_updated_at: now,
108            ttl: ttl_ms,
109            poll_interval: Some(self.default_poll_interval.as_millis() as u64),
110        };
111
112        // Create result channel
113        let (result_tx, result_rx) = oneshot::channel();
114
115        let entry = TaskEntry {
116            task: Arc::new(RwLock::new(task.clone())),
117            session_id: session_id.to_string(),
118            original_request,
119            result_tx: Some(result_tx),
120            result: Arc::new(RwLock::new(None)),
121            expires_at: ttl_ms.map(|ms| Instant::now() + Duration::from_millis(ms)),
122        };
123
124        // Store task
125        self.tasks.insert(task_id.clone(), entry);
126
127        // Index by session
128        self.by_session
129            .entry(session_id.to_string())
130            .or_default()
131            .insert(task_id);
132
133        (task, result_rx)
134    }
135
136    /// Get task by ID
137    pub async fn get_task(&self, task_id: &str) -> Option<Task> {
138        let entry = self.tasks.get(task_id)?;
139        Some(entry.task.read().await.clone())
140    }
141
142    /// Get task for specific session (enforces session isolation)
143    pub async fn get_task_for_session(&self, task_id: &str, session_id: &str) -> Option<Task> {
144        let entry = self.tasks.get(task_id)?;
145        if entry.session_id == session_id {
146            Some(entry.task.read().await.clone())
147        } else {
148            None
149        }
150    }
151
152    /// List tasks for a session (with pagination)
153    pub async fn list_tasks(
154        &self,
155        session_id: &str,
156        cursor: Option<&str>,
157        limit: usize,
158    ) -> (Vec<Task>, Option<String>) {
159        let task_ids = match self.by_session.get(session_id) {
160            Some(ids) => ids.clone(),
161            None => return (vec![], None),
162        };
163
164        let mut task_ids: Vec<_> = task_ids.into_iter().collect();
165        task_ids.sort(); // Deterministic ordering
166
167        // Find cursor position
168        let start = if let Some(c) = cursor {
169            task_ids
170                .iter()
171                .position(|id| id == c)
172                .map(|p| p + 1)
173                .unwrap_or(0)
174        } else {
175            0
176        };
177
178        let end = (start + limit).min(task_ids.len());
179        let page_ids = &task_ids[start..end];
180
181        // Collect tasks
182        let mut tasks = Vec::new();
183        for id in page_ids {
184            if let Some(task) = self.get_task(id).await {
185                tasks.push(task);
186            }
187        }
188
189        // Next cursor
190        let next_cursor = if end < task_ids.len() {
191            task_ids.get(end).cloned()
192        } else {
193            None
194        };
195
196        (tasks, next_cursor)
197    }
198
199    /// Update task status
200    pub async fn update_status(
201        &self,
202        task_id: &str,
203        new_status: TaskStatus,
204        status_message: Option<String>,
205    ) -> Result<(), TaskError> {
206        let entry = self
207            .tasks
208            .get(task_id)
209            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
210
211        let mut task = entry.task.write().await;
212
213        // Check valid transition
214        if let (TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled, _) =
215            (&task.status, &new_status)
216        {
217            return Err(TaskError::AlreadyTerminal(task.status));
218        }
219
220        task.status = new_status;
221        task.status_message = status_message;
222        task.last_updated_at = Utc::now().to_rfc3339();
223
224        Ok(())
225    }
226
227    /// Store task result
228    pub async fn store_result(&self, task_id: &str, result: Value) -> Result<(), TaskError> {
229        let entry = self
230            .tasks
231            .get(task_id)
232            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
233
234        // Store result
235        *entry.result.write().await = Some(result.clone());
236
237        // Send to waiting receiver (if any)
238        drop(entry); // Drop the read lock before taking result_tx
239        if let Some(mut entry_mut) = self.tasks.get_mut(task_id)
240            && let Some(tx) = entry_mut.result_tx.take()
241        {
242            let _ = tx.send(result);
243        }
244
245        Ok(())
246    }
247
248    /// Get stored result
249    pub async fn get_result(&self, task_id: &str) -> Option<Value> {
250        let entry = self.tasks.get(task_id)?;
251        entry.result.read().await.clone()
252    }
253
254    /// Wait for task result (blocks until terminal state)
255    pub async fn wait_for_result(
256        &self,
257        task_id: &str,
258        timeout: Duration,
259    ) -> Result<Value, TaskError> {
260        let entry = self
261            .tasks
262            .get(task_id)
263            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
264
265        let task = entry.task.read().await.clone();
266
267        // If already terminal, return stored result
268        if matches!(
269            task.status,
270            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
271        ) {
272            drop(task);
273            if let Some(result) = self.get_result(task_id).await {
274                return Ok(result);
275            } else {
276                return Err(TaskError::Failed(
277                    "Task in terminal state but no result stored".to_string(),
278                ));
279            }
280        }
281
282        // Otherwise wait on channel
283        drop(entry);
284
285        // Get a new receiver (or clone existing one)
286        // Note: this is a simplification - in production you'd want a broadcast channel
287        // For now, we'll poll
288        let start = Instant::now();
289        loop {
290            if start.elapsed() > timeout {
291                return Err(TaskError::Timeout(timeout));
292            }
293
294            let task = self
295                .get_task(task_id)
296                .await
297                .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
298
299            if matches!(
300                task.status,
301                TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
302            ) {
303                if let Some(result) = self.get_result(task_id).await {
304                    return Ok(result);
305                } else {
306                    return Err(TaskError::Failed(
307                        "Task completed but no result stored".to_string(),
308                    ));
309                }
310            }
311
312            tokio::time::sleep(Duration::from_millis(100)).await;
313        }
314    }
315
316    /// Cancel a task
317    pub async fn cancel_task(&self, task_id: &str, session_id: &str) -> Result<Task, TaskError> {
318        // Verify session owns task
319        let entry = self
320            .tasks
321            .get(task_id)
322            .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
323
324        if entry.session_id != session_id {
325            return Err(TaskError::NotFound(task_id.to_string()));
326        }
327
328        let mut task = entry.task.write().await;
329
330        // Check not already terminal
331        if matches!(
332            task.status,
333            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
334        ) {
335            return Err(TaskError::AlreadyTerminal(task.status));
336        }
337
338        // Update to cancelled
339        task.status = TaskStatus::Cancelled;
340        task.status_message = Some("Cancelled by request".to_string());
341        task.last_updated_at = Utc::now().to_rfc3339();
342
343        Ok(task.clone())
344    }
345
346    /// Cleanup expired tasks (should be called periodically)
347    pub async fn cleanup_expired(&self) {
348        let now = Instant::now();
349        let mut to_remove = Vec::new();
350
351        for entry in self.tasks.iter() {
352            if let Some(expires_at) = entry.expires_at
353                && now >= expires_at
354            {
355                to_remove.push(entry.key().clone());
356            }
357        }
358
359        for task_id in to_remove {
360            if let Some((_, entry)) = self.tasks.remove(&task_id) {
361                // Remove from session index
362                if let Some(mut ids) = self.by_session.get_mut(&entry.session_id) {
363                    ids.remove(&task_id);
364                }
365            }
366        }
367    }
368
369    /// Spawn background cleanup task
370    pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) {
371        tokio::spawn(async move {
372            let mut interval = tokio::time::interval(interval);
373            loop {
374                interval.tick().await;
375                self.cleanup_expired().await;
376            }
377        });
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Store with a 60 s default TTL and a 5 s advertised poll interval.
    fn new_store() -> TaskStore {
        TaskStore::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    /// Minimal `tools/call` request used to create tasks.
    fn new_request() -> JsonRpcRequest {
        JsonRpcRequest::new(json!(1), "tools/call", Some(json!({})))
    }

    #[tokio::test]
    async fn test_create_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        assert_eq!(task.status, TaskStatus::Working);
        assert!(!task.task_id.is_empty());
        assert_eq!(task.ttl, Some(60000));
    }

    #[tokio::test]
    async fn test_zero_ttl_never_expires() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), Some(0));
        assert_eq!(task.ttl, None);

        store.cleanup_expired().await;
        assert!(store.get_task(&task.task_id).await.is_some());
    }

    #[tokio::test]
    async fn test_get_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);
        let retrieved = store.get_task(&task.task_id).await.unwrap();

        assert_eq!(retrieved.task_id, task.task_id);
        assert_eq!(retrieved.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_session_isolation() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        // Not visible from a different session...
        assert!(
            store
                .get_task_for_session(&task.task_id, "session2")
                .await
                .is_none()
        );

        // ...but visible from the owning session.
        assert!(
            store
                .get_task_for_session(&task.task_id, "session1")
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_update_status() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        store
            .update_status(
                &task.task_id,
                TaskStatus::Completed,
                Some("Done".to_string()),
            )
            .await
            .unwrap();

        let updated = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
        assert_eq!(updated.status_message, Some("Done".to_string()));
    }

    #[tokio::test]
    async fn test_update_status_rejects_terminal() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);
        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();

        let outcome = store
            .update_status(&task.task_id, TaskStatus::Working, None)
            .await;
        assert!(matches!(
            outcome,
            Err(TaskError::AlreadyTerminal(TaskStatus::Completed))
        ));
    }

    #[tokio::test]
    async fn test_store_and_get_result() {
        let store = new_store();
        let (task, rx) = store.create_task("session1", new_request(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();

        assert_eq!(store.get_result(&task.task_id).await.unwrap(), result);
        // The receiver handed out at creation is notified as well.
        assert_eq!(rx.await.unwrap(), result);
    }

    #[tokio::test]
    async fn test_wait_for_result_returns_stored_result() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();
        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();

        let waited = store
            .wait_for_result(&task.task_id, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(waited, result);
    }

    #[tokio::test]
    async fn test_wait_for_result_times_out() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let outcome = store
            .wait_for_result(&task.task_id, Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(TaskError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let cancelled = store.cancel_task(&task.task_id, "session1").await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_from_other_session_is_not_found() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        assert!(matches!(
            store.cancel_task(&task.task_id, "session2").await,
            Err(TaskError::NotFound(_))
        ));
        // The task itself is left untouched.
        let unchanged = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(unchanged.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_cancel_twice_is_terminal() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        store.cancel_task(&task.task_id, "session1").await.unwrap();
        assert!(matches!(
            store.cancel_task(&task.task_id, "session1").await,
            Err(TaskError::AlreadyTerminal(TaskStatus::Cancelled))
        ));
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let store = new_store();

        for _ in 0..5 {
            store.create_task("session1", new_request(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 10).await;
        assert_eq!(tasks.len(), 5);
        assert!(cursor.is_none());

        let (first_page, cursor) = store.list_tasks("session1", None, 3).await;
        assert_eq!(first_page.len(), 3);
        assert!(cursor.is_some());

        let (none, cursor) = store.list_tasks("unknown-session", None, 10).await;
        assert!(none.is_empty());
        assert!(cursor.is_none());
    }
}