use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::time::Duration;
/// Client-side settings that control whether MCP tasks are used and how
/// their status is polled.
///
/// Start from [`McpTaskConfig::default`] (tasks disabled) or
/// [`McpTaskConfig::enabled`], then refine with the builder methods.
#[derive(Clone, Debug)]
pub struct McpTaskConfig {
    /// Whether task-based execution is switched on.
    pub enable_tasks: bool,
    /// Interval between status polls, in milliseconds.
    pub poll_interval_ms: u64,
    /// Overall time budget in milliseconds; `None` means no timeout.
    pub timeout_ms: Option<u64>,
    /// Optional cap on the number of polls; `None` leaves it unbounded
    /// (presumably enforced by the polling loop, which is not in this file section).
    pub max_poll_attempts: Option<u32>,
}
impl Default for McpTaskConfig {
fn default() -> Self {
Self {
enable_tasks: false,
poll_interval_ms: 1000,
timeout_ms: Some(300_000), max_poll_attempts: None,
}
}
}
impl McpTaskConfig {
    /// Returns the default configuration with task support switched on.
    ///
    /// Every other field keeps its [`Default`] value.
    pub fn enabled() -> Self {
        Self { enable_tasks: true, ..Default::default() }
    }

    /// Sets the interval between status polls.
    ///
    /// The value is stored in whole milliseconds (sub-millisecond parts are
    /// truncated). Durations whose millisecond count does not fit in a `u64`
    /// saturate to `u64::MAX` instead of silently wrapping.
    #[must_use]
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval_ms = Self::saturating_millis(interval);
        self
    }

    /// Sets the overall timeout.
    ///
    /// Stored in whole milliseconds; oversized durations saturate to `u64::MAX`.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout_ms = Some(Self::saturating_millis(timeout));
        self
    }

    /// Clears the timeout (`timeout_ms` becomes `None`).
    #[must_use]
    pub fn no_timeout(mut self) -> Self {
        self.timeout_ms = None;
        self
    }

    /// Caps the number of poll attempts.
    #[must_use]
    pub fn max_attempts(mut self, attempts: u32) -> Self {
        self.max_poll_attempts = Some(attempts);
        self
    }

    /// Returns the poll interval as a [`Duration`].
    pub fn poll_duration(&self) -> Duration {
        Duration::from_millis(self.poll_interval_ms)
    }

    /// Returns the timeout as a [`Duration`], or `None` when no timeout is set.
    pub fn timeout_duration(&self) -> Option<Duration> {
        self.timeout_ms.map(Duration::from_millis)
    }

    /// Builds the task-parameters JSON object.
    ///
    /// Only `poll_interval_ms` is included; `timeout_ms` and
    /// `max_poll_attempts` are not part of the payload.
    pub fn to_task_params(&self) -> Value {
        json!({
            "poll_interval_ms": self.poll_interval_ms
        })
    }

    /// Converts a [`Duration`] to whole milliseconds, clamped to `u64::MAX`.
    ///
    /// `Duration::as_millis` returns a `u128`. A bare `as u64` cast would keep
    /// only the low 64 bits, turning e.g. `Duration::MAX` into a wrong value.
    fn saturating_millis(duration: Duration) -> u64 {
        // After the `min`, the value fits in u64, so the cast is lossless.
        duration.as_millis().min(u128::from(u64::MAX)) as u64
    }
}
/// Lifecycle state of a task.
///
/// Serialized as a lowercase string, e.g. `TaskStatus::Running` <-> `"running"`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TaskStatus {
    /// In-progress state; presumably accepted but not yet started.
    Pending,
    /// In-progress state; the task is executing.
    Running,
    /// Terminal state: finished successfully.
    Completed,
    /// Terminal state: finished with an error.
    Failed,
    /// Terminal state: stopped before completion.
    Cancelled,
}
impl TaskStatus {
    /// Returns `true` for `Completed`, `Failed` and `Cancelled`: states that
    /// will not change any further.
    pub fn is_terminal(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            TaskStatus::Pending | TaskStatus::Running => false,
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => true,
        }
    }

    /// Returns `true` for `Pending` and `Running`, i.e. every non-terminal state.
    pub fn is_in_progress(&self) -> bool {
        !self.is_terminal()
    }
}
/// Snapshot of a task's state.
///
/// Field order is also the serialized field order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TaskInfo {
    /// Identifier of the task described by this snapshot.
    pub task_id: String,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Optional progress value; presumably a percentage (0-100), not validated here.
    pub progress: Option<u8>,
    /// Optional human-readable status message.
    pub message: Option<String>,
    /// Optional estimated time remaining, assumed to be in milliseconds (per the `_ms` suffix).
    pub eta_ms: Option<u64>,
}
/// Result of creating a task: its identifier plus the accompanying [`TaskInfo`].
#[derive(Clone, Debug)]
pub struct CreateTaskResult {
    /// Identifier of the created task.
    pub task_id: String,
    /// Task state returned alongside the id. Presumably `info.task_id == task_id`,
    /// but that is not enforced here.
    pub info: TaskInfo,
}
/// Errors that can occur while creating, polling or awaiting a task.
#[derive(Clone, Debug)]
pub enum TaskError {
    /// Creating the task failed; carries a description of the failure.
    CreateFailed(String),
    /// Polling the task failed; carries a description of the failure.
    PollFailed(String),
    /// The task timed out after `elapsed_ms` milliseconds.
    Timeout { task_id: String, elapsed_ms: u64 },
    /// The task was cancelled; carries the task id.
    Cancelled(String),
    /// The task itself reported failure with the given error text.
    TaskFailed { task_id: String, error: String },
    /// The task exceeded the allowed number of poll attempts.
    MaxAttemptsExceeded { task_id: String, attempts: u32 },
}

impl std::fmt::Display for TaskError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CreateFailed(reason) => write!(f, "Failed to create task: {}", reason),
            Self::PollFailed(reason) => write!(f, "Failed to poll task: {}", reason),
            Self::Timeout { task_id: id, elapsed_ms: ms } => {
                write!(f, "Task '{}' timed out after {}ms", id, ms)
            }
            Self::Cancelled(id) => write!(f, "Task '{}' was cancelled", id),
            Self::TaskFailed { task_id: id, error: cause } => {
                write!(f, "Task '{}' failed: {}", id, cause)
            }
            Self::MaxAttemptsExceeded { task_id: id, attempts: count } => {
                write!(f, "Task '{}' exceeded {} poll attempts", id, count)
            }
        }
    }
}

/// Lets `TaskError` be used wherever a `std::error::Error` is expected
/// (e.g. boxed as `Box<dyn Error>`); it has no underlying source.
impl std::error::Error for TaskError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_config_default() {
        let config = McpTaskConfig::default();
        assert!(!config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 1000);
        assert_eq!(config.timeout_ms, Some(300_000));
        assert_eq!(config.max_poll_attempts, None);
    }

    #[test]
    fn test_task_config_enabled() {
        let config = McpTaskConfig::enabled();
        assert!(config.enable_tasks);
        // Everything else comes from Default.
        assert_eq!(config.poll_interval_ms, 1000);
        assert_eq!(config.timeout_ms, Some(300_000));
    }

    #[test]
    fn test_task_config_builder() {
        let config = McpTaskConfig::enabled()
            .poll_interval(Duration::from_secs(2))
            .timeout(Duration::from_secs(60))
            .max_attempts(10);
        assert!(config.enable_tasks);
        assert_eq!(config.poll_interval_ms, 2000);
        assert_eq!(config.timeout_ms, Some(60_000));
        assert_eq!(config.max_poll_attempts, Some(10));
    }

    #[test]
    fn test_task_config_no_timeout() {
        let config = McpTaskConfig::enabled().no_timeout();
        assert_eq!(config.timeout_ms, None);
        assert_eq!(config.timeout_duration(), None);
    }

    #[test]
    fn test_task_config_durations() {
        let config = McpTaskConfig::default();
        assert_eq!(config.poll_duration(), Duration::from_millis(1000));
        assert_eq!(config.timeout_duration(), Some(Duration::from_secs(300)));
    }

    #[test]
    fn test_task_config_params() {
        let config = McpTaskConfig::enabled().poll_interval(Duration::from_millis(250));
        assert_eq!(config.to_task_params(), json!({ "poll_interval_ms": 250 }));
    }

    #[test]
    fn test_task_status_terminal() {
        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }

    #[test]
    fn test_task_status_in_progress() {
        assert!(TaskStatus::Pending.is_in_progress());
        assert!(TaskStatus::Running.is_in_progress());
        assert!(!TaskStatus::Completed.is_in_progress());
        assert!(!TaskStatus::Failed.is_in_progress());
        assert!(!TaskStatus::Cancelled.is_in_progress());
    }

    #[test]
    fn test_task_status_serde_lowercase() {
        assert_eq!(serde_json::to_string(&TaskStatus::Running).unwrap(), "\"running\"");
        let parsed: TaskStatus = serde_json::from_str("\"cancelled\"").unwrap();
        assert_eq!(parsed, TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_info_round_trip() {
        let wire = json!({
            "task_id": "t1",
            "status": "running",
            "progress": 50,
            "message": null,
            "eta_ms": null
        });
        let info: TaskInfo = serde_json::from_value(wire.clone()).unwrap();
        assert_eq!(info.task_id, "t1");
        assert_eq!(info.status, TaskStatus::Running);
        assert_eq!(info.progress, Some(50));
        assert_eq!(serde_json::to_value(&info).unwrap(), wire);
    }

    #[test]
    fn test_task_error_display() {
        let err = TaskError::Timeout { task_id: "t1".to_string(), elapsed_ms: 42 };
        assert_eq!(err.to_string(), "Task 't1' timed out after 42ms");
        let err = TaskError::MaxAttemptsExceeded { task_id: "t2".to_string(), attempts: 3 };
        assert_eq!(err.to_string(), "Task 't2' exceeded 3 poll attempts");
    }
}