use crate::error::{Error, ErrorCode};
use crate::types::{Task, TaskPayload, TaskStatus};
use chrono::Utc;
use serde::Serialize;
use std::{
cmp::Ordering,
collections::BinaryHeap,
sync::{
Mutex,
atomic::{AtomicU64, Ordering as AtomicOrdering},
},
};
use tokio::sync::watch::{Receiver, Sender, channel};
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
/// Registry of tracked tasks keyed by task id.
///
/// Terminal tasks (completed/failed) stay registered until their TTL deadline passes.
/// Cleanup is driven lazily by `cleanup_expired`, which runs at the start of the
/// tracker's public operations.
pub(crate) struct TaskTracker {
    // Tracked tasks keyed by task id.
    tasks: dashmap::DashMap<String, TaskEntry>,
    // Pending TTL expirations. `TaskExpiry`'s reversed `Ord` makes this max-heap pop
    // the earliest deadline first.
    expirations: Mutex<BinaryHeap<TaskExpiry>>,
    // Counter used to give every scheduled expiry a unique tie-breaking sequence number.
    next_expiry_seq: AtomicU64,
}
/// Value carried by a task's result watch channel.
///
/// `None` until a result is published; `reset` sets it back to `None`.
pub(crate) type MaybePayload = Option<TaskPayload>;
/// Per-task state stored in the tracker.
pub(crate) struct TaskEntry {
    // Current task metadata/status.
    task: Task,
    // Cancelled by `TaskTracker::cancel`; shared with the task's `TaskHandle`.
    token: CancellationToken,
    // Tracker-side sender, used by `reset` and `set_result` (both `server`-only).
    // Without the `server` feature the `TaskHandle` holds the only sender.
    #[cfg(feature = "server")]
    tx: Sender<MaybePayload>,
    // Receiver that `get_result` clones to wait for the task's result.
    rx: Receiver<MaybePayload>,
}
/// Handle given to the code executing a task.
///
/// It is used to publish the task's result and to observe cancellation.
pub(crate) struct TaskHandle {
    // Clone of the cancellation token stored in the tracker's entry for this task.
    token: CancellationToken,
    // Sender half of the task's result watch channel.
    tx: Sender<MaybePayload>,
}
/// A scheduled TTL removal of a terminal task, stored in the tracker's expiry heap.
struct TaskExpiry {
    /// Absolute deadline, in Unix-epoch milliseconds.
    deadline_ms: i64,
    /// Tie-breaker between equal deadlines: lower values were scheduled earlier.
    sequence: u64,
    /// Id of the task to remove once the deadline has passed.
    id: String,
}

// Equality and ordering look only at `(deadline_ms, sequence)`; `id` is payload.
impl PartialEq for TaskExpiry {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        (self.deadline_ms, self.sequence) == (other.deadline_ms, other.sequence)
    }
}

impl Eq for TaskExpiry {}

impl PartialOrd for TaskExpiry {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for TaskExpiry {
    /// Reversed lexicographic order on `(deadline_ms, sequence)`.
    ///
    /// `BinaryHeap` is a max-heap, so it pops the earliest deadline first. Among equal
    /// deadlines it pops the entry scheduled first.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = (self.deadline_ms, self.sequence);
        let theirs = (other.deadline_ms, other.sequence);
        theirs.cmp(&mine)
    }
}
impl TaskTracker {
#[inline]
pub(crate) fn new() -> Self {
Self {
tasks: dashmap::DashMap::new(),
expirations: Mutex::new(BinaryHeap::new()),
next_expiry_seq: AtomicU64::new(0),
}
}
pub(crate) fn tasks(&self) -> Vec<Task> {
self.cleanup_expired();
self.tasks
.iter()
.map(|entry| entry.task.clone())
.collect::<Vec<_>>()
}
pub(crate) fn track(&self, task: Task) -> TaskHandle {
self.cleanup_expired();
let token = CancellationToken::new();
let (tx, rx) = channel(None);
self.tasks.insert(
task.id.clone(),
TaskEntry {
token: token.clone(),
#[cfg(feature = "server")]
tx: tx.clone(),
task,
rx,
},
);
TaskHandle { token, tx }
}
pub(crate) fn cancel(&self, id: &str) -> Result<Task, Error> {
self.cleanup_expired();
if let Some((_, entry)) = self.tasks.remove(id) {
entry.token.cancel();
Ok(entry.task.cancel())
} else {
Err(Error::new(
ErrorCode::InvalidParams,
format!("Could not find task with id: {id}"),
))
}
}
pub(crate) fn complete(&self, id: &str) {
self.cleanup_expired();
if let Some(mut entry) = self.tasks.get_mut(id) {
entry.task.complete();
self.schedule_expiry(&entry.task);
}
}
#[cfg(feature = "server")]
pub(crate) fn fail(&self, id: &str) {
self.cleanup_expired();
if let Some(mut entry) = self.tasks.get_mut(id) {
entry.task.fail();
self.schedule_expiry(&entry.task);
}
}
#[cfg(feature = "server")]
pub(crate) fn require_input(&self, id: &str) {
self.cleanup_expired();
if let Some(mut entry) = self.tasks.get_mut(id) {
entry.task.require_input();
}
}
#[cfg(feature = "server")]
pub(crate) fn reset(&self, id: &str) {
self.cleanup_expired();
if let Some(mut entry) = self.tasks.get_mut(id) {
entry.task.reset();
let _ = entry.tx.send(None);
}
}
#[cfg(feature = "server")]
pub(crate) fn set_result<T: Serialize>(&self, id: &str, result: T) {
self.cleanup_expired();
if let Some(entry) = self.tasks.get(id) {
let result = match serde_json::to_value(result) {
Ok(result) => result,
Err(_err) => {
#[cfg(feature = "tracing")]
tracing::error!(logger = "neva", "Unable to serialize task result: {_err:?}");
return;
}
};
let _ = entry.tx.send(Some(TaskPayload(result)));
}
}
pub(crate) fn get_status(&self, id: &str) -> Result<Task, Error> {
self.cleanup_expired();
self.tasks.get(id).map(|t| t.task.clone()).ok_or_else(|| {
Error::new(
ErrorCode::InvalidParams,
format!("Could not find task with id: {id}"),
)
})
}
pub(crate) async fn get_result(&self, id: &str) -> Result<TaskPayload, Error> {
self.cleanup_expired();
let (status, mut result_rx, token) = {
let entry = self.tasks.get(id).ok_or_else(|| {
Error::new(
ErrorCode::InvalidParams,
format!("Could not find task with id: {id}"),
)
})?;
(entry.task.status, entry.rx.clone(), entry.token.clone())
};
if let Some(ref result) = *result_rx.borrow_and_update() {
if status != TaskStatus::InputRequired {
self.tasks.remove(id);
}
return Ok(result.clone());
}
loop {
tokio::select! {
changed = result_rx.changed() => {
if changed.is_err() {
return Err(Error::new(ErrorCode::InternalError, "Unable to get task result"));
}
if let Some(result) = result_rx.borrow_and_update().clone() {
let task = self.get_status(id)?;
if task.status != TaskStatus::InputRequired {
self.tasks.remove(id);
}
return Ok(result);
}
}
_ = token.cancelled() => {
return Err(Error::new(ErrorCode::InvalidRequest, "Task has been cancelled"));
}
}
}
}
#[inline]
fn cleanup_expired(&self) {
let now_ms = Utc::now().timestamp_millis();
let mut expired = Vec::new();
if let Ok(mut expirations) = self.expirations.lock() {
while expirations
.peek()
.is_some_and(|entry| entry.deadline_ms <= now_ms)
{
let entry = expirations.pop().expect("peeked entry must exist");
expired.push((entry.id, entry.deadline_ms));
}
}
for (id, deadline_ms) in expired {
let should_remove = self.tasks.get(&id).is_some_and(|entry| {
Self::is_terminal(&entry.task)
&& Self::task_deadline_ms(&entry.task).is_some_and(|d| d == deadline_ms)
});
if should_remove {
let _ = self.tasks.remove(&id);
}
}
}
#[inline]
fn is_terminal(task: &Task) -> bool {
matches!(
task.status,
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
)
}
#[inline]
fn schedule_expiry(&self, task: &Task) {
let Some(deadline_ms) = Self::task_deadline_ms(task) else {
return;
};
let sequence = self.next_expiry_seq.fetch_add(1, AtomicOrdering::Relaxed);
if let Ok(mut expirations) = self.expirations.lock() {
expirations.push(TaskExpiry {
deadline_ms,
sequence,
id: task.id.clone(),
});
}
}
#[inline]
fn task_deadline_ms(task: &Task) -> Option<i64> {
let ttl_ms = i64::try_from(task.ttl).unwrap_or(i64::MAX);
task.created_at.timestamp_millis().checked_add(ttl_ms)
}
}
impl Default for TaskTracker {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl TaskHandle {
    /// Publishes `result` as the task's payload, consuming the handle.
    ///
    /// The value is converted to JSON first. If that fails nothing is sent, and the
    /// error is logged when the `tracing` feature is enabled. A failed send is ignored.
    pub(crate) fn set_result<T: Serialize>(self, result: T) {
        match serde_json::to_value(result) {
            Ok(json) => {
                let _ = self.tx.send(Some(TaskPayload(json)));
            }
            Err(_err) => {
                #[cfg(feature = "tracing")]
                tracing::error!(logger = "neva", "Unable to serialize task result: {_err:?}");
            }
        }
    }

    /// Returns a future that resolves once this task's cancellation token is cancelled.
    #[inline]
    pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> {
        self.token.cancelled()
    }
}
// Unit tests for `TaskTracker` and `TaskHandle`. Several async tests rely on short
// sleeps and spawned tasks to order events, so their statement order matters.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TaskStatus;
    use std::sync::Arc;
    #[cfg(feature = "server")]
    use crate::types::CallToolResponse;

    // A fresh tracker has no tasks.
    #[test]
    fn it_can_create_new_tracker() {
        let tracker = TaskTracker::new();
        assert_eq!(tracker.tasks().len(), 0);
    }

    // A tracked task shows up in `tasks()` with its id.
    #[test]
    fn it_can_track_task() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        let tasks = tracker.tasks();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].id, task_id);
    }

    // Each distinct tracked task is listed.
    #[test]
    fn it_can_return_list_of_tasks() {
        let tracker = TaskTracker::new();
        let task1 = Task::new();
        let task2 = Task::new();
        let _handle1 = tracker.track(task1.clone());
        let _handle2 = tracker.track(task2.clone());
        let tasks = tracker.tasks();
        assert_eq!(tasks.len(), 2);
    }

    // Cancelling returns the task in `Cancelled` state and removes it from the tracker.
    #[test]
    fn it_can_cancel_task() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        let result = tracker.cancel(&task_id).unwrap();
        assert_eq!(result.status, TaskStatus::Cancelled);
        assert_eq!(tracker.tasks().len(), 0);
    }

    // Cancelling an unknown id is an `InvalidParams` error.
    #[test]
    fn it_does_return_error_when_cancelling_nonexistent_task() {
        let tracker = TaskTracker::new();
        let result = tracker.cancel("nonexistent");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, ErrorCode::InvalidParams);
    }

    // Completing a task updates its status; with the default TTL it is still queryable.
    #[test]
    fn it_can_complete_task() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        tracker.complete(&task_id);
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::Completed);
    }

    // Completing an unknown id is a silent no-op (must not panic).
    #[test]
    fn it_does_nothing_when_completing_nonexistent_task() {
        let tracker = TaskTracker::new();
        tracker.complete("nonexistent");
    }

    // Failing a task updates its status.
    #[cfg(feature = "server")]
    #[test]
    fn it_can_fail_task() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        tracker.fail(&task_id);
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::Failed);
    }

    // Failing an unknown id is a silent no-op (must not panic).
    #[cfg(feature = "server")]
    #[test]
    fn it_does_nothing_when_failing_nonexistent_task() {
        let tracker = TaskTracker::new();
        tracker.fail("nonexistent");
    }

    // `require_input` moves the task to `InputRequired`.
    #[cfg(feature = "server")]
    #[test]
    fn it_can_require_input() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        tracker.require_input(&task_id);
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::InputRequired);
    }

    // `require_input` on an unknown id is a silent no-op (must not panic).
    #[cfg(feature = "server")]
    #[test]
    fn it_does_nothing_when_requiring_input_for_nonexistent_task() {
        let tracker = TaskTracker::new();
        tracker.require_input("nonexistent");
    }

    // A freshly tracked task reports `Working`.
    #[test]
    fn it_can_get_task_status() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task.clone());
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.id, task.id);
        assert_eq!(status.status, TaskStatus::Working);
    }

    // Querying the status of an unknown id is an `InvalidParams` error.
    #[test]
    fn it_does_return_error_when_getting_status_of_nonexistent_task() {
        let tracker = TaskTracker::new();
        let result = tracker.get_status("nonexistent");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, ErrorCode::InvalidParams);
    }

    // `get_result` waits for a result published later by the handle.
    #[tokio::test]
    async fn it_can_get_task_result_when_completed() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            handle.set_result("test_result".to_string());
        });
        let result = tracker.get_result(&task_id).await.unwrap();
        assert_eq!(result.0, "test_result");
    }

    // `get_result` returns immediately when the result was published beforehand.
    #[tokio::test]
    async fn it_does_return_result_immediately_when_already_available() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        handle.set_result("immediate_result".to_string());
        let result = tracker.get_result(&task_id).await.unwrap();
        assert_eq!(result.0, "immediate_result");
    }

    // `get_result` on an unknown id is an `InvalidParams` error.
    #[tokio::test]
    async fn it_does_return_error_when_getting_result_of_nonexistent_task() {
        let tracker = TaskTracker::new();
        let result = tracker.get_result("nonexistent").await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, ErrorCode::InvalidParams);
    }

    // A waiter in `get_result` gets `InvalidRequest` when the task is cancelled meanwhile.
    #[tokio::test]
    async fn it_does_return_error_when_task_is_cancelled() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        // Keep the handle alive so the result channel stays open during the wait.
        let _handle = tracker.track(task.clone());
        let tracker = Arc::new(tracker);
        tokio::spawn({
            let tracker = tracker.clone();
            let task_id = task_id.clone();
            async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
                let _ = tracker.cancel(&task_id);
            }
        });
        let result = tracker.get_result(&task_id).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, ErrorCode::InvalidRequest);
    }

    // A delayed result is received and the task is then removed from the tracker.
    // NOTE(review): despite the name, only a single update is published here.
    #[tokio::test]
    async fn it_can_wait_for_result_with_multiple_updates() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            handle.set_result("final_result".to_string());
        });
        let result = tracker.get_result(&task_id).await.unwrap();
        assert_eq!(result.0, "final_result");
        assert_eq!(tracker.tasks().len(), 0);
    }

    // Reading the result of a non-`InputRequired` task removes it.
    #[tokio::test]
    async fn it_does_remove_task_after_getting_result() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        handle.set_result("result".to_string());
        let _ = tracker.get_result(&task_id).await.unwrap();
        assert_eq!(tracker.tasks().len(), 0);
    }

    // Smoke test: a handle can be moved into a spawned task that awaits `cancelled()`.
    #[tokio::test]
    async fn it_can_create_task_handle() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let handle = tracker.track(task);
        tokio::spawn(async move {
            tokio::select! {
                _ = handle.cancelled() => {}
            }
        });
    }

    // `TaskTracker::cancel` wakes the worker's `TaskHandle::cancelled()` future.
    #[tokio::test]
    async fn it_can_cancel_via_handle() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        let tracker = Arc::new(tracker);
        tokio::spawn({
            let tracker = tracker.clone();
            let task_id = task_id.clone();
            async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
                let _ = tracker.cancel(&task_id);
            }
        });
        // Fail if cancellation is not observed within one second.
        tokio::select! {
            _ = handle.cancelled() => {
            }
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {
                panic!("Task was not cancelled");
            }
        }
    }

    // A structured payload can be published after completion without disturbing status.
    #[test]
    #[cfg(feature = "server")]
    fn it_can_handle_complex_payload_types() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let handle = tracker.track(task.clone());
        let response = CallToolResponse::new("test");
        tracker.complete(&task_id);
        handle.set_result(response.clone());
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::Completed);
    }

    // Several tasks keep independent results; reading each one removes it.
    #[tokio::test]
    async fn it_can_track_multiple_concurrent_tasks() {
        let tracker = TaskTracker::new();
        let tasks: Vec<_> = (0..5).map(|_| Task::new()).collect();
        let task_ids: Vec<_> = tasks.iter().map(|t| t.id.clone()).collect();
        let handles: Vec<_> = tasks.into_iter().map(|t| tracker.track(t)).collect();
        for (i, handle) in handles.into_iter().enumerate() {
            let result = format!("result_{}", i);
            handle.set_result(result);
        }
        for (i, task_id) in task_ids.iter().enumerate() {
            let result = tracker.get_result(task_id).await.unwrap();
            assert_eq!(result.0, format!("result_{}", i));
        }
        assert_eq!(tracker.tasks().len(), 0);
    }

    // Status moves from `Working` to `Completed`.
    #[test]
    fn it_does_maintain_task_state_transitions() {
        let tracker = TaskTracker::new();
        let task = Task::new();
        let task_id = task.id.clone();
        let _handle = tracker.track(task.clone());
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::Working);
        tracker.complete(&task_id);
        let status = tracker.get_status(&task_id).unwrap();
        assert_eq!(status.status, TaskStatus::Completed);
    }

    // A completed task whose TTL (1 ms past creation) has elapsed is purged on the next
    // operation.
    #[test]
    fn it_does_remove_expired_completed_tasks() {
        let tracker = TaskTracker::new();
        let task = Task::from(crate::types::TaskMetadata { ttl: Some(1) });
        let task_id = task.id.clone();
        let _handle = tracker.track(task);
        tracker.complete(&task_id);
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(tracker.get_status(&task_id).is_err());
        assert_eq!(tracker.tasks().len(), 0);
    }
}