use rmcp::model::{
CallToolResult, CancelTaskParams, CreateTaskResult, EmptyObject, GetTaskInfoParams, GetTaskInfoResult,
GetTaskResultParams, ListTasksResult, PaginatedRequestParams, Task, TaskResult, TaskStatus,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Value stored in the `ttl` field of every newly created task, in milliseconds (one hour).
const DEFAULT_TASK_TTL_MS: u64 = 60 * 60 * 1_000;
/// Value stored in the `poll_interval` field of every newly created task, in milliseconds
/// (five seconds).
const DEFAULT_POLL_INTERVAL_MS: u64 = 5 * 1_000;
/// A stored task: its protocol-level metadata plus the recorded outcome, if any.
#[derive(Clone, Debug)]
pub struct TaskEntry {
/// Task metadata (id, status, status message, timestamps, TTL and poll interval).
pub task: Task,
/// Outcome of the task. `None` when the entry is created; set to `Ok` with the tool
/// result on completion or to `Err` with the error message on failure.
pub result: Option<Result<CallToolResult, String>>,
}
impl TaskEntry {
    /// Builds the entry for a freshly enqueued task.
    ///
    /// The task starts in [`TaskStatus::Working`] with the message `"Task enqueued"`,
    /// is stamped with the current UTC time (RFC 3339), carries the default TTL and
    /// poll interval, and has no result yet.
    fn new(task_id: String) -> Self {
        let task = Task {
            task_id,
            status: TaskStatus::Working,
            status_message: Some(String::from("Task enqueued")),
            created_at: chrono::Utc::now().to_rfc3339(),
            last_updated_at: None,
            ttl: Some(DEFAULT_TASK_TTL_MS),
            poll_interval: Some(DEFAULT_POLL_INTERVAL_MS),
        };
        Self { task, result: None }
    }

    /// Replaces the status and status message and refreshes `last_updated_at`
    /// with the current UTC time (RFC 3339).
    fn update_status(&mut self, status: TaskStatus, message: Option<String>) {
        let task = &mut self.task;
        task.last_updated_at = Some(chrono::Utc::now().to_rfc3339());
        task.status_message = message;
        task.status = status;
    }
}
/// In-memory registry of tasks, keyed by task id.
///
/// Clones share the same underlying map because the map lives behind an `Arc`.
#[derive(Clone, Default)]
pub struct TaskStore {
/// Task id -> entry, guarded by an async read/write lock.
tasks: Arc<RwLock<HashMap<String, TaskEntry>>>,
}
impl TaskStore {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Registers a new task under a freshly generated UUID v4 and returns its metadata.
///
/// The returned [`CreateTaskResult`] holds a copy of the task as it was stored.
pub async fn enqueue(&self) -> CreateTaskResult {
    let entry = TaskEntry::new(Uuid::new_v4().to_string());
    // Keep a copy of the metadata for the caller; the entry itself moves into the map.
    let task = entry.task.clone();
    self.tasks.write().await.insert(task.task_id.clone(), entry);
    tracing::debug!(task_id = %task.task_id, "Task enqueued");
    CreateTaskResult { task }
}
/// Lists the stored tasks, optionally one page at a time.
///
/// * `params == None` returns every task and no cursor.
/// * `params == Some(_)` returns at most 20 tasks starting at the offset encoded in
///   `cursor` (a decimal index; a missing or unparsable cursor means offset 0).
///   `next_cursor` is set only when more tasks remain after the returned page.
///
/// Tasks are ordered by `created_at`, then `task_id`, so that offsets stay meaningful
/// between calls. A cursor past the end of the list yields an empty page rather than
/// a panic.
pub async fn list(&self, params: Option<PaginatedRequestParams>) -> ListTasksResult {
    const PAGE_SIZE: usize = 20;

    let tasks = self.tasks.read().await;
    let mut all_tasks: Vec<Task> = tasks.values().map(|entry| entry.task.clone()).collect();
    // The snapshot is complete; release the read lock before sorting.
    drop(tasks);

    // `HashMap` iteration order is arbitrary and may change when entries are inserted,
    // which would make offset cursors point at different tasks from one call to the
    // next. Impose a deterministic order instead (ids are unique, so an unstable sort
    // is sufficient).
    all_tasks.sort_unstable_by(|a, b| {
        a.created_at
            .cmp(&b.created_at)
            .then_with(|| a.task_id.cmp(&b.task_id))
    });

    let (task_list, next_cursor) = match params {
        Some(pagination) => {
            let total = all_tasks.len();
            // The cursor is client-supplied: clamp it so an out-of-range value cannot
            // index past the end, and saturate so a huge value cannot overflow.
            let start = pagination
                .cursor
                .and_then(|c| c.parse::<usize>().ok())
                .unwrap_or(0)
                .min(total);
            let end = start.saturating_add(PAGE_SIZE).min(total);
            let cursor_value = (end < total).then(|| end.to_string());
            let page_tasks: Vec<Task> = all_tasks
                .into_iter()
                .skip(start)
                .take(end - start)
                .collect();
            (page_tasks, cursor_value)
        }
        None => (all_tasks, None),
    };

    ListTasksResult {
        tasks: task_list,
        next_cursor,
        ..Default::default()
    }
}
/// Looks up the current metadata of a single task.
///
/// Returns `None` when no task is stored under `params.task_id`; otherwise returns a
/// [`GetTaskInfoResult`] holding a copy of the task.
pub async fn get_info(&self, params: GetTaskInfoParams) -> Option<GetTaskInfoResult> {
    let tasks = self.tasks.read().await;
    let entry = tasks.get(&params.task_id)?;
    Some(GetTaskInfoResult {
        task: Some(entry.task.clone()),
    })
}
pub async fn get_result(&self, params: GetTaskResultParams) -> Option<TaskResult> {
let tasks = self.tasks.read().await;
let entry = tasks.get(¶ms.task_id)?;
match entry.task.status {
TaskStatus::Completed => {
if let Some(Ok(result)) = &entry.result {
let value = serde_json::to_value(&result.content).unwrap_or_default();
Some(TaskResult {
content_type: "application/json".to_string(),
value,
summary: None,
})
} else {
None
}
},
TaskStatus::Failed => {
if let Some(Err(error_msg)) = &entry.result {
Some(TaskResult {
content_type: "text/plain".to_string(),
value: serde_json::json!({ "error": error_msg }),
summary: Some(format!("Task failed: {}", error_msg)),
})
} else {
None
}
},
_ => None, }
}
/// Cancels a task that has not finished yet.
///
/// Only tasks in `Working` or `InputRequired` can be cancelled; they move to
/// `Cancelled` with the message `"Cancelled by client"`.
///
/// Returns `Some(EmptyObject {})` on success, and `None` when the id is unknown or
/// the task is in any other state.
pub async fn cancel(&self, params: CancelTaskParams) -> Option<EmptyObject> {
    let mut tasks = self.tasks.write().await;
    let entry = tasks.get_mut(&params.task_id)?;

    if matches!(
        entry.task.status,
        TaskStatus::Working | TaskStatus::InputRequired
    ) {
        entry.update_status(TaskStatus::Cancelled, Some("Cancelled by client".to_string()));
        tracing::debug!(task_id = %params.task_id, "Task cancelled");
        Some(EmptyObject {})
    } else {
        tracing::debug!(
            task_id = %params.task_id,
            status = ?entry.task.status,
            "Task cannot be cancelled (already terminal state)"
        );
        None
    }
}
/// Sets the status and status message of the task stored under `task_id`.
///
/// No check is made on the task's current status; any transition is applied.
/// Returns `true` when the task exists and was updated, `false` when the id is unknown.
pub async fn update_status(&self, task_id: &str, status: TaskStatus, message: Option<String>) -> bool {
    let mut guard = self.tasks.write().await;
    match guard.get_mut(task_id) {
        Some(entry) => {
            entry.update_status(status, message);
            true
        }
        None => false,
    }
}
/// Records a successful result and marks the task `Completed`.
///
/// Returns `true` when the result was stored. Returns `false` when the id is unknown
/// or the task is no longer in progress (anything other than `Working` or
/// `InputRequired`); in that case the stored status and result are left untouched.
/// This keeps a task the client already cancelled from being flipped to `Completed`
/// by a worker that finishes afterwards.
pub async fn complete(&self, task_id: &str, result: CallToolResult) -> bool {
    let mut tasks = self.tasks.write().await;
    let entry = match tasks.get_mut(task_id) {
        Some(entry) => entry,
        None => return false,
    };

    // Terminal states are final: never overwrite them with a late completion.
    if !matches!(
        entry.task.status,
        TaskStatus::Working | TaskStatus::InputRequired
    ) {
        tracing::debug!(
            task_id = %task_id,
            status = ?entry.task.status,
            "Ignoring completion for task already in a terminal state"
        );
        return false;
    }

    entry.result = Some(Ok(result));
    entry.update_status(TaskStatus::Completed, Some("Task completed successfully".to_string()));
    tracing::debug!(task_id = %task_id, "Task completed");
    true
}
/// Records a failure and marks the task `Failed`, using `error` both as the stored
/// result and as the status message.
///
/// Returns `true` when the failure was stored. Returns `false` when the id is unknown
/// or the task is no longer in progress (anything other than `Working` or
/// `InputRequired`); in that case the stored status and result are left untouched.
/// This keeps a task the client already cancelled, or one that already completed,
/// from being overwritten by a late failure report.
pub async fn fail(&self, task_id: &str, error: String) -> bool {
    let mut tasks = self.tasks.write().await;
    let entry = match tasks.get_mut(task_id) {
        Some(entry) => entry,
        None => return false,
    };

    // Terminal states are final: never overwrite them with a late failure.
    if !matches!(
        entry.task.status,
        TaskStatus::Working | TaskStatus::InputRequired
    ) {
        tracing::debug!(
            task_id = %task_id,
            status = ?entry.task.status,
            "Ignoring failure for task already in a terminal state"
        );
        return false;
    }

    entry.result = Some(Err(error.clone()));
    entry.update_status(TaskStatus::Failed, Some(error));
    tracing::debug!(task_id = %task_id, "Task failed");
    true
}
}