use crate::{
error::Error,
types::{Cursor, IntoResponse, Meta, Page, RequestId, Response},
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::ops::{Deref, DerefMut};
#[cfg(feature = "server")]
use crate::{
app::handler::{FromHandlerParams, HandlerParams},
types::request::{FromRequest, Request},
};
/// Key used for related-task metadata.
///
/// NOTE(review): presumably this is the key inside a request/notification
/// `_meta` object that carries a [`RelatedTaskMetadata`] — confirm at the use sites.
pub(crate) const RELATED_TASK_KEY: &str = "io.modelcontextprotocol/related-task";

/// TTL given to a task when none is requested (see `Task::new` and
/// `From<TaskMetadata> for Task`).
///
/// NOTE(review): the unit is not stated in this file; presumably milliseconds
/// per the MCP tasks spec — confirm.
const DEFAULT_TTL: usize = 30_000;
/// JSON-RPC method names used by the tasks feature.
pub mod commands {
    /// Request: list known tasks (paginated via a cursor).
    pub const LIST: &str = "tasks/list";

    /// Request: cancel a task by id.
    pub const CANCEL: &str = "tasks/cancel";

    /// Request: fetch a task's result payload by id.
    pub const RESULT: &str = "tasks/result";

    /// Request: fetch a task's current state by id.
    pub const GET: &str = "tasks/get";

    /// Notification: a task's status changed.
    pub const STATUS: &str = "notifications/tasks/status";
}
/// Parameters of a task-listing request.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ListTasksRequestParams {
    /// Opaque pagination cursor from a previous page; omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<Cursor>,
}
/// One page of tasks returned by a task-listing request.
///
/// Field names are camelCase on the wire (`tasks`, `nextCursor`).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListTasksResult {
    /// Tasks contained in this page.
    pub tasks: Vec<Task>,
    /// Cursor for fetching the next page; omitted from the serialized form
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<Cursor>,
}
/// Parameters identifying the task to cancel.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CancelTaskRequestParams {
    /// Id of the target task (wire name `taskId`).
    #[serde(rename = "taskId")]
    pub id: String,
}
/// Parameters identifying the task whose state is requested.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct GetTaskRequestParams {
    /// Id of the target task (wire name `taskId`).
    #[serde(rename = "taskId")]
    pub id: String,
}
/// Parameters identifying the task whose result payload is requested.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct GetTaskPayloadRequestParams {
    /// Id of the target task (wire name `taskId`).
    #[serde(rename = "taskId")]
    pub id: String,
}
/// Result returned when a task is created: the new task plus optional metadata.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CreateTaskResult {
    /// The newly created task.
    pub task: Task,
    /// Arbitrary metadata serialized under `_meta`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
    pub meta: Option<Value>,
}
/// A task record as exchanged by the `tasks/*` methods.
///
/// Fields are camelCase on the wire. `id` is serialized as `taskId` and
/// `status_msg` as `statusMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Task {
    /// Task identifier (wire name `taskId`).
    #[serde(rename = "taskId")]
    pub id: String,
    /// Creation time (wire name `createdAt`).
    pub created_at: DateTime<Utc>,
    /// Time of the most recent change (wire name `lastUpdatedAt`); the status
    /// and message setters on `Task` refresh it.
    pub last_updated_at: DateTime<Utc>,
    /// Time-to-live. NOTE(review): unit not stated here; presumably
    /// milliseconds per the MCP tasks spec — confirm.
    pub ttl: usize,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Optional human-readable status text; omitted when `None`.
    #[serde(rename = "statusMessage", skip_serializing_if = "Option::is_none")]
    pub status_msg: Option<String>,
    /// Optional suggested polling interval (wire name `pollInterval`);
    /// omitted when `None`. NOTE(review): unit presumably milliseconds — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub poll_interval: Option<usize>,
}
/// Lifecycle state of a [`Task`], serialized as a snake_case string.
///
/// `Working` is the `Default`, and it is the state new tasks are created in.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    /// Wire value `"cancelled"`; set by `Task::cancel`.
    Cancelled,
    /// Wire value `"completed"`; set by `Task::complete`.
    Completed,
    /// Wire value `"failed"`; set by `Task::fail`.
    Failed,
    /// Wire value `"working"`; set by `Task::new` and `Task::reset`.
    #[default]
    Working,
    /// Wire value `"input_required"`; set by `Task::require_input`.
    InputRequired,
}
/// Task-creation options supplied by the requester.
///
/// Converting into a [`Task`] (via `From`) uses `ttl`, falling back to
/// `DEFAULT_TTL` when it is `None`.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
pub struct TaskMetadata {
    /// Requested time-to-live; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<usize>,
}
/// Reference to a task that a message relates to.
///
/// NOTE(review): presumably carried under `RELATED_TASK_KEY` in `_meta` — confirm.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct RelatedTaskMetadata {
    /// Id of the related task (wire name `taskId`).
    #[serde(rename = "taskId")]
    pub id: String,
}
/// Untyped JSON result produced by a task.
///
/// Derefs to the inner [`Value`]. Use `TaskPayload::to` to deserialize it
/// into a concrete type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskPayload(pub Value);
impl Deref for TaskPayload {
type Target = Value;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for TaskPayload {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl IntoResponse for Task {
#[inline]
fn into_response(self, req_id: RequestId) -> Response {
match serde_json::to_value(self) {
Ok(v) => Response::success(req_id, v),
Err(err) => Response::error(req_id, err.into()),
}
}
}
impl IntoResponse for TaskPayload {
#[inline]
fn into_response(self, req_id: RequestId) -> Response {
self.0.into_response(req_id)
}
}
impl IntoResponse for CreateTaskResult {
#[inline]
fn into_response(self, req_id: RequestId) -> Response {
match serde_json::to_value(self) {
Ok(v) => Response::success(req_id, v),
Err(err) => Response::error(req_id, err.into()),
}
}
}
impl IntoResponse for ListTasksResult {
#[inline]
fn into_response(self, req_id: RequestId) -> Response {
match serde_json::to_value(self) {
Ok(v) => Response::success(req_id, v),
Err(err) => Response::error(req_id, err.into()),
}
}
}
impl<const N: usize> From<[Task; N]> for ListTasksResult {
    /// Builds a single, unpaginated page (`next_cursor: None`) from an array
    /// of tasks.
    ///
    /// The array is owned, so its tasks are moved into the `Vec`. The previous
    /// `to_vec()` cloned every task, including its `String` id and message.
    #[inline]
    fn from(tasks: [Task; N]) -> Self {
        Self {
            next_cursor: None,
            tasks: Vec::from(tasks),
        }
    }
}
impl From<Vec<Task>> for ListTasksResult {
    /// Wraps an owned list of tasks as a single page with no `next_cursor`.
    #[inline]
    fn from(list: Vec<Task>) -> Self {
        let next_cursor = None;
        Self {
            tasks: list,
            next_cursor,
        }
    }
}
impl From<Page<'_, Task>> for ListTasksResult {
    /// Converts a paginated view into a result. The page's items are cloned
    /// into an owned `Vec`, and its `next_cursor` is carried over unchanged.
    #[inline]
    fn from(page: Page<'_, Task>) -> Self {
        let tasks = page.items.to_vec();
        Self {
            tasks,
            next_cursor: page.next_cursor,
        }
    }
}
impl<T> From<T> for RelatedTaskMetadata
where
    T: Into<String>,
{
    /// Builds the metadata from anything string-like, which is used as the task id.
    #[inline]
    fn from(task_id: T) -> Self {
        let id: String = task_id.into();
        Self { id }
    }
}
impl From<Meta<RelatedTaskMetadata>> for RelatedTaskMetadata {
    /// Unwraps the value held by a `Meta` extractor via its `into_inner`.
    #[inline]
    fn from(wrapped: Meta<RelatedTaskMetadata>) -> Self {
        wrapped.into_inner()
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for ListTasksRequestParams {
    /// Extracts the underlying `Request` from the handler params, then
    /// decodes these parameters from it. Errors from either step are
    /// propagated.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for CancelTaskRequestParams {
    /// Extracts the underlying `Request` from the handler params, then
    /// decodes these parameters from it. Errors from either step are
    /// propagated.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for GetTaskRequestParams {
    /// Extracts the underlying `Request` from the handler params, then
    /// decodes these parameters from it. Errors from either step are
    /// propagated.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
#[cfg(feature = "server")]
impl FromHandlerParams for GetTaskPayloadRequestParams {
    /// Extracts the underlying `Request` from the handler params, then
    /// decodes these parameters from it. Errors from either step are
    /// propagated.
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
impl Default for Task {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl From<TaskMetadata> for Task {
    /// Creates a new task in [`TaskStatus::Working`] using the requested ttl.
    ///
    /// If `meta.ttl` is `None`, the ttl falls back to `DEFAULT_TTL`. The id is
    /// a random v4 UUID, and no status message or poll interval is set.
    /// `created_at` and `last_updated_at` come from a single clock read, so
    /// they are equal on creation.
    #[inline]
    fn from(meta: TaskMetadata) -> Self {
        // Read the clock once. Two separate `Utc::now()` calls could leave
        // `lastUpdatedAt` slightly later than `createdAt` on a task that has
        // never been updated.
        let now = Utc::now();
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            created_at: now,
            last_updated_at: now,
            ttl: meta.ttl.unwrap_or(DEFAULT_TTL),
            status: TaskStatus::Working,
            status_msg: None,
            poll_interval: None,
        }
    }
}
impl ListTasksResult {
    /// Creates an empty result: no tasks and no `next_cursor`.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}
impl CreateTaskResult {
    /// Wraps `task` in a creation result without any `_meta`.
    #[inline]
    pub fn new(task: Task) -> Self {
        let meta = None;
        Self { meta, task }
    }
}
impl Task {
#[inline]
pub fn new() -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
created_at: Utc::now(),
last_updated_at: Utc::now(),
ttl: DEFAULT_TTL,
status: TaskStatus::Working,
status_msg: None,
poll_interval: None,
}
}
pub fn set_message(&mut self, msg: impl Into<String>) {
self.status_msg = Some(msg.into());
self.last_updated_at = Utc::now();
}
pub fn reset(&mut self) {
self.status = TaskStatus::Working;
self.last_updated_at = Utc::now();
}
pub fn cancel(mut self) -> Self {
self.status = TaskStatus::Cancelled;
self.last_updated_at = Utc::now();
self
}
pub fn complete(&mut self) {
self.status = TaskStatus::Completed;
self.last_updated_at = Utc::now();
}
pub fn fail(&mut self) {
self.status = TaskStatus::Failed;
self.last_updated_at = Utc::now();
}
pub fn require_input(&mut self) {
self.status = TaskStatus::InputRequired;
self.last_updated_at = Utc::now();
}
}
impl TaskPayload {
    /// Unwraps and returns the raw JSON value.
    #[inline]
    pub fn into_inner(self) -> Value {
        let Self(value) = self;
        value
    }

    /// Deserializes the payload into `T`.
    ///
    /// # Errors
    /// Returns the crate `Error` converted from the `serde_json` error if the
    /// JSON does not match `T`.
    #[inline]
    pub fn to<T: DeserializeOwned>(self) -> Result<T, Error> {
        let value = self.into_inner();
        Ok(serde_json::from_value(value)?)
    }
}