use crate::error::TaskStorageError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use turul_mcp_protocol::TaskStatus;
/// Final outcome recorded for a task: either a successful JSON result or an error.
///
/// No `#[serde(tag = ...)]`/`untagged` attribute is set, so serde uses its default
/// externally tagged representation: `{"Success": <value>}` or
/// `{"Error": {"code": ..., "message": ..., "data": ...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskOutcome {
    /// Successful outcome carrying the result as an arbitrary JSON value.
    Success(Value),
    /// Failed outcome.
    Error {
        /// Numeric error code (presumably a JSON-RPC style error code — TODO confirm).
        code: i64,
        /// Error message text.
        message: String,
        /// Optional extra error payload; omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        data: Option<Value>,
    },
}
/// Stored representation of a single task, as passed to and returned from
/// [`TaskStorage`] backends.
///
/// Convert to the wire/protocol form with [`TaskRecord::to_protocol_task`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskRecord {
    /// Identifier of the task; used as the lookup key by `TaskStorage` methods.
    pub task_id: String,
    /// Optional owning session (presumably what `list_tasks_for_session` filters on — confirm).
    pub session_id: Option<String>,
    /// Current protocol-level status of the task.
    pub status: TaskStatus,
    /// Optional status message accompanying `status`.
    pub status_message: Option<String>,
    /// Creation timestamp, stored as a string (format not enforced here — TODO confirm, e.g. ISO 8601).
    pub created_at: String,
    /// Last-update timestamp, stored as a string (same format caveat as `created_at`).
    pub last_updated_at: String,
    /// Optional time-to-live (unit not stated here — NOTE(review): presumably milliseconds, confirm).
    pub ttl: Option<i64>,
    /// Optional suggested polling interval (unit not stated here — confirm).
    pub poll_interval: Option<u64>,
    /// Name of the request method the task was created for (presumably — confirm).
    pub original_method: String,
    /// Parameters of that original request, if any.
    pub original_params: Option<Value>,
    /// Stored outcome, if one has been recorded.
    pub result: Option<TaskOutcome>,
    /// Optional metadata map forwarded to the protocol `Task`.
    pub meta: Option<HashMap<String, Value>>,
}
impl TaskRecord {
pub fn to_protocol_task(&self) -> turul_mcp_protocol::Task {
let mut task = turul_mcp_protocol::Task::new(
&self.task_id,
self.status,
&self.created_at,
&self.last_updated_at,
);
if let Some(ref msg) = self.status_message {
task = task.with_status_message(msg);
}
if let Some(ttl) = self.ttl {
task = task.with_ttl(ttl);
}
if let Some(interval) = self.poll_interval {
task = task.with_poll_interval(interval);
}
if let Some(ref meta) = self.meta {
task = task.with_meta(meta.clone());
}
task
}
}
/// One page of results from `TaskStorage::list_tasks` / `list_tasks_for_session`.
#[derive(Debug, Clone)]
pub struct TaskListPage {
    /// Task records in this page.
    pub tasks: Vec<TaskRecord>,
    /// Cursor to request the next page; presumably `None` when there are no more
    /// results (NOTE(review): confirm against backend implementations).
    pub next_cursor: Option<String>,
}
/// Asynchronous storage backend for [`TaskRecord`]s.
///
/// The `Send + Sync` supertraits allow one backend instance to be shared across
/// threads. Methods are declared `async` through `#[async_trait]`, and every
/// fallible operation reports failures as [`TaskStorageError`].
#[async_trait]
pub trait TaskStorage: Send + Sync {
    /// Static name identifying this storage backend.
    fn backend_name(&self) -> &'static str;
    /// Stores a new task and returns a `TaskRecord` — presumably the record as
    /// persisted (NOTE(review): confirm whether implementations alter any fields).
    async fn create_task(&self, task: TaskRecord) -> Result<TaskRecord, TaskStorageError>;
    /// Looks up a task by ID; `Ok(None)` presumably means no such task exists.
    async fn get_task(&self, task_id: &str) -> Result<Option<TaskRecord>, TaskStorageError>;
    /// Writes the given record over the stored task with the same `task_id`
    /// (NOTE(review): behavior for unknown IDs is backend-defined — confirm).
    async fn update_task(&self, task: TaskRecord) -> Result<(), TaskStorageError>;
    /// Deletes a task by ID; the `bool` presumably reports whether a task was
    /// actually removed — confirm.
    async fn delete_task(&self, task_id: &str) -> Result<bool, TaskStorageError>;
    /// Returns one page of tasks.
    ///
    /// `cursor` is presumably the `next_cursor` from a previous [`TaskListPage`];
    /// `limit` bounds the page size (default when `None` is backend-defined).
    async fn list_tasks(
        &self,
        cursor: Option<&str>,
        limit: Option<u32>,
    ) -> Result<TaskListPage, TaskStorageError>;
    /// Sets a task's status and status message and returns the updated record.
    ///
    /// NOTE(review): whether status transitions are validated is left to
    /// implementations — confirm.
    async fn update_task_status(
        &self,
        task_id: &str,
        new_status: TaskStatus,
        status_message: Option<String>,
    ) -> Result<TaskRecord, TaskStorageError>;
    /// Records the final [`TaskOutcome`] for a task.
    async fn store_task_result(
        &self,
        task_id: &str,
        result: TaskOutcome,
    ) -> Result<(), TaskStorageError>;
    /// Fetches a task's stored outcome; `Ok(None)` presumably means no outcome has
    /// been stored yet (or the task is unknown — confirm).
    async fn get_task_result(&self, task_id: &str)
    -> Result<Option<TaskOutcome>, TaskStorageError>;
    /// Expires tasks and returns their IDs (presumably based on each record's
    /// `ttl` — confirm against implementations).
    async fn expire_tasks(&self) -> Result<Vec<String>, TaskStorageError>;
    /// Number of tasks currently held by the backend.
    async fn task_count(&self) -> Result<usize, TaskStorageError>;
    /// Backend-specific housekeeping hook; what it does is implementation-defined.
    async fn maintenance(&self) -> Result<(), TaskStorageError>;
    /// Like [`TaskStorage::list_tasks`], restricted to tasks belonging to
    /// `session_id` (presumably matched against `TaskRecord::session_id`).
    async fn list_tasks_for_session(
        &self,
        session_id: &str,
        cursor: Option<&str>,
        limit: Option<u32>,
    ) -> Result<TaskListPage, TaskStorageError>;
    /// Recovers tasks considered stuck for longer than `max_age_ms` milliseconds and
    /// returns their IDs (NOTE(review): the exact recovery action is backend-defined).
    async fn recover_stuck_tasks(&self, max_age_ms: u64) -> Result<Vec<String>, TaskStorageError>;
}