Skip to main content

a2a_client/
rest.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use a2a::*;
4use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
5use async_trait::async_trait;
6use futures::stream::BoxStream;
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::Value;
10use std::collections::HashMap;
11
12use crate::push_config_compat::{
13    deserialize_list_task_push_notification_configs_response,
14    deserialize_task_push_notification_config,
15};
16use crate::transport::{ServiceParams, Transport, TransportFactory};
17
/// Path (appended to the base URL) for the non-streaming send-message call.
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
/// Path for the streaming send-message call; the response is consumed as an SSE stream.
const REST_STREAM_MESSAGE_PATH: &str = "/message:stream";
/// Path for fetching the extended agent card.
const REST_EXTENDED_AGENT_CARD_PATH: &str = "/extendedAgentCard";
/// `@type` value identifying a `google.rpc.ErrorInfo` entry in an error's `details` list.
const REST_ERROR_INFO_TYPE_URL: &str = "type.googleapis.com/google.rpc.ErrorInfo";
/// `ErrorInfo.domain` value that marks an error detail as an A2A-protocol error.
const REST_ERROR_DOMAIN: &str = "a2a-protocol.org";
23
/// Top-level JSON error body returned by a REST A2A server: `{"error": {...}}`.
///
/// Consumed by `parse_rest_error`; a body that does not match this shape is
/// reported verbatim as an internal error instead.
#[derive(Debug, Deserialize)]
struct RestErrorEnvelope {
    /// The error status object carried under the `error` key.
    error: RestErrorStatus,
}
28
/// The `error` object inside [`RestErrorEnvelope`].
///
/// Only `message` and `details` are read; any other fields a server sends
/// (for example a numeric `code` or a `status` string) are ignored by serde.
#[derive(Debug, Deserialize)]
struct RestErrorStatus {
    /// Human-readable error message; becomes the resulting `A2AError` message.
    message: String,

    /// Free-form detail entries. Entries that are A2A `ErrorInfo` objects
    /// supply the error code and metadata; other JSON objects are merged into
    /// the error details. Defaults to empty when the key is absent.
    #[serde(default)]
    details: Vec<Value>,
}
36
/// A `google.rpc.ErrorInfo`-shaped entry from `RestErrorStatus::details`.
///
/// Deserialization succeeds for any object with string `@type`, `reason` and
/// `domain` fields; `parse_rest_error` additionally checks `type_url` and
/// `domain` before treating the entry as an A2A error.
#[derive(Debug, Deserialize)]
struct RestErrorInfo {
    /// The `@type` discriminator; compared against `REST_ERROR_INFO_TYPE_URL`.
    #[serde(rename = "@type")]
    type_url: String,
    /// Machine-readable reason, mapped to an A2A error code by `reason_to_error_code`.
    reason: String,
    /// Error domain; compared against `REST_ERROR_DOMAIN`.
    domain: String,

    /// String key/value metadata copied into the resulting error details.
    /// Defaults to empty when the key is absent.
    #[serde(default)]
    metadata: HashMap<String, String>,
}
47
/// REST (HTTP+JSON) transport implementation.
///
/// Maps A2A operations to RESTful HTTP endpoints.
pub struct RestTransport {
    /// HTTP client used for every request.
    client: Client,
    /// Server base URL, stored without a trailing `/` so endpoint paths that
    /// begin with `/` can be appended directly.
    base_url: String,
}
55
56impl RestTransport {
57    pub fn new(client: Client, base_url: String) -> Self {
58        let base_url = base_url.trim_end_matches('/').to_string();
59        RestTransport { client, base_url }
60    }
61
62    fn build_request(
63        &self,
64        method: reqwest::Method,
65        path: &str,
66        params: &ServiceParams,
67    ) -> reqwest::RequestBuilder {
68        let url = format!("{}{}", self.base_url, path);
69        let mut builder = self.client.request(method, &url);
70        for (key, values) in params {
71            for v in values {
72                builder = builder.header(key, v);
73            }
74        }
75        builder
76    }
77
78    fn build_request_with_query(
79        &self,
80        method: reqwest::Method,
81        path: &str,
82        params: &ServiceParams,
83        query: &[(String, String)],
84    ) -> reqwest::RequestBuilder {
85        let builder = self.build_request(method, path, params);
86        if query.is_empty() {
87            builder
88        } else {
89            builder.query(query)
90        }
91    }
92
93    async fn send(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::Response, A2AError> {
94        builder
95            .send()
96            .await
97            .map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))
98    }
99
100    async fn into_rest_error(resp: reqwest::Response) -> A2AError {
101        let status = resp.status();
102        let body = resp.text().await.unwrap_or_default();
103        parse_rest_error(status, &body)
104    }
105
106    async fn post_value<Req>(
107        &self,
108        path: &str,
109        params: &ServiceParams,
110        body: &Req,
111    ) -> Result<Value, A2AError>
112    where
113        Req: ProtoJsonPayload,
114    {
115        let payload = protojson_conv::to_value(body).map_err(|e| {
116            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
117        })?;
118        let resp = self
119            .send(
120                self.build_request(reqwest::Method::POST, path, params)
121                    .json(&payload),
122            )
123            .await?;
124
125        if !resp.status().is_success() {
126            return Err(Self::into_rest_error(resp).await);
127        }
128        let payload = resp
129            .json::<Value>()
130            .await
131            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
132
133        Ok(payload)
134    }
135
136    async fn post_json<Req, Resp>(
137        &self,
138        path: &str,
139        params: &ServiceParams,
140        body: &Req,
141    ) -> Result<Resp, A2AError>
142    where
143        Req: ProtoJsonPayload,
144        Resp: ProtoJsonPayload,
145    {
146        let payload = self.post_value(path, params, body).await?;
147
148        protojson_conv::from_value(payload).map_err(|e| {
149            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
150        })
151    }
152
153    async fn get_value(
154        &self,
155        path: &str,
156        params: &ServiceParams,
157        query: &[(String, String)],
158    ) -> Result<Value, A2AError> {
159        let resp = self
160            .send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
161            .await?;
162
163        if !resp.status().is_success() {
164            return Err(Self::into_rest_error(resp).await);
165        }
166        let payload = resp
167            .json::<Value>()
168            .await
169            .map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
170
171        Ok(payload)
172    }
173
174    async fn get_json<Resp>(
175        &self,
176        path: &str,
177        params: &ServiceParams,
178        query: &[(String, String)],
179    ) -> Result<Resp, A2AError>
180    where
181        Resp: ProtoJsonPayload,
182    {
183        let payload = self.get_value(path, params, query).await?;
184
185        protojson_conv::from_value(payload).map_err(|e| {
186            A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
187        })
188    }
189
190    async fn post_streaming<Req>(
191        &self,
192        path: &str,
193        params: &ServiceParams,
194        body: &Req,
195    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
196    where
197        Req: ProtoJsonPayload,
198    {
199        let payload = protojson_conv::to_value(body).map_err(|e| {
200            A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
201        })?;
202        let resp = self
203            .send(
204                self.build_request(reqwest::Method::POST, path, params)
205                    .header("Accept", "text/event-stream")
206                    .json(&payload),
207            )
208            .await?;
209
210        if !resp.status().is_success() {
211            return Err(Self::into_rest_error(resp).await);
212        }
213
214        let stream = resp.bytes_stream();
215        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
216    }
217
218    async fn get_streaming(
219        &self,
220        path: &str,
221        params: &ServiceParams,
222    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
223        let resp = self
224            .send(
225                self.build_request(reqwest::Method::GET, path, params)
226                    .header("Accept", "text/event-stream"),
227            )
228            .await?;
229
230        if !resp.status().is_success() {
231            return Err(Self::into_rest_error(resp).await);
232        }
233
234        let stream = resp.bytes_stream();
235        Ok(crate::jsonrpc::parse_sse_stream_rest(stream))
236    }
237
238    async fn delete(&self, path: &str, params: &ServiceParams) -> Result<(), A2AError> {
239        let resp = self
240            .send(self.build_request(reqwest::Method::DELETE, path, params))
241            .await?;
242
243        if !resp.status().is_success() {
244            return Err(Self::into_rest_error(resp).await);
245        }
246        Ok(())
247    }
248}
249
250fn parse_rest_error(status: reqwest::StatusCode, body: &str) -> A2AError {
251    let Ok(envelope) = serde_json::from_str::<RestErrorEnvelope>(body) else {
252        return A2AError::internal(format!("HTTP {status}: {body}"));
253    };
254
255    let mut details = HashMap::new();
256    let mut code = None;
257
258    for raw_detail in envelope.error.details {
259        if let Ok(info) = serde_json::from_value::<RestErrorInfo>(raw_detail.clone()) {
260            if info.type_url == REST_ERROR_INFO_TYPE_URL && info.domain == REST_ERROR_DOMAIN {
261                code = reason_to_error_code(&info.reason).or(code);
262                for (key, value) in info.metadata {
263                    details.insert(key, Value::String(value));
264                }
265                continue;
266            }
267        }
268
269        if let Value::Object(values) = raw_detail {
270            details.extend(values.into_iter());
271        }
272    }
273
274    A2AError {
275        code: code.unwrap_or(error_code::INTERNAL_ERROR),
276        message: envelope.error.message,
277        details: (!details.is_empty()).then_some(details),
278    }
279}
280
281fn reason_to_error_code(reason: &str) -> Option<i32> {
282    match reason {
283        "TASK_NOT_FOUND" => Some(error_code::TASK_NOT_FOUND),
284        "TASK_NOT_CANCELABLE" => Some(error_code::TASK_NOT_CANCELABLE),
285        "PUSH_NOTIFICATION_NOT_SUPPORTED" => Some(error_code::PUSH_NOTIFICATION_NOT_SUPPORTED),
286        "UNSUPPORTED_OPERATION" => Some(error_code::UNSUPPORTED_OPERATION),
287        "UNSUPPORTED_CONTENT_TYPE" | "CONTENT_TYPE_NOT_SUPPORTED" => {
288            Some(error_code::CONTENT_TYPE_NOT_SUPPORTED)
289        }
290        "INVALID_AGENT_RESPONSE" => Some(error_code::INVALID_AGENT_RESPONSE),
291        "EXTENDED_AGENT_CARD_NOT_CONFIGURED" | "EXTENDED_CARD_NOT_CONFIGURED" => {
292            Some(error_code::EXTENDED_CARD_NOT_CONFIGURED)
293        }
294        "EXTENSION_SUPPORT_REQUIRED" => Some(error_code::EXTENSION_SUPPORT_REQUIRED),
295        "VERSION_NOT_SUPPORTED" => Some(error_code::VERSION_NOT_SUPPORTED),
296        "PARSE_ERROR" => Some(error_code::PARSE_ERROR),
297        "INVALID_REQUEST" => Some(error_code::INVALID_REQUEST),
298        "METHOD_NOT_FOUND" => Some(error_code::METHOD_NOT_FOUND),
299        "INVALID_PARAMS" => Some(error_code::INVALID_PARAMS),
300        "INTERNAL_ERROR" => Some(error_code::INTERNAL_ERROR),
301        _ => None,
302    }
303}
304
305#[async_trait]
306impl Transport for RestTransport {
307    async fn send_message(
308        &self,
309        params: &ServiceParams,
310        req: &SendMessageRequest,
311    ) -> Result<SendMessageResponse, A2AError> {
312        self.post_json(REST_SEND_MESSAGE_PATH, params, req).await
313    }
314
315    async fn send_streaming_message(
316        &self,
317        params: &ServiceParams,
318        req: &SendMessageRequest,
319    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
320        self.post_streaming(REST_STREAM_MESSAGE_PATH, params, req)
321            .await
322    }
323
324    async fn get_task(
325        &self,
326        params: &ServiceParams,
327        req: &GetTaskRequest,
328    ) -> Result<Task, A2AError> {
329        let path = format!("/tasks/{}", req.id);
330        let mut query_parts = Vec::new();
331        if let Some(hl) = req.history_length {
332            query_parts.push(("historyLength".to_string(), hl.to_string()));
333        }
334        self.get_json(&path, params, &query_parts).await
335    }
336
337    async fn list_tasks(
338        &self,
339        params: &ServiceParams,
340        req: &ListTasksRequest,
341    ) -> Result<ListTasksResponse, A2AError> {
342        let mut query_parts = Vec::new();
343        if let Some(ref cid) = req.context_id {
344            query_parts.push(("contextId".to_string(), cid.clone()));
345        }
346        if let Some(ref status) = req.status {
347            let s = serde_json::to_value(status)
348                .ok()
349                .and_then(|v| v.as_str().map(String::from))
350                .unwrap_or_default();
351            query_parts.push(("status".to_string(), s));
352        }
353        if let Some(ps) = req.page_size {
354            query_parts.push(("pageSize".to_string(), ps.to_string()));
355        }
356        if let Some(ref pt) = req.page_token {
357            query_parts.push(("pageToken".to_string(), pt.clone()));
358        }
359        if let Some(hl) = req.history_length {
360            query_parts.push(("historyLength".to_string(), hl.to_string()));
361        }
362        if let Some(ref ts) = req.status_timestamp_after {
363            query_parts.push(("statusTimestampAfter".to_string(), ts.to_rfc3339()));
364        }
365        if let Some(ia) = req.include_artifacts {
366            query_parts.push(("includeArtifacts".to_string(), ia.to_string()));
367        }
368        self.get_json("/tasks", params, &query_parts).await
369    }
370
371    async fn cancel_task(
372        &self,
373        params: &ServiceParams,
374        req: &CancelTaskRequest,
375    ) -> Result<Task, A2AError> {
376        self.post_json(&format!("/tasks/{}:cancel", req.id), params, req)
377            .await
378    }
379
380    async fn subscribe_to_task(
381        &self,
382        params: &ServiceParams,
383        req: &SubscribeToTaskRequest,
384    ) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
385        self.get_streaming(&format!("/tasks/{}:subscribe", req.id), params)
386            .await
387    }
388
389    async fn create_push_config(
390        &self,
391        params: &ServiceParams,
392        req: &CreateTaskPushNotificationConfigRequest,
393    ) -> Result<TaskPushNotificationConfig, A2AError> {
394        let payload = self
395            .post_value(
396                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
397                params,
398                &req.config,
399            )
400            .await?;
401        deserialize_task_push_notification_config(payload)
402    }
403
404    async fn get_push_config(
405        &self,
406        params: &ServiceParams,
407        req: &GetTaskPushNotificationConfigRequest,
408    ) -> Result<TaskPushNotificationConfig, A2AError> {
409        let payload = self
410            .get_value(
411                &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
412                params,
413                &[],
414            )
415            .await?;
416        deserialize_task_push_notification_config(payload)
417    }
418
419    async fn list_push_configs(
420        &self,
421        params: &ServiceParams,
422        req: &ListTaskPushNotificationConfigsRequest,
423    ) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
424        let mut query_parts = Vec::new();
425        if let Some(page_size) = req.page_size {
426            query_parts.push(("pageSize".to_string(), page_size.to_string()));
427        }
428        if let Some(ref page_token) = req.page_token {
429            query_parts.push(("pageToken".to_string(), page_token.clone()));
430        }
431
432        let payload = self
433            .get_value(
434                &format!("/tasks/{}/pushNotificationConfigs", req.task_id),
435                params,
436                &query_parts,
437            )
438            .await?;
439        deserialize_list_task_push_notification_configs_response(payload)
440    }
441
442    async fn delete_push_config(
443        &self,
444        params: &ServiceParams,
445        req: &DeleteTaskPushNotificationConfigRequest,
446    ) -> Result<(), A2AError> {
447        self.delete(
448            &format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
449            params,
450        )
451        .await
452    }
453
454    async fn get_extended_agent_card(
455        &self,
456        params: &ServiceParams,
457        _req: &GetExtendedAgentCardRequest,
458    ) -> Result<AgentCard, A2AError> {
459        self.get_json(REST_EXTENDED_AGENT_CARD_PATH, params, &[])
460            .await
461    }
462
463    async fn destroy(&self) -> Result<(), A2AError> {
464        Ok(())
465    }
466}
467
/// Factory for creating [`RestTransport`] instances.
pub struct RestTransportFactory {
    /// Client cloned into every transport this factory creates.
    client: Client,
}
472
473impl RestTransportFactory {
474    pub fn new(client: Option<Client>) -> Self {
475        RestTransportFactory {
476            client: client.unwrap_or_default(),
477        }
478    }
479}
480
481#[async_trait]
482impl TransportFactory for RestTransportFactory {
483    fn protocol(&self) -> &str {
484        TRANSPORT_PROTOCOL_HTTP_JSON
485    }
486
487    async fn create(
488        &self,
489        _card: &AgentCard,
490        iface: &AgentInterface,
491    ) -> Result<Box<dyn Transport>, A2AError> {
492        Ok(Box::new(RestTransport::new(
493            self.client.clone(),
494            iface.url.clone(),
495        )))
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rest_transport_new_strips_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080/".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_new_no_trailing_slash() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        assert_eq!(t.base_url, "http://localhost:8080");
    }

    #[test]
    fn test_rest_transport_factory_protocol() {
        let f = RestTransportFactory::new(None);
        assert_eq!(f.protocol(), "HTTP+JSON");
    }

    #[tokio::test]
    async fn test_rest_transport_factory_create() {
        let f = RestTransportFactory::new(None);
        let card = AgentCard {
            name: "Test".into(),
            description: "Test".into(),
            version: "1.0".into(),
            supported_interfaces: vec![],
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec!["text/plain".into()],
            default_output_modes: vec!["text/plain".into()],
            skills: vec![],
            provider: None,
            documentation_url: None,
            icon_url: None,
            security_schemes: None,
            security_requirements: None,
            signatures: None,
        };
        let iface = AgentInterface::new("http://localhost:8080/", "HTTP+JSON");
        let transport = f.create(&card, &iface).await.unwrap();
        transport.destroy().await.unwrap();
    }

    #[test]
    fn test_build_request_adds_params() {
        let t = RestTransport::new(Client::new(), "http://localhost:8080".into());
        let mut params = ServiceParams::new();
        params.insert("X-Custom".into(), vec!["val1".into(), "val2".into()]);
        let builder = t.build_request(reqwest::Method::GET, "/test", &params);
        let req = builder.build().unwrap();
        let vals: Vec<_> = req
            .headers()
            .get_all("X-Custom")
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect();
        assert_eq!(vals, vec!["val1", "val2"]);
    }

    #[test]
    fn test_parse_rest_error_preserves_a2a_error_code() {
        let body = json!({
            "error": {
                "code": 404,
                "status": "NOT_FOUND",
                "message": "task not found: t1",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "TASK_NOT_FOUND",
                        "domain": REST_ERROR_DOMAIN,
                        "metadata": {
                            "taskId": "t1"
                        }
                    },
                    {
                        "resource": "task"
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        assert_eq!(err.code, error_code::TASK_NOT_FOUND);
        assert_eq!(err.message, "task not found: t1");
        let details = err.details.expect("expected structured details");
        assert_eq!(details.get("taskId"), Some(&Value::String("t1".into())));
        assert_eq!(details.get("resource"), Some(&Value::String("task".into())));
    }

    #[test]
    fn test_parse_rest_error_accepts_go_reason_aliases() {
        let body = json!({
            "error": {
                "code": 400,
                "status": "INVALID_ARGUMENT",
                "message": "incompatible content types",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "UNSUPPORTED_CONTENT_TYPE",
                        "domain": REST_ERROR_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);
        assert_eq!(err.code, error_code::CONTENT_TYPE_NOT_SUPPORTED);
    }

    #[test]
    fn test_parse_rest_error_falls_back_for_non_envelope_body() {
        let err = parse_rest_error(reqwest::StatusCode::BAD_GATEWAY, "upstream down");

        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        assert!(err.message.contains("502"), "message: {}", err.message);
        assert!(err.message.contains("upstream down"), "message: {}", err.message);
    }

    #[test]
    fn test_parse_rest_error_unknown_reason_defaults_to_internal() {
        let body = json!({
            "error": {
                "message": "something new",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "SOMETHING_NOT_YET_DEFINED",
                        "domain": REST_ERROR_DOMAIN,
                        "metadata": {}
                    }
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::BAD_REQUEST, &body);

        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        assert_eq!(err.message, "something new");
        assert!(err.details.is_none());
    }

    #[test]
    fn test_parse_rest_error_merges_foreign_domain_detail_verbatim() {
        let body = json!({
            "error": {
                "message": "not ours",
                "details": [
                    {
                        "@type": REST_ERROR_INFO_TYPE_URL,
                        "reason": "TASK_NOT_FOUND",
                        "domain": "example.com"
                    },
                    "ignored non-object detail"
                ]
            }
        })
        .to_string();

        let err = parse_rest_error(reqwest::StatusCode::NOT_FOUND, &body);

        // A foreign-domain ErrorInfo must not set the A2A code...
        assert_eq!(err.code, error_code::INTERNAL_ERROR);
        // ...but its fields are still surfaced as details.
        let details = err.details.expect("expected merged details");
        assert_eq!(details.get("domain"), Some(&Value::String("example.com".into())));
        assert_eq!(details.get("reason"), Some(&Value::String("TASK_NOT_FOUND".into())));
    }
}