a2a_protocol_server/handler/lifecycle/
subscribe.rs1use std::collections::HashMap;
9use std::time::Instant;
10
11use a2a_protocol_types::params::TaskIdParams;
12use a2a_protocol_types::task::TaskId;
13
14use crate::error::{ServerError, ServerResult};
15use crate::streaming::InMemoryQueueReader;
16
17use super::super::helpers::build_call_context;
18use super::super::RequestHandler;
19
20impl RequestHandler {
21 pub async fn on_resubscribe(
27 &self,
28 params: TaskIdParams,
29 headers: Option<&HashMap<String, String>>,
30 ) -> ServerResult<InMemoryQueueReader> {
31 let start = Instant::now();
32 trace_info!(method = "SubscribeToTask", task_id = %params.id, "handling resubscribe");
33 self.metrics.on_request("SubscribeToTask");
34
35 let tenant = params.tenant.clone().unwrap_or_default();
36 let result: ServerResult<_> = crate::store::tenant::TenantContext::scope(tenant, async {
37 let call_ctx = build_call_context("SubscribeToTask", headers);
38 self.interceptors.run_before(&call_ctx).await?;
39
40 let task_id = TaskId::new(¶ms.id);
41
42 let _task = self
44 .task_store
45 .get(&task_id)
46 .await?
47 .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
48
49 let reader = self
50 .event_queue_manager
51 .subscribe(&task_id)
52 .await
53 .ok_or_else(|| ServerError::Internal("no active event queue for task".into()))?;
54
55 self.interceptors.run_after(&call_ctx).await?;
56 Ok(reader)
57 })
58 .await;
59
60 let elapsed = start.elapsed();
61 match &result {
62 Ok(_) => {
63 self.metrics.on_response("SubscribeToTask");
64 self.metrics.on_latency("SubscribeToTask", elapsed);
65 }
66 Err(e) => {
67 self.metrics.on_error("SubscribeToTask", &e.to_string());
68 self.metrics.on_latency("SubscribeToTask", elapsed);
69 }
70 }
71 result
72 }
73}
74
#[cfg(test)]
mod tests {
    use a2a_protocol_types::params::TaskIdParams;

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;
    use crate::error::ServerError;

    /// Executor that returns `Ok(())` immediately without doing anything.
    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    #[tokio::test]
    async fn resubscribe_task_not_found_returns_error() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();
        let params = TaskIdParams {
            tenant: None,
            id: "nonexistent-task".to_owned(),
        };
        let result = handler.on_resubscribe(params, None).await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for missing task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_task_exists_but_no_queue_returns_internal_error() {
        use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();

        // Saved directly to the store, so no event queue was ever created for it.
        let task = Task {
            id: TaskId::new("t-resub-1"),
            context_id: ContextId::new("ctx-1"),
            status: TaskStatus::new(TaskState::Completed),
            history: None,
            artifacts: None,
            metadata: None,
        };
        handler.task_store.save(task).await.unwrap();

        let params = TaskIdParams {
            tenant: None,
            id: "t-resub-1".to_owned(),
        };
        let result = handler.on_resubscribe(params, None).await;
        assert!(
            matches!(result, Err(ServerError::Internal(_))),
            "expected Internal error when no event queue exists, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_success_returns_reader() {
        use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
        use a2a_protocol_types::params::MessageSendParams;
        use a2a_protocol_types::task::ContextId;

        use crate::handler::SendMessageResult;

        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();

        let params = MessageSendParams {
            context_id: None,
            message: Message {
                id: MessageId::new("msg-resub"),
                role: MessageRole::User,
                parts: vec![Part::text("hello")],
                context_id: Some(ContextId::new("ctx-resub")),
                task_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
            tenant: None,
        };

        // Keep `result` (and the stream it holds) alive across the resubscribe.
        let result = handler.on_send_message(params, true, None).await;
        assert!(matches!(result, Ok(SendMessageResult::Stream(_))));

        let tasks = handler
            .task_store
            .list(&a2a_protocol_types::params::ListTasksParams::default())
            .await
            .unwrap();
        let task_id = tasks
            .tasks
            .first()
            .expect("should have at least one task")
            .id
            .0
            .clone();

        let sub_params = TaskIdParams {
            tenant: None,
            id: task_id,
        };
        let sub_result = handler.on_resubscribe(sub_params, None).await;

        // The queue may or may not still be active at this point (presumably
        // because the dummy executor finishes immediately), so both a reader and
        // the "no active event queue" Internal error are accepted.
        match &sub_result {
            Ok(_) | Err(ServerError::Internal(_)) => {}
            Err(e) => panic!("unexpected error: {e:?}"),
        }
    }

    #[tokio::test]
    async fn resubscribe_with_tenant() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();

        let params = TaskIdParams {
            tenant: Some("test-tenant".to_string()),
            id: "nonexistent-task".to_owned(),
        };
        let result = handler.on_resubscribe(params, None).await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "resubscribe for missing task should fail with TaskNotFound, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn resubscribe_with_headers() {
        let handler = RequestHandlerBuilder::new(DummyExecutor).build().unwrap();

        let params = TaskIdParams {
            tenant: None,
            id: "nonexistent-task".to_owned(),
        };
        let mut headers = std::collections::HashMap::new();
        headers.insert("authorization".to_string(), "Bearer tok".to_string());
        let result = handler.on_resubscribe(params, Some(&headers)).await;
        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for missing task, got: {result:?}"
        );
    }

    /// Exercises the error path when a `before` interceptor rejects the call.
    ///
    /// NOTE(review): despite the name, no metrics are inspected here; the test
    /// only verifies that the interceptor short-circuits the request.
    #[tokio::test]
    async fn resubscribe_error_path_records_error_metrics() {
        use std::future::Future;
        use std::pin::Pin;

        use crate::call_context::CallContext;
        use crate::interceptor::ServerInterceptor;

        /// Interceptor whose `before` hook always fails.
        struct FailInterceptor;

        impl ServerInterceptor for FailInterceptor {
            fn before<'a>(
                &'a self,
                _ctx: &'a CallContext,
            ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
            {
                Box::pin(async {
                    Err(a2a_protocol_types::error::A2aError::internal(
                        "forced failure",
                    ))
                })
            }

            fn after<'a>(
                &'a self,
                _ctx: &'a CallContext,
            ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
            {
                Box::pin(async { Ok(()) })
            }
        }

        let handler = RequestHandlerBuilder::new(DummyExecutor)
            .with_interceptor(FailInterceptor)
            .build()
            .unwrap();

        let params = TaskIdParams {
            tenant: None,
            id: "t-resub-fail".to_owned(),
        };
        let result = handler.on_resubscribe(params, None).await;
        let err = result.expect_err("resubscribe should fail when interceptor rejects");

        // The task does not exist, so if the interceptor were skipped the store
        // lookup would yield TaskNotFound. Any other error proves the `before`
        // hook rejected the call first.
        assert!(
            !matches!(err, ServerError::TaskNotFound(_)),
            "interceptor rejection should precede the task lookup, got: {err:?}"
        );
    }
}