use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Initial value of `McpTask::max_retries` for tasks built with `McpTask::new`.
pub const DEFAULT_MAX_RETRIES: u32 = 3;
/// Lifecycle state of an [`McpTask`].
///
/// The `Display` form (see [`McpTaskState::as_str`]) is the lowercase snake_case
/// name of the variant, e.g. `"input_required"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpTaskState {
    /// Displayed as `"working"`; the initial state of a task from `McpTask::new`.
    Working,
    /// Displayed as `"input_required"`.
    InputRequired,
    /// Displayed as `"completed"`.
    Completed,
    /// Displayed as `"failed"`.
    Failed,
    /// Displayed as `"cancelled"`.
    Cancelled,
}

impl McpTaskState {
    /// Returns the static snake_case name of this state (the same text `Display` emits).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Working => "working",
            Self::InputRequired => "input_required",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Cancelled => "cancelled",
        }
    }
}

impl std::fmt::Display for McpTaskState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!(f, "{}", ..)`) honours the caller's width, fill,
        // alignment and precision flags, so `format!("{:<16}", state)` lines up.
        f.pad(self.as_str())
    }
}
/// A single tracked task: identity, lifecycle state, optional expiry and outcome.
#[derive(Debug, Clone)]
pub struct McpTask {
    /// Task id; `McpTask::new` fills it with a random UUIDv4 string.
    pub id: String,
    /// Current lifecycle state; `McpTask::new` starts it at `Working`.
    pub state: McpTaskState,
    /// Monotonic timestamp captured when the task was constructed via `McpTask::new`.
    pub created_at: Instant,
    /// Moment from which `is_expired` reports `true`; `None` means the task never expires.
    pub expires_at: Option<Instant>,
    /// Result payload, set when the store completes the task.
    pub result: Option<serde_json::Value>,
    /// Error message, set when the store fails the task.
    pub error: Option<String>,
    /// Initialised to 0 by `new`. NOTE(review): presumably a retry counter, but nothing
    /// in this file increments it — confirm where it is maintained.
    pub retry_count: u32,
    /// Initialised to `DEFAULT_MAX_RETRIES` by `new`. NOTE(review): not enforced
    /// anywhere in this file.
    pub max_retries: u32,
}
impl McpTask {
    /// Creates a task in [`McpTaskState::Working`] with a random UUIDv4 id, no expiry,
    /// no result or error, `retry_count = 0` and `max_retries = DEFAULT_MAX_RETRIES`.
    pub fn new() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            state: McpTaskState::Working,
            created_at: Instant::now(),
            expires_at: None,
            result: None,
            error: None,
            retry_count: 0,
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }

    /// Makes the task expire `ttl` after the moment this method is called.
    ///
    /// If `now + ttl` cannot be represented as an [`Instant`] (for example with
    /// `Duration::MAX`), the task is left with no expiry instead of panicking.
    #[must_use]
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        // `Instant + Duration` panics on overflow; `checked_add` lets an absurdly
        // large TTL degrade gracefully to "never expires".
        self.expires_at = Instant::now().checked_add(ttl);
        self
    }

    /// Returns `true` once the current time has reached `expires_at`.
    /// A task without an expiry never expires.
    pub fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(exp) if Instant::now() >= exp)
    }
}
impl Default for McpTask {
fn default() -> Self {
Self::new()
}
}
/// In-memory map of [`McpTask`]s keyed by task id.
///
/// Cloning the store is cheap: every clone shares the same underlying map through
/// an `Arc`.
#[derive(Clone)]
pub struct McpTaskStore {
    // Task id -> task, behind a `tokio::sync::RwLock`: lookups take `read()`,
    // mutations take `write()`.
    inner: Arc<RwLock<HashMap<String, McpTask>>>,
}
impl McpTaskStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns `true` for states a task may never leave once entered.
    ///
    /// Written as an exhaustive `match` (not `matches!`) so that adding a new
    /// [`McpTaskState`] variant forces a decision here.
    fn is_terminal(state: &McpTaskState) -> bool {
        match state {
            McpTaskState::Working | McpTaskState::InputRequired => false,
            McpTaskState::Completed | McpTaskState::Failed | McpTaskState::Cancelled => true,
        }
    }

    /// Applies `apply` to task `id` under the write lock, but only if the task exists,
    /// has not expired and is not in a terminal state. Returns whether it was applied.
    async fn transition(&self, id: &str, apply: impl FnOnce(&mut McpTask)) -> bool {
        let mut map = self.inner.write().await;
        match map.get_mut(id) {
            Some(task) if !task.is_expired() && !Self::is_terminal(&task.state) => {
                apply(task);
                true
            }
            _ => false,
        }
    }

    /// Stores `task` (replacing any task with the same id) and returns its id.
    pub async fn insert(&self, task: McpTask) -> String {
        let id = task.id.clone();
        self.inner.write().await.insert(id.clone(), task);
        id
    }

    /// Returns a clone of task `id`, or `None` if it is unknown or expired.
    ///
    /// Expired tasks are hidden here but stay in the map until
    /// [`Self::evict_expired`] removes them.
    pub async fn get(&self, id: &str) -> Option<McpTask> {
        let map = self.inner.read().await;
        let task = map.get(id)?;
        if task.is_expired() {
            None
        } else {
            Some(task.clone())
        }
    }

    /// Moves a live, non-terminal task to [`McpTaskState::Cancelled`].
    ///
    /// Returns `false` if the task is unknown, expired, or already terminal.
    pub async fn cancel(&self, id: &str) -> bool {
        self.transition(id, |task| task.state = McpTaskState::Cancelled).await
    }

    /// Sets the state of a live, non-terminal task.
    ///
    /// Returns `false` (leaving the task untouched) if the task is unknown, expired,
    /// or already `Completed`/`Failed`/`Cancelled` — terminal states are final.
    pub async fn update_state(&self, id: &str, state: McpTaskState) -> bool {
        self.transition(id, |task| task.state = state).await
    }

    /// Marks a live, non-terminal task `Completed` and stores `result`.
    ///
    /// Returns `false` if the task is unknown, expired, or already terminal; in
    /// particular a late completion cannot overwrite a cancellation.
    pub async fn complete(&self, id: &str, result: serde_json::Value) -> bool {
        self.transition(id, |task| {
            task.state = McpTaskState::Completed;
            task.result = Some(result);
        })
        .await
    }

    /// Marks a live, non-terminal task `Failed` and stores `error`.
    ///
    /// Returns `false` if the task is unknown, expired, or already terminal.
    pub async fn fail(&self, id: &str, error: impl Into<String>) -> bool {
        let error = error.into();
        self.transition(id, |task| {
            task.state = McpTaskState::Failed;
            task.error = Some(error);
        })
        .await
    }

    /// Removes every expired task and returns how many were removed.
    pub async fn evict_expired(&self) -> usize {
        let mut map = self.inner.write().await;
        let before = map.len();
        map.retain(|_, task| !task.is_expired());
        before - map.len()
    }

    /// Number of stored tasks, including expired ones not yet evicted.
    pub async fn len(&self) -> usize {
        self.inner.read().await.len()
    }

    /// Whether the map holds no tasks at all (expired-but-unevicted tasks count).
    pub async fn is_empty(&self) -> bool {
        self.inner.read().await.is_empty()
    }
}
impl Default for McpTaskStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_task_lifecycle_working_to_completed() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Working);
        assert!(store.complete(&id, serde_json::json!({"ok": true})).await);
        let task = store.get(&id).await.unwrap();
        assert_eq!(task.state, McpTaskState::Completed);
        assert_eq!(task.result, Some(serde_json::json!({"ok": true})));
    }

    #[tokio::test]
    async fn test_task_lifecycle_working_to_failed() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.fail(&id, "timeout").await);
        let task = store.get(&id).await.unwrap();
        assert_eq!(task.state, McpTaskState::Failed);
        assert_eq!(task.error.as_deref(), Some("timeout"));
    }

    #[tokio::test]
    async fn test_task_lifecycle_working_to_cancelled() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.cancel(&id).await);
        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_terminal_task_returns_false() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.complete(&id, serde_json::json!({})).await);
        assert!(!store.cancel(&id).await);
        assert_eq!(store.get(&id).await.unwrap().state, McpTaskState::Completed);
    }

    #[tokio::test]
    async fn test_input_required_state() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert!(store.update_state(&id, McpTaskState::InputRequired).await);
        assert_eq!(
            store.get(&id).await.unwrap().state,
            McpTaskState::InputRequired
        );
    }

    #[tokio::test]
    async fn test_unknown_id_is_rejected() {
        let store = McpTaskStore::new();
        assert!(store.get("missing").await.is_none());
        assert!(!store.cancel("missing").await);
        assert!(!store.update_state("missing", McpTaskState::InputRequired).await);
        assert!(!store.complete("missing", serde_json::json!(null)).await);
        assert!(!store.fail("missing", "nope").await);
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_ttl_expiry_eviction() {
        let store = McpTaskStore::new();
        let task = McpTask::new().with_ttl(Duration::from_millis(1));
        let id = store.insert(task).await;
        tokio::time::sleep(Duration::from_millis(5)).await;
        assert!(store.get(&id).await.is_none());
        // Expired tasks reject every mutation but still occupy the map.
        assert!(!store.cancel(&id).await);
        assert!(!store.complete(&id, serde_json::json!({})).await);
        assert_eq!(store.len().await, 1);
        let evicted = store.evict_expired().await;
        assert_eq!(evicted, 1);
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_evict_keeps_live_tasks() {
        let store = McpTaskStore::new();
        let id = store.insert(McpTask::new()).await;
        assert_eq!(store.evict_expired().await, 0);
        assert_eq!(store.len().await, 1);
        assert!(store.get(&id).await.is_some());
    }

    #[test]
    fn test_state_display() {
        assert_eq!(McpTaskState::Working.to_string(), "working");
        assert_eq!(McpTaskState::InputRequired.to_string(), "input_required");
        assert_eq!(McpTaskState::Completed.to_string(), "completed");
        assert_eq!(McpTaskState::Failed.to_string(), "failed");
        assert_eq!(McpTaskState::Cancelled.to_string(), "cancelled");
    }
}