use crate::errors::{ClaudeError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Identifier of a task; tasks created by [`TaskManager`] get a random v4 UUID string.
pub type TaskId = String;
/// Resource URI that addresses a single task.
pub type TaskUri = String;
/// Request payload describing the work a task performs.
///
/// When deserializing, a missing `method` becomes `""` and a missing `params` becomes JSON
/// `null`. `task_hint` and `priority` are left out of the serialized form while they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaskRequest {
    /// Method to run (the tests use `"tools/call"`).
    #[serde(default)]
    pub method: String,
    /// Arguments for `method`, as arbitrary JSON.
    #[serde(default)]
    pub params: serde_json::Value,
    /// Optional execution hints; `TaskManager::cancel_task` consults `cancellable`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_hint: Option<TaskHint>,
    /// Optional priority; not read by `TaskManager` in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<TaskPriority>,
}
/// Optional execution hints a client can attach to a [`TaskRequest`].
///
/// Fields missing from serialized input take the same values as [`TaskHint::default`]. In
/// particular, a hint that does not mention `cancellable` leaves the task cancellable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskHint {
    /// Expected run time in seconds; informational, not read by `TaskManager` in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub estimated_duration_secs: Option<u64>,
    /// Whether the task reports progress; informational, not read by `TaskManager` here.
    #[serde(default)]
    pub supports_progress: bool,
    /// Whether `TaskManager::cancel_task` is allowed to cancel the task.
    // A plain `#[serde(default)]` would yield `false` here, contradicting `Default` (`true`).
    #[serde(default = "default_cancellable")]
    pub cancellable: bool,
}

/// Default value of [`TaskHint::cancellable`].
///
/// Shared by serde and the `Default` impl so the two defaults cannot drift apart.
fn default_cancellable() -> bool {
    true
}

impl Default for TaskHint {
    /// No duration estimate, no progress support, cancellable.
    fn default() -> Self {
        Self {
            estimated_duration_secs: None,
            supports_progress: false,
            cancellable: default_cancellable(),
        }
    }
}
/// Relative priority of a task.
///
/// Comparison follows declaration order (`Low < Normal < High < Urgent`). The serialized form
/// is the lowercase variant name (e.g. `"urgent"`), and the default is [`TaskPriority::Normal`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TaskPriority {
    Low,
    #[default]
    Normal,
    High,
    Urgent,
}
/// Lifecycle state of a task.
///
/// Serialized in snake_case: `"queued"`, `"working"`, `"input_required"`, `"completed"`,
/// `"failed"`, `"cancelled"`. The older spelling `"inputrequired"` is still accepted when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TaskState {
    Queued,
    Working,
    // `rename_all = "lowercase"` produced "inputrequired"; keep accepting it for compatibility.
    #[serde(alias = "inputrequired")]
    InputRequired,
    Completed,
    Failed,
    Cancelled,
}

impl TaskState {
    /// Returns `true` for `Completed`, `Failed`, and `Cancelled`.
    ///
    /// `TaskManager` rejects further transitions out of these states.
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
    }

    /// Returns `true` for `Queued`, `Working`, and `InputRequired` (the complement of
    /// [`TaskState::is_terminal`]).
    pub fn is_active(&self) -> bool {
        matches!(self, Self::Queued | Self::Working | Self::InputRequired)
    }
}
/// Fractional progress of a task, with an optional human-readable message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskProgress {
    /// Completion fraction. The constructors restrict it to `0.0..=1.0`. The field is public
    /// and deserializable, so the range is not guaranteed for values built other ways.
    pub value: f64,
    /// Optional status text; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}

impl TaskProgress {
    /// Creates progress at `value` with no message.
    ///
    /// # Panics
    ///
    /// Panics with "Progress must be between 0.0 and 1.0" if `value` is outside `0.0..=1.0`
    /// (NaN included). Use [`TaskProgress::try_new`] for values that come from untrusted input.
    pub fn new(value: f64) -> Self {
        Self::try_new(value).expect("Progress must be between 0.0 and 1.0")
    }

    /// Creates progress at `value` with no message.
    ///
    /// Returns `None` if `value` is outside `0.0..=1.0`. NaN is rejected, because
    /// `RangeInclusive::contains` is false for it.
    pub fn try_new(value: f64) -> Option<Self> {
        (0.0..=1.0)
            .contains(&value)
            .then_some(Self { value, message: None })
    }

    /// Attaches a human-readable message, replacing any previous one.
    #[must_use]
    pub fn with_message(mut self, message: impl Into<String>) -> Self {
        self.message = Some(message.into());
        self
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskStatus {
pub id: TaskId,
pub state: TaskState,
#[serde(skip_serializing_if = "Option::is_none")]
pub progress: Option<TaskProgress>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
#[serde(skip_serializing_if = "Option::is_none")]
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
}
impl TaskStatus {
pub fn is_terminal(&self) -> bool {
self.state.is_terminal()
}
pub fn is_active(&self) -> bool {
self.state.is_active()
}
}
/// Outcome of a completed task, as returned by `TaskManager::get_task_result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Identifier of the completed task.
    pub id: TaskId,
    /// JSON payload passed to `TaskManager::mark_completed`.
    pub data: serde_json::Value,
    /// When the task was marked completed (UTC).
    pub completed_at: chrono::DateTime<chrono::Utc>,
}
/// Handle returned by `TaskManager::create_task` for a newly registered task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskHandle {
    /// Identifier to pass to the other `TaskManager` methods.
    pub id: TaskId,
    /// Task URI, formatted as `{base_uri}/{id}`.
    pub uri: TaskUri,
    /// Status snapshot taken when the task was created.
    pub status: TaskStatus,
}
/// Internal, mutable record kept by `TaskManager` for every task.
#[derive(Debug, Clone)]
struct Task {
    id: TaskId,
    /// Original request; its `task_hint` is consulted on cancellation.
    request: TaskRequest,
    state: TaskState,
    progress: Option<TaskProgress>,
    /// Payload stored when the task is marked completed.
    result: Option<serde_json::Value>,
    /// Message stored when the task is marked failed.
    error: Option<String>,
    created_at: chrono::DateTime<chrono::Utc>,
    updated_at: chrono::DateTime<chrono::Utc>,
    /// Set when the task enters a terminal state.
    completed_at: Option<chrono::DateTime<chrono::Utc>>,
}

impl Task {
    /// Builds a fresh `Queued` task for `request` with a random v4 UUID as its id.
    fn new(request: TaskRequest) -> Self {
        let created_at = chrono::Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            request,
            state: TaskState::Queued,
            progress: None,
            result: None,
            error: None,
            created_at,
            updated_at: created_at,
            completed_at: None,
        }
    }

    /// Copies the client-visible fields into a [`TaskStatus`] snapshot.
    ///
    /// The request and result are not included.
    fn to_status(&self) -> TaskStatus {
        let Task {
            id,
            state,
            progress,
            error,
            created_at,
            updated_at,
            completed_at,
            ..
        } = self;
        TaskStatus {
            id: id.clone(),
            state: state.clone(),
            progress: progress.clone(),
            error: error.clone(),
            created_at: *created_at,
            updated_at: *updated_at,
            completed_at: *completed_at,
        }
    }
}
/// In-memory registry that creates tasks and tracks their lifecycle.
///
/// Cloning is cheap: all clones share one task map through an `Arc`.
#[derive(Clone)]
pub struct TaskManager {
    /// Every tracked task keyed by id, behind an async (tokio) `RwLock`.
    tasks: Arc<RwLock<HashMap<TaskId, Task>>>,
    /// Prefix for task URIs, which are formatted as `{base_uri}/{id}`.
    base_uri: String,
}
impl TaskManager {
pub fn new() -> Self {
Self::with_base_uri("mcp://tasks".to_string())
}
pub fn with_base_uri(base_uri: impl Into<String>) -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
base_uri: base_uri.into(),
}
}
pub async fn create_task(&self, request: TaskRequest) -> Result<TaskHandle> {
let task = Task::new(request);
let task_id = task.id.clone();
let uri = format!("{}/{}", self.base_uri, task_id);
let status = task.to_status();
let mut tasks = self.tasks.write().await;
tasks.insert(task_id.clone(), task);
Ok(TaskHandle {
id: task_id,
uri,
status,
})
}
pub async fn get_task_status(&self, task_id: &TaskId) -> Result<TaskStatus> {
let tasks = self.tasks.read().await;
let task = tasks
.get(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
Ok(task.to_status())
}
pub async fn get_task_result(&self, task_id: &TaskId) -> Result<TaskResult> {
let tasks = self.tasks.read().await;
let task = tasks
.get(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state != TaskState::Completed {
return Err(ClaudeError::InvalidInput(format!(
"Task is not completed. Current state: {:?}",
task.state
)));
}
let result = task.result.as_ref().ok_or_else(|| {
ClaudeError::InternalError("Completed task has no result".to_string())
})?;
Ok(TaskResult {
id: task_id.clone(),
data: result.clone(),
completed_at: task.completed_at.unwrap(),
})
}
pub async fn update_progress(&self, task_id: &TaskId, progress: TaskProgress) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot update progress for terminal task".to_string(),
));
}
task.progress = Some(progress);
task.updated_at = chrono::Utc::now();
Ok(())
}
pub async fn mark_working(&self, task_id: &TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot transition terminal task".to_string(),
));
}
task.state = TaskState::Working;
task.updated_at = chrono::Utc::now();
Ok(())
}
pub async fn mark_completed(&self, task_id: &TaskId, result: serde_json::Value) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot transition terminal task".to_string(),
));
}
let now = chrono::Utc::now();
task.state = TaskState::Completed;
task.result = Some(result);
task.updated_at = now;
task.completed_at = Some(now);
Ok(())
}
pub async fn mark_failed(&self, task_id: &TaskId, error: impl Into<String>) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot transition terminal task".to_string(),
));
}
let now = chrono::Utc::now();
task.state = TaskState::Failed;
task.error = Some(error.into());
task.updated_at = now;
task.completed_at = Some(now);
Ok(())
}
pub async fn mark_cancelled(&self, task_id: &TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot transition terminal task".to_string(),
));
}
let now = chrono::Utc::now();
task.state = TaskState::Cancelled;
task.updated_at = now;
task.completed_at = Some(now);
Ok(())
}
pub async fn mark_input_required(&self, task_id: &TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(
"Cannot transition terminal task".to_string(),
));
}
task.state = TaskState::InputRequired;
task.updated_at = chrono::Utc::now();
Ok(())
}
pub async fn list_tasks(&self) -> Result<Vec<TaskStatus>> {
let tasks = self.tasks.read().await;
Ok(tasks.values().map(|t| t.to_status()).collect())
}
pub async fn cancel_task(&self, task_id: &TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
let task = tasks
.get_mut(task_id)
.ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
if task.state.is_terminal() {
return Err(ClaudeError::InvalidInput(format!(
"Cannot cancel task in state: {:?}",
task.state
)));
}
if let Some(hint) = &task.request.task_hint {
if !hint.cancellable {
return Err(ClaudeError::InvalidInput(
"Task is not cancellable".to_string(),
));
}
}
let now = chrono::Utc::now();
task.state = TaskState::Cancelled;
task.updated_at = now;
task.completed_at = Some(now);
Ok(())
}
pub async fn cleanup_old_tasks(&self, older_than: chrono::Duration) -> Result<usize> {
let mut tasks = self.tasks.write().await;
let cutoff = chrono::Utc::now() - older_than;
let initial_count = tasks.len();
tasks.retain(|_, task| {
if let Some(completed_at) = task.completed_at {
completed_at > cutoff
} else {
true }
});
Ok(initial_count - tasks.len())
}
}
impl Default for TaskManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `tools/call` request with empty params and no hint or priority.
    fn basic_request() -> TaskRequest {
        TaskRequest {
            method: "tools/call".to_string(),
            params: json!({}),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_task_creation() {
        let manager = TaskManager::new();
        let request = TaskRequest {
            method: "tools/call".to_string(),
            params: json!({"name": "test"}),
            ..Default::default()
        };
        let handle = manager.create_task(request).await.unwrap();
        assert!(!handle.id.is_empty());
        assert!(!handle.uri.is_empty());
        assert_eq!(handle.status.state, TaskState::Queued);
    }

    #[tokio::test]
    async fn test_task_status() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.id, handle.id);
        assert_eq!(status.state, TaskState::Queued);
        assert!(status.is_active());
        assert!(!status.is_terminal());
    }

    #[tokio::test]
    async fn test_task_lifecycle() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();

        manager.mark_working(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Working);

        let progress = TaskProgress::new(0.5).with_message("Half done");
        manager.update_progress(&handle.id, progress).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        let recorded = status.progress.as_ref().unwrap();
        assert_eq!(recorded.value, 0.5);
        assert_eq!(recorded.message.as_deref(), Some("Half done"));

        manager
            .mark_completed(&handle.id, json!({"output": "success"}))
            .await
            .unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Completed);
        assert!(status.is_terminal());
        assert!(!status.is_active());
        assert!(status.completed_at.is_some());
        assert!(status.updated_at >= status.created_at);

        let task_result = manager.get_task_result(&handle.id).await.unwrap();
        assert_eq!(task_result.id, handle.id);
        assert_eq!(task_result.data, json!({"output": "success"}));
    }

    #[tokio::test]
    async fn test_task_failure() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        manager
            .mark_failed(&handle.id, "Something went wrong")
            .await
            .unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Failed);
        assert!(status.is_terminal());
        assert_eq!(status.error.as_deref(), Some("Something went wrong"));
    }

    #[tokio::test]
    async fn test_task_cancellation() {
        let manager = TaskManager::new();
        let request = TaskRequest {
            task_hint: Some(TaskHint {
                cancellable: true,
                ..Default::default()
            }),
            ..basic_request()
        };
        let handle = manager.create_task(request).await.unwrap();
        manager.cancel_task(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
        assert!(status.is_terminal());
    }

    #[tokio::test]
    async fn test_cancel_without_hint_and_twice() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        // No hint at all means the task is cancellable.
        manager.cancel_task(&handle.id).await.unwrap();
        assert_eq!(
            manager.get_task_status(&handle.id).await.unwrap().state,
            TaskState::Cancelled
        );
        assert!(manager.cancel_task(&handle.id).await.is_err());
        assert!(manager.mark_cancelled(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_non_cancellable_task() {
        let manager = TaskManager::new();
        let request = TaskRequest {
            task_hint: Some(TaskHint {
                cancellable: false,
                ..Default::default()
            }),
            ..basic_request()
        };
        let handle = manager.create_task(request).await.unwrap();
        assert!(manager.cancel_task(&handle.id).await.is_err());
        // The task must be left untouched by the rejected cancellation.
        assert_eq!(
            manager.get_task_status(&handle.id).await.unwrap().state,
            TaskState::Queued
        );
    }

    #[tokio::test]
    async fn test_terminal_state_transitions() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        manager.mark_completed(&handle.id, json!({})).await.unwrap();
        assert!(manager.mark_working(&handle.id).await.is_err());
        assert!(manager.mark_input_required(&handle.id).await.is_err());
        assert!(manager.mark_failed(&handle.id, "late").await.is_err());
        assert!(manager
            .update_progress(&handle.id, TaskProgress::new(0.5))
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_input_required_round_trip() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        manager.mark_input_required(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::InputRequired);
        assert!(status.is_active());
        manager.mark_working(&handle.id).await.unwrap();
        assert_eq!(
            manager.get_task_status(&handle.id).await.unwrap().state,
            TaskState::Working
        );
    }

    #[tokio::test]
    async fn test_result_unavailable_until_completed() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        assert!(matches!(
            manager.get_task_result(&handle.id).await,
            Err(ClaudeError::InvalidInput(_))
        ));
        manager.mark_failed(&handle.id, "boom").await.unwrap();
        assert!(matches!(
            manager.get_task_result(&handle.id).await,
            Err(ClaudeError::InvalidInput(_))
        ));
    }

    #[tokio::test]
    async fn test_unknown_task_is_not_found() {
        let manager = TaskManager::new();
        let missing: TaskId = "missing".to_string();
        assert!(matches!(
            manager.get_task_status(&missing).await,
            Err(ClaudeError::NotFound(_))
        ));
        assert!(matches!(
            manager.get_task_result(&missing).await,
            Err(ClaudeError::NotFound(_))
        ));
        assert!(matches!(
            manager.mark_working(&missing).await,
            Err(ClaudeError::NotFound(_))
        ));
        assert!(matches!(
            manager
                .update_progress(&missing, TaskProgress::new(0.5))
                .await,
            Err(ClaudeError::NotFound(_))
        ));
        assert!(matches!(
            manager.cancel_task(&missing).await,
            Err(ClaudeError::NotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let manager = TaskManager::new();
        let request = basic_request();
        let _task1 = manager.create_task(request.clone()).await.unwrap();
        let _task2 = manager.create_task(request).await.unwrap();
        let tasks = manager.list_tasks().await.unwrap();
        assert_eq!(tasks.len(), 2);
    }

    #[test]
    fn test_progress_bounds() {
        assert_eq!(TaskProgress::new(0.0).value, 0.0);
        assert_eq!(TaskProgress::new(0.5).value, 0.5);
        assert_eq!(TaskProgress::new(1.0).value, 1.0);
    }

    #[test]
    #[should_panic(expected = "Progress must be between 0.0 and 1.0")]
    fn test_progress_out_of_range_panics() {
        let _ = TaskProgress::new(1.5);
    }

    #[test]
    fn test_priority_ordering() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Urgent);
    }

    #[tokio::test]
    async fn test_cleanup_old_tasks() {
        let manager = TaskManager::new();
        let handle = manager.create_task(basic_request()).await.unwrap();
        manager.mark_completed(&handle.id, json!({})).await.unwrap();
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(1))
            .await
            .unwrap();
        assert_eq!(cleaned, 0);
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 1);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_active_tasks() {
        let manager = TaskManager::new();
        manager.create_task(basic_request()).await.unwrap();
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 0);
        assert_eq!(manager.list_tasks().await.unwrap().len(), 1);
    }
}