1use a2a::jsonrpc::methods;
4use a2a::*;
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use std::sync::Arc;
8
9use crate::middleware::CallInterceptor;
10use crate::transport::{ServiceParams, Transport};
11
12pub struct A2AClient<T: Transport> {
14 transport: T,
15 interceptors: Vec<Arc<dyn CallInterceptor>>,
16 default_params: ServiceParams,
17}
18
19impl<T: Transport> A2AClient<T> {
20 pub fn new(transport: T) -> Self {
21 let mut default_params = ServiceParams::new();
22 default_params.insert(SVC_PARAM_VERSION.to_string(), vec![VERSION.to_string()]);
23 A2AClient {
24 transport,
25 interceptors: Vec::new(),
26 default_params,
27 }
28 }
29
30 pub fn with_interceptors(mut self, interceptors: Vec<Arc<dyn CallInterceptor>>) -> Self {
31 self.interceptors = interceptors;
32 self
33 }
34
35 fn params(&self) -> ServiceParams {
36 self.default_params.clone()
37 }
38
39 async fn apply_before(&self, method: &str) -> Result<ServiceParams, A2AError> {
40 let mut params = self.params();
41 for interceptor in &self.interceptors {
42 interceptor.before(method, &mut params).await?;
43 }
44 Ok(params)
45 }
46
47 async fn apply_after(
48 &self,
49 method: &str,
50 result: &Result<(), A2AError>,
51 ) -> Result<(), A2AError> {
52 for interceptor in self.interceptors.iter().rev() {
53 interceptor.after(method, result).await?;
54 }
55 Ok(())
56 }
57
58 async fn finish_call<R>(
59 &self,
60 method: &str,
61 result: Result<R, A2AError>,
62 ) -> Result<R, A2AError> {
63 let status = result.as_ref().map(|_| ()).map_err(Clone::clone);
64 let after_result = self.apply_after(method, &status).await;
65
66 match (result, after_result) {
67 (Ok(value), Ok(())) => Ok(value),
68 (Err(error), _) => Err(error),
69 (Ok(_), Err(error)) => Err(error),
70 }
71 }
72
73 pub async fn send_message(
74 &self,
75 req: &SendMessageRequest,
76 ) -> Result<SendMessageResponse, A2AError> {
77 let params = self.apply_before(methods::SEND_MESSAGE).await?;
78 let result = self.transport.send_message(¶ms, req).await;
79 self.finish_call(methods::SEND_MESSAGE, result).await
80 }
81
82 pub async fn send_streaming_message(
83 &self,
84 req: &SendMessageRequest,
85 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
86 let params = self.apply_before(methods::SEND_STREAMING_MESSAGE).await?;
87 let result = self.transport.send_streaming_message(¶ms, req).await;
88 self.finish_call(methods::SEND_STREAMING_MESSAGE, result)
89 .await
90 }
91
92 pub async fn get_task(&self, req: &GetTaskRequest) -> Result<Task, A2AError> {
93 let params = self.apply_before(methods::GET_TASK).await?;
94 let result = self.transport.get_task(¶ms, req).await;
95 self.finish_call(methods::GET_TASK, result).await
96 }
97
98 pub async fn list_tasks(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
99 let params = self.apply_before(methods::LIST_TASKS).await?;
100 let result = self.transport.list_tasks(¶ms, req).await;
101 self.finish_call(methods::LIST_TASKS, result).await
102 }
103
104 pub async fn cancel_task(&self, req: &CancelTaskRequest) -> Result<Task, A2AError> {
105 let params = self.apply_before(methods::CANCEL_TASK).await?;
106 let result = self.transport.cancel_task(¶ms, req).await;
107 self.finish_call(methods::CANCEL_TASK, result).await
108 }
109
110 pub async fn subscribe_to_task(
111 &self,
112 req: &SubscribeToTaskRequest,
113 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
114 let params = self.apply_before(methods::SUBSCRIBE_TO_TASK).await?;
115 let result = self.transport.subscribe_to_task(¶ms, req).await;
116 self.finish_call(methods::SUBSCRIBE_TO_TASK, result).await
117 }
118
119 pub async fn create_push_config(
120 &self,
121 req: &TaskPushNotificationConfig,
122 ) -> Result<TaskPushNotificationConfig, A2AError> {
123 let params = self.apply_before(methods::CREATE_PUSH_CONFIG).await?;
124 let result = self.transport.create_push_config(¶ms, req).await;
125 self.finish_call(methods::CREATE_PUSH_CONFIG, result).await
126 }
127
128 pub async fn get_push_config(
129 &self,
130 req: &GetTaskPushNotificationConfigRequest,
131 ) -> Result<TaskPushNotificationConfig, A2AError> {
132 let params = self.apply_before(methods::GET_PUSH_CONFIG).await?;
133 let result = self.transport.get_push_config(¶ms, req).await;
134 self.finish_call(methods::GET_PUSH_CONFIG, result).await
135 }
136
137 pub async fn list_push_configs(
138 &self,
139 req: &ListTaskPushNotificationConfigsRequest,
140 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
141 let params = self.apply_before(methods::LIST_PUSH_CONFIGS).await?;
142 let result = self.transport.list_push_configs(¶ms, req).await;
143 self.finish_call(methods::LIST_PUSH_CONFIGS, result).await
144 }
145
146 pub async fn delete_push_config(
147 &self,
148 req: &DeleteTaskPushNotificationConfigRequest,
149 ) -> Result<(), A2AError> {
150 let params = self.apply_before(methods::DELETE_PUSH_CONFIG).await?;
151 let result = self.transport.delete_push_config(¶ms, req).await;
152 self.finish_call(methods::DELETE_PUSH_CONFIG, result).await
153 }
154
155 pub async fn get_extended_agent_card(
156 &self,
157 req: &GetExtendedAgentCardRequest,
158 ) -> Result<AgentCard, A2AError> {
159 let params = self.apply_before(methods::GET_EXTENDED_AGENT_CARD).await?;
160 let result = self.transport.get_extended_agent_card(¶ms, req).await;
161 self.finish_call(methods::GET_EXTENDED_AGENT_CARD, result)
162 .await
163 }
164
165 pub async fn destroy(&self) -> Result<(), A2AError> {
166 self.transport.destroy().await
167 }
168}
169
170#[async_trait]
172pub trait SendMessageExt {
173 async fn send_text(
174 &self,
175 text: impl Into<String> + Send,
176 ) -> Result<SendMessageResponse, A2AError>;
177}
178
179#[async_trait]
180impl<T: Transport> SendMessageExt for A2AClient<T> {
181 async fn send_text(
182 &self,
183 text: impl Into<String> + Send,
184 ) -> Result<SendMessageResponse, A2AError> {
185 let msg = Message::new(Role::User, vec![Part::text(text)]);
186 let req = SendMessageRequest {
187 message: msg,
188 configuration: None,
189 metadata: None,
190 tenant: None,
191 };
192 self.send_message(&req).await
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use a2a::event::StreamResponse;
    use futures::stream;
    use std::sync::Mutex;

    /// State shared between a [`MockTransport`] and the test that created it.
    #[derive(Default)]
    struct MockTransportState {
        /// Recorded transport calls as `(method, params)`, in call order.
        calls: Mutex<Vec<(String, ServiceParams)>>,
        /// When set, `send_message` fails with a clone of this error.
        send_message_error: Mutex<Option<A2AError>>,
    }

    /// In-memory transport that records calls and returns canned responses.
    struct MockTransport {
        state: Arc<MockTransportState>,
    }

    impl MockTransport {
        /// Builds a transport together with a handle to its shared state.
        fn new() -> (Self, Arc<MockTransportState>) {
            let shared = Arc::new(MockTransportState::default());
            let transport = Self {
                state: Arc::clone(&shared),
            };
            (transport, shared)
        }

        /// Appends `(method, params)` to the recorded call log.
        fn record(&self, method: &str, params: &ServiceParams) {
            let entry = (method.to_string(), params.clone());
            self.state.calls.lock().unwrap().push(entry);
        }
    }

    /// Task status in `state` with no message and no timestamp.
    fn task_status(state: TaskState) -> TaskStatus {
        TaskStatus {
            state,
            message: None,
            timestamp: None,
        }
    }

    /// User message request containing the single text part `"hi"`.
    fn hello_request() -> SendMessageRequest {
        SendMessageRequest {
            message: Message::new(Role::User, vec![Part::text("hi")]),
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<SendMessageResponse, A2AError> {
            self.record(methods::SEND_MESSAGE, params);
            let injected = self.state.send_message_error.lock().unwrap().clone();
            match injected {
                Some(error) => Err(error),
                None => Ok(SendMessageResponse::Task(Task {
                    id: "t1".into(),
                    context_id: "c1".into(),
                    status: task_status(TaskState::Completed),
                    artifacts: None,
                    history: None,
                    metadata: None,
                })),
            }
        }

        async fn send_streaming_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SEND_STREAMING_MESSAGE, params);
            let update = StreamResponse::StatusUpdate(a2a::event::TaskStatusUpdateEvent {
                task_id: "t1".into(),
                context_id: "c1".into(),
                status: task_status(TaskState::Working),
                metadata: None,
            });
            Ok(Box::pin(stream::once(async move { Ok(update) })))
        }

        async fn get_task(
            &self,
            params: &ServiceParams,
            req: &GetTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::GET_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: task_status(TaskState::Completed),
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn list_tasks(
            &self,
            params: &ServiceParams,
            _req: &ListTasksRequest,
        ) -> Result<ListTasksResponse, A2AError> {
            self.record(methods::LIST_TASKS, params);
            Ok(ListTasksResponse {
                tasks: vec![],
                next_page_token: String::new(),
                page_size: 0,
                total_size: 0,
            })
        }

        async fn cancel_task(
            &self,
            params: &ServiceParams,
            req: &CancelTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::CANCEL_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: task_status(TaskState::Canceled),
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn subscribe_to_task(
            &self,
            params: &ServiceParams,
            _req: &SubscribeToTaskRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SUBSCRIBE_TO_TASK, params);
            Ok(Box::pin(stream::empty()))
        }

        async fn create_push_config(
            &self,
            params: &ServiceParams,
            req: &TaskPushNotificationConfig,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::CREATE_PUSH_CONFIG, params);
            Ok(req.clone())
        }

        async fn get_push_config(
            &self,
            params: &ServiceParams,
            req: &GetTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::GET_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                url: "http://example.com".into(),
                id: Some(req.id.clone()),
                token: None,
                authentication: None,
                tenant: None,
            })
        }

        async fn list_push_configs(
            &self,
            params: &ServiceParams,
            _req: &ListTaskPushNotificationConfigsRequest,
        ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
            self.record(methods::LIST_PUSH_CONFIGS, params);
            Ok(ListTaskPushNotificationConfigsResponse {
                configs: vec![],
                next_page_token: None,
            })
        }

        async fn delete_push_config(
            &self,
            params: &ServiceParams,
            _req: &DeleteTaskPushNotificationConfigRequest,
        ) -> Result<(), A2AError> {
            self.record(methods::DELETE_PUSH_CONFIG, params);
            Ok(())
        }

        async fn get_extended_agent_card(
            &self,
            params: &ServiceParams,
            _req: &GetExtendedAgentCardRequest,
        ) -> Result<AgentCard, A2AError> {
            self.record(methods::GET_EXTENDED_AGENT_CARD, params);
            Ok(AgentCard {
                name: "Test".into(),
                description: "Test agent".into(),
                version: "1.0".into(),
                supported_interfaces: vec![],
                capabilities: AgentCapabilities::default(),
                default_input_modes: vec!["text/plain".into()],
                default_output_modes: vec!["text/plain".into()],
                skills: vec![],
                provider: None,
                documentation_url: None,
                icon_url: None,
                security_schemes: None,
                security_requirements: None,
                signatures: None,
            })
        }

        async fn destroy(&self) -> Result<(), A2AError> {
            Ok(())
        }
    }

    /// Client over a fresh mock transport; the state handle is discarded.
    fn make_client() -> A2AClient<MockTransport> {
        A2AClient::new(MockTransport::new().0)
    }

    /// Interceptor that logs its hook invocations and tags the outgoing params.
    struct RecordingInterceptor {
        name: &'static str,
        events: Arc<Mutex<Vec<String>>>,
    }

    /// Builds a type-erased [`RecordingInterceptor`] writing into `events`.
    fn recorder(name: &'static str, events: &Arc<Mutex<Vec<String>>>) -> Arc<dyn CallInterceptor> {
        Arc::new(RecordingInterceptor {
            name,
            events: Arc::clone(events),
        })
    }

    #[async_trait]
    impl CallInterceptor for RecordingInterceptor {
        async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
            let entry = format!("before:{}", self.name);
            self.events.lock().unwrap().push(entry);
            // Append our name so tests can observe the order of `before` hooks.
            params
                .entry("X-Interceptor".to_string())
                .or_default()
                .push(self.name.to_string());
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            let outcome = match result {
                Ok(()) => "ok",
                Err(_) => "err",
            };
            let entry = format!("after:{}:{}", self.name, outcome);
            self.events.lock().unwrap().push(entry);
            Ok(())
        }
    }

    /// A new client carries the version entry in its default params.
    #[test]
    fn test_new_sets_default_params() {
        let defaults = make_client().params();
        assert!(defaults.contains_key(SVC_PARAM_VERSION));
    }

    /// Installing an empty chain leaves the client without interceptors.
    #[test]
    fn test_with_interceptors() {
        let client = make_client().with_interceptors(Vec::new());
        assert!(client.interceptors.is_empty());
    }

    #[tokio::test]
    async fn test_send_message() {
        let client = make_client();
        let response = client.send_message(&hello_request()).await.unwrap();
        assert!(matches!(response, SendMessageResponse::Task(_)));
    }

    /// `before` hooks run in order, `after` hooks in reverse order, and the
    /// transport sees the params produced by the `before` hooks.
    #[tokio::test]
    async fn test_send_message_applies_interceptors_and_reverses_after_order() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let chain = vec![recorder("first", &events), recorder("second", &events)];
        let client = A2AClient::new(transport).with_interceptors(chain);

        client.send_message(&hello_request()).await.unwrap();

        let calls = state.calls.lock().unwrap();
        let (_, sent_params) = &calls[0];
        let expected_tags = vec!["first".to_string(), "second".to_string()];
        assert_eq!(sent_params.get("X-Interceptor").unwrap(), &expected_tags);

        let recorded = events.lock().unwrap().clone();
        assert_eq!(
            recorded,
            [
                "before:first",
                "before:second",
                "after:second:ok",
                "after:first:ok",
            ]
        );
    }

    /// A transport failure is passed to the `after` hooks and still returned
    /// to the caller.
    #[tokio::test]
    async fn test_send_message_preserves_transport_error_after_after_hooks() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![recorder("only", &events)]);

        let failure = client.send_message(&hello_request()).await.unwrap_err();
        assert_eq!(failure.message, "boom");

        let recorded = events.lock().unwrap().clone();
        assert_eq!(recorded, ["before:only", "after:only:err"]);
    }

    #[tokio::test]
    async fn test_send_streaming_message() {
        use futures::StreamExt;

        let client = make_client();
        let mut responses = client
            .send_streaming_message(&hello_request())
            .await
            .unwrap();
        let first = responses.next().await.unwrap().unwrap();
        assert!(matches!(first, StreamResponse::StatusUpdate(_)));
    }

    #[tokio::test]
    async fn test_get_task() {
        let request = GetTaskRequest {
            id: "t1".into(),
            history_length: None,
            tenant: None,
        };
        let task = make_client().get_task(&request).await.unwrap();
        assert_eq!(task.id, "t1");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let request = ListTasksRequest {
            context_id: None,
            status: None,
            page_size: None,
            page_token: None,
            history_length: None,
            status_timestamp_after: None,
            include_artifacts: None,
            tenant: None,
        };
        let listing = make_client().list_tasks(&request).await.unwrap();
        assert!(listing.tasks.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let request = CancelTaskRequest {
            id: "t1".into(),
            metadata: None,
            tenant: None,
        };
        let task = make_client().cancel_task(&request).await.unwrap();
        assert_eq!(task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn test_subscribe_to_task() {
        let request = SubscribeToTaskRequest {
            id: "t1".into(),
            tenant: None,
        };
        let _subscription = make_client().subscribe_to_task(&request).await.unwrap();
    }

    #[tokio::test]
    async fn test_create_push_config() {
        let request = TaskPushNotificationConfig {
            task_id: "t1".into(),
            url: "http://example.com".into(),
            id: None,
            token: None,
            authentication: None,
            tenant: None,
        };
        let created = make_client().create_push_config(&request).await.unwrap();
        assert_eq!(created.task_id, "t1");
    }

    #[tokio::test]
    async fn test_get_push_config() {
        let request = GetTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        let config = make_client().get_push_config(&request).await.unwrap();
        assert_eq!(config.id, Some("cfg1".into()));
    }

    #[tokio::test]
    async fn test_list_push_configs() {
        let request = ListTaskPushNotificationConfigsRequest {
            task_id: "t1".into(),
            page_size: None,
            page_token: None,
            tenant: None,
        };
        let listing = make_client().list_push_configs(&request).await.unwrap();
        assert!(listing.configs.is_empty());
    }

    #[tokio::test]
    async fn test_delete_push_config() {
        let request = DeleteTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        make_client().delete_push_config(&request).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_extended_agent_card() {
        let request = GetExtendedAgentCardRequest { tenant: None };
        let card = make_client()
            .get_extended_agent_card(&request)
            .await
            .unwrap();
        assert_eq!(card.name, "Test");
    }

    #[tokio::test]
    async fn test_destroy() {
        make_client().destroy().await.unwrap();
    }
}