a2a_protocol_server/handler/lifecycle/
subscribe.rs1use std::collections::HashMap;
9use std::time::Instant;
10
11use a2a_protocol_types::params::TaskIdParams;
12use a2a_protocol_types::task::TaskId;
13
14use crate::error::{ServerError, ServerResult};
15use crate::streaming::InMemoryQueueReader;
16
17use super::super::helpers::build_call_context;
18use super::super::RequestHandler;
19
20impl RequestHandler {
21 pub async fn on_resubscribe(
27 &self,
28 params: TaskIdParams,
29 headers: Option<&HashMap<String, String>>,
30 ) -> ServerResult<InMemoryQueueReader> {
31 let start = Instant::now();
32 trace_info!(method = "SubscribeToTask", task_id = %params.id, "handling resubscribe");
33 self.metrics.on_request("SubscribeToTask");
34
35 let tenant = params.tenant.clone().unwrap_or_default();
36 let result: ServerResult<_> = crate::store::tenant::TenantContext::scope(tenant, async {
37 let call_ctx = build_call_context("SubscribeToTask", headers);
38 self.interceptors.run_before(&call_ctx).await?;
39
40 let task_id = TaskId::new(¶ms.id);
41
42 let task = self
44 .task_store
45 .get(&task_id)
46 .await?
47 .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
48
49 if task.status.state.is_terminal() {
52 return Err(ServerError::UnsupportedOperation(format!(
53 "task {} is in terminal state '{}' and cannot be subscribed to",
54 task_id, task.status.state
55 )));
56 }
57
58 let snapshot = a2a_protocol_types::events::StreamResponse::Task(task);
61 let reader = self
62 .event_queue_manager
63 .subscribe_with_snapshot(&task_id, snapshot)
64 .await
65 .ok_or_else(|| ServerError::Internal("no active event queue for task".into()))?;
66
67 self.interceptors.run_after(&call_ctx).await?;
68 Ok(reader)
69 })
70 .await;
71
72 let elapsed = start.elapsed();
73 match &result {
74 Ok(_) => {
75 self.metrics.on_response("SubscribeToTask");
76 self.metrics.on_latency("SubscribeToTask", elapsed);
77 }
78 Err(e) => {
79 self.metrics.on_error("SubscribeToTask", &e.to_string());
80 self.metrics.on_latency("SubscribeToTask", elapsed);
81 }
82 }
83 result
84 }
85}
86
#[cfg(test)]
mod tests {
    use a2a_protocol_types::params::TaskIdParams;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;
    use crate::error::ServerError;

    /// Executor that returns `Ok(())` immediately without emitting any events.
    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    /// Builds a minimal task (no history, artifacts or metadata) with the given
    /// id and state, for seeding the task store.
    fn task_in_state(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-1"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Builds resubscribe params for `id` with no tenant.
    fn params_for(id: &str) -> TaskIdParams {
        TaskIdParams {
            tenant: None,
            id: id.to_owned(),
        }
    }

    #[tokio::test]
    async fn resubscribe_task_not_found_returns_error() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let result = handler
            .on_resubscribe(params_for("nonexistent-task"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for missing task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_terminal_task_returns_unsupported_operation() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        handler
            .task_store
            .save(&task_in_state("t-resub-1", TaskState::Completed))
            .await
            .unwrap();

        let result = handler.on_resubscribe(params_for("t-resub-1"), None).await;
        assert!(
            matches!(result, Err(ServerError::UnsupportedOperation(ref msg)) if msg.contains("terminal")),
            "expected UnsupportedOperation for terminal task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_nonterminal_no_queue_returns_internal_error() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        handler
            .task_store
            .save(&task_in_state("t-resub-nonterminal", TaskState::Working))
            .await
            .unwrap();

        let result = handler
            .on_resubscribe(params_for("t-resub-nonterminal"), None)
            .await;
        assert!(
            matches!(result, Err(ServerError::Internal(_))),
            "expected Internal error when no event queue exists for non-terminal task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_success_returns_reader() {
        use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
        use a2a_protocol_types::params::{ListTasksParams, MessageSendParams};

        use crate::handler::SendMessageResult;

        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();

        let params = MessageSendParams {
            message: Message {
                id: MessageId::new("msg-resub"),
                role: MessageRole::User,
                parts: vec![Part::text("hello")],
                context_id: Some(ContextId::new("ctx-resub")),
                task_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
            tenant: None,
        };

        let result = handler.on_send_message(params, true, None).await;
        assert!(matches!(result, Ok(SendMessageResult::Stream(_))));

        let tasks = handler
            .task_store
            .list(&ListTasksParams::default())
            .await
            .unwrap();
        let task_id = tasks
            .tasks
            .first()
            .expect("should have at least one task")
            .id
            .0
            .clone();

        let sub_result = handler.on_resubscribe(params_for(&task_id), None).await;
        // DummyExecutor returns immediately, so its event queue may already be
        // gone when we resubscribe; the "no active event queue" Internal error
        // is accepted alongside a live reader. Any other error fails the test.
        match &sub_result {
            Ok(_) | Err(ServerError::Internal(_)) => {}
            Err(e) => panic!("unexpected error: {e:?}"),
        }
    }

    #[tokio::test]
    async fn resubscribe_with_tenant() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let params = TaskIdParams {
            tenant: Some("test-tenant".to_string()),
            id: "nonexistent-task".to_owned(),
        };
        let result = handler.on_resubscribe(params, None).await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "resubscribe for missing task should fail with TaskNotFound, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_with_headers() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let mut headers = std::collections::HashMap::new();
        headers.insert("authorization".to_string(), "Bearer tok".to_string());

        // No interceptors are configured, so headers must not change the outcome.
        let result = handler
            .on_resubscribe(params_for("nonexistent-task"), Some(&headers))
            .await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for missing task, got: {result:?}"
        );
    }

    /// Exercises the error path with an interceptor that rejects every call.
    /// Only the propagated error is checked here; metric values are not inspected.
    #[tokio::test]
    async fn resubscribe_error_path_records_error_metrics() {
        use crate::call_context::CallContext;
        use crate::interceptor::ServerInterceptor;
        use std::future::Future;
        use std::pin::Pin;

        struct FailInterceptor;
        impl ServerInterceptor for FailInterceptor {
            fn before<'a>(
                &'a self,
                _ctx: &'a CallContext,
            ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
            {
                Box::pin(async {
                    Err(a2a_protocol_types::error::A2aError::internal(
                        "forced failure",
                    ))
                })
            }
            fn after<'a>(
                &'a self,
                _ctx: &'a CallContext,
            ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
            {
                Box::pin(async { Ok(()) })
            }
        }

        let handler = RequestHandlerBuilder::new(DummyExecutor)
            .with_interceptor(FailInterceptor)
            .build()
            .unwrap();

        let result = handler.on_resubscribe(params_for("t-resub-fail"), None).await;
        // The `before` hook runs ahead of the task lookup, so a rejection must
        // surface as the interceptor's error rather than TaskNotFound.
        assert!(
            matches!(&result, Err(e) if !matches!(e, ServerError::TaskNotFound(_))),
            "resubscribe should fail with the interceptor's error, got: {result:?}"
        );
    }
}