// Copyright AGNTCY Contributors (https://github.com/agntcy)
// SPDX-License-Identifier: Apache-2.0
3use a2a::jsonrpc::methods;
4use a2a::*;
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use std::sync::Arc;
8
9use crate::middleware::CallInterceptor;
10use crate::transport::{ServiceParams, Transport};
11
12/// High-level A2A client wrapping a transport with middleware.
13pub struct A2AClient<T: Transport> {
14    transport: T,
15    interceptors: Vec<Arc<dyn CallInterceptor>>,
16    default_params: ServiceParams,
17}
18
19impl<T: Transport> A2AClient<T> {
20    pub fn new(transport: T) -> Self {
21        let mut default_params = ServiceParams::new();
22        default_params.insert(SVC_PARAM_VERSION.to_string(), vec![VERSION.to_string()]);
23        A2AClient {
24            transport,
25            interceptors: Vec::new(),
26            default_params,
27        }
28    }
29
30    pub fn with_interceptors(mut self, interceptors: Vec<Arc<dyn CallInterceptor>>) -> Self {
31        self.interceptors = interceptors;
32        self
33    }
34
35    fn params(&self) -> ServiceParams {
36        self.default_params.clone()
37    }
38
39    async fn apply_before(&self, method: &str) -> Result<ServiceParams, A2AError> {
40        let mut params = self.params();
41        for interceptor in &self.interceptors {
42            interceptor.before(method, &mut params).await?;
43        }
44        Ok(params)
45    }
46
47    async fn apply_after(
48        &self,
49        method: &str,
50        result: &Result<(), A2AError>,
51    ) -> Result<(), A2AError> {
52        for interceptor in self.interceptors.iter().rev() {
53            interceptor.after(method, result).await?;
54        }
55        Ok(())
56    }
57
58    async fn finish_call<R>(
59        &self,
60        method: &str,
61        result: Result<R, A2AError>,
62    ) -> Result<R, A2AError> {
63        let status = result.as_ref().map(|_| ()).map_err(Clone::clone);
64        let after_result = self.apply_after(method, &status).await;
65
66        match (result, after_result) {
67            (Ok(value), Ok(())) => Ok(value),
68            (Err(error), _) => Err(error),
69            (Ok(_), Err(error)) => Err(error),
70        }
71    }
72
73    pub async fn send_message(
74        &self,
75        req: &SendMessageRequest,
76    ) -> Result<SendMessageResponse, A2AError> {
77        let params = self.apply_before(methods::SEND_MESSAGE).await?;
78        let result = self.transport.send_message(&params, req).await;
79        self.finish_call(methods::SEND_MESSAGE, result).await
80    }
81
82    pub async fn send_streaming_message(
83        &self,
84        req: &SendMessageRequest,
85    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
86        let params = self.apply_before(methods::SEND_STREAMING_MESSAGE).await?;
87        let result = self.transport.send_streaming_message(&params, req).await;
88        self.finish_call(methods::SEND_STREAMING_MESSAGE, result)
89            .await
90    }
91
92    pub async fn get_task(&self, req: &GetTaskRequest) -> Result<Task, A2AError> {
93        let params = self.apply_before(methods::GET_TASK).await?;
94        let result = self.transport.get_task(&params, req).await;
95        self.finish_call(methods::GET_TASK, result).await
96    }
97
98    pub async fn list_tasks(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
99        let params = self.apply_before(methods::LIST_TASKS).await?;
100        let result = self.transport.list_tasks(&params, req).await;
101        self.finish_call(methods::LIST_TASKS, result).await
102    }
103
104    pub async fn cancel_task(&self, req: &CancelTaskRequest) -> Result<Task, A2AError> {
105        let params = self.apply_before(methods::CANCEL_TASK).await?;
106        let result = self.transport.cancel_task(&params, req).await;
107        self.finish_call(methods::CANCEL_TASK, result).await
108    }
109
110    pub async fn subscribe_to_task(
111        &self,
112        req: &SubscribeToTaskRequest,
113    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
114        let params = self.apply_before(methods::SUBSCRIBE_TO_TASK).await?;
115        let result = self.transport.subscribe_to_task(&params, req).await;
116        self.finish_call(methods::SUBSCRIBE_TO_TASK, result).await
117    }
118
119    pub async fn create_push_config(
120        &self,
121        req: &TaskPushNotificationConfig,
122    ) -> Result<TaskPushNotificationConfig, A2AError> {
123        let params = self.apply_before(methods::CREATE_PUSH_CONFIG).await?;
124        let result = self.transport.create_push_config(&params, req).await;
125        self.finish_call(methods::CREATE_PUSH_CONFIG, result).await
126    }
127
128    pub async fn get_push_config(
129        &self,
130        req: &GetTaskPushNotificationConfigRequest,
131    ) -> Result<TaskPushNotificationConfig, A2AError> {
132        let params = self.apply_before(methods::GET_PUSH_CONFIG).await?;
133        let result = self.transport.get_push_config(&params, req).await;
134        self.finish_call(methods::GET_PUSH_CONFIG, result).await
135    }
136
137    pub async fn list_push_configs(
138        &self,
139        req: &ListTaskPushNotificationConfigsRequest,
140    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
141        let params = self.apply_before(methods::LIST_PUSH_CONFIGS).await?;
142        let result = self.transport.list_push_configs(&params, req).await;
143        self.finish_call(methods::LIST_PUSH_CONFIGS, result).await
144    }
145
146    pub async fn delete_push_config(
147        &self,
148        req: &DeleteTaskPushNotificationConfigRequest,
149    ) -> Result<(), A2AError> {
150        let params = self.apply_before(methods::DELETE_PUSH_CONFIG).await?;
151        let result = self.transport.delete_push_config(&params, req).await;
152        self.finish_call(methods::DELETE_PUSH_CONFIG, result).await
153    }
154
155    pub async fn get_extended_agent_card(
156        &self,
157        req: &GetExtendedAgentCardRequest,
158    ) -> Result<AgentCard, A2AError> {
159        let params = self.apply_before(methods::GET_EXTENDED_AGENT_CARD).await?;
160        let result = self.transport.get_extended_agent_card(&params, req).await;
161        self.finish_call(methods::GET_EXTENDED_AGENT_CARD, result)
162            .await
163    }
164
165    pub async fn destroy(&self) -> Result<(), A2AError> {
166        self.transport.destroy().await
167    }
168}
169
170/// Convenience trait to extract client results.
171#[async_trait]
172pub trait SendMessageExt {
173    async fn send_text(
174        &self,
175        text: impl Into<String> + Send,
176    ) -> Result<SendMessageResponse, A2AError>;
177}
178
179#[async_trait]
180impl<T: Transport> SendMessageExt for A2AClient<T> {
181    async fn send_text(
182        &self,
183        text: impl Into<String> + Send,
184    ) -> Result<SendMessageResponse, A2AError> {
185        let msg = Message::new(Role::User, vec![Part::text(text)]);
186        let req = SendMessageRequest {
187            message: msg,
188            configuration: None,
189            metadata: None,
190            tenant: None,
191        };
192        self.send_message(&req).await
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use a2a::event::StreamResponse;
    use futures::stream;
    use std::sync::Mutex;

    /// State shared between a `MockTransport` and the test that created it.
    #[derive(Default)]
    struct MockTransportState {
        /// Every transport call as `(method, params)`, in call order.
        calls: Mutex<Vec<(String, ServiceParams)>>,
        /// When set, `send_message` fails with a clone of this error.
        send_message_error: Mutex<Option<A2AError>>,
    }

    /// Mock transport that records each call and returns canned responses.
    struct MockTransport {
        state: Arc<MockTransportState>,
    }

    impl MockTransport {
        /// Creates a mock transport together with a handle to its shared state.
        fn new() -> (Self, Arc<MockTransportState>) {
            let state = Arc::new(MockTransportState::default());
            (
                MockTransport {
                    state: state.clone(),
                },
                state,
            )
        }

        /// Records that `method` was invoked with `params`.
        fn record(&self, method: &str, params: &ServiceParams) {
            self.state
                .calls
                .lock()
                .unwrap()
                .push((method.to_string(), params.clone()));
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<SendMessageResponse, A2AError> {
            self.record(methods::SEND_MESSAGE, params);
            if let Some(error) = self.state.send_message_error.lock().unwrap().clone() {
                return Err(error);
            }
            Ok(SendMessageResponse::Task(Task {
                id: "t1".into(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            }))
        }

        async fn send_streaming_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SEND_STREAMING_MESSAGE, params);
            Ok(Box::pin(stream::once(async {
                Ok(StreamResponse::StatusUpdate(
                    a2a::event::TaskStatusUpdateEvent {
                        task_id: "t1".into(),
                        context_id: "c1".into(),
                        status: TaskStatus {
                            state: TaskState::Working,
                            message: None,
                            timestamp: None,
                        },
                        metadata: None,
                    },
                ))
            })))
        }

        async fn get_task(
            &self,
            params: &ServiceParams,
            req: &GetTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::GET_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn list_tasks(
            &self,
            params: &ServiceParams,
            _req: &ListTasksRequest,
        ) -> Result<ListTasksResponse, A2AError> {
            self.record(methods::LIST_TASKS, params);
            Ok(ListTasksResponse {
                tasks: vec![],
                next_page_token: String::new(),
                page_size: 0,
                total_size: 0,
            })
        }

        async fn cancel_task(
            &self,
            params: &ServiceParams,
            req: &CancelTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::CANCEL_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Canceled,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn subscribe_to_task(
            &self,
            params: &ServiceParams,
            _req: &SubscribeToTaskRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SUBSCRIBE_TO_TASK, params);
            Ok(Box::pin(stream::empty()))
        }

        async fn create_push_config(
            &self,
            params: &ServiceParams,
            req: &TaskPushNotificationConfig,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::CREATE_PUSH_CONFIG, params);
            Ok(req.clone())
        }

        async fn get_push_config(
            &self,
            params: &ServiceParams,
            req: &GetTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::GET_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                url: "http://example.com".into(),
                id: Some(req.id.clone()),
                token: None,
                authentication: None,
                tenant: None,
            })
        }

        async fn list_push_configs(
            &self,
            params: &ServiceParams,
            _req: &ListTaskPushNotificationConfigsRequest,
        ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
            self.record(methods::LIST_PUSH_CONFIGS, params);
            Ok(ListTaskPushNotificationConfigsResponse {
                configs: vec![],
                next_page_token: None,
            })
        }

        async fn delete_push_config(
            &self,
            params: &ServiceParams,
            _req: &DeleteTaskPushNotificationConfigRequest,
        ) -> Result<(), A2AError> {
            self.record(methods::DELETE_PUSH_CONFIG, params);
            Ok(())
        }

        async fn get_extended_agent_card(
            &self,
            params: &ServiceParams,
            _req: &GetExtendedAgentCardRequest,
        ) -> Result<AgentCard, A2AError> {
            self.record(methods::GET_EXTENDED_AGENT_CARD, params);
            Ok(AgentCard {
                name: "Test".into(),
                description: "Test agent".into(),
                version: "1.0".into(),
                supported_interfaces: vec![],
                capabilities: AgentCapabilities::default(),
                default_input_modes: vec!["text/plain".into()],
                default_output_modes: vec!["text/plain".into()],
                skills: vec![],
                provider: None,
                documentation_url: None,
                icon_url: None,
                security_schemes: None,
                security_requirements: None,
                signatures: None,
            })
        }

        async fn destroy(&self) -> Result<(), A2AError> {
            Ok(())
        }
    }

    /// Builds a client over a fresh mock transport, discarding the state handle.
    fn make_client() -> A2AClient<MockTransport> {
        let (transport, _) = MockTransport::new();
        A2AClient::new(transport)
    }

    /// Builds a send request carrying a single user text part.
    fn text_request(text: &str) -> SendMessageRequest {
        SendMessageRequest {
            message: Message::new(Role::User, vec![Part::text(text)]),
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// Interceptor that logs each hook invocation to `events` and appends its name to the
    /// `X-Interceptor` parameter.
    struct RecordingInterceptor {
        name: &'static str,
        events: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl CallInterceptor for RecordingInterceptor {
        async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
            self.events
                .lock()
                .unwrap()
                .push(format!("before:{}", self.name));
            params
                .entry("X-Interceptor".to_string())
                .or_default()
                .push(self.name.to_string());
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            let status = if result.is_ok() { "ok" } else { "err" };
            self.events
                .lock()
                .unwrap()
                .push(format!("after:{}:{status}", self.name));
            Ok(())
        }
    }

    /// Interceptor whose hooks fail on demand, for exercising error propagation.
    struct FailingInterceptor {
        fail_before: bool,
        fail_after: bool,
    }

    #[async_trait]
    impl CallInterceptor for FailingInterceptor {
        async fn before(
            &self,
            _method: &str,
            _params: &mut ServiceParams,
        ) -> Result<(), A2AError> {
            if self.fail_before {
                return Err(A2AError::internal("before failed"));
            }
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            _result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            if self.fail_after {
                return Err(A2AError::internal("after failed"));
            }
            Ok(())
        }
    }

    #[test]
    fn test_new_sets_default_params() {
        let client = make_client();
        let params = client.params();
        assert_eq!(
            params.get(SVC_PARAM_VERSION),
            Some(&vec![VERSION.to_string()])
        );
    }

    #[test]
    fn test_with_interceptors() {
        let client = make_client();
        assert!(client.interceptors.is_empty());

        let client = client.with_interceptors(vec![Arc::new(FailingInterceptor {
            fail_before: false,
            fail_after: false,
        })]);
        assert_eq!(client.interceptors.len(), 1);

        // A second call replaces, rather than extends, the chain.
        let client = client.with_interceptors(vec![]);
        assert!(client.interceptors.is_empty());
    }

    #[tokio::test]
    async fn test_send_message() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport);
        let resp = client.send_message(&text_request("hi")).await.unwrap();
        assert!(matches!(resp, SendMessageResponse::Task(_)));

        // Without interceptors the default version parameter reaches the transport unchanged.
        let calls = state.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, methods::SEND_MESSAGE);
        assert_eq!(
            calls[0].1.get(SVC_PARAM_VERSION),
            Some(&vec![VERSION.to_string()])
        );
    }

    #[tokio::test]
    async fn test_send_text_goes_through_send_message() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport);
        let resp = client.send_text("hi").await.unwrap();
        assert!(matches!(resp, SendMessageResponse::Task(_)));

        let calls = state.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, methods::SEND_MESSAGE);
    }

    #[tokio::test]
    async fn test_send_message_applies_interceptors_and_reverses_after_order() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![
            Arc::new(RecordingInterceptor {
                name: "first",
                events: events.clone(),
            }),
            Arc::new(RecordingInterceptor {
                name: "second",
                events: events.clone(),
            }),
        ]);

        client.send_message(&text_request("hi")).await.unwrap();

        let calls = state.calls.lock().unwrap();
        let params = &calls[0].1;
        assert_eq!(
            params.get("X-Interceptor").unwrap(),
            &vec!["first".to_string(), "second".to_string()]
        );

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before:first".to_string(),
                "before:second".to_string(),
                "after:second:ok".to_string(),
                "after:first:ok".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn test_send_message_preserves_transport_error_after_after_hooks() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let events = Arc::new(Mutex::new(Vec::new()));
        let client =
            A2AClient::new(transport).with_interceptors(vec![Arc::new(RecordingInterceptor {
                name: "only",
                events: events.clone(),
            })]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "boom");

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before:only".to_string(), "after:only:err".to_string()]
        );
    }

    #[tokio::test]
    async fn test_transport_error_wins_over_after_hook_error() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let client = A2AClient::new(transport).with_interceptors(vec![Arc::new(
            FailingInterceptor {
                fail_before: false,
                fail_after: true,
            },
        )]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "boom");
    }

    #[tokio::test]
    async fn test_before_hook_error_skips_transport_and_after_hooks() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![
            Arc::new(RecordingInterceptor {
                name: "first",
                events: events.clone(),
            }),
            Arc::new(FailingInterceptor {
                fail_before: true,
                fail_after: false,
            }),
        ]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "before failed");
        assert!(state.calls.lock().unwrap().is_empty());

        let events = events.lock().unwrap().clone();
        assert_eq!(events, vec!["before:first".to_string()]);
    }

    #[tokio::test]
    async fn test_after_hook_error_replaces_success_and_stops_chain() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![
            Arc::new(RecordingInterceptor {
                name: "outer",
                events: events.clone(),
            }),
            Arc::new(FailingInterceptor {
                fail_before: false,
                fail_after: true,
            }),
        ]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "after failed");
        assert_eq!(state.calls.lock().unwrap().len(), 1);

        // The failing interceptor is registered last, so its `after` runs first and stops the
        // chain before the outer interceptor's `after` is reached.
        let events = events.lock().unwrap().clone();
        assert_eq!(events, vec!["before:outer".to_string()]);
    }

    #[tokio::test]
    async fn test_send_streaming_message() {
        use futures::StreamExt;
        let client = make_client();
        let mut stream = client
            .send_streaming_message(&text_request("hi"))
            .await
            .unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert!(matches!(item, StreamResponse::StatusUpdate(_)));
    }

    #[tokio::test]
    async fn test_streaming_after_hooks_run_when_stream_opens() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = make_client().with_interceptors(vec![Arc::new(RecordingInterceptor {
            name: "only",
            events: events.clone(),
        })]);

        // The stream is never polled: the hooks have already completed once it is returned.
        let _stream = client
            .send_streaming_message(&text_request("hi"))
            .await
            .unwrap();

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before:only".to_string(), "after:only:ok".to_string()]
        );
    }

    #[tokio::test]
    async fn test_get_task() {
        let client = make_client();
        let req = GetTaskRequest {
            id: "t1".into(),
            history_length: None,
            tenant: None,
        };
        let task = client.get_task(&req).await.unwrap();
        assert_eq!(task.id, "t1");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let client = make_client();
        let req = ListTasksRequest {
            context_id: None,
            status: None,
            page_size: None,
            page_token: None,
            history_length: None,
            status_timestamp_after: None,
            include_artifacts: None,
            tenant: None,
        };
        let resp = client.list_tasks(&req).await.unwrap();
        assert!(resp.tasks.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let client = make_client();
        let req = CancelTaskRequest {
            id: "t1".into(),
            metadata: None,
            tenant: None,
        };
        let task = client.cancel_task(&req).await.unwrap();
        assert_eq!(task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn test_subscribe_to_task() {
        use futures::StreamExt;
        let client = make_client();
        let req = SubscribeToTaskRequest {
            id: "t1".into(),
            tenant: None,
        };
        let mut stream = client.subscribe_to_task(&req).await.unwrap();
        // The mock returns an empty stream; make sure it is passed through untouched.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_create_push_config() {
        let client = make_client();
        let req = TaskPushNotificationConfig {
            task_id: "t1".into(),
            url: "http://example.com".into(),
            id: None,
            token: None,
            authentication: None,
            tenant: None,
        };
        let resp = client.create_push_config(&req).await.unwrap();
        assert_eq!(resp.task_id, "t1");
    }

    #[tokio::test]
    async fn test_get_push_config() {
        let client = make_client();
        let req = GetTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        let resp = client.get_push_config(&req).await.unwrap();
        assert_eq!(resp.id, Some("cfg1".into()));
    }

    #[tokio::test]
    async fn test_list_push_configs() {
        let client = make_client();
        let req = ListTaskPushNotificationConfigsRequest {
            task_id: "t1".into(),
            page_size: None,
            page_token: None,
            tenant: None,
        };
        let resp = client.list_push_configs(&req).await.unwrap();
        assert!(resp.configs.is_empty());
    }

    #[tokio::test]
    async fn test_delete_push_config() {
        let client = make_client();
        let req = DeleteTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        client.delete_push_config(&req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_extended_agent_card() {
        let client = make_client();
        let req = GetExtendedAgentCardRequest { tenant: None };
        let card = client.get_extended_agent_card(&req).await.unwrap();
        assert_eq!(card.name, "Test");
    }

    #[tokio::test]
    async fn test_destroy() {
        let client = make_client();
        client.destroy().await.unwrap();
    }
}