Skip to main content

a2a_protocol_client/transport/
jsonrpc.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! JSON-RPC 2.0 over HTTP transport implementation.
5//!
6//! [`JsonRpcTransport`] sends every A2A method call as an HTTP POST to the
7//! agent's single JSON-RPC endpoint URL. Streaming requests include
8//! `Accept: text/event-stream` and the response body is consumed as SSE.
9//!
10//! # Connection pooling
11//!
12//! The underlying [`hyper_util::client::legacy::Client`] pools connections
13//! across requests. Cloning [`JsonRpcTransport`] is cheap — it clones the
14//! inner `Arc`.
15
16use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21
22use http_body_util::{BodyExt, Full};
23use hyper::body::Bytes;
24use hyper::header;
25#[cfg(not(feature = "tls-rustls"))]
26use hyper_util::client::legacy::connect::HttpConnector;
27#[cfg(not(feature = "tls-rustls"))]
28use hyper_util::client::legacy::Client;
29#[cfg(not(feature = "tls-rustls"))]
30use hyper_util::rt::TokioExecutor;
31use tokio::sync::mpsc;
32use uuid::Uuid;
33
34use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
35
36use crate::error::{ClientError, ClientResult};
37use crate::streaming::EventStream;
38use crate::transport::Transport;
39
40// ── Type aliases ──────────────────────────────────────────────────────────────
41
42#[cfg(not(feature = "tls-rustls"))]
43type HttpClient = Client<HttpConnector, Full<Bytes>>;
44
45#[cfg(feature = "tls-rustls")]
46type HttpClient = crate::tls::HttpsClient;
47
48// ── JsonRpcTransport ──────────────────────────────────────────────────────────
49
50/// JSON-RPC 2.0 transport: HTTP POST to a single endpoint.
51///
52/// Create via [`JsonRpcTransport::new`] or let [`crate::ClientBuilder`]
53/// construct one automatically from the agent card.
54#[derive(Clone, Debug)]
55pub struct JsonRpcTransport {
56    inner: Arc<Inner>,
57}
58
59#[derive(Debug)]
60struct Inner {
61    client: HttpClient,
62    endpoint: String,
63    request_timeout: Duration,
64    stream_connect_timeout: Duration,
65}
66
67impl JsonRpcTransport {
68    /// Creates a new transport targeting the given endpoint URL.
69    ///
70    /// The endpoint is typically the `url` field from an [`a2a_protocol_types::AgentCard`].
71    ///
72    /// # Errors
73    ///
74    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
75    pub fn new(endpoint: impl Into<String>) -> ClientResult<Self> {
76        Self::with_timeout(endpoint, Duration::from_secs(30))
77    }
78
79    /// Creates a new transport with a custom request timeout.
80    ///
81    /// # Errors
82    ///
83    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
84    pub fn with_timeout(
85        endpoint: impl Into<String>,
86        request_timeout: Duration,
87    ) -> ClientResult<Self> {
88        Self::with_timeouts(endpoint, request_timeout, request_timeout)
89    }
90
91    /// Creates a new transport with separate request and stream connect timeouts.
92    ///
93    /// # Errors
94    ///
95    /// Returns [`ClientError::InvalidEndpoint`] if the URL is malformed.
96    pub fn with_timeouts(
97        endpoint: impl Into<String>,
98        request_timeout: Duration,
99        stream_connect_timeout: Duration,
100    ) -> ClientResult<Self> {
101        let endpoint = endpoint.into();
102        validate_url(&endpoint)?;
103
104        #[cfg(not(feature = "tls-rustls"))]
105        let client = Client::builder(TokioExecutor::new()).build_http::<Full<Bytes>>();
106
107        #[cfg(feature = "tls-rustls")]
108        let client = crate::tls::build_https_client();
109
110        Ok(Self {
111            inner: Arc::new(Inner {
112                client,
113                endpoint,
114                request_timeout,
115                stream_connect_timeout,
116            }),
117        })
118    }
119
120    /// Returns the endpoint URL this transport targets.
121    #[must_use]
122    pub fn endpoint(&self) -> &str {
123        &self.inner.endpoint
124    }
125
126    // ── internals ─────────────────────────────────────────────────────────────
127
128    fn build_request(
129        &self,
130        method: &str,
131        params: serde_json::Value,
132        extra_headers: &HashMap<String, String>,
133        accept_sse: bool,
134    ) -> ClientResult<hyper::Request<Full<Bytes>>> {
135        let id = serde_json::Value::String(Uuid::new_v4().to_string());
136        let rpc_req = JsonRpcRequest::with_params(id, method, params);
137        let body_bytes = serde_json::to_vec(&rpc_req).map_err(ClientError::Serialization)?;
138
139        let accept = if accept_sse {
140            "text/event-stream"
141        } else {
142            "application/json"
143        };
144
145        let mut builder = hyper::Request::builder()
146            .method(hyper::Method::POST)
147            .uri(&self.inner.endpoint)
148            .header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
149            .header(
150                a2a_protocol_types::A2A_VERSION_HEADER,
151                a2a_protocol_types::A2A_VERSION,
152            )
153            .header(header::ACCEPT, accept);
154
155        for (k, v) in extra_headers {
156            builder = builder.header(k.as_str(), v.as_str());
157        }
158
159        builder
160            .body(Full::new(Bytes::from(body_bytes)))
161            .map_err(|e| ClientError::Transport(e.to_string()))
162    }
163
164    async fn execute_request(
165        &self,
166        method: &str,
167        params: serde_json::Value,
168        extra_headers: &HashMap<String, String>,
169    ) -> ClientResult<serde_json::Value> {
170        trace_info!(method, endpoint = %self.inner.endpoint, "sending JSON-RPC request");
171
172        let req = self.build_request(method, params, extra_headers, false)?;
173
174        let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
175            .await
176            .map_err(|_| {
177                trace_error!(method, "request timed out");
178                ClientError::Transport("request timed out".into())
179            })?
180            .map_err(|e| {
181                trace_error!(method, error = %e, "HTTP client error");
182                ClientError::HttpClient(e.to_string())
183            })?;
184
185        let status = resp.status();
186        trace_debug!(method, %status, "received response");
187
188        let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
189
190        if !status.is_success() {
191            let body_str = String::from_utf8_lossy(&body_bytes);
192            trace_warn!(method, %status, "unexpected HTTP status");
193            return Err(ClientError::UnexpectedStatus {
194                status: status.as_u16(),
195                body: super::truncate_body(&body_str),
196            });
197        }
198
199        let envelope: JsonRpcResponse<serde_json::Value> =
200            serde_json::from_slice(&body_bytes).map_err(ClientError::Serialization)?;
201
202        match envelope {
203            JsonRpcResponse::Success(ok) => {
204                trace_info!(method, "request succeeded");
205                Ok(ok.result)
206            }
207            JsonRpcResponse::Error(err) => {
208                trace_warn!(method, code = err.error.code, "JSON-RPC error response");
209                let a2a = a2a_protocol_types::A2aError::new(
210                    a2a_protocol_types::ErrorCode::try_from(err.error.code)
211                        .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
212                    err.error.message,
213                );
214                Err(ClientError::Protocol(a2a))
215            }
216        }
217    }
218
219    async fn execute_streaming_request(
220        &self,
221        method: &str,
222        params: serde_json::Value,
223        extra_headers: &HashMap<String, String>,
224    ) -> ClientResult<EventStream> {
225        trace_info!(method, endpoint = %self.inner.endpoint, "opening SSE stream");
226
227        let req = self.build_request(method, params, extra_headers, true)?;
228
229        let resp = tokio::time::timeout(
230            self.inner.stream_connect_timeout,
231            self.inner.client.request(req),
232        )
233        .await
234        .map_err(|_| {
235            trace_error!(method, "stream connect timed out");
236            ClientError::Timeout("stream connect timed out".into())
237        })?
238        .map_err(|e| {
239            trace_error!(method, error = %e, "HTTP client error");
240            ClientError::HttpClient(e.to_string())
241        })?;
242
243        let status = resp.status();
244        if !status.is_success() {
245            let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
246            let body_str = String::from_utf8_lossy(&body_bytes);
247            return Err(ClientError::UnexpectedStatus {
248                status: status.as_u16(),
249                body: super::truncate_body(&body_str),
250            });
251        }
252
253        let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
254        let body = resp.into_body();
255
256        // Spawn a background task that reads body chunks and forwards them.
257        let task_handle = tokio::spawn(async move {
258            body_reader_task(body, tx).await;
259        });
260
261        Ok(EventStream::with_abort_handle(
262            rx,
263            task_handle.abort_handle(),
264        ))
265    }
266}
267
268impl Transport for JsonRpcTransport {
269    fn send_request<'a>(
270        &'a self,
271        method: &'a str,
272        params: serde_json::Value,
273        extra_headers: &'a HashMap<String, String>,
274    ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
275        Box::pin(self.execute_request(method, params, extra_headers))
276    }
277
278    fn send_streaming_request<'a>(
279        &'a self,
280        method: &'a str,
281        params: serde_json::Value,
282        extra_headers: &'a HashMap<String, String>,
283    ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
284        Box::pin(self.execute_streaming_request(method, params, extra_headers))
285    }
286}
287
288// ── Body reader task ──────────────────────────────────────────────────────────
289
290/// Background task: reads chunks from a hyper response body and forwards them
291/// to the SSE channel.
292///
293/// Exits when the body is exhausted or the channel receiver is dropped.
294async fn body_reader_task(
295    body: hyper::body::Incoming,
296    tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
297) {
298    // We need to pin the Incoming body before polling it.
299    // Safety: we do not move `body` after this point.
300    tokio::pin!(body);
301
302    loop {
303        // Poll one frame from the body.
304        let frame = std::future::poll_fn(|cx| {
305            use hyper::body::Body;
306            // SAFETY: `body` is pinned by `tokio::pin!` above and we do not
307            // move it. `Pin::new_unchecked` is safe here because the future
308            // created by `poll_fn` ensures stable addressing.
309            let pinned = unsafe { Pin::new_unchecked(&mut *body) };
310            pinned.poll_frame(cx)
311        })
312        .await;
313
314        match frame {
315            None => break, // body exhausted
316            Some(Err(e)) => {
317                let _ = tx.send(Err(ClientError::Http(e))).await;
318                break;
319            }
320            Some(Ok(frame)) => {
321                if let Ok(data) = frame.into_data() {
322                    if tx.send(Ok(data)).await.is_err() {
323                        // Receiver dropped; stop reading.
324                        break;
325                    }
326                }
327                // Non-data frames (trailers) are skipped.
328            }
329        }
330    }
331}
332
333// ── Helpers ───────────────────────────────────────────────────────────────────
334
335fn validate_url(url: &str) -> ClientResult<()> {
336    if url.is_empty() {
337        return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
338    }
339    if !url.starts_with("http://") && !url.starts_with("https://") {
340        return Err(ClientError::InvalidEndpoint(format!(
341            "URL must start with http:// or https://: {url}"
342        )));
343    }
344    Ok(())
345}
346
347// ── Tests ─────────────────────────────────────────────────────────────────────
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_url_rejects_empty() {
        let outcome = validate_url("");
        assert!(outcome.is_err(), "an empty URL must be rejected");
    }

    #[test]
    fn validate_url_rejects_non_http() {
        let outcome = validate_url("ftp://example.com");
        assert!(outcome.is_err(), "non-HTTP schemes must be rejected");
    }

    #[test]
    fn validate_url_accepts_http() {
        let outcome = validate_url("http://localhost:8080");
        assert!(outcome.is_ok(), "plain HTTP URLs must be accepted");
    }

    #[test]
    fn validate_url_accepts_https() {
        let outcome = validate_url("https://agent.example.com/a2a");
        assert!(outcome.is_ok(), "HTTPS URLs must be accepted");
    }

    #[test]
    fn transport_new_rejects_bad_url() {
        let outcome = JsonRpcTransport::new("not-a-url");
        assert!(outcome.is_err(), "construction must fail for a bad URL");
    }

    #[test]
    fn transport_new_stores_endpoint() {
        let endpoint = "http://localhost:9090";
        let transport = JsonRpcTransport::new(endpoint).expect("endpoint is valid");
        assert_eq!(transport.endpoint(), endpoint);
    }
}