Skip to main content

a2a_protocol_client/transport/
rest.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
//! HTTP REST transport implementation.
//!
//! [`RestTransport`] maps A2A method names to REST HTTP verb + path
//! combinations, extracts path parameters from the JSON params, and sends
//! standard JSON bodies.
//!
//! # Method → REST mapping
//!
//! Each `{name}` placeholder is filled from the JSON param of the same name.
//!
//! | A2A method | HTTP verb | Path |
//! |---|---|---|
//! | `SendMessage` | POST | `/message:send` |
//! | `SendStreamingMessage` | POST | `/message:stream` |
//! | `GetTask` | GET | `/tasks/{id}` |
//! | `CancelTask` | POST | `/tasks/{id}:cancel` |
//! | `ListTasks` | GET | `/tasks` |
//! | `SubscribeToTask` | POST | `/tasks/{id}:subscribe` |
//! | `CreateTaskPushNotificationConfig` | POST | `/tasks/{taskId}/pushNotificationConfigs` |
//! | `GetTaskPushNotificationConfig` | GET | `/tasks/{taskId}/pushNotificationConfigs/{id}` |
//! | `ListTaskPushNotificationConfigs` | GET | `/tasks/{taskId}/pushNotificationConfigs` |
//! | `DeleteTaskPushNotificationConfig` | DELETE | `/tasks/{taskId}/pushNotificationConfigs/{id}` |
//! | `GetExtendedAgentCard` | GET | `/extendedAgentCard` |
25
26use std::collections::HashMap;
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::time::Duration;
31
32use http_body_util::{BodyExt, Full};
33use hyper::body::Bytes;
34use hyper::header;
35#[cfg(not(feature = "tls-rustls"))]
36use hyper_util::client::legacy::connect::HttpConnector;
37#[cfg(not(feature = "tls-rustls"))]
38use hyper_util::client::legacy::Client;
39#[cfg(not(feature = "tls-rustls"))]
40use hyper_util::rt::TokioExecutor;
41use tokio::sync::mpsc;
42
43use a2a_protocol_types::JsonRpcResponse;
44
45use crate::error::{ClientError, ClientResult};
46use crate::streaming::EventStream;
47use crate::transport::Transport;
48
49// ── Type aliases ──────────────────────────────────────────────────────────────
50
51#[cfg(not(feature = "tls-rustls"))]
52type HttpClient = Client<HttpConnector, Full<Bytes>>;
53
54#[cfg(feature = "tls-rustls")]
55type HttpClient = crate::tls::HttpsClient;
56
57// ── Route ─────────────────────────────────────────────────────────────────────
58
/// HTTP verbs used by the REST routes in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HttpMethod {
    /// Sent with an empty body; non-path params travel in the query string.
    Get,
    /// Sent with the non-path params serialized as a JSON body.
    Post,
    /// Sent with an empty body; non-path params travel in the query string.
    Delete,
}
65
66#[derive(Debug)]
67struct Route {
68    http_method: HttpMethod,
69    path_template: &'static str,
70    /// Names of params that are path parameters (extracted from JSON params).
71    path_params: &'static [&'static str],
72    /// Whether the response is SSE (used in tests).
73    #[allow(dead_code)]
74    streaming: bool,
75}
76
77// ── Method routing ────────────────────────────────────────────────────────────
78
79#[allow(clippy::too_many_lines)]
80fn route_for(method: &str) -> Option<Route> {
81    match method {
82        "SendMessage" => Some(Route {
83            http_method: HttpMethod::Post,
84            path_template: "/message:send",
85            path_params: &[],
86            streaming: false,
87        }),
88        "SendStreamingMessage" => Some(Route {
89            http_method: HttpMethod::Post,
90            path_template: "/message:stream",
91            path_params: &[],
92            streaming: true,
93        }),
94        "GetTask" => Some(Route {
95            http_method: HttpMethod::Get,
96            path_template: "/tasks/{id}",
97            path_params: &["id"],
98            streaming: false,
99        }),
100        "CancelTask" => Some(Route {
101            http_method: HttpMethod::Post,
102            path_template: "/tasks/{id}:cancel",
103            path_params: &["id"],
104            streaming: false,
105        }),
106        "ListTasks" => Some(Route {
107            http_method: HttpMethod::Get,
108            path_template: "/tasks",
109            path_params: &[],
110            streaming: false,
111        }),
112        "SubscribeToTask" => Some(Route {
113            http_method: HttpMethod::Post,
114            path_template: "/tasks/{id}:subscribe",
115            path_params: &["id"],
116            streaming: true,
117        }),
118        "CreateTaskPushNotificationConfig" => Some(Route {
119            http_method: HttpMethod::Post,
120            path_template: "/tasks/{taskId}/pushNotificationConfigs",
121            path_params: &["taskId"],
122            streaming: false,
123        }),
124        "GetTaskPushNotificationConfig" => Some(Route {
125            http_method: HttpMethod::Get,
126            path_template: "/tasks/{taskId}/pushNotificationConfigs/{id}",
127            path_params: &["taskId", "id"],
128            streaming: false,
129        }),
130        "ListTaskPushNotificationConfigs" => Some(Route {
131            http_method: HttpMethod::Get,
132            path_template: "/tasks/{taskId}/pushNotificationConfigs",
133            path_params: &["taskId"],
134            streaming: false,
135        }),
136        "DeleteTaskPushNotificationConfig" => Some(Route {
137            http_method: HttpMethod::Delete,
138            path_template: "/tasks/{taskId}/pushNotificationConfigs/{id}",
139            path_params: &["taskId", "id"],
140            streaming: false,
141        }),
142        "GetExtendedAgentCard" => Some(Route {
143            http_method: HttpMethod::Get,
144            path_template: "/extendedAgentCard",
145            path_params: &[],
146            streaming: false,
147        }),
148        _ => None,
149    }
150}
151
152// ── RestTransport ─────────────────────────────────────────────────────────────
153
154/// REST transport: HTTP verbs mapped to REST paths.
155///
156/// Create via [`RestTransport::new`] or let [`crate::ClientBuilder`] construct
157/// one from the agent card.
158#[derive(Clone, Debug)]
159pub struct RestTransport {
160    inner: Arc<Inner>,
161}
162
163#[derive(Debug)]
164struct Inner {
165    client: HttpClient,
166    base_url: String,
167    request_timeout: Duration,
168    stream_connect_timeout: Duration,
169}
170
171impl RestTransport {
172    /// Creates a new transport using `base_url` as the root URL.
173    ///
174    /// # Errors
175    ///
176    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
177    pub fn new(base_url: impl Into<String>) -> ClientResult<Self> {
178        Self::with_timeout(base_url, Duration::from_secs(30))
179    }
180
181    /// Creates a new transport with a custom request timeout.
182    ///
183    /// # Errors
184    ///
185    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
186    pub fn with_timeout(
187        base_url: impl Into<String>,
188        request_timeout: Duration,
189    ) -> ClientResult<Self> {
190        Self::with_timeouts(base_url, request_timeout, request_timeout)
191    }
192
193    /// Creates a new transport with separate request and stream connect timeouts.
194    ///
195    /// # Errors
196    ///
197    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
198    pub fn with_timeouts(
199        base_url: impl Into<String>,
200        request_timeout: Duration,
201        stream_connect_timeout: Duration,
202    ) -> ClientResult<Self> {
203        let base_url = base_url.into();
204        if base_url.is_empty()
205            || (!base_url.starts_with("http://") && !base_url.starts_with("https://"))
206        {
207            return Err(ClientError::InvalidEndpoint(format!(
208                "invalid base URL: {base_url}"
209            )));
210        }
211
212        #[cfg(not(feature = "tls-rustls"))]
213        let client = Client::builder(TokioExecutor::new()).build_http::<Full<Bytes>>();
214
215        #[cfg(feature = "tls-rustls")]
216        let client = crate::tls::build_https_client();
217
218        Ok(Self {
219            inner: Arc::new(Inner {
220                client,
221                base_url: base_url.trim_end_matches('/').to_owned(),
222                request_timeout,
223                stream_connect_timeout,
224            }),
225        })
226    }
227
228    /// Returns the base URL this transport targets.
229    #[must_use]
230    pub fn base_url(&self) -> &str {
231        &self.inner.base_url
232    }
233
234    // ── internals ─────────────────────────────────────────────────────────────
235
236    fn build_uri(
237        &self,
238        route: &Route,
239        params: &serde_json::Value,
240    ) -> ClientResult<(String, serde_json::Value)> {
241        let mut path = route.path_template.to_owned();
242        let mut remaining = params.clone();
243
244        for &param in route.path_params {
245            let value = remaining
246                .get(param)
247                .and_then(serde_json::Value::as_str)
248                .ok_or_else(|| ClientError::Transport(format!("missing path parameter: {param}")))?
249                .to_owned();
250
251            path = path.replace(&format!("{{{param}}}"), &value);
252
253            if let Some(obj) = remaining.as_object_mut() {
254                obj.remove(param);
255            }
256        }
257
258        let mut uri = format!("{}{path}", self.inner.base_url);
259
260        // For GET/DELETE, append remaining params as query string.
261        if route.http_method == HttpMethod::Get || route.http_method == HttpMethod::Delete {
262            let query = build_query_string(&remaining);
263            if !query.is_empty() {
264                uri.push('?');
265                uri.push_str(&query);
266            }
267        }
268
269        Ok((uri, remaining))
270    }
271
272    fn build_request(
273        &self,
274        method: &str,
275        params: &serde_json::Value,
276        extra_headers: &HashMap<String, String>,
277        streaming: bool,
278    ) -> ClientResult<hyper::Request<Full<Bytes>>> {
279        let route = route_for(method)
280            .ok_or_else(|| ClientError::Transport(format!("no REST route for method: {method}")))?;
281
282        let (uri, body_params) = self.build_uri(&route, params)?;
283        let accept = if streaming {
284            "text/event-stream"
285        } else {
286            "application/json"
287        };
288
289        let hyper_method = match route.http_method {
290            HttpMethod::Get => hyper::Method::GET,
291            HttpMethod::Post => hyper::Method::POST,
292            HttpMethod::Delete => hyper::Method::DELETE,
293        };
294
295        let body =
296            if route.http_method == HttpMethod::Get || route.http_method == HttpMethod::Delete {
297                // For GET/DELETE, body is empty; params were in the path.
298                Full::new(Bytes::new())
299            } else {
300                let bytes = serde_json::to_vec(&body_params).map_err(ClientError::Serialization)?;
301                Full::new(Bytes::from(bytes))
302            };
303
304        let mut builder = hyper::Request::builder()
305            .method(hyper_method)
306            .uri(uri)
307            .header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
308            .header(
309                a2a_protocol_types::A2A_VERSION_HEADER,
310                a2a_protocol_types::A2A_VERSION,
311            )
312            .header(header::ACCEPT, accept);
313
314        for (k, v) in extra_headers {
315            builder = builder.header(k.as_str(), v.as_str());
316        }
317
318        builder
319            .body(body)
320            .map_err(|e| ClientError::Transport(e.to_string()))
321    }
322
323    async fn execute_request(
324        &self,
325        method: &str,
326        params: serde_json::Value,
327        extra_headers: &HashMap<String, String>,
328    ) -> ClientResult<serde_json::Value> {
329        trace_info!(method, base_url = %self.inner.base_url, "sending REST request");
330
331        let req = self.build_request(method, &params, extra_headers, false)?;
332
333        let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
334            .await
335            .map_err(|_| {
336                trace_error!(method, "request timed out");
337                ClientError::Transport("request timed out".into())
338            })?
339            .map_err(|e| {
340                trace_error!(method, error = %e, "HTTP client error");
341                ClientError::HttpClient(e.to_string())
342            })?;
343
344        let status = resp.status();
345        trace_debug!(method, %status, "received response");
346        let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
347
348        if !status.is_success() {
349            let body_str = String::from_utf8_lossy(&body_bytes);
350            return Err(ClientError::UnexpectedStatus {
351                status: status.as_u16(),
352                body: super::truncate_body(&body_str),
353            });
354        }
355
356        // REST responses may or may not wrap in JSON-RPC; try JSON-RPC first.
357        if let Ok(envelope) =
358            serde_json::from_slice::<JsonRpcResponse<serde_json::Value>>(&body_bytes)
359        {
360            return match envelope {
361                JsonRpcResponse::Success(ok) => Ok(ok.result),
362                JsonRpcResponse::Error(err) => {
363                    let a2a = a2a_protocol_types::A2aError::new(
364                        a2a_protocol_types::ErrorCode::try_from(err.error.code)
365                            .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
366                        err.error.message,
367                    );
368                    Err(ClientError::Protocol(a2a))
369                }
370            };
371        }
372
373        // Fall back to raw JSON value.
374        serde_json::from_slice(&body_bytes).map_err(ClientError::Serialization)
375    }
376
377    async fn execute_streaming_request(
378        &self,
379        method: &str,
380        params: serde_json::Value,
381        extra_headers: &HashMap<String, String>,
382    ) -> ClientResult<EventStream> {
383        trace_info!(method, base_url = %self.inner.base_url, "opening REST SSE stream");
384
385        let req = self.build_request(method, &params, extra_headers, true)?;
386
387        let resp = tokio::time::timeout(
388            self.inner.stream_connect_timeout,
389            self.inner.client.request(req),
390        )
391        .await
392        .map_err(|_| {
393            trace_error!(method, "stream connect timed out");
394            ClientError::Timeout("stream connect timed out".into())
395        })?
396        .map_err(|e| {
397            trace_error!(method, error = %e, "HTTP client error");
398            ClientError::HttpClient(e.to_string())
399        })?;
400
401        let status = resp.status();
402        if !status.is_success() {
403            let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
404            let body_str = String::from_utf8_lossy(&body_bytes);
405            return Err(ClientError::UnexpectedStatus {
406                status: status.as_u16(),
407                body: super::truncate_body(&body_str),
408            });
409        }
410
411        let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
412        let body = resp.into_body();
413
414        let task_handle = tokio::spawn(async move {
415            body_reader_task(body, tx).await;
416        });
417
418        Ok(EventStream::with_abort_handle(
419            rx,
420            task_handle.abort_handle(),
421        ))
422    }
423}
424
425impl Transport for RestTransport {
426    fn send_request<'a>(
427        &'a self,
428        method: &'a str,
429        params: serde_json::Value,
430        extra_headers: &'a HashMap<String, String>,
431    ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
432        Box::pin(self.execute_request(method, params, extra_headers))
433    }
434
435    fn send_streaming_request<'a>(
436        &'a self,
437        method: &'a str,
438        params: serde_json::Value,
439        extra_headers: &'a HashMap<String, String>,
440    ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
441        Box::pin(self.execute_streaming_request(method, params, extra_headers))
442    }
443}
444
445// ── Body reader task (shared with jsonrpc.rs pattern) ─────────────────────────
446
447async fn body_reader_task(
448    body: hyper::body::Incoming,
449    tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
450) {
451    tokio::pin!(body);
452    loop {
453        let frame = std::future::poll_fn(|cx| {
454            use hyper::body::Body;
455            // SAFETY: `body` is pinned by `tokio::pin!` and not moved.
456            let pinned = unsafe { Pin::new_unchecked(&mut *body) };
457            pinned.poll_frame(cx)
458        })
459        .await;
460
461        match frame {
462            None => break,
463            Some(Err(e)) => {
464                let _ = tx.send(Err(ClientError::Http(e))).await;
465                break;
466            }
467            Some(Ok(f)) => {
468                if let Ok(data) = f.into_data() {
469                    if tx.send(Ok(data)).await.is_err() {
470                        break;
471                    }
472                }
473            }
474        }
475    }
476}
477
478// ── Query string builder ─────────────────────────────────────────────────────
479
480/// Builds a URL query string from a JSON object's non-null fields.
481fn build_query_string(params: &serde_json::Value) -> String {
482    let Some(obj) = params.as_object() else {
483        return String::new();
484    };
485    let mut parts = Vec::new();
486    for (k, v) in obj {
487        match v {
488            serde_json::Value::Null => {}
489            serde_json::Value::String(s) => parts.push(format!("{k}={s}")),
490            serde_json::Value::Number(n) => parts.push(format!("{k}={n}")),
491            serde_json::Value::Bool(b) => parts.push(format!("{k}={b}")),
492            _ => {
493                // Complex types: serialize as JSON string.
494                if let Ok(s) = serde_json::to_string(v) {
495                    parts.push(format!("{k}={s}"));
496                }
497            }
498        }
499    }
500    parts.join("&")
501}
502
503// ── Tests ─────────────────────────────────────────────────────────────────────
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Known method names resolve to a route, and the streaming one is flagged.
    #[test]
    fn route_for_known_methods() {
        for name in ["SendMessage", "GetTask", "ListTasks"] {
            assert!(route_for(name).is_some());
        }
        let streaming_route = route_for("SendStreamingMessage");
        assert!(streaming_route.is_some_and(|r| r.streaming));
    }

    /// Names outside the routing table yield `None`.
    #[test]
    fn route_for_unknown_method_returns_none() {
        let lookup = route_for("unknown/method");
        assert!(lookup.is_none());
    }

    /// The `id` param lands in the path and the leftover param in the query.
    #[test]
    fn build_uri_extracts_path_param_and_appends_query() {
        let params = serde_json::json!({"id": "task-123", "historyLength": 5});
        let route = route_for("GetTask").unwrap();
        let transport = RestTransport::new("http://localhost:8080").unwrap();

        let (uri, _remaining) = transport.build_uri(&route, &params).unwrap();

        let has_task_in_path = uri.starts_with("http://localhost:8080/tasks/task-123");
        assert!(has_task_in_path, "should have task ID in path");

        let has_history_length = uri.contains("historyLength=5");
        assert!(has_history_length, "should have historyLength in query");
    }

    /// GET routes without path params still receive their params as a query.
    #[test]
    fn build_uri_appends_query_for_get() {
        let params = serde_json::json!({"pageSize": 10});
        let route = route_for("ListTasks").unwrap();
        let transport = RestTransport::new("http://localhost:8080").unwrap();

        let (uri, _remaining) = transport.build_uri(&route, &params).unwrap();

        assert!(uri.contains("pageSize=10"), "should have pageSize in query");
    }

    /// A base URL without an http(s) scheme is rejected.
    #[test]
    fn rest_transport_rejects_invalid_url() {
        let outcome = RestTransport::new("not-a-url");
        assert!(outcome.is_err());
    }

    /// The accessor returns the base URL given at construction.
    #[test]
    fn rest_transport_stores_base_url() {
        let transport = RestTransport::new("http://localhost:9090").unwrap();
        assert_eq!(transport.base_url(), "http://localhost:9090");
    }
}
557}