// claude_agent_sdk/mcp/tasks.rs

//! MCP 2025-11-25 Async Tasks Implementation
//!
//! This module implements the Tasks primitive from the MCP 2025-11-25 spec,
//! enabling "call-now, fetch-later" asynchronous workflows.
//!
//! # Overview
//!
//! The Tasks primitive allows any request to become asynchronous:
//! - Client calls a tool with a task hint
//! - Server returns a task handle immediately
//! - Client polls or subscribes to the task resource for progress and results
//!
//! # Task States
//!
//! Tasks move through these states (the last three are terminal and never change again):
//! - `Queued` - Task is waiting to start
//! - `Working` - Task is in progress
//! - `InputRequired` - Task needs user input
//! - `Completed` - Task finished successfully
//! - `Failed` - Task failed with an error
//! - `Cancelled` - Task was cancelled
//!
//! # Example
//!
//! ```no_run
//! use claude_agent_sdk::mcp::tasks::{TaskManager, TaskStatus, TaskRequest};
//! use serde_json::json;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let manager = TaskManager::new();
//!
//! // Create a task
//! let request = TaskRequest {
//!     method: "tools/call".to_string(),
//!     params: json!({"name": "my_tool", "arguments": {}}),
//!     ..Default::default()
//! };
//!
//! let task = manager.create_task(request).await?;
//! println!("Task ID: {}", task.id);
//!
//! // Poll for status
//! loop {
//!     let status = manager.get_task_status(&task.id).await?;
//!     if status.is_terminal() {
//!         break;
//!     }
//!     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
//! }
//!
//! // Get result
//! let result = manager.get_task_result(&task.id).await?;
//! println!("Result: {:?}", result);
//! # Ok(())
//! # }
//! ```

use std::collections::HashMap;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::errors::{ClaudeError, Result};

/// Unique identifier of a task (a UUID v4 string for tasks created by `TaskManager`).
pub type TaskId = String;

/// Resource URI of a task, formed as `{base_uri}/{task_id}` by `TaskManager`.
pub type TaskUri = String;

71/// Task request
72#[derive(Debug, Clone, Serialize, Deserialize, Default)]
73pub struct TaskRequest {
74    /// JSON-RPC method name
75    #[serde(default)]
76    pub method: String,
77    /// Request parameters
78    #[serde(default)]
79    pub params: serde_json::Value,
80    /// Task hint - indicates this might take a while
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub task_hint: Option<TaskHint>,
83    /// Priority hint for task scheduling
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub priority: Option<TaskPriority>,
86}
87
88/// Task hint
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskHint {
91    /// Estimated duration in seconds
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub estimated_duration_secs: Option<u64>,
94    /// Whether progress notifications will be sent
95    #[serde(default)]
96    pub supports_progress: bool,
97    /// Whether the task can be cancelled
98    #[serde(default)]
99    pub cancellable: bool,
100}
101
102impl Default for TaskHint {
103    fn default() -> Self {
104        Self {
105            estimated_duration_secs: None,
106            supports_progress: false,
107            cancellable: true,
108        }
109    }
110}
111
112/// Task priority
113#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
114#[serde(rename_all = "lowercase")]
115pub enum TaskPriority {
116    Low,
117    Normal,
118    High,
119    Urgent,
120}
121
122impl Default for TaskPriority {
123    fn default() -> Self {
124        Self::Normal
125    }
126}
127
128/// Task state
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
130#[serde(rename_all = "lowercase")]
131pub enum TaskState {
132    Queued,
133    Working,
134    InputRequired,
135    Completed,
136    Failed,
137    Cancelled,
138}
139
140impl TaskState {
141    /// Check if this is a terminal state (no further transitions possible)
142    pub fn is_terminal(&self) -> bool {
143        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
144    }
145
146    /// Check if this state represents an active task
147    pub fn is_active(&self) -> bool {
148        matches!(self, Self::Queued | Self::Working | Self::InputRequired)
149    }
150}
151
152/// Task progress
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct TaskProgress {
155    /// Progress value between 0.0 and 1.0
156    pub value: f64,
157    /// Human-readable progress message
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub message: Option<String>,
160}
161
162impl TaskProgress {
163    /// Create new progress
164    pub fn new(value: f64) -> Self {
165        assert!(
166            (0.0..=1.0).contains(&value),
167            "Progress must be between 0.0 and 1.0"
168        );
169        Self {
170            value,
171            message: None,
172        }
173    }
174
175    /// Add a message to the progress
176    pub fn with_message(mut self, message: impl Into<String>) -> Self {
177        self.message = Some(message.into());
178        self
179    }
180}
181
182/// Task status
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TaskStatus {
185    /// Task ID
186    pub id: TaskId,
187    /// Task state
188    pub state: TaskState,
189    /// Current progress (if available)
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub progress: Option<TaskProgress>,
192    /// Error message (if failed)
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub error: Option<String>,
195    /// Timestamp when task was created
196    pub created_at: chrono::DateTime<chrono::Utc>,
197    /// Timestamp when task was last updated
198    pub updated_at: chrono::DateTime<chrono::Utc>,
199    /// Timestamp when task completed (if terminal)
200    #[serde(skip_serializing_if = "Option::is_none")]
201    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
202}
203
204impl TaskStatus {
205    /// Check if task is in a terminal state
206    pub fn is_terminal(&self) -> bool {
207        self.state.is_terminal()
208    }
209
210    /// Check if task is still active
211    pub fn is_active(&self) -> bool {
212        self.state.is_active()
213    }
214}
215
216/// Task result
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct TaskResult {
219    /// Task ID
220    pub id: TaskId,
221    /// Result data
222    pub data: serde_json::Value,
223    /// Timestamp when result was produced
224    pub completed_at: chrono::DateTime<chrono::Utc>,
225}
226
227/// Task handle - returned immediately when creating a task
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct TaskHandle {
230    /// Task ID
231    pub id: TaskId,
232    /// Task resource URI for polling/subscribing
233    pub uri: TaskUri,
234    /// Initial status
235    pub status: TaskStatus,
236}
237
238/// Internal task storage
239#[derive(Debug, Clone)]
240struct Task {
241    id: TaskId,
242    request: TaskRequest,
243    state: TaskState,
244    progress: Option<TaskProgress>,
245    result: Option<serde_json::Value>,
246    error: Option<String>,
247    created_at: chrono::DateTime<chrono::Utc>,
248    updated_at: chrono::DateTime<chrono::Utc>,
249    completed_at: Option<chrono::DateTime<chrono::Utc>>,
250}
251
252impl Task {
253    fn new(request: TaskRequest) -> Self {
254        let now = chrono::Utc::now();
255        Self {
256            id: Uuid::new_v4().to_string(),
257            request,
258            state: TaskState::Queued,
259            progress: None,
260            result: None,
261            error: None,
262            created_at: now,
263            updated_at: now,
264            completed_at: None,
265        }
266    }
267
268    fn to_status(&self) -> TaskStatus {
269        TaskStatus {
270            id: self.id.clone(),
271            state: self.state.clone(),
272            progress: self.progress.clone(),
273            error: self.error.clone(),
274            created_at: self.created_at,
275            updated_at: self.updated_at,
276            completed_at: self.completed_at,
277        }
278    }
279}
280
281/// Task manager
282///
283/// Manages the lifecycle of async tasks, including creation,
284/// status polling, progress updates, and result retrieval.
285#[derive(Clone)]
286pub struct TaskManager {
287    tasks: Arc<RwLock<HashMap<TaskId, Task>>>,
288    base_uri: String,
289}
290
291impl TaskManager {
292    /// Create a new task manager
293    pub fn new() -> Self {
294        Self::with_base_uri("mcp://tasks".to_string())
295    }
296
297    /// Create a new task manager with a custom base URI
298    pub fn with_base_uri(base_uri: impl Into<String>) -> Self {
299        Self {
300            tasks: Arc::new(RwLock::new(HashMap::new())),
301            base_uri: base_uri.into(),
302        }
303    }
304
305    /// Create a new task
306    ///
307    /// Returns a task handle immediately with the task in Queued state.
308    pub async fn create_task(&self, request: TaskRequest) -> Result<TaskHandle> {
309        let task = Task::new(request);
310        let task_id = task.id.clone();
311        let uri = format!("{}/{}", self.base_uri, task_id);
312        let status = task.to_status();
313
314        // Store the task
315        let mut tasks = self.tasks.write().await;
316        tasks.insert(task_id.clone(), task);
317
318        Ok(TaskHandle {
319            id: task_id,
320            uri,
321            status,
322        })
323    }
324
325    /// Get task status
326    pub async fn get_task_status(&self, task_id: &TaskId) -> Result<TaskStatus> {
327        let tasks = self.tasks.read().await;
328        let task = tasks
329            .get(task_id)
330            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
331
332        Ok(task.to_status())
333    }
334
335    /// Get task result
336    ///
337    /// Returns an error if the task hasn't completed yet.
338    pub async fn get_task_result(&self, task_id: &TaskId) -> Result<TaskResult> {
339        let tasks = self.tasks.read().await;
340        let task = tasks
341            .get(task_id)
342            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
343
344        if task.state != TaskState::Completed {
345            return Err(ClaudeError::InvalidInput(format!(
346                "Task is not completed. Current state: {:?}",
347                task.state
348            )));
349        }
350
351        let result = task.result.as_ref().ok_or_else(|| {
352            ClaudeError::InternalError("Completed task has no result".to_string())
353        })?;
354
355        Ok(TaskResult {
356            id: task_id.clone(),
357            data: result.clone(),
358            completed_at: task.completed_at.unwrap(),
359        })
360    }
361
362    /// Update task progress
363    ///
364    /// This should be called by the worker executing the task.
365    pub async fn update_progress(&self, task_id: &TaskId, progress: TaskProgress) -> Result<()> {
366        let mut tasks = self.tasks.write().await;
367        let task = tasks
368            .get_mut(task_id)
369            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
370
371        if task.state.is_terminal() {
372            return Err(ClaudeError::InvalidInput(
373                "Cannot update progress for terminal task".to_string(),
374            ));
375        }
376
377        task.progress = Some(progress);
378        task.updated_at = chrono::Utc::now();
379
380        Ok(())
381    }
382
383    /// Mark task as working
384    pub async fn mark_working(&self, task_id: &TaskId) -> Result<()> {
385        let mut tasks = self.tasks.write().await;
386        let task = tasks
387            .get_mut(task_id)
388            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
389
390        if task.state.is_terminal() {
391            return Err(ClaudeError::InvalidInput(
392                "Cannot transition terminal task".to_string(),
393            ));
394        }
395
396        task.state = TaskState::Working;
397        task.updated_at = chrono::Utc::now();
398
399        Ok(())
400    }
401
402    /// Mark task as completed with result
403    pub async fn mark_completed(&self, task_id: &TaskId, result: serde_json::Value) -> Result<()> {
404        let mut tasks = self.tasks.write().await;
405        let task = tasks
406            .get_mut(task_id)
407            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
408
409        if task.state.is_terminal() {
410            return Err(ClaudeError::InvalidInput(
411                "Cannot transition terminal task".to_string(),
412            ));
413        }
414
415        let now = chrono::Utc::now();
416        task.state = TaskState::Completed;
417        task.result = Some(result);
418        task.updated_at = now;
419        task.completed_at = Some(now);
420
421        Ok(())
422    }
423
424    /// Mark task as failed
425    pub async fn mark_failed(&self, task_id: &TaskId, error: impl Into<String>) -> Result<()> {
426        let mut tasks = self.tasks.write().await;
427        let task = tasks
428            .get_mut(task_id)
429            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
430
431        if task.state.is_terminal() {
432            return Err(ClaudeError::InvalidInput(
433                "Cannot transition terminal task".to_string(),
434            ));
435        }
436
437        let now = chrono::Utc::now();
438        task.state = TaskState::Failed;
439        task.error = Some(error.into());
440        task.updated_at = now;
441        task.completed_at = Some(now);
442
443        Ok(())
444    }
445
446    /// Mark task as cancelled
447    pub async fn mark_cancelled(&self, task_id: &TaskId) -> Result<()> {
448        let mut tasks = self.tasks.write().await;
449        let task = tasks
450            .get_mut(task_id)
451            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
452
453        if task.state.is_terminal() {
454            return Err(ClaudeError::InvalidInput(
455                "Cannot transition terminal task".to_string(),
456            ));
457        }
458
459        let now = chrono::Utc::now();
460        task.state = TaskState::Cancelled;
461        task.updated_at = now;
462        task.completed_at = Some(now);
463
464        Ok(())
465    }
466
467    /// Mark task as requiring input
468    pub async fn mark_input_required(&self, task_id: &TaskId) -> Result<()> {
469        let mut tasks = self.tasks.write().await;
470        let task = tasks
471            .get_mut(task_id)
472            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
473
474        if task.state.is_terminal() {
475            return Err(ClaudeError::InvalidInput(
476                "Cannot transition terminal task".to_string(),
477            ));
478        }
479
480        task.state = TaskState::InputRequired;
481        task.updated_at = chrono::Utc::now();
482
483        Ok(())
484    }
485
486    /// List all tasks
487    pub async fn list_tasks(&self) -> Result<Vec<TaskStatus>> {
488        let tasks = self.tasks.read().await;
489        Ok(tasks.values().map(|t| t.to_status()).collect())
490    }
491
492    /// Cancel a task
493    ///
494    /// Returns an error if the task is already in a terminal state
495    /// or doesn't support cancellation.
496    pub async fn cancel_task(&self, task_id: &TaskId) -> Result<()> {
497        let mut tasks = self.tasks.write().await;
498        let task = tasks
499            .get_mut(task_id)
500            .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
501
502        if task.state.is_terminal() {
503            return Err(ClaudeError::InvalidInput(format!(
504                "Cannot cancel task in state: {:?}",
505                task.state
506            )));
507        }
508
509        // Check if task is cancellable (based on request hint)
510        if let Some(hint) = &task.request.task_hint {
511            if !hint.cancellable {
512                return Err(ClaudeError::InvalidInput(
513                    "Task is not cancellable".to_string(),
514                ));
515            }
516        }
517
518        let now = chrono::Utc::now();
519        task.state = TaskState::Cancelled;
520        task.updated_at = now;
521        task.completed_at = Some(now);
522
523        Ok(())
524    }
525
526    /// Clean up old completed tasks
527    ///
528    /// Removes tasks that completed before the given threshold.
529    pub async fn cleanup_old_tasks(&self, older_than: chrono::Duration) -> Result<usize> {
530        let mut tasks = self.tasks.write().await;
531        let cutoff = chrono::Utc::now() - older_than;
532
533        let initial_count = tasks.len();
534        tasks.retain(|_, task| {
535            if let Some(completed_at) = task.completed_at {
536                completed_at > cutoff
537            } else {
538                true // Keep active tasks
539            }
540        });
541
542        Ok(initial_count - tasks.len())
543    }
544}
545
546impl Default for TaskManager {
547    fn default() -> Self {
548        Self::new()
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a minimal `tools/call` request with empty params and no hints.
    fn tool_call() -> TaskRequest {
        TaskRequest {
            method: "tools/call".to_string(),
            params: json!({}),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_task_creation() {
        let manager = TaskManager::new();

        let request = TaskRequest {
            method: "tools/call".to_string(),
            params: json!({"name": "test"}),
            ..Default::default()
        };

        let handle = manager.create_task(request).await.unwrap();
        assert!(!handle.id.is_empty());
        assert!(!handle.uri.is_empty());
        assert_eq!(handle.status.state, TaskState::Queued);
    }

    #[tokio::test]
    async fn test_task_uri_uses_base_uri() {
        let manager = TaskManager::with_base_uri("custom://jobs");
        let handle = manager.create_task(tool_call()).await.unwrap();
        assert_eq!(handle.uri, format!("custom://jobs/{}", handle.id));
    }

    #[tokio::test]
    async fn test_task_status() {
        let manager = TaskManager::new();

        let handle = manager.create_task(tool_call()).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();

        assert_eq!(status.id, handle.id);
        assert_eq!(status.state, TaskState::Queued);
        assert!(status.is_active());
        assert!(!status.is_terminal());
    }

    #[tokio::test]
    async fn test_unknown_task_is_not_found() {
        let manager = TaskManager::new();
        let missing: TaskId = "missing".to_string();

        assert!(manager.get_task_status(&missing).await.is_err());
        assert!(manager.get_task_result(&missing).await.is_err());
        assert!(manager.mark_working(&missing).await.is_err());
        assert!(manager.cancel_task(&missing).await.is_err());
    }

    #[tokio::test]
    async fn test_task_lifecycle() {
        let manager = TaskManager::new();

        let handle = manager.create_task(tool_call()).await.unwrap();

        // Transition to Working
        manager.mark_working(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Working);

        // Update progress
        let progress = TaskProgress::new(0.5).with_message("Half done");
        manager.update_progress(&handle.id, progress).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.progress.as_ref().unwrap().value, 0.5);
        assert_eq!(
            status.progress.as_ref().unwrap().message.as_ref().unwrap(),
            "Half done"
        );

        // Complete with result
        let result = json!({"output": "success"});
        manager.mark_completed(&handle.id, result).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Completed);
        assert!(status.is_terminal());
        assert!(!status.is_active());
        assert!(status.completed_at.is_some());

        // Get result
        let task_result = manager.get_task_result(&handle.id).await.unwrap();
        assert_eq!(task_result.id, handle.id);
        assert_eq!(task_result.data, json!({"output": "success"}));
    }

    #[tokio::test]
    async fn test_result_unavailable_before_completion() {
        let manager = TaskManager::new();
        let handle = manager.create_task(tool_call()).await.unwrap();

        assert!(manager.get_task_result(&handle.id).await.is_err());
        manager.mark_working(&handle.id).await.unwrap();
        assert!(manager.get_task_result(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_input_required_round_trip() {
        let manager = TaskManager::new();
        let handle = manager.create_task(tool_call()).await.unwrap();

        manager.mark_input_required(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::InputRequired);
        assert!(status.is_active());

        // Once input arrives the worker resumes the task.
        manager.mark_working(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Working);
    }

    #[tokio::test]
    async fn test_task_failure() {
        let manager = TaskManager::new();

        let handle = manager.create_task(tool_call()).await.unwrap();

        manager
            .mark_failed(&handle.id, "Something went wrong")
            .await
            .unwrap();

        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Failed);
        assert!(status.is_terminal());
        assert_eq!(status.error.as_ref().unwrap(), "Something went wrong");
        assert!(manager.get_task_result(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_task_cancellation() {
        let manager = TaskManager::new();

        let request = TaskRequest {
            task_hint: Some(TaskHint {
                cancellable: true,
                ..Default::default()
            }),
            ..tool_call()
        };

        let handle = manager.create_task(request).await.unwrap();
        manager.cancel_task(&handle.id).await.unwrap();

        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
        assert!(status.is_terminal());
    }

    #[tokio::test]
    async fn test_cancel_without_hint() {
        let manager = TaskManager::new();
        let handle = manager.create_task(tool_call()).await.unwrap();

        manager.cancel_task(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
    }

    #[tokio::test]
    async fn test_non_cancellable_task() {
        let manager = TaskManager::new();

        let request = TaskRequest {
            task_hint: Some(TaskHint {
                cancellable: false,
                ..Default::default()
            }),
            ..tool_call()
        };

        let handle = manager.create_task(request).await.unwrap();
        let result = manager.cancel_task(&handle.id).await;

        assert!(result.is_err());
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Queued);
    }

    #[tokio::test]
    async fn test_cannot_cancel_terminal_task() {
        let manager = TaskManager::new();
        let handle = manager.create_task(tool_call()).await.unwrap();

        manager.mark_failed(&handle.id, "boom").await.unwrap();
        assert!(manager.cancel_task(&handle.id).await.is_err());
        assert!(manager.mark_cancelled(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_terminal_state_transitions() {
        let manager = TaskManager::new();

        let handle = manager.create_task(tool_call()).await.unwrap();

        // Complete the task
        manager.mark_completed(&handle.id, json!({})).await.unwrap();

        // Try to transition from terminal state
        assert!(manager.mark_working(&handle.id).await.is_err());
        assert!(manager.mark_input_required(&handle.id).await.is_err());
        assert!(manager.mark_failed(&handle.id, "late").await.is_err());
        assert!(manager.mark_completed(&handle.id, json!({})).await.is_err());
        assert!(
            manager
                .update_progress(&handle.id, TaskProgress::new(0.5))
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let manager = TaskManager::new();

        let _task1 = manager.create_task(tool_call()).await.unwrap();
        let _task2 = manager.create_task(tool_call()).await.unwrap();

        let tasks = manager.list_tasks().await.unwrap();
        assert_eq!(tasks.len(), 2);
    }

    #[test]
    fn test_progress_bounds() {
        assert_eq!(TaskProgress::new(0.0).value, 0.0);
        assert_eq!(TaskProgress::new(0.5).value, 0.5);
        assert_eq!(TaskProgress::new(1.0).value, 1.0);
    }

    #[test]
    #[should_panic(expected = "Progress must be between 0.0 and 1.0")]
    fn test_progress_out_of_range_panics() {
        let _ = TaskProgress::new(1.5);
    }

    #[test]
    fn test_priority_ordering() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Urgent);
        assert_eq!(TaskPriority::default(), TaskPriority::Normal);
    }

    #[tokio::test]
    async fn test_cleanup_old_tasks() {
        let manager = TaskManager::new();

        let handle = manager.create_task(tool_call()).await.unwrap();
        manager.mark_completed(&handle.id, json!({})).await.unwrap();

        // Cleanup tasks older than 1 second (should be none immediately)
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(1))
            .await
            .unwrap();
        assert_eq!(cleaned, 0);

        // Cleanup tasks older than 0 seconds (should clean all)
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 1);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_active_tasks() {
        let manager = TaskManager::new();
        let handle = manager.create_task(tool_call()).await.unwrap();
        manager.mark_working(&handle.id).await.unwrap();

        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 0);
        assert!(manager.get_task_status(&handle.id).await.is_ok());
    }
}