// a2a_client/rest.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::push_config_compat::{
12    deserialize_list_task_push_notification_configs_response,
13    deserialize_task_push_notification_config,
14};
15use crate::transport::{ServiceParams, Transport, TransportFactory};
16
/// Path, appended to the transport's base URL, that `send_message` POSTs to.
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
/// Path, appended to the transport's base URL, that `send_streaming_message` POSTs to.
const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
/// Path, appended to the transport's base URL, that `get_extended_agent_card` GETs.
const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";
20
/// Top-level JSON shape of a REST error response body: one `error` object.
///
/// `parse_rest_error` attempts to decode non-success response bodies into this type.
#[derive(Debug, Deserialize)]
struct RestErrorEnvelope {
    /// The status object carried under the `"error"` key.
    error: RestErrorStatus,
}
25
/// The object stored under `"error"` in a [`RestErrorEnvelope`].
///
/// Only `message` and `details` are declared here. Other keys in the object
/// (the tests in this file send `code` and `status`, for example) have no
/// matching field and are skipped during deserialization.
#[derive(Debug, Deserialize)]
struct RestErrorStatus {
    /// Error message text; the key must be present for the body to parse.
    message: String,

    /// Structured error details; a missing `details` key yields an empty list.
    #[serde(default)]
    details: Vec<TypedDetail>,
}
33
/// REST (HTTP+JSON) transport implementation.
///
/// Maps A2A operations to RESTful HTTP endpoints.
pub struct RestTransport {
    /// HTTP client used to issue every request made by this transport.
    client: Client,
    /// Prefix that endpoint paths are appended to; `new` strips trailing `/`.
    base_url: String,
}
41
42impl RestTransport {
43    /// Build a `RestTransport` from a pre-constructed `reqwest::Client`.
44    pub fn new(client: Client, base_url: String) -> Self {
45        let base_url = base_url.trim_end_matches('/').to_string();
46        RestTransport { client, base_url }
47    }
48
49    fn build_request(
50        &self,
51        method: reqwest::Method,
52        path: &str,
53        params: &ServiceParams,
54    ) -> reqwest::RequestBuilder {
55        let url = format!("{}{}", self.base_url, path);
56        let mut builder = self.client.request(method, &url);
57        for (key, values) in params {
58            for v in values {
59                builder = builder.header(key, v);
60            }
61        }
62        builder
63    }
64
65    fn build_request_with_query(
66        &self,
67        method: reqwest::Method,
68        path: &str,
69        params: &ServiceParams,
70        query: &[(String, String)],
71    ) -> reqwest::RequestBuilder {
72        let builder = self.build_request(method, path, params);
73        if query.is_empty() {
74            builder
75        } else {
76            builder.query(query)
77        }
78    }
79
80    async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
81        builder
82            .send()
83            .await
84            .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
85    }
86
87    async fn into_rest_error(resp: reqwest::Response) -> A2AError {
88        let status = resp.status();
89        let body = resp.text().await.unwrap_or_default();
90        parse_rest_error(status, &body)
91    }
92
93    async fn post_value<Req>(
94        &self,
95        path: &str,
96        params: &ServiceParams,
97        body: &Req,
98    ) -> Result<Value, A2AError>
99    where
100        Req: ProtoJsonPayload,
101    {
102        let payload = protojson_conv::to_value(body).map_err(|e| {
103            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
104        })?;
105        let resp = self
106            .send(
107                self.build_request(reqwest::Method::POST, path, params)
108                    .json(&payload),
109            )
110            .await?;
111
112        if !resp.status().is_success() {
113            return Err(Self::into_rest_error(resp).await);
114        }
115        let payload = resp
116            .json::<Value>()
117            .await
118            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
119
120        Ok(payload)
121    }
122
123    async fn post_json<Req, Resp>(
124        &self,
125        path: &str,
126        params: &ServiceParams,
127        body: &Req,
128    ) -> Result<Resp, A2AError>
129    where
130        Req: ProtoJsonPayload,
131        Resp: ProtoJsonPayload,
132    {
133        let payload = self.post_value(path, params, body).await?;
134
135        protojson_conv::from_value(payload).map_err(|e| {
136            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
137        })
138    }
139
140    async fn get_value(
141        &self,
142        path: &str,
143        params: &ServiceParams,
144        query: &[(String, String)],
145    ) -> Result<Value, A2AError> {
146        let resp = self
147            .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
148            .await?;
149
150        if !resp.status().is_success() {
151            return Err(Self::into_rest_error(resp).await);
152        }
153        let payload = resp
154            .json::<Value>()
155            .await
156            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
157
158        Ok(payload)
159    }
160
161    async fn get_json<Resp>(
162        &self,
163        path: &str,
164        params: &ServiceParams,
165        query: &[(String, String)],
166    ) -> Result<Resp, A2AError>
167    where
168        Resp: ProtoJsonPayload,
169    {
170        let payload = self.get_value(path, params, query).await?;
171
172        protojson_conv::from_value(payload).map_err(|e| {
173            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
174        })
175    }
176
177    async fn post_streaming<Req>(
178        &self,
179        path: &str,
180        params: &ServiceParams,
181        body: &Req,
182    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
183    where
184        Req: ProtoJsonPayload,
185    {
186        let payload = protojson_conv::to_value(body).map_err(|e| {
187            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
188        })?;
189        let resp = self
190            .send(
191                self.build_request(reqwest::Method::POST, path, params)
192                    .header("Accept", "text/event-stream")
193                    .json(&payload),
194            )
195            .await?;
196
197        if !resp.status().is_success() {
198            return Err(Self::into_rest_error(resp).await);
199        }
200
201        let stream = resp.bytes_stream();
202        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
203    }
204
205    async fn get_streaming(
206        &self,
207        path: &str,
208        params: &ServiceParams,
209    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
210        let resp = self
211            .send(
212                self.build_request(reqwest::Method::GET, path, params)
213                    .header("Accept", "text/event-stream"),
214            )
215            .await?;
216
217        if !resp.status().is_success() {
218            return Err(Self::into_rest_error(resp).await);
219        }
220
221        let stream = resp.bytes_stream();
222        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
223    }
224
225    async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
226        let resp = self
227            .send(self.build_request(reqwest::Method::DELETE, path, params))
228            .await?;
229
230        if !resp.status().is_success() {
231            return Err(Self::into_rest_error(resp).await);
232        }
233        Ok(())
234    }
235}
236
237fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
238    let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
239        return A2AError::internal(format!("HTTP {status}: {body}"));
240    };
241
242    crate::a2a_error_from_details(
243        error_code::INTERNAL_ERROR,
244        envelope.error.message,
245        envelope.error.details,
246    )
247}
248
249#[async_trait]
250impl Transport for RestTransport {
251    async fn send_message(
252        &self,
253        params: &ServiceParams,
254        req: &SendMessageRequest,
255    ) -> Result<SendMessageResponse, A2AError> {
256        self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
257    }
258
259    async fn send_streaming_message(
260        &self,
261        params: &ServiceParams,
262        req: &SendMessageRequest,
263    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
264        self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
265            .await
266    }
267
268    async fn get_task(
269        &self,
270        params: &ServiceParams,
271        req: &GetTaskRequest,
272    ) -> Result<Task, A2AError> {
273        let path = format!("/tasks/{}", req.id);
274        let mut query_parts = Vec::new();
275        if let Some(hl) = req.history_length {
276            query_parts.push(("historyLength".to_string(), hl.to_string()));
277        }
278        self.get_json(&path, params, &query_parts).await
279    }
280
281    async fn list_tasks(
282        &self,
283        params: &ServiceParams,
284        req: &ListTasksRequest,
285    ) -> Result<ListTasksResponse, A2AError> {
286        let mut query_parts = Vec::new();
287        if let Some(ref cid) = req.context_id {
288            query_parts.push(("contextId".to_string(), cid.clone()));
289        }
290        if let Some(ref status) = req.status {
291            let s = serde_json::to_value(status)
292                .ok()
293                .and_then(|v| v.as_str().map(String::from))
294                .unwrap_or_default();
295            query_parts.push(("status".to_string(), s));
296        }
297        if let Some(ps) = req.page_size {
298            query_parts.push(("pageSize".to_string(), ps.to_string()));
299        }
300        if let Some(ref pt) = req.page_token {
301            query_parts.push(("pageToken".to_string(), pt.clone()));
302        }
303        if let Some(hl) = req.history_length {
304            query_parts.push(("historyLength".to_string(), hl.to_string()));
305        }
306        if let Some(ref ts) = req.status_timestamp_after {
307            query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
308        }
309        if let Some(ia) = req.include_artifacts {
310            query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
311        }
312        self.get_json("/tasks", params, &query_parts).await
313    }
314
315    async fn cancel_task(
316        &self,
317        params: &ServiceParams,
318        req: &CancelTaskRequest,
319    ) -> Result<Task, A2AError> {
320        self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
321            .await
322    }
323
324    async fn subscribe_to_task(
325        &self,
326        params: &ServiceParams,
327        req: &SubscribeToTaskRequest,
328    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
329        self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
330            .await
331    }
332
333    async fn create_push_config(
334        &self,
335        params: &ServiceParams,
336        req: &TaskPushNotificationConfig,
337    ) -> Result<TaskPushNotificationConfig, A2AError> {
338        let payload = self
339            .post_value(
340                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
341                params,
342                req,
343            )
344            .await?;
345        deserialize_task_push_notification_config(payload)
346    }
347
348    async fn get_push_config(
349        &self,
350        params: &ServiceParams,
351        req: &GetTaskPushNotificationConfigRequest,
352    ) -> Result<TaskPushNotificationConfig, A2AError> {
353        let payload = self
354            .get_value(
355                &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
356                params,
357                &[],
358            )
359            .await?;
360        deserialize_task_push_notification_config(payload)
361    }
362
363    async fn list_push_configs(
364        &self,
365        params: &ServiceParams,
366        req: &ListTaskPushNotificationConfigsRequest,
367    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
368        let mut query_parts = Vec::new();
369        if let Some(page_size) = req.page_size {
370            query_parts.push(("pageSize".to_string(), page_size.to_string()));
371        }
372        if let Some(ref page_token) = req.page_token {
373            query_parts.push(("pageToken".to_string(), page_token.clone()));
374        }
375
376        let payload = self
377            .get_value(
378                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
379                params,
380                &query_parts,
381            )
382            .await?;
383        deserialize_list_task_push_notification_configs_response(payload)
384    }
385
386    async fn delete_push_config(
387        &self,
388        params: &ServiceParams,
389        req: &DeleteTaskPushNotificationConfigRequest,
390    ) -> Result<(), A2AError> {
391        self.delete(
392            &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
393            params,
394        )
395        .await
396    }
397
398    async fn get_extended_agent_card(
399        &self,
400        params: &ServiceParams,
401        _req: &GetExtendedAgentCardRequest,
402    ) -> Result<AgentCard, A2AError> {
403        self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
404            .await
405    }
406
407    async fn destroy(&self) -> Result<(), A2AError> {
408        Ok(())
409    }
410}
411
/// Factory for creating [`RestTransport`] instances.
pub struct RestTransportFactory {
    /// HTTP client that is cloned into every transport this factory creates.
    client: Client,
}
416
417impl RestTransportFactory {
418    pub fn new(client: Option<Client>) -> Self {
419        RestTransportFactory {
420            client: client
421                .unwrap_or_else(|| crate::default_reqwest_client(None).expect("default client")),
422        }
423    }
424
425    #[cfg(any(
426        feature = "rustls-tls",
427        feature = "rustls-no-provider",
428        feature = "native-tls"
429    ))]
430    pub fn with_root_certificates_pem(pem: &[u8]) -> Result<Self, A2AError> {
431        Ok(Self {
432            client: crate::default_reqwest_client(Some(pem))?,
433        })
434    }
435}
436
437#[async_trait]
438impl TransportFactory for RestTransportFactory {
439    fn protocol(&self) -> &str {
440        TRANSPORT_PROTOCOL_HTTP_JSON
441    }
442
443    async fn create(
444        &self,
445        _card: &AgentCard,
446        iface: &AgentInterface,
447    ) -> Result<Box<dyn Transport>, A2AError> {
448        Ok(Box::new(RestTransport::new(
449            self.client.clone(),
450            iface.url.clone(),
451        )))
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Transport over the crate's default client, pointed at `base_url`.
    fn transport_for(base_url: &str) -> RestTransport {
        RestTransport::new(
            crate::default_reqwest_client(None).unwrap(),
            base_url.to_owned(),
        )
    }

    /// Minimal agent card shared by the factory tests.
    fn sample_agent_card() -> AgentCard {
        AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        }
    }

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let transport = transport_for("http://localhost:8080/");
        assert_eq!(transport.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let transport = transport_for("http://localhost:8080");
        assert_eq!(transport.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let factory = RestTransportFactory::new(None);
        assert_eq!(factory.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let factory = RestTransportFactory::new(None);
        let card = sample_agent_card();
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = factory.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let transport = transport_for("http://localhost:8080");
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);

        let request = transport
            .build_request(reqwest::Method::GET, "/test", &params)
            .build()
            .unwrap();
        let observed: Vec<String> = request
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|value| value.to_str().unwrap().to_owned())
            .collect();

        // Both values of the single key must arrive as separate headers.
        assert_eq!(observed, vec!["val1", "val2"]);
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let payload = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "TASK_NOT_FOUND",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let parsed = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &payload);

        assert_eq!(parsed.code, error_code::TASK_NOT_FOUND);
        assert_eq!(parsed.message, "task not found: t1");
        let details = parsed.details.expect("expected structured details");
        assert_eq!(details.len(), 2);
        assert_eq!(details[0].type_url, errordetails::ERROR_INFO_TYPE);
        assert_eq!(
            details[1].value.get("resource"),
            Some(&Value::String("task".into()))
        );
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let payload = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let parsed = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &payload);
        assert_eq!(parsed.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_bad_request_fallback() {
        let payload = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "invalid request parameters",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {
                                "field": "message.parts",
                                "description": "At least one part is required"
                            }
                        ]
                    }
                ]
            }
        })
        .to_string();

        let parsed = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &payload);
        assert_eq!(parsed.code, error_code::INVALID_PARAMS);
        assert!(
            parsed
                .message
                .contains("message.parts: At least one part is required")
        );
        let details = parsed.details.expect("expected details");
        assert_eq!(details.len(), 1);
        assert_eq!(details[0].type_url, errordetails::BAD_REQUEST_TYPE);
        let violations = details[0].value.get("fieldViolations").unwrap();
        assert_eq!(violations[0]["field"], "message.parts");
    }

    #[test]
    fn test_parse_rest_error_bad_request_with_error_info_uses_reason() {
        let payload = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "bad params",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {"field": "task.id", "description": "required"}
                        ]
                    },
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "INVALID_PARAMS",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let parsed = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &payload);
        assert_eq!(parsed.code, error_code::INVALID_PARAMS);
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[test]
    fn test_with_root_certificates_pem_valid() {
        let pem = crate::test_utils::rcgen_self_signed_ca_pem();
        let factory = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        assert_eq!(factory.protocol(), TRANSPORT_PROTOCOL_HTTP_JSON);
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[tokio::test]
    async fn test_with_root_certificates_pem_factory_create() {
        let pem = crate::test_utils::rcgen_self_signed_ca_pem();
        let factory = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        let card = sample_agent_card();
        let iface = AgentInterface::new("https://localhost:3443/rest", "HTTP+JSON");
        let transport = factory.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }
}
684}