use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::protocol::{CallToolResult, TaskObject, TaskStatus};
const DEFAULT_TTL_MS: u64 = 300_000;
const DEFAULT_POLL_INTERVAL_MS: u64 = 2_000;
/// A single tool invocation tracked by [`TaskStore`].
///
/// Created in the `Working` state by `Task::new`. `TaskStore::complete_task`,
/// `fail_task` and `cancel_task` move it to a terminal state. After that it
/// becomes eligible for removal by `TaskStore::cleanup_expired` once `ttl` has
/// elapsed.
#[derive(Debug)]
pub struct Task {
/// Store-assigned identifier of the form `task-N`.
pub id: String,
/// Tool name passed to `TaskStore::create_task`.
pub tool_name: String,
/// Tool arguments passed to `TaskStore::create_task`, stored unchanged.
pub arguments: serde_json::Value,
/// Current lifecycle state.
pub status: TaskStatus,
/// Monotonic creation instant.
/// NOTE(review): not read in this file (expiry is measured from `completed_at`) — confirm it is used elsewhere.
pub created_at: Instant,
/// Creation time as an ISO 8601 UTC string; reported as `TaskObject::created_at`.
pub created_at_str: String,
/// ISO 8601 UTC time of the most recent state change; reported as `TaskObject::last_updated_at`.
pub last_updated_at_str: String,
/// Retention window in milliseconds, counted from `completed_at` (defaults to `DEFAULT_TTL_MS`).
pub ttl: u64,
/// Poll interval in milliseconds reported in `TaskObject::poll_interval`; set to `DEFAULT_POLL_INTERVAL_MS`.
pub poll_interval: u64,
/// Human-readable description of the current state.
pub status_message: Option<String>,
/// Tool result, set only by `TaskStore::complete_task`.
pub result: Option<CallToolResult>,
/// Error text, set only by `TaskStore::fail_task`.
pub error: Option<String>,
/// Cancellation flag; a clone is handed back to the caller of `TaskStore::create_task`.
pub cancellation_token: CancellationToken,
/// Instant the task entered a terminal state; starts the `ttl` countdown.
pub completed_at: Option<Instant>,
/// Woken with `notify_waiters` on every terminal transition; awaited by `TaskStore::wait_for_completion`.
pub completion_notify: Arc<tokio::sync::Notify>,
}
impl Task {
fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
let cancelled = Arc::new(AtomicBool::new(false));
let now_str = chrono_now_iso8601();
Self {
id,
tool_name,
arguments,
status: TaskStatus::Working,
created_at: Instant::now(),
created_at_str: now_str.clone(),
last_updated_at_str: now_str,
ttl: ttl.unwrap_or(DEFAULT_TTL_MS),
poll_interval: DEFAULT_POLL_INTERVAL_MS,
status_message: Some("Task started".to_string()),
result: None,
error: None,
cancellation_token: CancellationToken { cancelled },
completed_at: None,
completion_notify: Arc::new(tokio::sync::Notify::new()),
}
}
pub fn to_task_object(&self) -> TaskObject {
TaskObject {
task_id: self.id.clone(),
status: self.status,
status_message: self.status_message.clone(),
created_at: self.created_at_str.clone(),
last_updated_at: self.last_updated_at_str.clone(),
ttl: Some(self.ttl),
poll_interval: Some(self.poll_interval),
meta: None,
}
}
pub fn is_expired(&self) -> bool {
if let Some(completed_at) = self.completed_at {
completed_at.elapsed() > Duration::from_millis(self.ttl)
} else {
false
}
}
pub fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
}
/// Shared, clonable flag used to ask a running task to stop.
///
/// Every clone points at the same `AtomicBool`, so a cancellation requested
/// through one handle is visible through all of them. This type offers no way
/// to clear the flag once it is set.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}
impl CancellationToken {
    /// Requests cancellation. Calling it again has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }

    /// Reports whether `cancel` has been called on this token or any clone of it.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }
}
/// Registry of tasks keyed by task id.
///
/// Both fields are behind `Arc`, so clones of a `TaskStore` share the same task
/// map and the same id counter.
#[derive(Debug, Clone)]
pub struct TaskStore {
/// Task map, guarded by a read/write lock and shared between clones.
tasks: Arc<RwLock<HashMap<String, Task>>>,
/// Counter used to build `task-N` ids; starts at 1 and is shared between clones.
next_id: Arc<AtomicU64>,
}
impl Default for TaskStore {
fn default() -> Self {
Self::new()
}
}
impl TaskStore {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
next_id: Arc::new(AtomicU64::new(1)),
}
}
fn generate_id(&self) -> String {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
format!("task-{}", id)
}
pub fn create_task(
&self,
tool_name: &str,
arguments: serde_json::Value,
ttl: Option<u64>,
) -> (String, CancellationToken) {
let id = self.generate_id();
let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
let token = task.cancellation_token.clone();
if let Ok(mut tasks) = self.tasks.write() {
tasks.insert(id.clone(), task);
}
(id, token)
}
pub fn get_task(&self, task_id: &str) -> Option<TaskObject> {
if let Ok(tasks) = self.tasks.read() {
tasks.get(task_id).map(|t| t.to_task_object())
} else {
None
}
}
pub fn get_task_result(
&self,
task_id: &str,
) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
if let Ok(tasks) = self.tasks.read() {
tasks
.get(task_id)
.map(|t| (t.to_task_object(), t.result.clone(), t.error.clone()))
} else {
None
}
}
pub async fn wait_for_completion(
&self,
task_id: &str,
) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
let notify = {
let tasks = self.tasks.read().ok()?;
let task = tasks.get(task_id)?;
if task.status.is_terminal() {
return Some((
task.to_task_object(),
task.result.clone(),
task.error.clone(),
));
}
task.completion_notify.clone()
};
notify.notified().await;
self.get_task_result(task_id)
}
pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskObject> {
if let Ok(tasks) = self.tasks.read() {
tasks
.values()
.filter(|t| status_filter.is_none() || status_filter == Some(t.status))
.map(|t| t.to_task_object())
.collect()
} else {
vec![]
}
}
pub fn require_input(&self, task_id: &str, message: &str) -> bool {
let Ok(mut tasks) = self.tasks.write() else {
return false;
};
let Some(task) = tasks.get_mut(task_id) else {
return false;
};
if task.status.is_terminal() {
return false;
}
task.status = TaskStatus::InputRequired;
task.status_message = Some(message.to_string());
task.last_updated_at_str = chrono_now_iso8601();
true
}
pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
let Ok(mut tasks) = self.tasks.write() else {
return false;
};
let Some(task) = tasks.get_mut(task_id) else {
return false;
};
if task.status.is_terminal() {
return false;
}
task.status = TaskStatus::Completed;
task.status_message = Some("Task completed".to_string());
task.result = Some(result);
task.completed_at = Some(Instant::now());
task.last_updated_at_str = chrono_now_iso8601();
task.completion_notify.notify_waiters();
true
}
pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
let Ok(mut tasks) = self.tasks.write() else {
return false;
};
let Some(task) = tasks.get_mut(task_id) else {
return false;
};
if task.status.is_terminal() {
return false;
}
task.status = TaskStatus::Failed;
task.status_message = Some(format!("Task failed: {}", error));
task.error = Some(error.to_string());
task.completed_at = Some(Instant::now());
task.last_updated_at_str = chrono_now_iso8601();
task.completion_notify.notify_waiters();
true
}
pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskObject> {
let mut tasks = self.tasks.write().ok()?;
let task = tasks.get_mut(task_id)?;
task.cancellation_token.cancel();
if !task.status.is_terminal() {
task.status = TaskStatus::Cancelled;
task.status_message = Some(
reason
.map(|r| format!("Cancelled: {}", r))
.unwrap_or_else(|| "Task cancelled".to_string()),
);
task.completed_at = Some(Instant::now());
task.last_updated_at_str = chrono_now_iso8601();
task.completion_notify.notify_waiters();
}
Some(task.to_task_object())
}
pub fn cleanup_expired(&self) -> usize {
if let Ok(mut tasks) = self.tasks.write() {
let before = tasks.len();
tasks.retain(|_, t| !t.is_expired());
before - tasks.len()
} else {
0
}
}
#[cfg(test)]
pub fn len(&self) -> usize {
if let Ok(tasks) = self.tasks.read() {
tasks.len()
} else {
0
}
}
#[cfg(test)]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Returns the current UTC time as an ISO 8601 timestamp with millisecond
/// precision, e.g. `2024-01-31T09:05:07.042Z`.
///
/// A system clock set before the Unix epoch is clamped to the epoch.
fn chrono_now_iso8601() -> String {
    use std::time::SystemTime;
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso8601_utc(since_epoch)
}

/// Formats a duration since the Unix epoch as `YYYY-MM-DDTHH:MM:SS.mmmZ`.
///
/// Uses UTC and the proleptic Gregorian calendar; the sub-millisecond part is
/// truncated. It is kept separate from the clock read so the calendar
/// arithmetic can be tested at fixed instants.
fn format_iso8601_utc(since_epoch: Duration) -> String {
    let secs = since_epoch.as_secs();
    let millis = since_epoch.subsec_millis();
    let days = secs / 86_400;
    let secs_of_day = secs % 86_400;
    let hours = secs_of_day / 3_600;
    let minutes = secs_of_day % 3_600 / 60;
    let seconds = secs_of_day % 60;

    // Consume whole years from 1970, then whole months of the final year.
    // `as i32` is lossless for any date within ~5.8 million years of the epoch.
    let mut year = 1970i32;
    let mut remaining_days = days as i32;
    loop {
        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
        if remaining_days < days_in_year {
            break;
        }
        remaining_days -= days_in_year;
        year += 1;
    }
    let february = if is_leap_year(year) { 29 } else { 28 };
    let days_in_months = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1;
    for days_in_month in days_in_months {
        if remaining_days < days_in_month {
            break;
        }
        remaining_days -= days_in_month;
        month += 1;
    }
    let day = remaining_days + 1;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
        year, month, day, hours, minutes, seconds, millis
    )
}

/// Gregorian leap-year rule: divisible by 4, except centuries not divisible by 400.
fn is_leap_year(year: i32) -> bool {
    (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);
        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());
        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
    }

    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(!token.is_cancelled());
        let task_obj = store.cancel_task(&id, Some("User requested"));
        assert!(task_obj.is_some());
        assert_eq!(task_obj.unwrap().status, TaskStatus::Cancelled);
        assert!(token.is_cancelled());
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Cancelled);
        assert_eq!(info.status_message.as_deref(), Some("Cancelled: User requested"));
    }

    #[test]
    fn test_cancel_terminal_task_keeps_status() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        let task_obj = store.cancel_task(&id, Some("Too late")).expect("task should exist");
        assert_eq!(task_obj.status, TaskStatus::Completed);
        assert_eq!(task_obj.status_message.as_deref(), Some("Task completed"));
    }

    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.fail_task(&id, "Something went wrong"));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        assert!(info.status_message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_failed_task_result() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.fail_task(&id, "boom"));
        let (task_obj, result, error) = store.get_task_result(&id).unwrap();
        assert_eq!(task_obj.status, TaskStatus::Failed);
        assert_eq!(task_obj.status_message.as_deref(), Some("Task failed: boom"));
        assert!(result.is_none());
        assert_eq!(error.as_deref(), Some("boom"));
    }

    #[test]
    fn test_require_input() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.require_input(&id, "Need confirmation"));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::InputRequired);
        assert_eq!(info.status_message.as_deref(), Some("Need confirmation"));
        // input_required is not terminal, so the task can still finish...
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        // ...but a finished task cannot go back to waiting for input.
        assert!(!store.require_input(&id, "Too late"));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_unknown_task_id() {
        let store = TaskStore::new();
        assert!(store.get_task("task-missing").is_none());
        assert!(store.get_task_result("task-missing").is_none());
        assert!(!store.complete_task("task-missing", CallToolResult::text("x")));
        assert!(!store.fail_task("task-missing", "x"));
        assert!(!store.require_input("task-missing", "x"));
        assert!(store.cancel_task("task-missing", None).is_none());
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        store.create_task("tool1", serde_json::json!({}), None);
        store.create_task("tool2", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool3", serde_json::json!({}), None);
        store.complete_task(&id3, CallToolResult::text("Done"));
        let all = store.list_tasks(None);
        assert_eq!(all.len(), 3);
        let working = store.list_tasks(Some(TaskStatus::Working));
        assert_eq!(working.len(), 2);
        let completed = store.list_tasks(Some(TaskStatus::Completed));
        assert_eq!(completed.len(), 1);
    }

    #[test]
    fn test_cleanup_keeps_unfinished_tasks() {
        let store = TaskStore::new();
        // Even with a zero TTL, a task that never finished has no completion
        // instant and therefore never expires.
        store.create_task("test-tool", serde_json::json!({}), Some(0));
        assert_eq!(store.cleanup_expired(), 0);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
    }

    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        store.complete_task(&id, CallToolResult::text("Done"));
        assert!(!store.fail_task(&id, "Error"));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (id1, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id2, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool", serde_json::json!({}), None);
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_get_task_result() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        let result = CallToolResult::text("The result");
        store.complete_task(&id, result);
        let (task_obj, result, error) = store.get_task_result(&id).unwrap();
        assert_eq!(task_obj.status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        // Shape: YYYY-MM-DDTHH:MM:SS.mmmZ
        assert_eq!(ts.len(), 24);
        assert!(ts.ends_with('Z'));
        let bytes = ts.as_bytes();
        assert_eq!(bytes[4], b'-');
        assert_eq!(bytes[7], b'-');
        assert_eq!(bytes[10], b'T');
        assert_eq!(bytes[13], b':');
        assert_eq!(bytes[16], b':');
        assert_eq!(bytes[19], b'.');
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Working.to_string(), "working");
        assert_eq!(TaskStatus::InputRequired.to_string(), "input_required");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }
}