1use a2a::jsonrpc::methods;
4use a2a::*;
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use std::sync::Arc;
8
9use crate::middleware::CallInterceptor;
10use crate::transport::{ServiceParams, Transport};
11
12pub struct A2AClient<T: Transport> {
14 transport: T,
15 interceptors: Vec<Arc<dyn CallInterceptor>>,
16 default_params: ServiceParams,
17}
18
19impl<T: Transport> A2AClient<T> {
20 pub fn new(transport: T) -> Self {
21 let mut default_params = ServiceParams::new();
22 default_params.insert(SVC_PARAM_VERSION.to_string(), vec![VERSION.to_string()]);
23 A2AClient {
24 transport,
25 interceptors: Vec::new(),
26 default_params,
27 }
28 }
29
30 pub fn with_interceptors(mut self, interceptors: Vec<Arc<dyn CallInterceptor>>) -> Self {
31 self.interceptors = interceptors;
32 self
33 }
34
35 fn params(&self) -> ServiceParams {
36 self.default_params.clone()
37 }
38
39 async fn apply_before(&self, method: &str) -> Result<ServiceParams, A2AError> {
40 let mut params = self.params();
41 for interceptor in &self.interceptors {
42 interceptor.before(method, &mut params).await?;
43 }
44 Ok(params)
45 }
46
47 async fn apply_after(
48 &self,
49 method: &str,
50 result: &Result<(), A2AError>,
51 ) -> Result<(), A2AError> {
52 for interceptor in self.interceptors.iter().rev() {
53 interceptor.after(method, result).await?;
54 }
55 Ok(())
56 }
57
58 async fn finish_call<R>(
59 &self,
60 method: &str,
61 result: Result<R, A2AError>,
62 ) -> Result<R, A2AError> {
63 let status = result.as_ref().map(|_| ()).map_err(Clone::clone);
64 let after_result = self.apply_after(method, &status).await;
65
66 match (result, after_result) {
67 (Ok(value), Ok(())) => Ok(value),
68 (Err(error), _) => Err(error),
69 (Ok(_), Err(error)) => Err(error),
70 }
71 }
72
73 pub async fn send_message(
74 &self,
75 req: &SendMessageRequest,
76 ) -> Result<SendMessageResponse, A2AError> {
77 let params = self.apply_before(methods::SEND_MESSAGE).await?;
78 let result = self.transport.send_message(¶ms, req).await;
79 self.finish_call(methods::SEND_MESSAGE, result).await
80 }
81
82 pub async fn send_streaming_message(
83 &self,
84 req: &SendMessageRequest,
85 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
86 let params = self.apply_before(methods::SEND_STREAMING_MESSAGE).await?;
87 let result = self.transport.send_streaming_message(¶ms, req).await;
88 self.finish_call(methods::SEND_STREAMING_MESSAGE, result)
89 .await
90 }
91
92 pub async fn get_task(&self, req: &GetTaskRequest) -> Result<Task, A2AError> {
93 let params = self.apply_before(methods::GET_TASK).await?;
94 let result = self.transport.get_task(¶ms, req).await;
95 self.finish_call(methods::GET_TASK, result).await
96 }
97
98 pub async fn list_tasks(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
99 let params = self.apply_before(methods::LIST_TASKS).await?;
100 let result = self.transport.list_tasks(¶ms, req).await;
101 self.finish_call(methods::LIST_TASKS, result).await
102 }
103
104 pub async fn cancel_task(&self, req: &CancelTaskRequest) -> Result<Task, A2AError> {
105 let params = self.apply_before(methods::CANCEL_TASK).await?;
106 let result = self.transport.cancel_task(¶ms, req).await;
107 self.finish_call(methods::CANCEL_TASK, result).await
108 }
109
110 pub async fn subscribe_to_task(
111 &self,
112 req: &SubscribeToTaskRequest,
113 ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
114 let params = self.apply_before(methods::SUBSCRIBE_TO_TASK).await?;
115 let result = self.transport.subscribe_to_task(¶ms, req).await;
116 self.finish_call(methods::SUBSCRIBE_TO_TASK, result).await
117 }
118
119 pub async fn create_push_config(
120 &self,
121 req: &CreateTaskPushNotificationConfigRequest,
122 ) -> Result<TaskPushNotificationConfig, A2AError> {
123 let params = self.apply_before(methods::CREATE_PUSH_CONFIG).await?;
124 let result = self.transport.create_push_config(¶ms, req).await;
125 self.finish_call(methods::CREATE_PUSH_CONFIG, result).await
126 }
127
128 pub async fn get_push_config(
129 &self,
130 req: &GetTaskPushNotificationConfigRequest,
131 ) -> Result<TaskPushNotificationConfig, A2AError> {
132 let params = self.apply_before(methods::GET_PUSH_CONFIG).await?;
133 let result = self.transport.get_push_config(¶ms, req).await;
134 self.finish_call(methods::GET_PUSH_CONFIG, result).await
135 }
136
137 pub async fn list_push_configs(
138 &self,
139 req: &ListTaskPushNotificationConfigsRequest,
140 ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
141 let params = self.apply_before(methods::LIST_PUSH_CONFIGS).await?;
142 let result = self.transport.list_push_configs(¶ms, req).await;
143 self.finish_call(methods::LIST_PUSH_CONFIGS, result).await
144 }
145
146 pub async fn delete_push_config(
147 &self,
148 req: &DeleteTaskPushNotificationConfigRequest,
149 ) -> Result<(), A2AError> {
150 let params = self.apply_before(methods::DELETE_PUSH_CONFIG).await?;
151 let result = self.transport.delete_push_config(¶ms, req).await;
152 self.finish_call(methods::DELETE_PUSH_CONFIG, result).await
153 }
154
155 pub async fn get_extended_agent_card(
156 &self,
157 req: &GetExtendedAgentCardRequest,
158 ) -> Result<AgentCard, A2AError> {
159 let params = self.apply_before(methods::GET_EXTENDED_AGENT_CARD).await?;
160 let result = self.transport.get_extended_agent_card(¶ms, req).await;
161 self.finish_call(methods::GET_EXTENDED_AGENT_CARD, result)
162 .await
163 }
164
165 pub async fn destroy(&self) -> Result<(), A2AError> {
166 self.transport.destroy().await
167 }
168}
169
170#[async_trait]
172pub trait SendMessageExt {
173 async fn send_text(
174 &self,
175 text: impl Into<String> + Send,
176 ) -> Result<SendMessageResponse, A2AError>;
177}
178
179#[async_trait]
180impl<T: Transport> SendMessageExt for A2AClient<T> {
181 async fn send_text(
182 &self,
183 text: impl Into<String> + Send,
184 ) -> Result<SendMessageResponse, A2AError> {
185 let msg = Message::new(Role::User, vec![Part::text(text)]);
186 let req = SendMessageRequest {
187 message: msg,
188 configuration: None,
189 metadata: None,
190 tenant: None,
191 };
192 self.send_message(&req).await
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use a2a::event::StreamResponse;
    use futures::stream;
    use std::sync::Mutex;

    /// Shared, inspectable state of a [`MockTransport`].
    #[derive(Default)]
    struct MockTransportState {
        /// Every call the transport received, with the service params it was given.
        calls: Mutex<Vec<(String, ServiceParams)>>,
        /// When set, `send_message` fails with this error after recording the call.
        send_message_error: Mutex<Option<A2AError>>,
    }

    /// In-memory transport that records calls and returns canned responses.
    struct MockTransport {
        state: Arc<MockTransportState>,
    }

    impl MockTransport {
        /// Returns the transport plus a handle to its state for later assertions.
        fn new() -> (Self, Arc<MockTransportState>) {
            let state = Arc::new(MockTransportState::default());
            (
                MockTransport {
                    state: state.clone(),
                },
                state,
            )
        }

        fn record(&self, method: &str, params: &ServiceParams) {
            self.state
                .calls
                .lock()
                .unwrap()
                .push((method.to_string(), params.clone()));
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<SendMessageResponse, A2AError> {
            self.record(methods::SEND_MESSAGE, params);
            if let Some(error) = self.state.send_message_error.lock().unwrap().clone() {
                return Err(error);
            }
            Ok(SendMessageResponse::Task(Task {
                id: "t1".into(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            }))
        }

        async fn send_streaming_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SEND_STREAMING_MESSAGE, params);
            Ok(Box::pin(stream::once(async {
                Ok(StreamResponse::StatusUpdate(
                    a2a::event::TaskStatusUpdateEvent {
                        task_id: "t1".into(),
                        context_id: "c1".into(),
                        status: TaskStatus {
                            state: TaskState::Working,
                            message: None,
                            timestamp: None,
                        },
                        metadata: None,
                    },
                ))
            })))
        }

        async fn get_task(
            &self,
            params: &ServiceParams,
            req: &GetTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::GET_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn list_tasks(
            &self,
            params: &ServiceParams,
            _req: &ListTasksRequest,
        ) -> Result<ListTasksResponse, A2AError> {
            self.record(methods::LIST_TASKS, params);
            Ok(ListTasksResponse {
                tasks: vec![],
                next_page_token: String::new(),
                page_size: 0,
                total_size: 0,
            })
        }

        async fn cancel_task(
            &self,
            params: &ServiceParams,
            req: &CancelTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::CANCEL_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Canceled,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn subscribe_to_task(
            &self,
            params: &ServiceParams,
            _req: &SubscribeToTaskRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SUBSCRIBE_TO_TASK, params);
            Ok(Box::pin(stream::empty()))
        }

        async fn create_push_config(
            &self,
            params: &ServiceParams,
            req: &CreateTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::CREATE_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                config: req.config.clone(),
                tenant: None,
            })
        }

        async fn get_push_config(
            &self,
            params: &ServiceParams,
            req: &GetTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::GET_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                config: PushNotificationConfig {
                    url: "http://example.com".into(),
                    id: Some(req.id.clone()),
                    token: None,
                    authentication: None,
                },
                tenant: None,
            })
        }

        async fn list_push_configs(
            &self,
            params: &ServiceParams,
            _req: &ListTaskPushNotificationConfigsRequest,
        ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
            self.record(methods::LIST_PUSH_CONFIGS, params);
            Ok(ListTaskPushNotificationConfigsResponse {
                configs: vec![],
                next_page_token: None,
            })
        }

        async fn delete_push_config(
            &self,
            params: &ServiceParams,
            _req: &DeleteTaskPushNotificationConfigRequest,
        ) -> Result<(), A2AError> {
            self.record(methods::DELETE_PUSH_CONFIG, params);
            Ok(())
        }

        async fn get_extended_agent_card(
            &self,
            params: &ServiceParams,
            _req: &GetExtendedAgentCardRequest,
        ) -> Result<AgentCard, A2AError> {
            self.record(methods::GET_EXTENDED_AGENT_CARD, params);
            Ok(AgentCard {
                name: "Test".into(),
                description: "Test agent".into(),
                version: "1.0".into(),
                supported_interfaces: vec![],
                capabilities: AgentCapabilities::default(),
                default_input_modes: vec!["text/plain".into()],
                default_output_modes: vec!["text/plain".into()],
                skills: vec![],
                provider: None,
                documentation_url: None,
                icon_url: None,
                security_schemes: None,
                security_requirements: None,
                signatures: None,
            })
        }

        async fn destroy(&self) -> Result<(), A2AError> {
            Ok(())
        }
    }

    fn make_client() -> A2AClient<MockTransport> {
        let (transport, _) = MockTransport::new();
        A2AClient::new(transport)
    }

    /// Builds a minimal user-role text request; shared by the message-sending tests.
    fn text_request(text: &str) -> SendMessageRequest {
        SendMessageRequest {
            message: Message::new(Role::User, vec![Part::text(text)]),
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// Interceptor that logs hook invocations and tags the outgoing params with its name.
    struct RecordingInterceptor {
        name: &'static str,
        events: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl CallInterceptor for RecordingInterceptor {
        async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
            self.events
                .lock()
                .unwrap()
                .push(format!("before:{}", self.name));
            params
                .entry("X-Interceptor".to_string())
                .or_default()
                .push(self.name.to_string());
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            let status = if result.is_ok() { "ok" } else { "err" };
            self.events
                .lock()
                .unwrap()
                .push(format!("after:{}:{status}", self.name));
            Ok(())
        }
    }

    /// Interceptor that fails in a chosen phase, used to exercise error propagation.
    struct FailingInterceptor {
        fail_before: bool,
        fail_after: bool,
    }

    #[async_trait]
    impl CallInterceptor for FailingInterceptor {
        async fn before(
            &self,
            _method: &str,
            _params: &mut ServiceParams,
        ) -> Result<(), A2AError> {
            if self.fail_before {
                return Err(A2AError::internal("before failed"));
            }
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            _result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            if self.fail_after {
                return Err(A2AError::internal("after failed"));
            }
            Ok(())
        }
    }

    #[test]
    fn test_new_sets_default_params() {
        let client = make_client();
        let params = client.params();
        assert_eq!(
            params.get(SVC_PARAM_VERSION),
            Some(&vec![VERSION.to_string()])
        );
    }

    #[test]
    fn test_with_interceptors() {
        let client = make_client().with_interceptors(vec![]);
        assert!(client.interceptors.is_empty());
    }

    #[tokio::test]
    async fn test_send_message() {
        let client = make_client();
        let resp = client.send_message(&text_request("hi")).await.unwrap();
        assert!(matches!(resp, SendMessageResponse::Task(_)));
    }

    #[tokio::test]
    async fn test_default_version_param_forwarded_to_transport() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport);

        client.send_message(&text_request("hi")).await.unwrap();

        let calls = state.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, methods::SEND_MESSAGE);
        assert_eq!(
            calls[0].1.get(SVC_PARAM_VERSION),
            Some(&vec![VERSION.to_string()])
        );
    }

    #[tokio::test]
    async fn test_send_message_applies_interceptors_and_reverses_after_order() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![
            Arc::new(RecordingInterceptor {
                name: "first",
                events: events.clone(),
            }),
            Arc::new(RecordingInterceptor {
                name: "second",
                events: events.clone(),
            }),
        ]);

        client.send_message(&text_request("hi")).await.unwrap();

        let calls = state.calls.lock().unwrap();
        let params = &calls[0].1;
        assert_eq!(
            params.get("X-Interceptor").unwrap(),
            &vec!["first".to_string(), "second".to_string()]
        );

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before:first".to_string(),
                "before:second".to_string(),
                "after:second:ok".to_string(),
                "after:first:ok".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn test_send_message_preserves_transport_error_after_after_hooks() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let events = Arc::new(Mutex::new(Vec::new()));
        let client =
            A2AClient::new(transport).with_interceptors(vec![Arc::new(RecordingInterceptor {
                name: "only",
                events: events.clone(),
            })]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "boom");

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before:only".to_string(), "after:only:err".to_string()]
        );
    }

    #[tokio::test]
    async fn test_before_error_short_circuits_transport() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport).with_interceptors(vec![Arc::new(
            FailingInterceptor {
                fail_before: true,
                fail_after: false,
            },
        )]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "before failed");
        assert!(state.calls.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_after_error_replaces_successful_result() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport).with_interceptors(vec![Arc::new(
            FailingInterceptor {
                fail_before: false,
                fail_after: true,
            },
        )]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "after failed");
        // The transport still ran; only the final result was replaced.
        assert_eq!(state.calls.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_transport_error_wins_over_after_error() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let client = A2AClient::new(transport).with_interceptors(vec![Arc::new(
            FailingInterceptor {
                fail_before: false,
                fail_after: true,
            },
        )]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "boom");
    }

    #[tokio::test]
    async fn test_send_streaming_message() {
        use futures::StreamExt;
        let client = make_client();
        let mut stream = client
            .send_streaming_message(&text_request("hi"))
            .await
            .unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert!(matches!(item, StreamResponse::StatusUpdate(_)));
    }

    #[tokio::test]
    async fn test_get_task() {
        let client = make_client();
        let req = GetTaskRequest {
            id: "t1".into(),
            history_length: None,
            tenant: None,
        };
        let task = client.get_task(&req).await.unwrap();
        assert_eq!(task.id, "t1");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let client = make_client();
        let req = ListTasksRequest {
            context_id: None,
            status: None,
            page_size: None,
            page_token: None,
            history_length: None,
            status_timestamp_after: None,
            include_artifacts: None,
            tenant: None,
        };
        let resp = client.list_tasks(&req).await.unwrap();
        assert!(resp.tasks.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let client = make_client();
        let req = CancelTaskRequest {
            id: "t1".into(),
            metadata: None,
            tenant: None,
        };
        let task = client.cancel_task(&req).await.unwrap();
        assert_eq!(task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn test_subscribe_to_task() {
        use futures::StreamExt;
        let client = make_client();
        let req = SubscribeToTaskRequest {
            id: "t1".into(),
            tenant: None,
        };
        let mut stream = client.subscribe_to_task(&req).await.unwrap();
        // The mock transport yields an empty stream.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_create_push_config() {
        let client = make_client();
        let req = CreateTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            config: PushNotificationConfig {
                url: "http://example.com".into(),
                id: None,
                token: None,
                authentication: None,
            },
            tenant: None,
        };
        let resp = client.create_push_config(&req).await.unwrap();
        assert_eq!(resp.task_id, "t1");
    }

    #[tokio::test]
    async fn test_get_push_config() {
        let client = make_client();
        let req = GetTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        let resp = client.get_push_config(&req).await.unwrap();
        assert_eq!(resp.config.id, Some("cfg1".into()));
    }

    #[tokio::test]
    async fn test_list_push_configs() {
        let client = make_client();
        let req = ListTaskPushNotificationConfigsRequest {
            task_id: "t1".into(),
            page_size: None,
            page_token: None,
            tenant: None,
        };
        let resp = client.list_push_configs(&req).await.unwrap();
        assert!(resp.configs.is_empty());
    }

    #[tokio::test]
    async fn test_delete_push_config() {
        let client = make_client();
        let req = DeleteTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        client.delete_push_config(&req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_extended_agent_card() {
        let client = make_client();
        let req = GetExtendedAgentCardRequest { tenant: None };
        let card = client.get_extended_agent_card(&req).await.unwrap();
        assert_eq!(card.name, "Test");
    }

    #[tokio::test]
    async fn test_destroy() {
        let client = make_client();
        client.destroy().await.unwrap();
    }
}