a2a_protocol_server/handler/lifecycle/
cancel_task.rs1use std::collections::HashMap;
9use std::time::Instant;
10
11use a2a_protocol_types::params::CancelTaskParams;
12use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
13
14use crate::error::{ServerError, ServerResult};
15use crate::request_context::RequestContext;
16
17use super::super::helpers::build_call_context;
18use super::super::RequestHandler;
19
20impl RequestHandler {
21 #[allow(clippy::too_many_lines)]
27 pub async fn on_cancel_task(
28 &self,
29 params: CancelTaskParams,
30 headers: Option<&HashMap<String, String>>,
31 ) -> ServerResult<Task> {
32 let start = Instant::now();
33 trace_info!(method = "CancelTask", task_id = %params.id, "handling cancel task");
34 self.metrics.on_request("CancelTask");
35
36 let tenant = params.tenant.clone().unwrap_or_default();
37 let result: ServerResult<_> = crate::store::tenant::TenantContext::scope(tenant, async {
38 let call_ctx = build_call_context("CancelTask", headers);
39 self.interceptors.run_before(&call_ctx).await?;
40
41 let task_id = TaskId::new(¶ms.id);
42 let task = self
43 .task_store
44 .get(&task_id)
45 .await?
46 .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
47
48 if task.status.state.is_terminal() {
49 return Err(ServerError::TaskNotCancelable(task_id));
50 }
51
52 {
54 let tokens = self.cancellation_tokens.read().await;
55 if let Some(entry) = tokens.get(&task_id) {
56 entry.token.cancel();
57 }
58 }
59
60 let ctx = RequestContext::new(
62 a2a_protocol_types::message::Message {
63 id: a2a_protocol_types::message::MessageId::new(
64 uuid::Uuid::new_v4().to_string(),
65 ),
66 role: a2a_protocol_types::message::MessageRole::User,
67 parts: vec![],
68 task_id: Some(task_id.clone()),
69 context_id: Some(task.context_id.clone()),
70 reference_task_ids: None,
71 extensions: None,
72 metadata: None,
73 },
74 task_id.clone(),
75 task.context_id.0.clone(),
76 );
77
78 let (writer, _reader) = self.event_queue_manager.get_or_create(&task_id).await;
79 self.executor.cancel(&ctx, writer.as_ref()).await?;
80
81 let current = self
85 .task_store
86 .get(&task_id)
87 .await?
88 .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
89 if current.status.state.is_terminal() {
90 return Err(ServerError::TaskNotCancelable(task_id));
91 }
92
93 let mut updated = current;
94 updated.status = TaskStatus::with_timestamp(TaskState::Canceled);
95 self.task_store.save(updated).await?;
96 let final_task = self
98 .task_store
99 .get(&task_id)
100 .await?
101 .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
102
103 self.interceptors.run_after(&call_ctx).await?;
104 Ok(final_task)
105 })
106 .await;
107
108 let elapsed = start.elapsed();
109 match &result {
110 Ok(_) => {
111 self.metrics.on_response("CancelTask");
112 self.metrics.on_latency("CancelTask", elapsed);
113 }
114 Err(e) => {
115 self.metrics.on_error("CancelTask", &e.to_string());
116 self.metrics.on_latency("CancelTask", elapsed);
117 }
118 }
119 result
120 }
121}
122
#[cfg(test)]
mod tests {
    use a2a_protocol_types::params::CancelTaskParams;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;
    use crate::error::ServerError;

    /// Executor defined with only an `execute` closure (no explicit `cancel`).
    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    /// Executor whose `cancel` hook succeeds without doing anything.
    struct CancelableExecutor;
    agent_executor!(CancelableExecutor,
        execute: |_ctx, _queue| async { Ok(()) },
        cancel: |_ctx, _queue| async { Ok(()) }
    );

    /// Builds a task in context `ctx-1` with the given id and state.
    fn make_task(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-1"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Builds cancel params for `id` with no tenant and no metadata.
    fn cancel_params(id: &str) -> CancelTaskParams {
        CancelTaskParams {
            tenant: None,
            id: id.to_owned(),
            metadata: None,
        }
    }

    #[tokio::test]
    async fn cancel_task_not_found_returns_error() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let result = handler
            .on_cancel_task(cancel_params("nonexistent-task"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for missing task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn cancel_task_terminal_state_returns_not_cancelable() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        handler
            .task_store
            .save(make_task("t-cancel-terminal", TaskState::Completed))
            .await
            .unwrap();

        let result = handler
            .on_cancel_task(cancel_params("t-cancel-terminal"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotCancelable(_))),
            "expected TaskNotCancelable for completed task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn cancel_task_already_canceled_returns_not_cancelable() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        handler
            .task_store
            .save(make_task("t-cancel-twice", TaskState::Canceled))
            .await
            .unwrap();

        let result = handler
            .on_cancel_task(cancel_params("t-cancel-twice"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotCancelable(_))),
            "expected TaskNotCancelable for already-canceled task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn cancel_task_non_terminal_succeeds() {
        let handler = RequestHandlerBuilder::new(CancelableExecutor)
            .build()
            .unwrap();
        handler
            .task_store
            .save(make_task("t-cancel-active", TaskState::Submitted))
            .await
            .unwrap();

        let task = handler
            .on_cancel_task(cancel_params("t-cancel-active"), None)
            .await
            .unwrap_or_else(|e| {
                panic!("canceling a non-terminal task should succeed, got: {e:?}")
            });
        assert_eq!(
            task.status.state,
            TaskState::Canceled,
            "canceled task should have Canceled state"
        );

        // The new state must also be persisted, not only returned.
        let stored = handler
            .task_store
            .get(&TaskId::new("t-cancel-active"))
            .await
            .unwrap()
            .expect("canceled task should still be in the store");
        assert_eq!(
            stored.status.state,
            TaskState::Canceled,
            "stored task should have Canceled state"
        );
    }

    #[tokio::test]
    async fn cancel_task_error_path_records_metrics() {
        // NOTE(review): this only exercises the error path through the metrics
        // code; actually asserting the recorded error would need a recording
        // metrics hook injected via the builder.
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let result = handler
            .on_cancel_task(cancel_params("nonexistent-for-metrics"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound, got: {result:?}"
        );
    }
}