mcpkit_server/capability/
tasks.rs

1//! Task capability implementation.
2//!
3//! This module provides full support for long-running tasks
4//! in MCP servers - a feature missing from rmcp.
5//!
6//! Tasks allow servers to execute long-running operations while
7//! providing progress updates and supporting cancellation.
8
9use crate::context::CancellationToken;
10use crate::context::Context;
11use crate::handler::TaskHandler;
12use mcpkit_core::error::McpError;
13use mcpkit_core::types::task::{Task, TaskId, TaskStatus};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::Instant;
18
19/// Internal state for a running task.
20#[derive(Debug)]
21pub struct TaskState {
22    /// Task metadata.
23    pub task: Task,
24    /// Cancellation token.
25    pub cancel_token: CancellationToken,
26    /// When the task was last accessed (for cleanup).
27    pub last_access: Instant,
28}
29
30impl TaskState {
31    /// Create a new task state.
32    fn new(task: Task) -> Self {
33        Self {
34            task,
35            cancel_token: CancellationToken::new(),
36            last_access: Instant::now(),
37        }
38    }
39
40    /// Check if the task is cancelled.
41    pub fn is_cancelled(&self) -> bool {
42        self.cancel_token.is_cancelled()
43    }
44}
45
46/// Handle for interacting with a running task.
47///
48/// This handle is given to task executors to report progress
49/// and completion.
50pub struct TaskHandle {
51    task_id: TaskId,
52    manager: Arc<TaskManager>,
53}
54
55impl TaskHandle {
56    /// Get the task ID.
57    pub fn id(&self) -> &TaskId {
58        &self.task_id
59    }
60
61    /// Report that the task is now running.
62    pub async fn running(&self) -> Result<(), McpError> {
63        self.manager.update_status(&self.task_id, TaskStatus::Running).await
64    }
65
66    /// Report progress on the task.
67    pub async fn progress(
68        &self,
69        current: u64,
70        total: Option<u64>,
71        message: Option<&str>,
72    ) -> Result<(), McpError> {
73        self.manager
74            .update_progress(&self.task_id, current, total, message)
75            .await
76    }
77
78    /// Mark the task as completed with a result.
79    pub async fn complete(&self, result: Value) -> Result<(), McpError> {
80        self.manager.complete_success(&self.task_id, result).await
81    }
82
83    /// Mark the task as failed with an error.
84    pub async fn error(&self, message: impl Into<String>) -> Result<(), McpError> {
85        self.manager.complete_error(&self.task_id, message.into()).await
86    }
87
88    /// Check if the task has been cancelled.
89    pub fn is_cancelled(&self) -> bool {
90        self.manager
91            .get(&self.task_id)
92            .map(|s| s.is_cancelled())
93            .unwrap_or(true)
94    }
95
96    /// Get a future that completes when the task is cancelled.
97    pub async fn cancelled(&self) {
98        if let Some(state) = self.manager.get(&self.task_id) {
99            state.cancel_token.cancelled().await;
100        }
101    }
102}
103
104/// Manager for coordinating tasks.
105///
106/// This manages the lifecycle of tasks, including creation,
107/// progress tracking, cancellation, and cleanup.
108pub struct TaskManager {
109    tasks: RwLock<HashMap<TaskId, TaskState>>,
110}
111
112impl Default for TaskManager {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl TaskManager {
119    /// Create a new task manager.
120    pub fn new() -> Self {
121        Self {
122            tasks: RwLock::new(HashMap::new()),
123        }
124    }
125
126    /// Create a new task.
127    pub fn create(self: &Arc<Self>, tool_name: Option<&str>) -> TaskHandle {
128        let mut task = Task::create();
129        task.tool = tool_name.map(String::from);
130
131        let task_id = task.id.clone();
132        let state = TaskState::new(task);
133
134        if let Ok(mut tasks) = self.tasks.write() {
135            tasks.insert(task_id.clone(), state);
136        }
137
138        TaskHandle {
139            task_id,
140            manager: Arc::clone(self),
141        }
142    }
143
144    /// Get a task state by ID.
145    pub fn get(&self, id: &TaskId) -> Option<TaskState> {
146        self.tasks.read().ok()?.get(id).map(|s| TaskState {
147            task: s.task.clone(),
148            cancel_token: s.cancel_token.clone(),
149            last_access: s.last_access,
150        })
151    }
152
153    /// List all tasks.
154    pub fn list(&self) -> Vec<Task> {
155        self.tasks
156            .read()
157            .map(|tasks| tasks.values().map(|s| s.task.clone()).collect())
158            .unwrap_or_default()
159    }
160
161    /// Cancel a task.
162    pub fn cancel(&self, id: &TaskId) -> Result<(), McpError> {
163        let mut tasks = self.tasks.write().map_err(|_| {
164            McpError::internal("Failed to acquire task lock")
165        })?;
166
167        if let Some(state) = tasks.get_mut(id) {
168            state.cancel_token.cancel();
169            state.task.status = TaskStatus::Cancelled;
170            state.task.updated_at = chrono::Utc::now();
171            Ok(())
172        } else {
173            Err(McpError::invalid_params(
174                "tasks/cancel",
175                format!("Unknown task: {}", id.as_str()),
176            ))
177        }
178    }
179
180    /// Update task status.
181    async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Result<(), McpError> {
182        let mut tasks = self.tasks.write().map_err(|_| {
183            McpError::internal("Failed to acquire task lock")
184        })?;
185
186        if let Some(state) = tasks.get_mut(id) {
187            state.task.status = status;
188            state.task.updated_at = chrono::Utc::now();
189            state.last_access = Instant::now();
190            Ok(())
191        } else {
192            Err(McpError::invalid_params(
193                "tasks/get",
194                format!("Unknown task: {}", id.as_str()),
195            ))
196        }
197    }
198
199    /// Update task progress.
200    async fn update_progress(
201        &self,
202        id: &TaskId,
203        current: u64,
204        total: Option<u64>,
205        message: Option<&str>,
206    ) -> Result<(), McpError> {
207        let mut tasks = self.tasks.write().map_err(|_| {
208            McpError::internal("Failed to acquire task lock")
209        })?;
210
211        if let Some(state) = tasks.get_mut(id) {
212            state.task.progress = Some(mcpkit_core::types::task::TaskProgress {
213                current,
214                total,
215                message: message.map(String::from),
216            });
217            state.task.updated_at = chrono::Utc::now();
218            state.last_access = Instant::now();
219            Ok(())
220        } else {
221            Err(McpError::invalid_params(
222                "tasks/get",
223                format!("Unknown task: {}", id.as_str()),
224            ))
225        }
226    }
227
228    /// Complete a task with success.
229    async fn complete_success(&self, id: &TaskId, result: Value) -> Result<(), McpError> {
230        let mut tasks = self.tasks.write().map_err(|_| {
231            McpError::internal("Failed to acquire task lock")
232        })?;
233
234        if let Some(state) = tasks.get_mut(id) {
235            state.task.status = TaskStatus::Completed;
236            state.task.result = Some(result);
237            state.task.updated_at = chrono::Utc::now();
238            state.last_access = Instant::now();
239            Ok(())
240        } else {
241            Err(McpError::invalid_params(
242                "tasks/get",
243                format!("Unknown task: {}", id.as_str()),
244            ))
245        }
246    }
247
248    /// Complete a task with an error.
249    async fn complete_error(&self, id: &TaskId, message: String) -> Result<(), McpError> {
250        let mut tasks = self.tasks.write().map_err(|_| {
251            McpError::internal("Failed to acquire task lock")
252        })?;
253
254        if let Some(state) = tasks.get_mut(id) {
255            state.task.status = TaskStatus::Failed;
256            state.task.error = Some(mcpkit_core::types::task::TaskError {
257                code: -1,
258                message,
259                data: None,
260            });
261            state.task.updated_at = chrono::Utc::now();
262            state.last_access = Instant::now();
263            Ok(())
264        } else {
265            Err(McpError::invalid_params(
266                "tasks/get",
267                format!("Unknown task: {}", id.as_str()),
268            ))
269        }
270    }
271
272    /// Remove completed tasks older than the given duration.
273    pub fn cleanup(&self, max_age: std::time::Duration) {
274        if let Ok(mut tasks) = self.tasks.write() {
275            tasks.retain(|_, state| {
276                let is_terminal = state.task.status.is_terminal();
277                !is_terminal || state.last_access.elapsed() < max_age
278            });
279        }
280    }
281}
282
283/// Task service implementing the TaskHandler trait.
284pub struct TaskService {
285    manager: Arc<TaskManager>,
286}
287
288impl Default for TaskService {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294impl TaskService {
295    /// Create a new task service.
296    pub fn new() -> Self {
297        Self {
298            manager: Arc::new(TaskManager::new()),
299        }
300    }
301
302    /// Get the underlying task manager.
303    pub fn manager(&self) -> &Arc<TaskManager> {
304        &self.manager
305    }
306
307    /// Create a new task and get a handle for it.
308    pub fn create(&self, tool_name: Option<&str>) -> TaskHandle {
309        self.manager.create(tool_name)
310    }
311}
312
313impl TaskHandler for TaskService {
314    async fn list_tasks(&self, _ctx: &Context<'_>) -> Result<Vec<Task>, McpError> {
315        Ok(self.manager.list())
316    }
317
318    async fn get_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<Option<Task>, McpError> {
319        Ok(self.manager.get(task_id).map(|s| s.task))
320    }
321
322    async fn cancel_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<bool, McpError> {
323        match self.manager.cancel(task_id) {
324            Ok(()) => Ok(true),
325            Err(_) => Ok(false),
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created task is listed, pending, and not cancelled.
    #[test]
    fn test_task_manager() {
        let mgr = Arc::new(TaskManager::new());

        let task_handle = mgr.create(Some("test-tool"));
        assert!(!task_handle.is_cancelled());

        let listed = mgr.list();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].tool.as_deref(), Some("test-tool"));
        assert_eq!(listed[0].status, TaskStatus::Pending);
    }

    /// Walk one task through running -> progress -> completed.
    #[tokio::test]
    async fn test_task_lifecycle() {
        let mgr = Arc::new(TaskManager::new());

        let task_handle = mgr.create(Some("processor"));
        let id = task_handle.id().clone();

        // Transition to running.
        task_handle.running().await.unwrap();
        let snapshot = mgr.get(&id).unwrap();
        assert_eq!(snapshot.task.status, TaskStatus::Running);

        // Publish a progress update.
        task_handle
            .progress(50, Some(100), Some("Halfway done"))
            .await
            .unwrap();
        let snapshot = mgr.get(&id).unwrap();
        assert_eq!(snapshot.task.progress.as_ref().map(|p| p.current), Some(50));

        // Finish successfully.
        task_handle
            .complete(serde_json::json!({"result": "success"}))
            .await
            .unwrap();
        let snapshot = mgr.get(&id).unwrap();
        assert_eq!(snapshot.task.status, TaskStatus::Completed);
    }

    /// Cancelling through the manager is visible on the handle and the status.
    #[test]
    fn test_task_cancellation() {
        let mgr = Arc::new(TaskManager::new());

        let task_handle = mgr.create(None);
        let id = task_handle.id().clone();

        assert!(!task_handle.is_cancelled());

        mgr.cancel(&id).unwrap();

        assert!(task_handle.is_cancelled());
        let snapshot = mgr.get(&id).unwrap();
        assert_eq!(snapshot.task.status, TaskStatus::Cancelled);
    }

    /// Tasks created through the service show up in its manager.
    #[tokio::test]
    async fn test_task_service() {
        let svc = TaskService::new();

        let task_handle = svc.create(Some("service-task"));
        task_handle.running().await.unwrap();

        let listed = svc.manager().list();
        assert_eq!(listed.len(), 1);
    }
}