a2a_protocol_client/transport/
jsonrpc.rs1use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21
22use http_body_util::{BodyExt, Full};
23use hyper::body::Bytes;
24use hyper::header;
25#[cfg(not(feature = "tls-rustls"))]
26use hyper_util::client::legacy::connect::HttpConnector;
27#[cfg(not(feature = "tls-rustls"))]
28use hyper_util::client::legacy::Client;
29#[cfg(not(feature = "tls-rustls"))]
30use hyper_util::rt::TokioExecutor;
31use tokio::sync::mpsc;
32use uuid::Uuid;
33
34use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
35
36use crate::error::{ClientError, ClientResult};
37use crate::streaming::EventStream;
38use crate::transport::Transport;
39
/// HTTP client used when the `tls-rustls` feature is disabled: a hyper
/// legacy `Client` over `HttpConnector`, sending `Full<Bytes>` bodies.
#[cfg(not(feature = "tls-rustls"))]
type HttpClient = Client<HttpConnector, Full<Bytes>>;

/// HTTP client used when the `tls-rustls` feature is enabled; the concrete
/// type is supplied by `crate::tls`.
#[cfg(feature = "tls-rustls")]
type HttpClient = crate::tls::HttpsClient;
47
/// JSON-RPC transport that sends every request as an HTTP `POST` to a single
/// endpoint URL.
///
/// Cloning is cheap: all clones share the same HTTP client and configuration
/// through an `Arc`.
#[derive(Clone, Debug)]
pub struct JsonRpcTransport {
    inner: Arc<Inner>,
}
58
/// Shared state behind a [`JsonRpcTransport`]; never mutated after
/// construction.
#[derive(Debug)]
struct Inner {
    // Underlying hyper client (plain HTTP or TLS depending on features).
    client: HttpClient,
    // Target URL, validated by `validate_url` at construction time.
    endpoint: String,
    // Timeout applied to non-streaming requests.
    request_timeout: Duration,
    // Timeout applied while opening an SSE stream.
    stream_connect_timeout: Duration,
}
66
67impl JsonRpcTransport {
68 pub fn new(endpoint: impl Into<String>) -> ClientResult<Self> {
76 Self::with_timeout(endpoint, Duration::from_secs(30))
77 }
78
79 pub fn with_timeout(
85 endpoint: impl Into<String>,
86 request_timeout: Duration,
87 ) -> ClientResult<Self> {
88 Self::with_timeouts(endpoint, request_timeout, request_timeout)
89 }
90
91 pub fn with_timeouts(
97 endpoint: impl Into<String>,
98 request_timeout: Duration,
99 stream_connect_timeout: Duration,
100 ) -> ClientResult<Self> {
101 let endpoint = endpoint.into();
102 validate_url(&endpoint)?;
103
104 #[cfg(not(feature = "tls-rustls"))]
105 let client = Client::builder(TokioExecutor::new()).build_http::<Full<Bytes>>();
106
107 #[cfg(feature = "tls-rustls")]
108 let client = crate::tls::build_https_client();
109
110 Ok(Self {
111 inner: Arc::new(Inner {
112 client,
113 endpoint,
114 request_timeout,
115 stream_connect_timeout,
116 }),
117 })
118 }
119
120 #[must_use]
122 pub fn endpoint(&self) -> &str {
123 &self.inner.endpoint
124 }
125
126 fn build_request(
129 &self,
130 method: &str,
131 params: serde_json::Value,
132 extra_headers: &HashMap<String, String>,
133 accept_sse: bool,
134 ) -> ClientResult<hyper::Request<Full<Bytes>>> {
135 let id = serde_json::Value::String(Uuid::new_v4().to_string());
136 let rpc_req = JsonRpcRequest::with_params(id, method, params);
137 let body_bytes = serde_json::to_vec(&rpc_req).map_err(ClientError::Serialization)?;
138
139 let accept = if accept_sse {
140 "text/event-stream"
141 } else {
142 "application/json"
143 };
144
145 let mut builder = hyper::Request::builder()
146 .method(hyper::Method::POST)
147 .uri(&self.inner.endpoint)
148 .header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
149 .header(
150 a2a_protocol_types::A2A_VERSION_HEADER,
151 a2a_protocol_types::A2A_VERSION,
152 )
153 .header(header::ACCEPT, accept);
154
155 for (k, v) in extra_headers {
156 builder = builder.header(k.as_str(), v.as_str());
157 }
158
159 builder
160 .body(Full::new(Bytes::from(body_bytes)))
161 .map_err(|e| ClientError::Transport(e.to_string()))
162 }
163
164 async fn execute_request(
165 &self,
166 method: &str,
167 params: serde_json::Value,
168 extra_headers: &HashMap<String, String>,
169 ) -> ClientResult<serde_json::Value> {
170 trace_info!(method, endpoint = %self.inner.endpoint, "sending JSON-RPC request");
171
172 let req = self.build_request(method, params, extra_headers, false)?;
173
174 let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
175 .await
176 .map_err(|_| {
177 trace_error!(method, "request timed out");
178 ClientError::Transport("request timed out".into())
179 })?
180 .map_err(|e| {
181 trace_error!(method, error = %e, "HTTP client error");
182 ClientError::HttpClient(e.to_string())
183 })?;
184
185 let status = resp.status();
186 trace_debug!(method, %status, "received response");
187
188 let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
189
190 if !status.is_success() {
191 let body_str = String::from_utf8_lossy(&body_bytes);
192 trace_warn!(method, %status, "unexpected HTTP status");
193 return Err(ClientError::UnexpectedStatus {
194 status: status.as_u16(),
195 body: super::truncate_body(&body_str),
196 });
197 }
198
199 let envelope: JsonRpcResponse<serde_json::Value> =
200 serde_json::from_slice(&body_bytes).map_err(ClientError::Serialization)?;
201
202 match envelope {
203 JsonRpcResponse::Success(ok) => {
204 trace_info!(method, "request succeeded");
205 Ok(ok.result)
206 }
207 JsonRpcResponse::Error(err) => {
208 trace_warn!(method, code = err.error.code, "JSON-RPC error response");
209 let a2a = a2a_protocol_types::A2aError::new(
210 a2a_protocol_types::ErrorCode::try_from(err.error.code)
211 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
212 err.error.message,
213 );
214 Err(ClientError::Protocol(a2a))
215 }
216 }
217 }
218
219 async fn execute_streaming_request(
220 &self,
221 method: &str,
222 params: serde_json::Value,
223 extra_headers: &HashMap<String, String>,
224 ) -> ClientResult<EventStream> {
225 trace_info!(method, endpoint = %self.inner.endpoint, "opening SSE stream");
226
227 let req = self.build_request(method, params, extra_headers, true)?;
228
229 let resp = tokio::time::timeout(
230 self.inner.stream_connect_timeout,
231 self.inner.client.request(req),
232 )
233 .await
234 .map_err(|_| {
235 trace_error!(method, "stream connect timed out");
236 ClientError::Timeout("stream connect timed out".into())
237 })?
238 .map_err(|e| {
239 trace_error!(method, error = %e, "HTTP client error");
240 ClientError::HttpClient(e.to_string())
241 })?;
242
243 let status = resp.status();
244 if !status.is_success() {
245 let body_bytes = resp.collect().await.map_err(ClientError::Http)?.to_bytes();
246 let body_str = String::from_utf8_lossy(&body_bytes);
247 return Err(ClientError::UnexpectedStatus {
248 status: status.as_u16(),
249 body: super::truncate_body(&body_str),
250 });
251 }
252
253 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
254 let body = resp.into_body();
255
256 let task_handle = tokio::spawn(async move {
258 body_reader_task(body, tx).await;
259 });
260
261 Ok(EventStream::with_abort_handle(
262 rx,
263 task_handle.abort_handle(),
264 ))
265 }
266}
267
268impl Transport for JsonRpcTransport {
269 fn send_request<'a>(
270 &'a self,
271 method: &'a str,
272 params: serde_json::Value,
273 extra_headers: &'a HashMap<String, String>,
274 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
275 Box::pin(self.execute_request(method, params, extra_headers))
276 }
277
278 fn send_streaming_request<'a>(
279 &'a self,
280 method: &'a str,
281 params: serde_json::Value,
282 extra_headers: &'a HashMap<String, String>,
283 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
284 Box::pin(self.execute_streaming_request(method, params, extra_headers))
285 }
286}
287
288async fn body_reader_task(
295 body: hyper::body::Incoming,
296 tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
297) {
298 tokio::pin!(body);
301
302 loop {
303 let frame = std::future::poll_fn(|cx| {
305 use hyper::body::Body;
306 let pinned = unsafe { Pin::new_unchecked(&mut *body) };
310 pinned.poll_frame(cx)
311 })
312 .await;
313
314 match frame {
315 None => break, Some(Err(e)) => {
317 let _ = tx.send(Err(ClientError::Http(e))).await;
318 break;
319 }
320 Some(Ok(frame)) => {
321 if let Ok(data) = frame.into_data() {
322 if tx.send(Ok(data)).await.is_err() {
323 break;
325 }
326 }
327 }
329 }
330 }
331}
332
333fn validate_url(url: &str) -> ClientResult<()> {
336 if url.is_empty() {
337 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
338 }
339 if !url.starts_with("http://") && !url.starts_with("https://") {
340 return Err(ClientError::InvalidEndpoint(format!(
341 "URL must start with http:// or https://: {url}"
342 )));
343 }
344 Ok(())
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // URL validation.

    #[test]
    fn validate_url_rejects_empty() {
        let outcome = validate_url("");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_url_rejects_non_http() {
        let outcome = validate_url("ftp://example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_url_accepts_http() {
        let outcome = validate_url("http://localhost:8080");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_url_accepts_https() {
        let outcome = validate_url("https://agent.example.com/a2a");
        assert!(outcome.is_ok());
    }

    // Transport construction.

    #[test]
    fn transport_new_rejects_bad_url() {
        let outcome = JsonRpcTransport::new("not-a-url");
        assert!(outcome.is_err());
    }

    #[test]
    fn transport_new_stores_endpoint() {
        let url = "http://localhost:9090";
        let transport = JsonRpcTransport::new(url).unwrap();
        assert_eq!(transport.endpoint(), url);
    }
}