Skip to main content

a2a_client/
rest.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::push_config_compat::{
12    deserialize_list_task_push_notification_configs_response,
13    deserialize_task_push_notification_config,
14};
15use crate::transport::{ServiceParams, Transport, TransportFactory};
16
/// Endpoint for unary `SendMessage` calls (JSON request, JSON response).
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
/// Endpoint for streaming message calls; the response is read as a
/// `text/event-stream` body.
const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
/// Endpoint returning the extended agent card via `GET`.
const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";
20
21#[derive(Debug, Deserialize)]
22struct RestErrorEnvelope {
23    error: RestErrorStatus,
24}
25
26#[derive(Debug, Deserialize)]
27struct RestErrorStatus {
28    message: String,
29
30    #[serde(default)]
31    details: Vec<TypedDetail>,
32}
33
34/// REST (HTTP+JSON) transport implementation.
35///
36/// Maps A2A operations to RESTful HTTP endpoints.
37pub struct RestTransport {
38    client: Client,
39    base_url: String,
40}
41
42impl RestTransport {
43    pub fn new(client: Client, base_url: String) -> Self {
44        let base_url = base_url.trim_end_matches('/').to_string();
45        RestTransport { client, base_url }
46    }
47
48    fn build_request(
49        &self,
50        method: reqwest::Method,
51        path: &str,
52        params: &ServiceParams,
53    ) -> reqwest::RequestBuilder {
54        let url = format!("{}{}", self.base_url, path);
55        let mut builder = self.client.request(method, &url);
56        for (key, values) in params {
57            for v in values {
58                builder = builder.header(key, v);
59            }
60        }
61        builder
62    }
63
64    fn build_request_with_query(
65        &self,
66        method: reqwest::Method,
67        path: &str,
68        params: &ServiceParams,
69        query: &[(String, String)],
70    ) -> reqwest::RequestBuilder {
71        let builder = self.build_request(method, path, params);
72        if query.is_empty() {
73            builder
74        } else {
75            builder.query(query)
76        }
77    }
78
79    async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
80        builder
81            .send()
82            .await
83            .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
84    }
85
86    async fn into_rest_error(resp: reqwest::Response) -> A2AError {
87        let status = resp.status();
88        let body = resp.text().await.unwrap_or_default();
89        parse_rest_error(status, &body)
90    }
91
92    async fn post_value<Req>(
93        &self,
94        path: &str,
95        params: &ServiceParams,
96        body: &Req,
97    ) -> Result<Value, A2AError>
98    where
99        Req: ProtoJsonPayload,
100    {
101        let payload = protojson_conv::to_value(body).map_err(|e| {
102            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
103        })?;
104        let resp = self
105            .send(
106                self.build_request(reqwest::Method::POST, path, params)
107                    .json(&payload),
108            )
109            .await?;
110
111        if !resp.status().is_success() {
112            return Err(Self::into_rest_error(resp).await);
113        }
114        let payload = resp
115            .json::<Value>()
116            .await
117            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
118
119        Ok(payload)
120    }
121
122    async fn post_json<Req, Resp>(
123        &self,
124        path: &str,
125        params: &ServiceParams,
126        body: &Req,
127    ) -> Result<Resp, A2AError>
128    where
129        Req: ProtoJsonPayload,
130        Resp: ProtoJsonPayload,
131    {
132        let payload = self.post_value(path, params, body).await?;
133
134        protojson_conv::from_value(payload).map_err(|e| {
135            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
136        })
137    }
138
139    async fn get_value(
140        &self,
141        path: &str,
142        params: &ServiceParams,
143        query: &[(String, String)],
144    ) -> Result<Value, A2AError> {
145        let resp = self
146            .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
147            .await?;
148
149        if !resp.status().is_success() {
150            return Err(Self::into_rest_error(resp).await);
151        }
152        let payload = resp
153            .json::<Value>()
154            .await
155            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
156
157        Ok(payload)
158    }
159
160    async fn get_json<Resp>(
161        &self,
162        path: &str,
163        params: &ServiceParams,
164        query: &[(String, String)],
165    ) -> Result<Resp, A2AError>
166    where
167        Resp: ProtoJsonPayload,
168    {
169        let payload = self.get_value(path, params, query).await?;
170
171        protojson_conv::from_value(payload).map_err(|e| {
172            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
173        })
174    }
175
176    async fn post_streaming<Req>(
177        &self,
178        path: &str,
179        params: &ServiceParams,
180        body: &Req,
181    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
182    where
183        Req: ProtoJsonPayload,
184    {
185        let payload = protojson_conv::to_value(body).map_err(|e| {
186            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
187        })?;
188        let resp = self
189            .send(
190                self.build_request(reqwest::Method::POST, path, params)
191                    .header("Accept", "text/event-stream")
192                    .json(&payload),
193            )
194            .await?;
195
196        if !resp.status().is_success() {
197            return Err(Self::into_rest_error(resp).await);
198        }
199
200        let stream = resp.bytes_stream();
201        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
202    }
203
204    async fn get_streaming(
205        &self,
206        path: &str,
207        params: &ServiceParams,
208    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
209        let resp = self
210            .send(
211                self.build_request(reqwest::Method::GET, path, params)
212                    .header("Accept", "text/event-stream"),
213            )
214            .await?;
215
216        if !resp.status().is_success() {
217            return Err(Self::into_rest_error(resp).await);
218        }
219
220        let stream = resp.bytes_stream();
221        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
222    }
223
224    async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
225        let resp = self
226            .send(self.build_request(reqwest::Method::DELETE, path, params))
227            .await?;
228
229        if !resp.status().is_success() {
230            return Err(Self::into_rest_error(resp).await);
231        }
232        Ok(())
233    }
234}
235
236fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
237    let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
238        return A2AError::internal(format!("HTTP {status}: {body}"));
239    };
240
241    crate::a2a_error_from_details(
242        error_code::INTERNAL_ERROR,
243        envelope.error.message,
244        envelope.error.details,
245    )
246}
247
248#[async_trait]
249impl Transport for RestTransport {
250    async fn send_message(
251        &self,
252        params: &ServiceParams,
253        req: &SendMessageRequest,
254    ) -> Result<SendMessageResponse, A2AError> {
255        self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
256    }
257
258    async fn send_streaming_message(
259        &self,
260        params: &ServiceParams,
261        req: &SendMessageRequest,
262    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
263        self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
264            .await
265    }
266
267    async fn get_task(
268        &self,
269        params: &ServiceParams,
270        req: &GetTaskRequest,
271    ) -> Result<Task, A2AError> {
272        let path = format!("/tasks/{}", req.id);
273        let mut query_parts = Vec::new();
274        if let Some(hl) = req.history_length {
275            query_parts.push(("historyLength".to_string(), hl.to_string()));
276        }
277        self.get_json(&path, params, &query_parts).await
278    }
279
280    async fn list_tasks(
281        &self,
282        params: &ServiceParams,
283        req: &ListTasksRequest,
284    ) -> Result<ListTasksResponse, A2AError> {
285        let mut query_parts = Vec::new();
286        if let Some(ref cid) = req.context_id {
287            query_parts.push(("contextId".to_string(), cid.clone()));
288        }
289        if let Some(ref status) = req.status {
290            let s = serde_json::to_value(status)
291                .ok()
292                .and_then(|v| v.as_str().map(String::from))
293                .unwrap_or_default();
294            query_parts.push(("status".to_string(), s));
295        }
296        if let Some(ps) = req.page_size {
297            query_parts.push(("pageSize".to_string(), ps.to_string()));
298        }
299        if let Some(ref pt) = req.page_token {
300            query_parts.push(("pageToken".to_string(), pt.clone()));
301        }
302        if let Some(hl) = req.history_length {
303            query_parts.push(("historyLength".to_string(), hl.to_string()));
304        }
305        if let Some(ref ts) = req.status_timestamp_after {
306            query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
307        }
308        if let Some(ia) = req.include_artifacts {
309            query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
310        }
311        self.get_json("/tasks", params, &query_parts).await
312    }
313
314    async fn cancel_task(
315        &self,
316        params: &ServiceParams,
317        req: &CancelTaskRequest,
318    ) -> Result<Task, A2AError> {
319        self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
320            .await
321    }
322
323    async fn subscribe_to_task(
324        &self,
325        params: &ServiceParams,
326        req: &SubscribeToTaskRequest,
327    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
328        self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
329            .await
330    }
331
332    async fn create_push_config(
333        &self,
334        params: &ServiceParams,
335        req: &TaskPushNotificationConfig,
336    ) -> Result<TaskPushNotificationConfig, A2AError> {
337        let payload = self
338            .post_value(
339                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
340                params,
341                req,
342            )
343            .await?;
344        deserialize_task_push_notification_config(payload)
345    }
346
347    async fn get_push_config(
348        &self,
349        params: &ServiceParams,
350        req: &GetTaskPushNotificationConfigRequest,
351    ) -> Result<TaskPushNotificationConfig, A2AError> {
352        let payload = self
353            .get_value(
354                &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
355                params,
356                &[],
357            )
358            .await?;
359        deserialize_task_push_notification_config(payload)
360    }
361
362    async fn list_push_configs(
363        &self,
364        params: &ServiceParams,
365        req: &ListTaskPushNotificationConfigsRequest,
366    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
367        let mut query_parts = Vec::new();
368        if let Some(page_size) = req.page_size {
369            query_parts.push(("pageSize".to_string(), page_size.to_string()));
370        }
371        if let Some(ref page_token) = req.page_token {
372            query_parts.push(("pageToken".to_string(), page_token.clone()));
373        }
374
375        let payload = self
376            .get_value(
377                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
378                params,
379                &query_parts,
380            )
381            .await?;
382        deserialize_list_task_push_notification_configs_response(payload)
383    }
384
385    async fn delete_push_config(
386        &self,
387        params: &ServiceParams,
388        req: &DeleteTaskPushNotificationConfigRequest,
389    ) -> Result<(), A2AError> {
390        self.delete(
391            &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
392            params,
393        )
394        .await
395    }
396
397    async fn get_extended_agent_card(
398        &self,
399        params: &ServiceParams,
400        _req: &GetExtendedAgentCardRequest,
401    ) -> Result<AgentCard, A2AError> {
402        self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
403            .await
404    }
405
406    async fn destroy(&self) -> Result<(), A2AError> {
407        Ok(())
408    }
409}
410
411/// Factory for creating [`RestTransport`] instances.
412pub struct RestTransportFactory {
413    client: Client,
414}
415
416impl RestTransportFactory {
417    pub fn new(client: Option<Client>) -> Self {
418        RestTransportFactory {
419            client: client.unwrap_or_default(),
420        }
421    }
422
423    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
424    pub fn with_root_certificates_pem(pem: &[u8]) -> Result<Self, A2AError> {
425        Ok(Self {
426            client: crate::build_reqwest_client_with_root_pem(pem)?,
427        })
428    }
429}
430
431#[async_trait]
432impl TransportFactory for RestTransportFactory {
433    fn protocol(&self) -> &str {
434        TRANSPORT_PROTOCOL_HTTP_JSON
435    }
436
437    async fn create(
438        &self,
439        _card: &AgentCard,
440        iface: &AgentInterface,
441    ) -> Result<Box<dyn Transport>, A2AError> {
442        Ok(Box::new(RestTransport::new(
443            self.client.clone(),
444            iface.url.clone(),
445        )))
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    // Only the TLS-gated tests below use the certificate fixture; gating the
    // import avoids an unused/unresolved import when TLS features are off.
    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    use crate::test_utils::rcgen_self_signed_ca_pem;
    use serde_json::json;

    /// Minimal agent card for the factory tests; `RestTransportFactory::create`
    /// ignores its contents.
    fn test_agent_card() -> AgentCard {
        AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        }
    }

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080/".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let f = RestTransportFactory::new(None);
        assert_eq!(f.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let f = RestTransportFactory::new(None);
        let card = test_agent_card();
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);
        let builder = t.build_request(reqwest::Method::GET, "/test", &params);
        let req = builder.build().unwrap();
        let vals: Vec<_> = req
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vals, vec!["val1", "val2"]);
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let body = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "TASK_NOT_FOUND",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        assert_eq!(err.code, error_code::TASK_NOT_FOUND);
        assert_eq!(err.message, "task not found: t1");
        let details = err.details.expect("expected structured details");
        assert_eq!(details.len(), 2);
        assert_eq!(details[0].type_url, errordetails::ERROR_INFO_TYPE);
        assert_eq!(
            details[1].value.get("resource"),
            Some(&Value::String("task".into()))
        );
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_bad_request_fallback() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "invalid request parameters",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {
                                "field": "message.parts",
                                "description": "At least one part is required"
                            }
                        ]
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
        assert!(
            err.message
                .contains("message.parts: At least one part is required")
        );
        let details = err.details.expect("expected details");
        assert_eq!(details.len(), 1);
        assert_eq!(details[0].type_url, errordetails::BAD_REQUEST_TYPE);
        let violations = details[0].value.get("fieldViolations").unwrap();
        assert_eq!(violations[0]["field"], "message.parts");
    }

    #[test]
    fn test_parse_rest_error_bad_request_with_error_info_uses_reason() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "bad params",
                "details": [
                    {
                        "@type": errordetails::BAD_REQUEST_TYPE,
                        "fieldViolations": [
                            {"field": "task.id", "description": "required"}
                        ]
                    },
                    {
                        "@type": errordetails::ERROR_INFO_TYPE,
                        "reason": "INVALID_PARAMS",
                        "domain": errordetails::PROTOCOL_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::INVALID_PARAMS);
    }

    // `with_root_certificates_pem` only exists with a TLS feature enabled,
    // so these tests carry the same gate.
    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[test]
    fn test_with_root_certificates_pem_valid() {
        let pem = rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        assert_eq!(f.protocol(), TRANSPORT_PROTOCOL_HTTP_JSON);
    }

    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
    #[tokio::test]
    async fn test_with_root_certificates_pem_factory_create() {
        let pem = rcgen_self_signed_ca_pem();
        let f = RestTransportFactory::with_root_certificates_pem(&pem).unwrap();
        let card = test_agent_card();
        let iface = AgentInterface::new("https://localhost:3443/rest", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }
}