Skip to main content

adk_tool/mcp/
task.rs

// MCP Task Support (SEP-1686)
//
// Implements the asynchronous task lifecycle for long-running MCP tool operations.
// Tasks let a tool call be queued and polled for completion instead of blocking the caller.
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Configuration for MCP task-based execution.
///
/// Controls whether long-running tools are run as tasks and how task status
/// is polled. Start from [`Default::default`] (tasks disabled) or
/// [`McpTaskConfig::enabled`], then adjust with the builder-style setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpTaskConfig {
    /// Enable task mode for long-running tools.
    pub enable_tasks: bool,
    /// Poll interval, in milliseconds.
    pub poll_interval_ms: u64,
    /// Maximum wait time in milliseconds before timing out (`None` = no timeout).
    pub timeout_ms: Option<u64>,
    /// Maximum number of poll attempts (`None` = unlimited).
    pub max_poll_attempts: Option<u32>,
}
22
23impl Default for McpTaskConfig {
24    fn default() -> Self {
25        Self {
26            enable_tasks: false,
27            poll_interval_ms: 1000,
28            timeout_ms: Some(300_000), // 5 minutes default
29            max_poll_attempts: None,
30        }
31    }
32}
33
34impl McpTaskConfig {
35    /// Create a new task config with tasks enabled
36    pub fn enabled() -> Self {
37        Self { enable_tasks: true, ..Default::default() }
38    }
39
40    /// Set the poll interval
41    pub fn poll_interval(mut self, interval: Duration) -> Self {
42        self.poll_interval_ms = interval.as_millis() as u64;
43        self
44    }
45
46    /// Set the timeout
47    pub fn timeout(mut self, timeout: Duration) -> Self {
48        self.timeout_ms = Some(timeout.as_millis() as u64);
49        self
50    }
51
52    /// Set no timeout (wait indefinitely)
53    pub fn no_timeout(mut self) -> Self {
54        self.timeout_ms = None;
55        self
56    }
57
58    /// Set maximum poll attempts
59    pub fn max_attempts(mut self, attempts: u32) -> Self {
60        self.max_poll_attempts = Some(attempts);
61        self
62    }
63
64    /// Get poll interval as Duration
65    pub fn poll_duration(&self) -> Duration {
66        Duration::from_millis(self.poll_interval_ms)
67    }
68
69    /// Get timeout as Duration
70    pub fn timeout_duration(&self) -> Option<Duration> {
71        self.timeout_ms.map(Duration::from_millis)
72    }
73
74    /// Convert to MCP task request parameters
75    pub fn to_task_params(&self) -> Value {
76        json!({
77            "poll_interval_ms": self.poll_interval_ms
78        })
79    }
80}
81
82/// Status of an MCP task
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[serde(rename_all = "lowercase")]
85pub enum TaskStatus {
86    /// Task is queued but not started
87    Pending,
88    /// Task is currently running
89    Running,
90    /// Task completed successfully
91    Completed,
92    /// Task failed with an error
93    Failed,
94    /// Task was cancelled
95    Cancelled,
96}
97
98impl TaskStatus {
99    /// Check if the task is in a terminal state
100    pub fn is_terminal(&self) -> bool {
101        matches!(self, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled)
102    }
103
104    /// Check if the task is still in progress
105    pub fn is_in_progress(&self) -> bool {
106        matches!(self, TaskStatus::Pending | TaskStatus::Running)
107    }
108}
109
110/// Information about an MCP task
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct TaskInfo {
113    /// Unique task identifier
114    pub task_id: String,
115    /// Current status
116    pub status: TaskStatus,
117    /// Progress percentage (0-100) if available
118    pub progress: Option<u8>,
119    /// Human-readable status message
120    pub message: Option<String>,
121    /// Estimated time remaining in milliseconds
122    pub eta_ms: Option<u64>,
123}
124
125/// Result of creating a task
126#[derive(Debug, Clone)]
127pub struct CreateTaskResult {
128    /// The task ID for polling
129    pub task_id: String,
130    /// Initial task info
131    pub info: TaskInfo,
132}
133
/// Errors produced by MCP task operations (creating, polling, or waiting on a task).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskError {
    /// Task creation failed; carries a description of the failure.
    CreateFailed(String),
    /// Polling the task failed; carries a description of the failure.
    PollFailed(String),
    /// Waiting for the task timed out.
    Timeout {
        /// Identifier of the task that timed out.
        task_id: String,
        /// Elapsed time, in milliseconds.
        elapsed_ms: u64,
    },
    /// The task was cancelled; carries the task identifier.
    Cancelled(String),
    /// The task failed with an error.
    TaskFailed {
        /// Identifier of the failed task.
        task_id: String,
        /// Error description reported for the task.
        error: String,
    },
    /// The maximum number of poll attempts was exceeded.
    MaxAttemptsExceeded {
        /// Identifier of the task being polled.
        task_id: String,
        /// The attempt limit that was exceeded.
        attempts: u32,
    },
}
150
151impl std::fmt::Display for TaskError {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            TaskError::CreateFailed(msg) => write!(f, "Failed to create task: {}", msg),
155            TaskError::PollFailed(msg) => write!(f, "Failed to poll task: {}", msg),
156            TaskError::Timeout { task_id, elapsed_ms } => {
157                write!(f, "Task '{}' timed out after {}ms", task_id, elapsed_ms)
158            }
159            TaskError::Cancelled(task_id) => write!(f, "Task '{}' was cancelled", task_id),
160            TaskError::TaskFailed { task_id, error } => {
161                write!(f, "Task '{}' failed: {}", task_id, error)
162            }
163            TaskError::MaxAttemptsExceeded { task_id, attempts } => {
164                write!(f, "Task '{}' exceeded {} poll attempts", task_id, attempts)
165            }
166        }
167    }
168}
169
170impl std::error::Error for TaskError {}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_config_default() {
        let config = McpTaskConfig::default();
        assert!(!config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 1000);
        assert_eq!(config.timeout_ms, Some(300_000));
        assert_eq!(config.max_poll_attempts, None);
    }

    #[test]
    fn test_task_config_enabled() {
        let config = McpTaskConfig::enabled();
        assert!(config.enable_tasks);
        // Everything other than `enable_tasks` should match the defaults.
        let defaults = McpTaskConfig::default();
        assert_eq!(config.poll_interval_ms, defaults.poll_interval_ms);
        assert_eq!(config.timeout_ms, defaults.timeout_ms);
        assert_eq!(config.max_poll_attempts, defaults.max_poll_attempts);
    }

    #[test]
    fn test_task_config_builder() {
        let config = McpTaskConfig::enabled()
            .poll_interval(Duration::from_secs(2))
            .timeout(Duration::from_secs(60))
            .max_attempts(10);

        assert!(config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 2000);
        assert_eq!(config.timeout_ms, Some(60_000));
        assert_eq!(config.max_poll_attempts, Some(10));
    }

    #[test]
    fn test_task_config_no_timeout() {
        let config = McpTaskConfig::enabled()
            .timeout(Duration::from_secs(5))
            .no_timeout();
        assert_eq!(config.timeout_ms, None);
        assert_eq!(config.timeout_duration(), None);
    }

    #[test]
    fn test_task_config_duration_accessors() {
        let config = McpTaskConfig::default()
            .poll_interval(Duration::from_millis(250))
            .timeout(Duration::from_secs(3));
        assert_eq!(config.poll_duration(), Duration::from_millis(250));
        assert_eq!(config.timeout_duration(), Some(Duration::from_secs(3)));
    }

    #[test]
    fn test_task_params() {
        let params = McpTaskConfig::default()
            .poll_interval(Duration::from_millis(1500))
            .to_task_params();
        assert_eq!(
            params.get("poll_interval_ms").and_then(|v| v.as_u64()),
            Some(1500)
        );
    }

    #[test]
    fn test_task_status_terminal() {
        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }

    #[test]
    fn test_task_status_in_progress() {
        assert!(TaskStatus::Pending.is_in_progress());
        assert!(TaskStatus::Running.is_in_progress());
        assert!(!TaskStatus::Completed.is_in_progress());
        assert!(!TaskStatus::Failed.is_in_progress());
        assert!(!TaskStatus::Cancelled.is_in_progress());
    }

    #[test]
    fn test_task_status_wire_format() {
        assert_eq!(serde_json::to_string(&TaskStatus::Pending).unwrap(), "\"pending\"");
        assert_eq!(serde_json::to_string(&TaskStatus::Running).unwrap(), "\"running\"");
        assert_eq!(serde_json::to_string(&TaskStatus::Completed).unwrap(), "\"completed\"");
        assert_eq!(serde_json::to_string(&TaskStatus::Failed).unwrap(), "\"failed\"");
        assert_eq!(serde_json::to_string(&TaskStatus::Cancelled).unwrap(), "\"cancelled\"");

        let parsed: TaskStatus = serde_json::from_str("\"cancelled\"").unwrap();
        assert_eq!(parsed, TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_info_optional_fields_default_to_none() {
        let info: TaskInfo =
            serde_json::from_str(r#"{"task_id":"t1","status":"running"}"#).unwrap();
        assert_eq!(info.task_id, "t1");
        assert_eq!(info.status, TaskStatus::Running);
        assert!(info.progress.is_none());
        assert!(info.message.is_none());
        assert!(info.eta_ms.is_none());
    }

    #[test]
    fn test_task_error_display() {
        assert_eq!(
            TaskError::CreateFailed("boom".to_string()).to_string(),
            "Failed to create task: boom"
        );
        assert_eq!(
            TaskError::PollFailed("net".to_string()).to_string(),
            "Failed to poll task: net"
        );
        assert_eq!(
            TaskError::Timeout { task_id: "t1".to_string(), elapsed_ms: 1500 }.to_string(),
            "Task 't1' timed out after 1500ms"
        );
        assert_eq!(
            TaskError::Cancelled("t2".to_string()).to_string(),
            "Task 't2' was cancelled"
        );
        assert_eq!(
            TaskError::TaskFailed { task_id: "t3".to_string(), error: "oops".to_string() }
                .to_string(),
            "Task 't3' failed: oops"
        );
        assert_eq!(
            TaskError::MaxAttemptsExceeded { task_id: "t4".to_string(), attempts: 7 }
                .to_string(),
            "Task 't4' exceeded 7 poll attempts"
        );
    }

    #[test]
    fn test_task_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(TaskError::Cancelled("t1".to_string()));
        assert!(err.source().is_none());
        assert_eq!(err.to_string(), "Task 't1' was cancelled");
    }
}