mcpkit_server/capability/
tasks.rs

1//! Task capability implementation.
2//!
3//! This module provides support for long-running tasks in MCP servers.
4//!
5//! Tasks allow servers to execute long-running operations while
6//! providing progress updates and supporting cancellation.
7
8use crate::context::CancellationToken;
9use crate::context::Context;
10use crate::handler::TaskHandler;
11use mcpkit_core::error::McpError;
12use mcpkit_core::types::task::{Task, TaskId, TaskStatus};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16use std::time::Instant;
17
18/// Internal state for a running task.
19#[derive(Debug)]
20pub struct TaskState {
21    /// Task metadata.
22    pub task: Task,
23    /// Cancellation token.
24    pub cancel_token: CancellationToken,
25    /// When the task was last accessed (for cleanup).
26    pub last_access: Instant,
27}
28
29impl TaskState {
30    /// Create a new task state.
31    fn new(task: Task) -> Self {
32        Self {
33            task,
34            cancel_token: CancellationToken::new(),
35            last_access: Instant::now(),
36        }
37    }
38
39    /// Check if the task is cancelled.
40    #[must_use]
41    pub fn is_cancelled(&self) -> bool {
42        self.cancel_token.is_cancelled()
43    }
44}
45
46/// Handle for interacting with a running task.
47///
48/// This handle is given to task executors to report progress
49/// and completion.
50pub struct TaskHandle {
51    task_id: TaskId,
52    manager: Arc<TaskManager>,
53}
54
55impl TaskHandle {
56    /// Get the task ID.
57    #[must_use]
58    pub const fn id(&self) -> &TaskId {
59        &self.task_id
60    }
61
62    /// Report that the task is now running.
63    pub async fn running(&self) -> Result<(), McpError> {
64        self.manager
65            .update_status(&self.task_id, TaskStatus::Running)
66            .await
67    }
68
69    /// Report progress on the task.
70    pub async fn progress(
71        &self,
72        current: u64,
73        total: Option<u64>,
74        message: Option<&str>,
75    ) -> Result<(), McpError> {
76        self.manager
77            .update_progress(&self.task_id, current, total, message)
78            .await
79    }
80
81    /// Mark the task as completed with a result.
82    pub async fn complete(&self, result: Value) -> Result<(), McpError> {
83        self.manager.complete_success(&self.task_id, result).await
84    }
85
86    /// Mark the task as failed with an error.
87    pub async fn error(&self, message: impl Into<String>) -> Result<(), McpError> {
88        self.manager
89            .complete_error(&self.task_id, message.into())
90            .await
91    }
92
93    /// Check if the task has been cancelled.
94    #[must_use]
95    pub fn is_cancelled(&self) -> bool {
96        self.manager
97            .get(&self.task_id)
98            .is_none_or(|s| s.is_cancelled())
99    }
100
101    /// Get a future that completes when the task is cancelled.
102    pub async fn cancelled(&self) {
103        if let Some(state) = self.manager.get(&self.task_id) {
104            state.cancel_token.cancelled().await;
105        }
106    }
107}
108
109/// Manager for coordinating tasks.
110///
111/// This manages the lifecycle of tasks, including creation,
112/// progress tracking, cancellation, and cleanup.
113pub struct TaskManager {
114    tasks: RwLock<HashMap<TaskId, TaskState>>,
115}
116
117impl Default for TaskManager {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl TaskManager {
124    /// Create a new task manager.
125    #[must_use]
126    pub fn new() -> Self {
127        Self {
128            tasks: RwLock::new(HashMap::new()),
129        }
130    }
131
132    /// Create a new task.
133    pub fn create(self: &Arc<Self>, tool_name: Option<&str>) -> TaskHandle {
134        let mut task = Task::create();
135        task.tool = tool_name.map(String::from);
136
137        let task_id = task.id.clone();
138        let state = TaskState::new(task);
139
140        if let Ok(mut tasks) = self.tasks.write() {
141            tasks.insert(task_id.clone(), state);
142        }
143
144        TaskHandle {
145            task_id,
146            manager: Arc::clone(self),
147        }
148    }
149
150    /// Get a task state by ID.
151    pub fn get(&self, id: &TaskId) -> Option<TaskState> {
152        self.tasks.read().ok()?.get(id).map(|s| TaskState {
153            task: s.task.clone(),
154            cancel_token: s.cancel_token.clone(),
155            last_access: s.last_access,
156        })
157    }
158
159    /// List all tasks.
160    pub fn list(&self) -> Vec<Task> {
161        self.tasks
162            .read()
163            .map(|tasks| tasks.values().map(|s| s.task.clone()).collect())
164            .unwrap_or_default()
165    }
166
167    /// Cancel a task.
168    pub fn cancel(&self, id: &TaskId) -> Result<(), McpError> {
169        let mut tasks = self
170            .tasks
171            .write()
172            .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
173
174        if let Some(state) = tasks.get_mut(id) {
175            state.cancel_token.cancel();
176            state.task.status = TaskStatus::Cancelled;
177            state.task.updated_at = chrono::Utc::now();
178            Ok(())
179        } else {
180            Err(McpError::invalid_params(
181                "tasks/cancel",
182                format!("Unknown task: {}", id.as_str()),
183            ))
184        }
185    }
186
187    /// Update task status.
188    async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Result<(), McpError> {
189        let mut tasks = self
190            .tasks
191            .write()
192            .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
193
194        if let Some(state) = tasks.get_mut(id) {
195            state.task.status = status;
196            state.task.updated_at = chrono::Utc::now();
197            state.last_access = Instant::now();
198            Ok(())
199        } else {
200            Err(McpError::invalid_params(
201                "tasks/get",
202                format!("Unknown task: {}", id.as_str()),
203            ))
204        }
205    }
206
207    /// Update task progress.
208    async fn update_progress(
209        &self,
210        id: &TaskId,
211        current: u64,
212        total: Option<u64>,
213        message: Option<&str>,
214    ) -> Result<(), McpError> {
215        let mut tasks = self
216            .tasks
217            .write()
218            .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
219
220        if let Some(state) = tasks.get_mut(id) {
221            state.task.progress = Some(mcpkit_core::types::task::TaskProgress {
222                current,
223                total,
224                message: message.map(String::from),
225            });
226            state.task.updated_at = chrono::Utc::now();
227            state.last_access = Instant::now();
228            Ok(())
229        } else {
230            Err(McpError::invalid_params(
231                "tasks/get",
232                format!("Unknown task: {}", id.as_str()),
233            ))
234        }
235    }
236
237    /// Complete a task with success.
238    async fn complete_success(&self, id: &TaskId, result: Value) -> Result<(), McpError> {
239        let mut tasks = self
240            .tasks
241            .write()
242            .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
243
244        if let Some(state) = tasks.get_mut(id) {
245            state.task.status = TaskStatus::Completed;
246            state.task.result = Some(result);
247            state.task.updated_at = chrono::Utc::now();
248            state.last_access = Instant::now();
249            Ok(())
250        } else {
251            Err(McpError::invalid_params(
252                "tasks/get",
253                format!("Unknown task: {}", id.as_str()),
254            ))
255        }
256    }
257
258    /// Complete a task with an error.
259    async fn complete_error(&self, id: &TaskId, message: String) -> Result<(), McpError> {
260        let mut tasks = self
261            .tasks
262            .write()
263            .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
264
265        if let Some(state) = tasks.get_mut(id) {
266            state.task.status = TaskStatus::Failed;
267            state.task.error = Some(mcpkit_core::types::task::TaskError {
268                code: -1,
269                message,
270                data: None,
271            });
272            state.task.updated_at = chrono::Utc::now();
273            state.last_access = Instant::now();
274            Ok(())
275        } else {
276            Err(McpError::invalid_params(
277                "tasks/get",
278                format!("Unknown task: {}", id.as_str()),
279            ))
280        }
281    }
282
283    /// Remove completed tasks older than the given duration.
284    pub fn cleanup(&self, max_age: std::time::Duration) {
285        if let Ok(mut tasks) = self.tasks.write() {
286            tasks.retain(|_, state| {
287                let is_terminal = state.task.status.is_terminal();
288                !is_terminal || state.last_access.elapsed() < max_age
289            });
290        }
291    }
292}
293
294/// Task service implementing the `TaskHandler` trait.
295pub struct TaskService {
296    manager: Arc<TaskManager>,
297}
298
299impl Default for TaskService {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305impl TaskService {
306    /// Create a new task service.
307    #[must_use]
308    pub fn new() -> Self {
309        Self {
310            manager: Arc::new(TaskManager::new()),
311        }
312    }
313
314    /// Get the underlying task manager.
315    #[must_use]
316    pub const fn manager(&self) -> &Arc<TaskManager> {
317        &self.manager
318    }
319
320    /// Create a new task and get a handle for it.
321    #[must_use]
322    pub fn create(&self, tool_name: Option<&str>) -> TaskHandle {
323        self.manager.create(tool_name)
324    }
325}
326
327impl TaskHandler for TaskService {
328    async fn list_tasks(&self, _ctx: &Context<'_>) -> Result<Vec<Task>, McpError> {
329        Ok(self.manager.list())
330    }
331
332    async fn get_task(
333        &self,
334        task_id: &TaskId,
335        _ctx: &Context<'_>,
336    ) -> Result<Option<Task>, McpError> {
337        Ok(self.manager.get(task_id).map(|s| s.task))
338    }
339
340    async fn cancel_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<bool, McpError> {
341        match self.manager.cancel(task_id) {
342            Ok(()) => Ok(true),
343            Err(_) => Ok(false),
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_manager() {
        let manager = Arc::new(TaskManager::new());

        let handle = manager.create(Some("test-tool"));
        assert!(!handle.is_cancelled());

        let tasks = manager.list();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].tool.as_deref(), Some("test-tool"));
        assert_eq!(tasks[0].status, TaskStatus::Pending);
    }

    #[tokio::test]
    async fn test_task_lifecycle() -> Result<(), Box<dyn std::error::Error>> {
        let manager = Arc::new(TaskManager::new());

        let handle = manager.create(Some("processor"));
        let task_id = handle.id().clone();

        // Start running
        handle.running().await?;
        let state = manager.get(&task_id).ok_or("Task not found")?;
        assert_eq!(state.task.status, TaskStatus::Running);

        // Report progress
        handle.progress(50, Some(100), Some("Halfway done")).await?;
        let state = manager.get(&task_id).ok_or("Task not found")?;
        assert_eq!(state.task.progress.as_ref().map(|p| p.current), Some(50));

        // Complete
        handle
            .complete(serde_json::json!({"result": "success"}))
            .await?;
        let state = manager.get(&task_id).ok_or("Task not found")?;
        assert_eq!(state.task.status, TaskStatus::Completed);

        Ok(())
    }

    #[test]
    fn test_task_cancellation() -> Result<(), Box<dyn std::error::Error>> {
        let manager = Arc::new(TaskManager::new());

        let handle = manager.create(None);
        let task_id = handle.id().clone();

        assert!(!handle.is_cancelled());

        manager.cancel(&task_id)?;

        assert!(handle.is_cancelled());
        let state = manager.get(&task_id).ok_or("Task not found")?;
        assert_eq!(state.task.status, TaskStatus::Cancelled);

        Ok(())
    }

    #[tokio::test]
    async fn test_task_service() -> Result<(), Box<dyn std::error::Error>> {
        let service = TaskService::new();

        let handle = service.create(Some("service-task"));
        handle.running().await?;

        let tasks = service.manager.list();
        assert_eq!(tasks.len(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_task_error() -> Result<(), Box<dyn std::error::Error>> {
        let manager = Arc::new(TaskManager::new());

        let handle = manager.create(Some("failing-tool"));
        let task_id = handle.id().clone();

        handle.error("boom").await?;
        let state = manager.get(&task_id).ok_or("Task not found")?;
        assert_eq!(state.task.status, TaskStatus::Failed);
        assert_eq!(
            state.task.error.as_ref().map(|e| e.message.as_str()),
            Some("boom")
        );

        Ok(())
    }

    #[test]
    fn test_cancel_unknown_task() {
        let manager = TaskManager::new();
        // A task that was created but never registered with this manager.
        let stray = Task::create();
        assert!(manager.cancel(&stray.id).is_err());
    }

    #[tokio::test]
    async fn test_cleanup_removes_only_finished_tasks() -> Result<(), Box<dyn std::error::Error>>
    {
        let manager = Arc::new(TaskManager::new());

        let finished = manager.create(Some("finished"));
        finished.complete(serde_json::json!(null)).await?;
        let active = manager.create(Some("active"));

        // A zero max age makes every terminal task old enough to be removed.
        manager.cleanup(std::time::Duration::ZERO);

        assert!(manager.get(finished.id()).is_none());
        // Handles to removed tasks report themselves as cancelled.
        assert!(finished.is_cancelled());
        assert!(manager.get(active.id()).is_some());
        assert!(!active.is_cancelled());

        Ok(())
    }
}