use std::collections::HashMap;
use std::time::Instant;
use a2a_protocol_types::params::CancelTaskParams;
use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
use crate::error::{ServerError, ServerResult};
use crate::request_context::RequestContext;
use super::super::helpers::build_call_context;
use super::super::RequestHandler;
impl RequestHandler {
#[allow(clippy::too_many_lines)]
pub async fn on_cancel_task(
&self,
params: CancelTaskParams,
headers: Option<&HashMap<String, String>>,
) -> ServerResult<Task> {
let start = Instant::now();
trace_info!(method = "CancelTask", task_id = %params.id, "handling cancel task");
self.metrics.on_request("CancelTask");
let tenant = params.tenant.clone().unwrap_or_default();
let result: ServerResult<_> = crate::store::tenant::TenantContext::scope(tenant, async {
let call_ctx = build_call_context("CancelTask", headers);
self.interceptors.run_before(&call_ctx).await?;
let task_id = TaskId::new(¶ms.id);
let task = self
.task_store
.get(&task_id)
.await?
.ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
if task.status.state.is_terminal() {
return Err(ServerError::TaskNotCancelable(task_id));
}
{
let tokens = self.cancellation_tokens.read().await;
if let Some(entry) = tokens.get(&task_id) {
entry.token.cancel();
}
}
let ctx = RequestContext::new(
a2a_protocol_types::message::Message {
id: a2a_protocol_types::message::MessageId::new(
uuid::Uuid::new_v4().to_string(),
),
role: a2a_protocol_types::message::MessageRole::User,
parts: vec![],
task_id: Some(task_id.clone()),
context_id: Some(task.context_id.clone()),
reference_task_ids: None,
extensions: None,
metadata: None,
},
task_id.clone(),
task.context_id.0.clone(),
);
let (writer, _reader) = self.event_queue_manager.get_or_create(&task_id).await;
self.executor.cancel(&ctx, writer.as_ref()).await?;
let current = self
.task_store
.get(&task_id)
.await?
.ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
if current.status.state.is_terminal() {
return Err(ServerError::TaskNotCancelable(task_id));
}
let mut updated = current;
updated.status = TaskStatus::with_timestamp(TaskState::Canceled);
self.task_store.save(updated).await?;
let final_task = self
.task_store
.get(&task_id)
.await?
.ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
self.interceptors.run_after(&call_ctx).await?;
Ok(final_task)
})
.await;
let elapsed = start.elapsed();
match &result {
Ok(_) => {
self.metrics.on_response("CancelTask");
self.metrics.on_latency("CancelTask", elapsed);
}
Err(e) => {
self.metrics.on_error("CancelTask", &e.to_string());
self.metrics.on_latency("CancelTask", elapsed);
}
}
result
}
}
#[cfg(test)]
mod tests {
use a2a_protocol_types::params::CancelTaskParams;
use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
use crate::agent_executor;
use crate::builder::RequestHandlerBuilder;
use crate::error::ServerError;
struct DummyExecutor;
agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });
struct CancelableExecutor;
agent_executor!(CancelableExecutor,
execute: |_ctx, _queue| async { Ok(()) },
cancel: |_ctx, _queue| async { Ok(()) }
);
fn make_completed_task(id: &str) -> Task {
Task {
id: TaskId::new(id),
context_id: ContextId::new("ctx-1"),
status: TaskStatus::new(TaskState::Completed),
history: None,
artifacts: None,
metadata: None,
}
}
fn make_submitted_task(id: &str) -> Task {
Task {
id: TaskId::new(id),
context_id: ContextId::new("ctx-1"),
status: TaskStatus::new(TaskState::Submitted),
history: None,
artifacts: None,
metadata: None,
}
}
#[tokio::test]
async fn cancel_task_not_found_returns_error() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = CancelTaskParams {
tenant: None,
id: "nonexistent-task".to_owned(),
metadata: None,
};
let result = handler.on_cancel_task(params, None).await;
assert!(
matches!(result, Err(ServerError::TaskNotFound(_))),
"expected TaskNotFound for missing task, got: {result:?}"
);
}
#[tokio::test]
async fn cancel_task_terminal_state_returns_not_cancelable() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let task = make_completed_task("t-cancel-terminal");
handler.task_store.save(task).await.unwrap();
let params = CancelTaskParams {
tenant: None,
id: "t-cancel-terminal".to_owned(),
metadata: None,
};
let result = handler.on_cancel_task(params, None).await;
assert!(
matches!(result, Err(ServerError::TaskNotCancelable(_))),
"expected TaskNotCancelable for completed task, got: {result:?}"
);
}
#[tokio::test]
async fn cancel_task_non_terminal_succeeds() {
let handler = RequestHandlerBuilder::new(CancelableExecutor)
.build()
.unwrap();
let task = make_submitted_task("t-cancel-active");
handler.task_store.save(task).await.unwrap();
let params = CancelTaskParams {
tenant: None,
id: "t-cancel-active".to_owned(),
metadata: None,
};
let result = handler.on_cancel_task(params, None).await;
assert!(
result.is_ok(),
"canceling a non-terminal task should succeed, got: {result:?}"
);
assert_eq!(
result.unwrap().status.state,
TaskState::Canceled,
"canceled task should have Canceled state"
);
}
#[tokio::test]
async fn cancel_task_error_path_records_metrics() {
let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
let params = CancelTaskParams {
tenant: None,
id: "nonexistent-for-metrics".to_owned(),
metadata: None,
};
let result = handler.on_cancel_task(params, None).await;
assert!(
matches!(result, Err(ServerError::TaskNotFound(_))),
"expected TaskNotFound, got: {result:?}"
);
}
}