use serde::{Deserialize, Serialize};
/// Metadata key string `io.modelcontextprotocol/related-task`.
///
/// NOTE(review): judging by its name, this is the metadata entry used to link a
/// message to a task (likely carrying a [`RelatedTaskMetadata`] value) — confirm
/// against the call sites, which are not visible in this file.
pub const RELATED_TASK_META_KEY: &str = "io.modelcontextprotocol/related-task";
/// Lifecycle state of a [`Task`].
///
/// Serde reads and writes each variant as its snake_case name (for example
/// `"input_required"`). The `Default` value is [`TaskStatus::Working`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
/// Non-terminal state; also the `Default` value.
#[default]
Working,
/// Non-terminal state.
/// NOTE(review): presumably the task is paused waiting for input — confirm.
InputRequired,
/// Terminal state (see [`TaskStatus::is_terminal`]).
Completed,
/// Terminal state (see [`TaskStatus::is_terminal`]).
Failed,
/// Terminal state (see [`TaskStatus::is_terminal`]).
Cancelled,
}
/// Renders the status as its lowercase snake_case name (e.g. `input_required`).
///
/// The text is written verbatim: width, fill and precision flags on the
/// formatter are not applied.
impl std::fmt::Display for TaskStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first so there is a single write call below.
        let label = match self {
            Self::Working => "working",
            Self::InputRequired => "input_required",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Cancelled => "cancelled",
        };
        f.write_str(label)
    }
}
impl TaskStatus {
    /// Returns `true` for `Completed`, `Failed` and `Cancelled`, and `false`
    /// for `Working` and `InputRequired`.
    pub fn is_terminal(&self) -> bool {
        !matches!(self, Self::Working | Self::InputRequired)
    }

    /// Reports whether a task in state `self` may move to state `next`.
    ///
    /// Rules:
    /// - a transition to the same state is never allowed;
    /// - `Working` and `InputRequired` may move to any *other* state;
    /// - `Completed`, `Failed` and `Cancelled` may not move anywhere.
    pub fn can_transition_to(&self, next: &Self) -> bool {
        match self {
            // Terminal states are final.
            Self::Completed | Self::Failed | Self::Cancelled => false,
            // Non-terminal states accept every target except themselves.
            Self::Working | Self::InputRequired => self != next,
        }
    }
}
/// A task record as it appears on the wire.
///
/// Field names are serialized in camelCase (`taskId`, `createdAt`,
/// `lastUpdatedAt`, `pollInterval`, `statusMessage`).
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal; use [`Task::new`] and the `with_*` builders instead.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct Task {
/// Identifier of the task (`taskId` on the wire).
pub task_id: String,
/// Current lifecycle state.
pub status: TaskStatus,
/// Time-to-live. Always serialized: `None` is written as `null`, not omitted.
/// NOTE(review): unit is presumably milliseconds — confirm against the protocol.
pub ttl: Option<u64>,
/// Creation timestamp, stored as an opaque string (no format validation here).
/// NOTE(review): the tests use RFC 3339 UTC timestamps — confirm this is required.
pub created_at: String,
/// Last-update timestamp, stored as an opaque string (no format validation here).
pub last_updated_at: String,
/// Polling interval; omitted from the output when `None`.
/// NOTE(review): unit is presumably milliseconds — confirm against the protocol.
#[serde(skip_serializing_if = "Option::is_none")]
pub poll_interval: Option<u64>,
/// Free-form message describing the status; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub status_message: Option<String>,
}
impl Task {
    /// Creates a task with the given id and status.
    ///
    /// Every other field starts at its default: no TTL, empty timestamps,
    /// no poll interval and no status message.
    pub fn new(task_id: impl Into<String>, status: TaskStatus) -> Self {
        Self {
            task_id: task_id.into(),
            status,
            ..Self::default()
        }
    }

    /// Returns the task with `ttl` set to `Some(ttl)`.
    pub fn with_ttl(self, ttl: u64) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Returns the task with both timestamp fields replaced.
    pub fn with_timestamps(
        self,
        created_at: impl Into<String>,
        last_updated_at: impl Into<String>,
    ) -> Self {
        Self {
            created_at: created_at.into(),
            last_updated_at: last_updated_at.into(),
            ..self
        }
    }

    /// Returns the task with `poll_interval` set to `Some(interval)`.
    pub fn with_poll_interval(self, interval: u64) -> Self {
        Self {
            poll_interval: Some(interval),
            ..self
        }
    }

    /// Returns the task with `status_message` set to `Some(message)`.
    pub fn with_status_message(self, message: impl Into<String>) -> Self {
        Self {
            status_message: Some(message.into()),
            ..self
        }
    }
}
/// Optional parameters for creating a task, serialized with camelCase keys.
///
/// The struct is `#[non_exhaustive]`; outside this crate, build it with
/// [`TaskCreationParams::new`] and the `with_*` builders.
///
/// NOTE(review): presumably sent by a requester that wants a request run as a
/// task — confirm against the call sites.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct TaskCreationParams {
/// Requested time-to-live. Always serialized: `None` is written as `null`.
/// NOTE(review): unit is presumably milliseconds — confirm against the protocol.
pub ttl: Option<u64>,
/// Requested polling interval (`pollInterval`); omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub poll_interval: Option<u64>,
}
impl TaskCreationParams {
    /// Creates parameters with neither a TTL nor a poll interval set.
    pub fn new() -> Self {
        Self {
            ttl: None,
            poll_interval: None,
        }
    }

    /// Returns the parameters with `ttl` set to `Some(ttl)`.
    pub fn with_ttl(self, ttl: u64) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Returns the parameters with `poll_interval` set to `Some(interval)`.
    pub fn with_poll_interval(self, interval: u64) -> Self {
        Self {
            poll_interval: Some(interval),
            ..self
        }
    }
}
/// Payload carrying the id of a related task; serialized as `{"taskId": ...}`.
///
/// NOTE(review): presumably the value stored under [`RELATED_TASK_META_KEY`] —
/// confirm against the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RelatedTaskMetadata {
/// Identifier of the related task (`taskId` on the wire).
pub task_id: String,
}
/// Result payload holding a single [`Task`] under the `task` key.
///
/// The struct is `#[non_exhaustive]`; outside this crate, build it with
/// [`CreateTaskResult::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct CreateTaskResult {
/// The task carried by this result.
pub task: Task,
}
impl CreateTaskResult {
/// Wraps `task` in a `CreateTaskResult`.
pub fn new(task: Task) -> Self {
Self { task }
}
}
/// Notification payload holding a single [`Task`] under the `task` key.
///
/// NOTE(review): by its name, sent when a task's status changes — confirm
/// against the code that emits it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskStatusNotification {
/// The task carried by this notification.
pub task: Task,
}
/// Request parameters naming one task; serialized as `{"taskId": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetTaskRequest {
/// Identifier of the task being requested (`taskId` on the wire).
pub task_id: String,
}
/// Result payload holding a single [`Task`] under the `task` key.
///
/// The struct is `#[non_exhaustive]`; outside this crate, build it with
/// [`GetTaskResult::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct GetTaskResult {
/// The task carried by this result.
pub task: Task,
}
impl GetTaskResult {
/// Wraps `task` in a `GetTaskResult`.
pub fn new(task: Task) -> Self {
Self { task }
}
}
/// Request parameters naming one task; serialized as `{"taskId": ...}`.
///
/// NOTE(review): by its name, asks for the task's result payload rather than
/// its status record — confirm against the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetTaskPayloadRequest {
/// Identifier of the task whose payload is requested (`taskId` on the wire).
pub task_id: String,
}
/// Request parameters for listing tasks. The `Default` value has no cursor.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListTasksRequest {
/// Optional cursor string; omitted from the output when `None`.
/// NOTE(review): presumably the `next_cursor` of a previous
/// [`ListTasksResult`] — confirm against the handler.
#[serde(skip_serializing_if = "Option::is_none")]
pub cursor: Option<String>,
}
/// Result payload holding a list of tasks and an optional continuation cursor.
///
/// The struct is `#[non_exhaustive]`; outside this crate, build it with
/// [`ListTasksResult::new`] and [`ListTasksResult::with_next_cursor`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct ListTasksResult {
/// The tasks in this result.
pub tasks: Vec<Task>,
/// Optional cursor (`nextCursor` on the wire); omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
}
impl ListTasksResult {
    /// Creates a result holding `tasks` and no continuation cursor.
    pub fn new(tasks: Vec<Task>) -> Self {
        Self {
            tasks,
            ..Self::default()
        }
    }

    /// Returns the result with `next_cursor` set to `Some(cursor)`.
    pub fn with_next_cursor(self, cursor: impl Into<String>) -> Self {
        Self {
            next_cursor: Some(cursor.into()),
            ..self
        }
    }
}
/// Request parameters naming one task, with an optional JSON value attached.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CancelTaskRequest {
/// Identifier of the task to cancel (`taskId` on the wire).
pub task_id: String,
/// Optional arbitrary JSON value; omitted from the output when `None`.
/// NOTE(review): its meaning is not visible in this file — confirm against the handler.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
}
/// Result payload holding a single [`Task`] under the `task` key.
///
/// The struct is `#[non_exhaustive]`; outside this crate, build it with
/// [`CancelTaskResult::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "camelCase")]
pub struct CancelTaskResult {
/// The task carried by this result.
pub task: Task,
}
impl CancelTaskResult {
/// Wraps `task` in a `CancelTaskResult`.
pub fn new(task: Task) -> Self {
Self { task }
}
}
/// Unit tests pinning the JSON wire format of the task types and the
/// [`TaskStatus`] transition rules.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every status paired with the exact string it uses on the wire and in `Display`.
    const STATUS_NAMES: [(TaskStatus, &str); 5] = [
        (TaskStatus::Working, "working"),
        (TaskStatus::InputRequired, "input_required"),
        (TaskStatus::Completed, "completed"),
        (TaskStatus::Failed, "failed"),
        (TaskStatus::Cancelled, "cancelled"),
    ];

    /// All statuses, for exhaustive transition checks.
    const ALL_STATUSES: [TaskStatus; 5] = [
        TaskStatus::Working,
        TaskStatus::InputRequired,
        TaskStatus::Completed,
        TaskStatus::Failed,
        TaskStatus::Cancelled,
    ];

    /// Each status serializes to its snake_case name.
    #[test]
    fn task_status_serialization() {
        for (status, wire) in STATUS_NAMES {
            assert_eq!(
                serde_json::to_value(status).unwrap(),
                wire,
                "serializing {status:?}"
            );
        }
    }

    /// Each snake_case name parses back to the matching status.
    #[test]
    fn task_status_deserialization() {
        for (status, wire) in STATUS_NAMES {
            let parsed: TaskStatus = serde_json::from_value(json!(wire)).unwrap();
            assert_eq!(parsed, status, "parsing {wire:?}");
        }
    }

    /// Names outside the snake_case set (unknown, camelCase, capitalized, empty)
    /// are rejected rather than mapped to a default.
    #[test]
    fn task_status_rejects_unknown_names() {
        for bad in ["paused", "inputRequired", "Working", ""] {
            assert!(
                serde_json::from_value::<TaskStatus>(json!(bad)).is_err(),
                "{bad:?} must not parse as a TaskStatus"
            );
        }
    }

    /// A fully populated task keeps every field through serialize + deserialize.
    #[test]
    fn task_roundtrip() {
        let task = Task::new("t-123", TaskStatus::Working)
            .with_timestamps("2025-11-25T00:00:00Z", "2025-11-25T00:01:00Z")
            .with_ttl(60000)
            .with_poll_interval(5000)
            .with_status_message("Processing...");

        let json = serde_json::to_value(&task).unwrap();
        assert_eq!(json["taskId"], "t-123");
        assert_eq!(json["status"], "working");
        assert_eq!(json["ttl"], 60000);
        assert_eq!(json["createdAt"], "2025-11-25T00:00:00Z");
        assert_eq!(json["lastUpdatedAt"], "2025-11-25T00:01:00Z");
        assert_eq!(json["pollInterval"], 5000);
        assert_eq!(json["statusMessage"], "Processing...");

        let roundtrip: Task = serde_json::from_value(json).unwrap();
        assert_eq!(roundtrip.task_id, "t-123");
        assert_eq!(roundtrip.status, TaskStatus::Working);
        assert_eq!(roundtrip.ttl, Some(60000));
        assert_eq!(roundtrip.created_at, "2025-11-25T00:00:00Z");
        assert_eq!(roundtrip.last_updated_at, "2025-11-25T00:01:00Z");
        assert_eq!(roundtrip.poll_interval, Some(5000));
        assert_eq!(roundtrip.status_message.as_deref(), Some("Processing..."));
    }

    /// The result wrapper nests the task under the `task` key and round-trips.
    #[test]
    fn create_task_result_roundtrip() {
        let result = CreateTaskResult::new(
            Task::new("t-456", TaskStatus::Completed)
                .with_timestamps("2025-11-25T00:00:00Z", "2025-11-25T00:05:00Z"),
        );

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["task"]["taskId"], "t-456");
        assert_eq!(json["task"]["status"], "completed");
        assert_eq!(json["task"]["createdAt"], "2025-11-25T00:00:00Z");
        assert_eq!(json["task"]["lastUpdatedAt"], "2025-11-25T00:05:00Z");

        let roundtrip: CreateTaskResult = serde_json::from_value(json).unwrap();
        assert_eq!(roundtrip.task.task_id, "t-456");
        assert_eq!(roundtrip.task.status, TaskStatus::Completed);
        assert_eq!(roundtrip.task.last_updated_at, "2025-11-25T00:05:00Z");
    }

    /// A camelCase document without a `ttl` key parses, leaving `ttl` as `None`.
    #[test]
    fn task_ts_format_interop() {
        let ts_json = json!({
            "taskId": "task-abc",
            "status": "input_required",
            "createdAt": "2025-11-25T12:00:00.000Z",
            "lastUpdatedAt": "2025-11-25T12:01:00.000Z",
            "pollInterval": 3000,
            "statusMessage": "Waiting for user input"
        });

        let task: Task = serde_json::from_value(ts_json).unwrap();
        assert_eq!(task.task_id, "task-abc");
        assert_eq!(task.status, TaskStatus::InputRequired);
        assert_eq!(task.ttl, None);
        assert_eq!(task.created_at, "2025-11-25T12:00:00.000Z");
        assert_eq!(task.last_updated_at, "2025-11-25T12:01:00.000Z");
        assert_eq!(task.poll_interval, Some(3000));
        assert_eq!(
            task.status_message.as_deref(),
            Some("Waiting for user input")
        );
    }

    /// The metadata key constant keeps its exact value.
    #[test]
    fn related_task_meta_key_value() {
        assert_eq!(
            RELATED_TASK_META_KEY,
            "io.modelcontextprotocol/related-task"
        );
    }

    /// Only `Completed`, `Failed` and `Cancelled` are terminal.
    #[test]
    fn task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }

    /// Both non-terminal states may move to every other state.
    #[test]
    fn task_status_can_transition_to() {
        assert!(TaskStatus::Working.can_transition_to(&TaskStatus::InputRequired));
        assert!(TaskStatus::Working.can_transition_to(&TaskStatus::Completed));
        assert!(TaskStatus::Working.can_transition_to(&TaskStatus::Failed));
        assert!(TaskStatus::Working.can_transition_to(&TaskStatus::Cancelled));
        assert!(TaskStatus::InputRequired.can_transition_to(&TaskStatus::Working));
        assert!(TaskStatus::InputRequired.can_transition_to(&TaskStatus::Completed));
        assert!(TaskStatus::InputRequired.can_transition_to(&TaskStatus::Failed));
        assert!(TaskStatus::InputRequired.can_transition_to(&TaskStatus::Cancelled));
    }

    /// No state may transition to itself.
    #[test]
    fn task_status_self_transition_rejected() {
        for status in ALL_STATUSES {
            assert!(
                !status.can_transition_to(&status),
                "{status:?} should not transition to itself"
            );
        }
    }

    /// Terminal states reject every target, including other terminal states.
    #[test]
    fn task_status_terminal_rejects_all() {
        for terminal in [
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::Cancelled,
        ] {
            for target in ALL_STATUSES {
                assert!(
                    !terminal.can_transition_to(&target),
                    "{terminal:?} should not transition to {target:?}"
                );
            }
        }
    }

    /// An unset TTL is written as an explicit `null`, while unset optional
    /// fields marked `skip_serializing_if` are left out entirely.
    #[test]
    fn task_ttl_null_serialization() {
        let task = Task::new("test-null-ttl", TaskStatus::Working)
            .with_timestamps("2025-11-25T00:00:00Z", "2025-11-25T00:01:00Z");

        let json = serde_json::to_value(&task).unwrap();
        assert!(json.get("ttl").is_some(), "ttl must be present");
        assert!(json["ttl"].is_null(), "ttl must be null when None");
        assert!(
            json.get("pollInterval").is_none(),
            "pollInterval should be omitted when None"
        );
        assert!(
            json.get("statusMessage").is_none(),
            "statusMessage should be omitted when None"
        );
    }

    /// An explicit `"ttl": null` parses to `None`.
    #[test]
    fn task_ttl_null_deserialization() {
        let task: Task = serde_json::from_value(json!({
            "taskId": "test-null-ttl",
            "status": "working",
            "ttl": null,
            "createdAt": "2025-11-25T00:00:00Z",
            "lastUpdatedAt": "2025-11-25T00:01:00Z"
        }))
        .unwrap();
        assert_eq!(task.ttl, None);
        assert_eq!(task.poll_interval, None);
        assert_eq!(task.status_message, None);
    }

    /// A set TTL is written as a number.
    #[test]
    fn task_ttl_present_serialization() {
        let task = Task::new("test-present-ttl", TaskStatus::Working)
            .with_timestamps("2025-11-25T00:00:00Z", "2025-11-25T00:01:00Z")
            .with_ttl(60000);

        let json = serde_json::to_value(&task).unwrap();
        assert_eq!(json["ttl"], 60000);
    }

    /// `Display` prints the same name that is used on the wire.
    #[test]
    fn task_status_display() {
        for (status, wire) in STATUS_NAMES {
            assert_eq!(status.to_string(), wire, "displaying {status:?}");
        }
    }

    /// `nextCursor` is left out when unset and written when set; a document
    /// without it parses with `next_cursor == None`.
    #[test]
    fn list_tasks_result_cursor_serialization() {
        let without = serde_json::to_value(ListTasksResult::new(Vec::new())).unwrap();
        assert_eq!(without["tasks"], json!([]));
        assert!(
            without.get("nextCursor").is_none(),
            "nextCursor should be omitted when None"
        );

        let with =
            serde_json::to_value(ListTasksResult::new(Vec::new()).with_next_cursor("c-1"))
                .unwrap();
        assert_eq!(with["nextCursor"], "c-1");

        let parsed: ListTasksResult = serde_json::from_value(json!({ "tasks": [] })).unwrap();
        assert!(parsed.tasks.is_empty());
        assert_eq!(parsed.next_cursor, None);
    }
}