a2a_protocol_client/methods/
tasks.rs1use a2a_protocol_types::{
12 CancelTaskParams, ListTasksParams, Task, TaskIdParams, TaskListResponse, TaskQueryParams,
13};
14
15use crate::client::A2aClient;
16use crate::error::{ClientError, ClientResult};
17use crate::interceptor::{ClientRequest, ClientResponse};
18use crate::streaming::EventStream;
19
20impl A2aClient {
21 pub async fn get_task(&self, params: TaskQueryParams) -> ClientResult<Task> {
30 const METHOD: &str = "GetTask";
31
32 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
33
34 let mut req = ClientRequest::new(METHOD, params_value);
35 self.interceptors.run_before(&mut req).await?;
36
37 let result = self
38 .transport
39 .send_request(METHOD, req.params, &req.extra_headers)
40 .await?;
41
42 let resp = ClientResponse {
43 method: METHOD.to_owned(),
44 result,
45 status_code: 200,
46 };
47 self.interceptors.run_after(&resp).await?;
48
49 serde_json::from_value::<Task>(resp.result).map_err(ClientError::Serialization)
50 }
51
52 pub async fn list_tasks(&self, params: ListTasksParams) -> ClientResult<TaskListResponse> {
61 const METHOD: &str = "ListTasks";
62
63 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
64
65 let mut req = ClientRequest::new(METHOD, params_value);
66 self.interceptors.run_before(&mut req).await?;
67
68 let result = self
69 .transport
70 .send_request(METHOD, req.params, &req.extra_headers)
71 .await?;
72
73 let resp = ClientResponse {
74 method: METHOD.to_owned(),
75 result,
76 status_code: 200,
77 };
78 self.interceptors.run_after(&resp).await?;
79
80 serde_json::from_value::<TaskListResponse>(resp.result).map_err(ClientError::Serialization)
81 }
82
83 pub async fn cancel_task(&self, id: impl Into<String>) -> ClientResult<Task> {
94 const METHOD: &str = "CancelTask";
95
96 let params = CancelTaskParams {
97 tenant: None,
98 id: id.into(),
99 metadata: None,
100 };
101 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
102
103 let mut req = ClientRequest::new(METHOD, params_value);
104 self.interceptors.run_before(&mut req).await?;
105
106 let result = self
107 .transport
108 .send_request(METHOD, req.params, &req.extra_headers)
109 .await?;
110
111 let resp = ClientResponse {
112 method: METHOD.to_owned(),
113 result,
114 status_code: 200,
115 };
116 self.interceptors.run_after(&resp).await?;
117
118 serde_json::from_value::<Task>(resp.result).map_err(ClientError::Serialization)
119 }
120
121 pub async fn subscribe_to_task(&self, id: impl Into<String>) -> ClientResult<EventStream> {
134 const METHOD: &str = "SubscribeToTask";
135
136 let params = TaskIdParams {
137 tenant: None,
138 id: id.into(),
139 };
140 let params_value = serde_json::to_value(¶ms).map_err(ClientError::Serialization)?;
141
142 let mut req = ClientRequest::new(METHOD, params_value);
143 self.interceptors.run_before(&mut req).await?;
144
145 let stream = self
146 .transport
147 .send_streaming_request(METHOD, req.params, &req.extra_headers)
148 .await?;
149
150 let resp = ClientResponse {
153 method: METHOD.to_owned(),
154 result: serde_json::Value::Null,
155 status_code: 200,
156 };
157 self.interceptors.run_after(&resp).await?;
158
159 Ok(stream)
160 }
161}
162
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    use a2a_protocol_types::{ListTasksParams, TaskQueryParams};

    use crate::error::{ClientError, ClientResult};
    use crate::streaming::EventStream;
    use crate::transport::Transport;
    use crate::ClientBuilder;

    /// Shared log of `(method, params)` pairs passed to `MockTransport::send_request`.
    type CallLog = Arc<Mutex<Vec<(String, serde_json::Value)>>>;

    /// Transport stub that answers every unary call with a canned response,
    /// records what it was asked to send, and rejects streaming calls.
    struct MockTransport {
        response: serde_json::Value,
        calls: CallLog,
    }

    impl MockTransport {
        fn new(response: serde_json::Value) -> Self {
            Self {
                response,
                calls: CallLog::default(),
            }
        }

        /// Handle on the call log that stays usable after the transport has
        /// been moved into a client.
        fn call_log(&self) -> CallLog {
            Arc::clone(&self.calls)
        }
    }

    impl Transport for MockTransport {
        fn send_request<'a>(
            &'a self,
            method: &'a str,
            params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
            // Record before returning; the lock is released before the future is built.
            self.calls
                .lock()
                .expect("call log poisoned")
                .push((method.to_owned(), params));
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }

        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
            Box::pin(async move {
                Err(ClientError::Transport(
                    "mock: streaming not supported".into(),
                ))
            })
        }
    }

    fn make_client(transport: impl Transport) -> crate::A2aClient {
        ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(transport)
            .build()
            .expect("build client")
    }

    fn task_json() -> serde_json::Value {
        serde_json::json!({
            "id": "task-1",
            "contextId": "ctx-1",
            "status": {
                "state": "TASK_STATE_COMPLETED"
            }
        })
    }

    /// Asserts that exactly one unary call was sent, using `method` and
    /// carrying `task_id` in its `id` param.
    fn assert_single_call(log: &CallLog, method: &str, task_id: &str) {
        let calls = log.lock().expect("call log poisoned");
        assert_eq!(calls.len(), 1, "expected exactly one transport call");
        let (sent_method, sent_params) = &calls[0];
        assert_eq!(sent_method.as_str(), method);
        assert_eq!(sent_params["id"], task_id);
    }

    #[tokio::test]
    async fn get_task_success() {
        let transport = MockTransport::new(task_json());
        let log = transport.call_log();
        let client = make_client(transport);

        let params = TaskQueryParams {
            tenant: None,
            id: "task-1".into(),
            history_length: None,
        };
        let task = client.get_task(params).await.unwrap();
        assert_eq!(task.id.as_ref(), "task-1");
        assert_single_call(&log, "GetTask", "task-1");
    }

    #[tokio::test]
    async fn list_tasks_success() {
        let response = serde_json::json!({
            "tasks": [
                {
                    "id": "task-1",
                    "contextId": "ctx-1",
                    "status": { "state": "TASK_STATE_COMPLETED" }
                },
                {
                    "id": "task-2",
                    "contextId": "ctx-2",
                    "status": { "state": "TASK_STATE_WORKING" }
                }
            ]
        });
        let transport = MockTransport::new(response);
        let log = transport.call_log();
        let client = make_client(transport);

        let params = ListTasksParams::default();
        let result = client.list_tasks(params).await.unwrap();
        assert_eq!(result.tasks.len(), 2);
        assert_eq!(result.tasks[0].id.as_ref(), "task-1");

        let calls = log.lock().expect("call log poisoned");
        assert_eq!(calls.len(), 1, "expected exactly one transport call");
        assert_eq!(calls[0].0.as_str(), "ListTasks");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let transport = MockTransport::new(task_json());
        let log = transport.call_log();
        let client = make_client(transport);

        let task = client.cancel_task("task-1").await.unwrap();
        assert_eq!(task.id.as_ref(), "task-1");
        assert_single_call(&log, "CancelTask", "task-1");
    }

    #[tokio::test]
    async fn subscribe_to_task_calls_after_interceptor() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        use crate::interceptor::{CallInterceptor, ClientRequest, ClientResponse};

        /// Transport whose streaming call succeeds with an already-closed stream.
        struct StreamingOkTransport;

        impl Transport for StreamingOkTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async move {
                    let (tx, rx) = tokio::sync::mpsc::channel(8);
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }

        /// Interceptor that counts how often each hook runs.
        struct CountingInterceptor {
            before_count: Arc<AtomicUsize>,
            after_count: Arc<AtomicUsize>,
        }

        impl CallInterceptor for CountingInterceptor {
            async fn before<'a>(&'a self, _req: &'a mut ClientRequest) -> ClientResult<()> {
                self.before_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
            async fn after<'a>(&'a self, _resp: &'a ClientResponse) -> ClientResult<()> {
                self.after_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor {
            before_count: Arc::clone(&before),
            after_count: Arc::clone(&after),
        };

        let client = ClientBuilder::new("http://localhost:8080")
            .with_custom_transport(StreamingOkTransport)
            .with_interceptor(interceptor)
            .build()
            .expect("build");

        let result = client.subscribe_to_task("task-1").await;
        assert!(result.is_ok(), "subscribe should succeed");
        assert_eq!(before.load(Ordering::SeqCst), 1, "before should be called");
        assert_eq!(
            after.load(Ordering::SeqCst),
            1,
            "after should be called for subscribe streaming"
        );
    }

    #[tokio::test]
    async fn subscribe_to_task_returns_transport_error() {
        let transport = MockTransport::new(serde_json::Value::Null);
        let client = make_client(transport);

        let err = client.subscribe_to_task("task-1").await.unwrap_err();
        assert!(
            matches!(err, ClientError::Transport(ref msg) if msg.contains("streaming not supported")),
            "expected Transport error, got {err:?}"
        );
    }
}