1use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Configuration for running MCP requests as polled, long-running tasks.
///
/// Start from [`McpTaskConfig::default`] (task mode off) and adjust fields, or
/// use the builder methods provided on the type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpTaskConfig {
    /// Whether task mode is turned on. Defaults to `false`.
    pub enable_tasks: bool,
    /// Delay between status polls, in milliseconds. Defaults to `1000`.
    pub poll_interval_ms: u64,
    /// Overall timeout in milliseconds; `None` means no timeout is configured.
    /// Defaults to `Some(300_000)` (5 minutes).
    pub timeout_ms: Option<u64>,
    /// Upper bound on the number of status polls; `None` means no limit is
    /// configured. Defaults to `None`.
    pub max_poll_attempts: Option<u32>,
}

impl Default for McpTaskConfig {
    /// Task mode off, 1 s poll interval, 5 min timeout, no poll-attempt limit.
    fn default() -> Self {
        Self {
            enable_tasks: false,
            poll_interval_ms: 1000,
            timeout_ms: Some(300_000),
            max_poll_attempts: None,
        }
    }
}
33
34impl McpTaskConfig {
35 pub fn enabled() -> Self {
37 Self { enable_tasks: true, ..Default::default() }
38 }
39
40 pub fn poll_interval(mut self, interval: Duration) -> Self {
42 self.poll_interval_ms = interval.as_millis() as u64;
43 self
44 }
45
46 pub fn timeout(mut self, timeout: Duration) -> Self {
48 self.timeout_ms = Some(timeout.as_millis() as u64);
49 self
50 }
51
52 pub fn no_timeout(mut self) -> Self {
54 self.timeout_ms = None;
55 self
56 }
57
58 pub fn max_attempts(mut self, attempts: u32) -> Self {
60 self.max_poll_attempts = Some(attempts);
61 self
62 }
63
64 pub fn poll_duration(&self) -> Duration {
66 Duration::from_millis(self.poll_interval_ms)
67 }
68
69 pub fn timeout_duration(&self) -> Option<Duration> {
71 self.timeout_ms.map(Duration::from_millis)
72 }
73
74 pub fn to_task_params(&self) -> Value {
76 json!({
77 "poll_interval_ms": self.poll_interval_ms
78 })
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[serde(rename_all = "lowercase")]
85pub enum TaskStatus {
86 Pending,
88 Running,
90 Completed,
92 Failed,
94 Cancelled,
96}
97
98impl TaskStatus {
99 pub fn is_terminal(&self) -> bool {
101 matches!(self, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled)
102 }
103
104 pub fn is_in_progress(&self) -> bool {
106 matches!(self, TaskStatus::Pending | TaskStatus::Running)
107 }
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct TaskInfo {
113 pub task_id: String,
115 pub status: TaskStatus,
117 pub progress: Option<u8>,
119 pub message: Option<String>,
121 pub eta_ms: Option<u64>,
123}
124
125#[derive(Debug, Clone)]
127pub struct CreateTaskResult {
128 pub task_id: String,
130 pub info: TaskInfo,
132}
133
/// Errors that can occur while creating, polling or waiting on a task.
#[derive(Debug, Clone)]
pub enum TaskError {
    /// Task creation failed; carries a description of the failure.
    CreateFailed(String),
    /// A status poll failed; carries a description of the failure.
    PollFailed(String),
    /// Task `task_id` timed out after `elapsed_ms` milliseconds.
    Timeout { task_id: String, elapsed_ms: u64 },
    /// The task was cancelled; carries the task id.
    Cancelled(String),
    /// Task `task_id` failed with the given error text.
    TaskFailed { task_id: String, error: String },
    /// Task `task_id` exceeded the allowed number of poll `attempts`.
    MaxAttemptsExceeded { task_id: String, attempts: u32 },
}

impl std::fmt::Display for TaskError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CreateFailed(msg) => write!(f, "Failed to create task: {}", msg),
            Self::PollFailed(msg) => write!(f, "Failed to poll task: {}", msg),
            Self::Cancelled(task_id) => write!(f, "Task '{}' was cancelled", task_id),
            Self::Timeout { task_id, elapsed_ms } => {
                write!(f, "Task '{}' timed out after {}ms", task_id, elapsed_ms)
            }
            Self::TaskFailed { task_id, error } => {
                write!(f, "Task '{}' failed: {}", task_id, error)
            }
            Self::MaxAttemptsExceeded { task_id, attempts } => {
                write!(f, "Task '{}' exceeded {} poll attempts", task_id, attempts)
            }
        }
    }
}

/// Lets `TaskError` be used as a standard error (e.g. boxed as `dyn Error`).
impl std::error::Error for TaskError {}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_config_default() {
        let config = McpTaskConfig::default();
        assert!(!config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 1000);
        assert_eq!(config.timeout_ms, Some(300_000));
        assert_eq!(config.max_poll_attempts, None);
    }

    #[test]
    fn test_task_config_enabled() {
        let config = McpTaskConfig::enabled();
        assert!(config.enable_tasks);
    }

    #[test]
    fn test_task_config_builder() {
        let config = McpTaskConfig::enabled()
            .poll_interval(Duration::from_secs(2))
            .timeout(Duration::from_secs(60))
            .max_attempts(10);

        assert!(config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 2000);
        assert_eq!(config.timeout_ms, Some(60_000));
        assert_eq!(config.max_poll_attempts, Some(10));
    }

    #[test]
    fn test_task_config_no_timeout() {
        let config = McpTaskConfig::enabled().no_timeout();
        assert_eq!(config.timeout_ms, None);
        assert_eq!(config.timeout_duration(), None);
    }

    #[test]
    fn test_task_config_durations() {
        let config = McpTaskConfig::default();
        assert_eq!(config.poll_duration(), Duration::from_millis(1000));
        assert_eq!(config.timeout_duration(), Some(Duration::from_secs(300)));
    }

    #[test]
    fn test_task_params() {
        let params = McpTaskConfig::default()
            .poll_interval(Duration::from_millis(250))
            .to_task_params();
        assert_eq!(params, json!({ "poll_interval_ms": 250 }));
    }

    #[test]
    fn test_task_status_terminal() {
        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }

    #[test]
    fn test_task_status_in_progress() {
        assert!(TaskStatus::Pending.is_in_progress());
        assert!(TaskStatus::Running.is_in_progress());
        assert!(!TaskStatus::Completed.is_in_progress());
        assert!(!TaskStatus::Failed.is_in_progress());
        assert!(!TaskStatus::Cancelled.is_in_progress());
    }

    #[test]
    fn test_task_status_serde_lowercase() {
        let cases = [
            (TaskStatus::Pending, "\"pending\""),
            (TaskStatus::Running, "\"running\""),
            (TaskStatus::Completed, "\"completed\""),
            (TaskStatus::Failed, "\"failed\""),
            (TaskStatus::Cancelled, "\"cancelled\""),
        ];
        for (status, text) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), text);
            assert_eq!(serde_json::from_str::<TaskStatus>(text).unwrap(), status);
        }
    }

    #[test]
    fn test_task_info_optional_fields_default_to_none() {
        let info: TaskInfo =
            serde_json::from_value(json!({ "task_id": "t1", "status": "running" })).unwrap();
        assert_eq!(info.task_id, "t1");
        assert_eq!(info.status, TaskStatus::Running);
        assert!(info.progress.is_none());
        assert!(info.message.is_none());
        assert!(info.eta_ms.is_none());
    }

    #[test]
    fn test_task_error_display() {
        let err = TaskError::Timeout { task_id: "t1".to_string(), elapsed_ms: 5 };
        assert_eq!(err.to_string(), "Task 't1' timed out after 5ms");
        let err = TaskError::MaxAttemptsExceeded { task_id: "t2".to_string(), attempts: 3 };
        assert_eq!(err.to_string(), "Task 't2' exceeded 3 poll attempts");
    }
}