Skip to main content

a2a_client/
client.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::jsonrpc::methods;
4use a2a::*;
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use std::sync::Arc;
8
9use crate::middleware::CallInterceptor;
10use crate::transport::{ServiceParams, Transport};
11
12/// High-level A2A client wrapping a transport with middleware.
13pub struct A2AClient<T: Transport> {
14    transport: T,
15    interceptors: Vec<Arc<dyn CallInterceptor>>,
16    default_params: ServiceParams,
17}
18
19impl<T: Transport> A2AClient<T> {
20    pub fn new(transport: T) -> Self {
21        let mut default_params = ServiceParams::new();
22        default_params.insert(SVC_PARAM_VERSION.to_string(), vec![VERSION.to_string()]);
23        A2AClient {
24            transport,
25            interceptors: Vec::new(),
26            default_params,
27        }
28    }
29
30    pub fn with_interceptors(mut self, interceptors: Vec<Arc<dyn CallInterceptor>>) -> Self {
31        self.interceptors = interceptors;
32        self
33    }
34
35    fn params(&self) -> ServiceParams {
36        self.default_params.clone()
37    }
38
39    async fn apply_before(&self, method: &str) -> Result<ServiceParams, A2AError> {
40        let mut params = self.params();
41        for interceptor in &self.interceptors {
42            interceptor.before(method, &mut params).await?;
43        }
44        Ok(params)
45    }
46
47    async fn apply_after(
48        &self,
49        method: &str,
50        result: &Result<(), A2AError>,
51    ) -> Result<(), A2AError> {
52        for interceptor in self.interceptors.iter().rev() {
53            interceptor.after(method, result).await?;
54        }
55        Ok(())
56    }
57
58    async fn finish_call<R>(
59        &self,
60        method: &str,
61        result: Result<R, A2AError>,
62    ) -> Result<R, A2AError> {
63        let status = result.as_ref().map(|_| ()).map_err(Clone::clone);
64        let after_result = self.apply_after(method, &status).await;
65
66        match (result, after_result) {
67            (Ok(value), Ok(())) => Ok(value),
68            (Err(error), _) => Err(error),
69            (Ok(_), Err(error)) => Err(error),
70        }
71    }
72
73    pub async fn send_message(
74        &self,
75        req: &SendMessageRequest,
76    ) -> Result<SendMessageResponse, A2AError> {
77        let params = self.apply_before(methods::SEND_MESSAGE).await?;
78        let result = self.transport.send_message(&params, req).await;
79        self.finish_call(methods::SEND_MESSAGE, result).await
80    }
81
82    pub async fn send_streaming_message(
83        &self,
84        req: &SendMessageRequest,
85    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
86        let params = self.apply_before(methods::SEND_STREAMING_MESSAGE).await?;
87        let result = self.transport.send_streaming_message(&params, req).await;
88        self.finish_call(methods::SEND_STREAMING_MESSAGE, result)
89            .await
90    }
91
92    pub async fn get_task(&self, req: &GetTaskRequest) -> Result<Task, A2AError> {
93        let params = self.apply_before(methods::GET_TASK).await?;
94        let result = self.transport.get_task(&params, req).await;
95        self.finish_call(methods::GET_TASK, result).await
96    }
97
98    pub async fn list_tasks(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
99        let params = self.apply_before(methods::LIST_TASKS).await?;
100        let result = self.transport.list_tasks(&params, req).await;
101        self.finish_call(methods::LIST_TASKS, result).await
102    }
103
104    pub async fn cancel_task(&self, req: &CancelTaskRequest) -> Result<Task, A2AError> {
105        let params = self.apply_before(methods::CANCEL_TASK).await?;
106        let result = self.transport.cancel_task(&params, req).await;
107        self.finish_call(methods::CANCEL_TASK, result).await
108    }
109
110    pub async fn subscribe_to_task(
111        &self,
112        req: &SubscribeToTaskRequest,
113    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
114        let params = self.apply_before(methods::SUBSCRIBE_TO_TASK).await?;
115        let result = self.transport.subscribe_to_task(&params, req).await;
116        self.finish_call(methods::SUBSCRIBE_TO_TASK, result).await
117    }
118
119    pub async fn create_push_config(
120        &self,
121        req: &CreateTaskPushNotificationConfigRequest,
122    ) -> Result<TaskPushNotificationConfig, A2AError> {
123        let params = self.apply_before(methods::CREATE_PUSH_CONFIG).await?;
124        let result = self.transport.create_push_config(&params, req).await;
125        self.finish_call(methods::CREATE_PUSH_CONFIG, result).await
126    }
127
128    pub async fn get_push_config(
129        &self,
130        req: &GetTaskPushNotificationConfigRequest,
131    ) -> Result<TaskPushNotificationConfig, A2AError> {
132        let params = self.apply_before(methods::GET_PUSH_CONFIG).await?;
133        let result = self.transport.get_push_config(&params, req).await;
134        self.finish_call(methods::GET_PUSH_CONFIG, result).await
135    }
136
137    pub async fn list_push_configs(
138        &self,
139        req: &ListTaskPushNotificationConfigsRequest,
140    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
141        let params = self.apply_before(methods::LIST_PUSH_CONFIGS).await?;
142        let result = self.transport.list_push_configs(&params, req).await;
143        self.finish_call(methods::LIST_PUSH_CONFIGS, result).await
144    }
145
146    pub async fn delete_push_config(
147        &self,
148        req: &DeleteTaskPushNotificationConfigRequest,
149    ) -> Result<(), A2AError> {
150        let params = self.apply_before(methods::DELETE_PUSH_CONFIG).await?;
151        let result = self.transport.delete_push_config(&params, req).await;
152        self.finish_call(methods::DELETE_PUSH_CONFIG, result).await
153    }
154
155    pub async fn get_extended_agent_card(
156        &self,
157        req: &GetExtendedAgentCardRequest,
158    ) -> Result<AgentCard, A2AError> {
159        let params = self.apply_before(methods::GET_EXTENDED_AGENT_CARD).await?;
160        let result = self.transport.get_extended_agent_card(&params, req).await;
161        self.finish_call(methods::GET_EXTENDED_AGENT_CARD, result)
162            .await
163    }
164
165    pub async fn destroy(&self) -> Result<(), A2AError> {
166        self.transport.destroy().await
167    }
168}
169
170/// Convenience trait to extract client results.
171#[async_trait]
172pub trait SendMessageExt {
173    async fn send_text(
174        &self,
175        text: impl Into<String> + Send,
176    ) -> Result<SendMessageResponse, A2AError>;
177}
178
179#[async_trait]
180impl<T: Transport> SendMessageExt for A2AClient<T> {
181    async fn send_text(
182        &self,
183        text: impl Into<String> + Send,
184    ) -> Result<SendMessageResponse, A2AError> {
185        let msg = Message::new(Role::User, vec![Part::text(text)]);
186        let req = SendMessageRequest {
187            message: msg,
188            configuration: None,
189            metadata: None,
190            tenant: None,
191        };
192        self.send_message(&req).await
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use a2a::event::StreamResponse;
    use futures::stream;
    use std::sync::Mutex;

    /// Inspectable state shared between a `MockTransport` and the test.
    #[derive(Default)]
    struct MockTransportState {
        /// Every transport call as `(method, params it received)`, in call order.
        calls: Mutex<Vec<(String, ServiceParams)>>,
        /// When set, `send_message` fails with a clone of this error.
        send_message_error: Mutex<Option<A2AError>>,
    }

    /// Mock transport that records each call and returns canned responses.
    struct MockTransport {
        state: Arc<MockTransportState>,
    }

    impl MockTransport {
        /// Returns the transport plus a handle to its shared state.
        fn new() -> (Self, Arc<MockTransportState>) {
            let state = Arc::new(MockTransportState::default());
            (
                MockTransport {
                    state: state.clone(),
                },
                state,
            )
        }

        fn record(&self, method: &str, params: &ServiceParams) {
            self.state
                .calls
                .lock()
                .unwrap()
                .push((method.to_string(), params.clone()));
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<SendMessageResponse, A2AError> {
            self.record(methods::SEND_MESSAGE, params);
            if let Some(error) = self.state.send_message_error.lock().unwrap().clone() {
                return Err(error);
            }
            Ok(SendMessageResponse::Task(Task {
                id: "t1".into(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            }))
        }

        async fn send_streaming_message(
            &self,
            params: &ServiceParams,
            _req: &SendMessageRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SEND_STREAMING_MESSAGE, params);
            Ok(Box::pin(stream::once(async {
                Ok(StreamResponse::StatusUpdate(
                    a2a::event::TaskStatusUpdateEvent {
                        task_id: "t1".into(),
                        context_id: "c1".into(),
                        status: TaskStatus {
                            state: TaskState::Working,
                            message: None,
                            timestamp: None,
                        },
                        metadata: None,
                    },
                ))
            })))
        }

        async fn get_task(
            &self,
            params: &ServiceParams,
            req: &GetTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::GET_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn list_tasks(
            &self,
            params: &ServiceParams,
            _req: &ListTasksRequest,
        ) -> Result<ListTasksResponse, A2AError> {
            self.record(methods::LIST_TASKS, params);
            Ok(ListTasksResponse {
                tasks: vec![],
                next_page_token: String::new(),
                page_size: 0,
                total_size: 0,
            })
        }

        async fn cancel_task(
            &self,
            params: &ServiceParams,
            req: &CancelTaskRequest,
        ) -> Result<Task, A2AError> {
            self.record(methods::CANCEL_TASK, params);
            Ok(Task {
                id: req.id.clone(),
                context_id: "c1".into(),
                status: TaskStatus {
                    state: TaskState::Canceled,
                    message: None,
                    timestamp: None,
                },
                artifacts: None,
                history: None,
                metadata: None,
            })
        }

        async fn subscribe_to_task(
            &self,
            params: &ServiceParams,
            _req: &SubscribeToTaskRequest,
        ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
            self.record(methods::SUBSCRIBE_TO_TASK, params);
            Ok(Box::pin(stream::empty()))
        }

        async fn create_push_config(
            &self,
            params: &ServiceParams,
            req: &CreateTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::CREATE_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                config: req.config.clone(),
                tenant: None,
            })
        }

        async fn get_push_config(
            &self,
            params: &ServiceParams,
            req: &GetTaskPushNotificationConfigRequest,
        ) -> Result<TaskPushNotificationConfig, A2AError> {
            self.record(methods::GET_PUSH_CONFIG, params);
            Ok(TaskPushNotificationConfig {
                task_id: req.task_id.clone(),
                config: PushNotificationConfig {
                    url: "http://example.com".into(),
                    id: Some(req.id.clone()),
                    token: None,
                    authentication: None,
                },
                tenant: None,
            })
        }

        async fn list_push_configs(
            &self,
            params: &ServiceParams,
            _req: &ListTaskPushNotificationConfigsRequest,
        ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
            self.record(methods::LIST_PUSH_CONFIGS, params);
            Ok(ListTaskPushNotificationConfigsResponse {
                configs: vec![],
                next_page_token: None,
            })
        }

        async fn delete_push_config(
            &self,
            params: &ServiceParams,
            _req: &DeleteTaskPushNotificationConfigRequest,
        ) -> Result<(), A2AError> {
            self.record(methods::DELETE_PUSH_CONFIG, params);
            Ok(())
        }

        async fn get_extended_agent_card(
            &self,
            params: &ServiceParams,
            _req: &GetExtendedAgentCardRequest,
        ) -> Result<AgentCard, A2AError> {
            self.record(methods::GET_EXTENDED_AGENT_CARD, params);
            Ok(AgentCard {
                name: "Test".into(),
                description: "Test agent".into(),
                version: "1.0".into(),
                supported_interfaces: vec![],
                capabilities: AgentCapabilities::default(),
                default_input_modes: vec!["text/plain".into()],
                default_output_modes: vec!["text/plain".into()],
                skills: vec![],
                provider: None,
                documentation_url: None,
                icon_url: None,
                security_schemes: None,
                security_requirements: None,
                signatures: None,
            })
        }

        async fn destroy(&self) -> Result<(), A2AError> {
            Ok(())
        }
    }

    fn make_client() -> A2AClient<MockTransport> {
        let (transport, _) = MockTransport::new();
        A2AClient::new(transport)
    }

    /// Builds a request carrying a single user text part and nothing else.
    fn text_request(text: &str) -> SendMessageRequest {
        SendMessageRequest {
            message: Message::new(Role::User, vec![Part::text(text)]),
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// Interceptor that logs hook invocations and tags params with its name.
    struct RecordingInterceptor {
        name: &'static str,
        events: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl CallInterceptor for RecordingInterceptor {
        async fn before(&self, _method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
            self.events
                .lock()
                .unwrap()
                .push(format!("before:{}", self.name));
            params
                .entry("X-Interceptor".to_string())
                .or_default()
                .push(self.name.to_string());
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            let status = if result.is_ok() { "ok" } else { "err" };
            self.events
                .lock()
                .unwrap()
                .push(format!("after:{}:{status}", self.name));
            Ok(())
        }
    }

    /// Interceptor whose hooks fail on demand, for exercising error paths.
    struct FailingInterceptor {
        fail_before: bool,
        fail_after: bool,
    }

    #[async_trait]
    impl CallInterceptor for FailingInterceptor {
        async fn before(&self, _method: &str, _params: &mut ServiceParams) -> Result<(), A2AError> {
            if self.fail_before {
                return Err(A2AError::internal("before failed"));
            }
            Ok(())
        }

        async fn after(
            &self,
            _method: &str,
            _result: &Result<(), A2AError>,
        ) -> Result<(), A2AError> {
            if self.fail_after {
                return Err(A2AError::internal("after failed"));
            }
            Ok(())
        }
    }

    #[test]
    fn test_new_sets_default_params() {
        let client = make_client();
        let params = client.params();
        assert_eq!(
            params.get(SVC_PARAM_VERSION).unwrap(),
            &vec![VERSION.to_string()]
        );
    }

    #[test]
    fn test_with_interceptors() {
        let client = make_client().with_interceptors(vec![]);
        assert!(client.interceptors.is_empty());
    }

    #[tokio::test]
    async fn test_send_message() {
        let client = make_client();
        let resp = client.send_message(&text_request("hi")).await.unwrap();
        assert!(matches!(resp, SendMessageResponse::Task(_)));
    }

    #[tokio::test]
    async fn test_default_params_are_forwarded_to_transport() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport);

        client.send_message(&text_request("hi")).await.unwrap();

        let calls = state.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, methods::SEND_MESSAGE);
        assert_eq!(
            calls[0].1.get(SVC_PARAM_VERSION).unwrap(),
            &vec![VERSION.to_string()]
        );
    }

    #[tokio::test]
    async fn test_send_text_routes_through_send_message() {
        let (transport, state) = MockTransport::new();
        let client = A2AClient::new(transport);

        let resp = client.send_text("hello").await.unwrap();
        assert!(matches!(resp, SendMessageResponse::Task(_)));

        let calls = state.calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, methods::SEND_MESSAGE);
    }

    #[tokio::test]
    async fn test_send_message_applies_interceptors_and_reverses_after_order() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let client = A2AClient::new(transport).with_interceptors(vec![
            Arc::new(RecordingInterceptor {
                name: "first",
                events: events.clone(),
            }),
            Arc::new(RecordingInterceptor {
                name: "second",
                events: events.clone(),
            }),
        ]);

        client.send_message(&text_request("hi")).await.unwrap();

        let calls = state.calls.lock().unwrap();
        let params = &calls[0].1;
        assert_eq!(
            params.get("X-Interceptor").unwrap(),
            &vec!["first".to_string(), "second".to_string()]
        );

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before:first".to_string(),
                "before:second".to_string(),
                "after:second:ok".to_string(),
                "after:first:ok".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn test_send_message_preserves_transport_error_after_after_hooks() {
        let (transport, state) = MockTransport::new();
        *state.send_message_error.lock().unwrap() = Some(A2AError::internal("boom"));
        let events = Arc::new(Mutex::new(Vec::new()));
        let client =
            A2AClient::new(transport).with_interceptors(vec![Arc::new(RecordingInterceptor {
                name: "only",
                events: events.clone(),
            })]);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "boom");

        let events = events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before:only".to_string(), "after:only:err".to_string(),]
        );
    }

    #[tokio::test]
    async fn test_before_error_skips_transport_and_after_hooks() {
        let (transport, state) = MockTransport::new();
        let events = Arc::new(Mutex::new(Vec::new()));
        let interceptors: Vec<Arc<dyn CallInterceptor>> = vec![
            Arc::new(RecordingInterceptor {
                name: "outer",
                events: events.clone(),
            }),
            Arc::new(FailingInterceptor {
                fail_before: true,
                fail_after: false,
            }),
        ];
        let client = A2AClient::new(transport).with_interceptors(interceptors);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "before failed");

        // The transport was never reached and no `after` hook ran.
        assert!(state.calls.lock().unwrap().is_empty());
        let events = events.lock().unwrap().clone();
        assert_eq!(events, vec!["before:outer".to_string()]);
    }

    #[tokio::test]
    async fn test_after_error_overrides_transport_success() {
        let (transport, state) = MockTransport::new();
        let interceptors: Vec<Arc<dyn CallInterceptor>> = vec![Arc::new(FailingInterceptor {
            fail_before: false,
            fail_after: true,
        })];
        let client = A2AClient::new(transport).with_interceptors(interceptors);

        let err = client.send_message(&text_request("hi")).await.unwrap_err();
        assert_eq!(err.message, "after failed");

        // The transport call itself still happened.
        assert_eq!(state.calls.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_send_streaming_message() {
        use futures::StreamExt;
        let client = make_client();
        let mut stream = client
            .send_streaming_message(&text_request("hi"))
            .await
            .unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert!(matches!(item, StreamResponse::StatusUpdate(_)));
    }

    #[tokio::test]
    async fn test_get_task() {
        let client = make_client();
        let req = GetTaskRequest {
            id: "t1".into(),
            history_length: None,
            tenant: None,
        };
        let task = client.get_task(&req).await.unwrap();
        assert_eq!(task.id, "t1");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let client = make_client();
        let req = ListTasksRequest {
            context_id: None,
            status: None,
            page_size: None,
            page_token: None,
            history_length: None,
            status_timestamp_after: None,
            include_artifacts: None,
            tenant: None,
        };
        let resp = client.list_tasks(&req).await.unwrap();
        assert!(resp.tasks.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let client = make_client();
        let req = CancelTaskRequest {
            id: "t1".into(),
            metadata: None,
            tenant: None,
        };
        let task = client.cancel_task(&req).await.unwrap();
        assert_eq!(task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn test_subscribe_to_task() {
        use futures::StreamExt;
        let client = make_client();
        let req = SubscribeToTaskRequest {
            id: "t1".into(),
            tenant: None,
        };
        let mut stream = client.subscribe_to_task(&req).await.unwrap();
        // The mock yields an empty stream; make sure it is passed through as-is.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_create_push_config() {
        let client = make_client();
        let req = CreateTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            config: PushNotificationConfig {
                url: "http://example.com".into(),
                id: None,
                token: None,
                authentication: None,
            },
            tenant: None,
        };
        let resp = client.create_push_config(&req).await.unwrap();
        assert_eq!(resp.task_id, "t1");
    }

    #[tokio::test]
    async fn test_get_push_config() {
        let client = make_client();
        let req = GetTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        let resp = client.get_push_config(&req).await.unwrap();
        assert_eq!(resp.config.id, Some("cfg1".into()));
    }

    #[tokio::test]
    async fn test_list_push_configs() {
        let client = make_client();
        let req = ListTaskPushNotificationConfigsRequest {
            task_id: "t1".into(),
            page_size: None,
            page_token: None,
            tenant: None,
        };
        let resp = client.list_push_configs(&req).await.unwrap();
        assert!(resp.configs.is_empty());
    }

    #[tokio::test]
    async fn test_delete_push_config() {
        let client = make_client();
        let req = DeleteTaskPushNotificationConfigRequest {
            task_id: "t1".into(),
            id: "cfg1".into(),
            tenant: None,
        };
        client.delete_push_config(&req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_extended_agent_card() {
        let client = make_client();
        let req = GetExtendedAgentCardRequest { tenant: None };
        let card = client.get_extended_agent_card(&req).await.unwrap();
        assert_eq!(card.name, "Test");
    }

    #[tokio::test]
    async fn test_destroy() {
        let client = make_client();
        client.destroy().await.unwrap();
    }
}