1use std::collections::HashMap;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::time::Duration;
23
24use http_body_util::{BodyExt, Full};
25use hyper::body::Bytes;
26use hyper::header;
27#[cfg(not(feature = "tls-rustls"))]
28use hyper_util::client::legacy::connect::HttpConnector;
29#[cfg(not(feature = "tls-rustls"))]
30use hyper_util::client::legacy::Client;
31#[cfg(not(feature = "tls-rustls"))]
32use hyper_util::rt::TokioExecutor;
33use tokio::sync::mpsc;
34use uuid::Uuid;
35
36use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
37
38use crate::error::{ClientError, ClientResult};
39use crate::streaming::EventStream;
40use crate::transport::Transport;
41
/// HTTP client used by the transport when the `tls-rustls` feature is
/// disabled: a hyper-util legacy client over `HttpConnector` whose request
/// bodies are fully-buffered `Full<Bytes>`.
#[cfg(not(feature = "tls-rustls"))]
type HttpClient = Client<HttpConnector, Full<Bytes>>;

/// HTTP client used by the transport when the `tls-rustls` feature is
/// enabled (defined in `crate::tls`).
#[cfg(feature = "tls-rustls")]
type HttpClient = crate::tls::HttpsClient;
49
/// JSON-RPC transport that sends every call as an HTTP `POST` to a single
/// endpoint.
///
/// Cloning is cheap: all clones share the same HTTP client, endpoint and
/// timeout configuration through an [`Arc`].
#[derive(Clone, Debug)]
pub struct JsonRpcTransport {
    inner: Arc<Inner>,
}
60
/// Shared, immutable state behind a [`JsonRpcTransport`].
#[derive(Debug)]
struct Inner {
    /// HTTP client used for every request; its concrete type depends on the
    /// `tls-rustls` feature (see `HttpClient`).
    client: HttpClient,
    /// Endpoint URL every request is posted to; checked by `validate_url`
    /// when the transport is constructed.
    endpoint: String,
    /// Timeout used for non-streaming requests.
    request_timeout: Duration,
    /// Timeout for receiving the response head when opening an SSE stream
    /// (also bounds reading the body of a non-2xx streaming response).
    stream_connect_timeout: Duration,
}
68
69impl JsonRpcTransport {
70 pub fn new(endpoint: impl Into<String>) -> ClientResult<Self> {
78 Self::with_timeout(endpoint, Duration::from_secs(30))
79 }
80
81 pub fn with_timeout(
87 endpoint: impl Into<String>,
88 request_timeout: Duration,
89 ) -> ClientResult<Self> {
90 Self::with_timeouts(endpoint, request_timeout, request_timeout)
91 }
92
93 pub fn with_timeouts(
101 endpoint: impl Into<String>,
102 request_timeout: Duration,
103 stream_connect_timeout: Duration,
104 ) -> ClientResult<Self> {
105 Self::with_all_timeouts(
106 endpoint,
107 request_timeout,
108 stream_connect_timeout,
109 Duration::from_secs(10),
110 )
111 }
112
113 pub fn with_all_timeouts(
122 endpoint: impl Into<String>,
123 request_timeout: Duration,
124 stream_connect_timeout: Duration,
125 connection_timeout: Duration,
126 ) -> ClientResult<Self> {
127 let endpoint = endpoint.into();
128 validate_url(&endpoint)?;
129
130 #[cfg(not(feature = "tls-rustls"))]
131 let client = {
132 let mut connector = HttpConnector::new();
133 connector.set_connect_timeout(Some(connection_timeout));
134 connector.set_nodelay(true);
135 Client::builder(TokioExecutor::new())
136 .pool_idle_timeout(Duration::from_secs(90))
137 .build(connector)
138 };
139
140 #[cfg(feature = "tls-rustls")]
141 let client = crate::tls::build_https_client_with_connect_timeout(
142 crate::tls::default_tls_config(),
143 connection_timeout,
144 );
145
146 Ok(Self {
147 inner: Arc::new(Inner {
148 client,
149 endpoint,
150 request_timeout,
151 stream_connect_timeout,
152 }),
153 })
154 }
155
156 #[must_use]
158 pub fn endpoint(&self) -> &str {
159 &self.inner.endpoint
160 }
161
162 fn build_request(
165 &self,
166 method: &str,
167 params: serde_json::Value,
168 extra_headers: &HashMap<String, String>,
169 accept_sse: bool,
170 ) -> ClientResult<hyper::Request<Full<Bytes>>> {
171 let id = serde_json::Value::String(Uuid::new_v4().to_string());
172 let rpc_req = JsonRpcRequest::with_params(id, method, params);
173 let body_bytes = serde_json::to_vec(&rpc_req).map_err(ClientError::Serialization)?;
174
175 let accept = if accept_sse {
176 "text/event-stream"
177 } else {
178 "application/json"
179 };
180
181 let mut builder = hyper::Request::builder()
182 .method(hyper::Method::POST)
183 .uri(&self.inner.endpoint)
184 .header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
185 .header(
186 a2a_protocol_types::A2A_VERSION_HEADER,
187 a2a_protocol_types::A2A_VERSION,
188 )
189 .header(header::ACCEPT, accept);
190
191 for (k, v) in extra_headers {
192 builder = builder.header(k.as_str(), v.as_str());
193 }
194
195 builder
196 .body(Full::new(Bytes::from(body_bytes)))
197 .map_err(|e| ClientError::Transport(e.to_string()))
198 }
199
200 async fn execute_request(
201 &self,
202 method: &str,
203 params: serde_json::Value,
204 extra_headers: &HashMap<String, String>,
205 ) -> ClientResult<serde_json::Value> {
206 trace_info!(method, endpoint = %self.inner.endpoint, "sending JSON-RPC request");
207
208 let req = self.build_request(method, params, extra_headers, false)?;
209
210 let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
211 .await
212 .map_err(|_| {
213 trace_error!(method, "request timed out");
214 ClientError::Timeout("request timed out".into())
215 })?
216 .map_err(|e| {
217 trace_error!(method, error = %e, "HTTP client error");
218 ClientError::HttpClient(e.to_string())
219 })?;
220
221 let status = resp.status();
222 trace_debug!(method, %status, "received response");
223
224 let body_bytes = tokio::time::timeout(self.inner.request_timeout, resp.collect())
225 .await
226 .map_err(|_| {
227 trace_error!(method, "response body read timed out");
228 ClientError::Timeout("response body read timed out".into())
229 })?
230 .map_err(ClientError::Http)?
231 .to_bytes();
232
233 if !status.is_success() {
234 let body_str = String::from_utf8_lossy(&body_bytes);
235 trace_warn!(method, %status, "unexpected HTTP status");
236 return Err(ClientError::UnexpectedStatus {
237 status: status.as_u16(),
238 body: super::truncate_body(&body_str),
239 });
240 }
241
242 let envelope: JsonRpcResponse<serde_json::Value> = serde_json::from_slice(&body_bytes)
243 .map_err(|e| {
244 let preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(200)]);
247 if preview.contains("jsonrpc") {
248 ClientError::Serialization(e)
249 } else {
250 ClientError::ProtocolBindingMismatch(format!(
251 "response is not JSON-RPC ({e}); the server may use REST transport",
252 ))
253 }
254 })?;
255
256 match envelope {
257 JsonRpcResponse::Success(ok) => {
258 trace_info!(method, "request succeeded");
259 Ok(ok.result)
260 }
261 JsonRpcResponse::Error(err) => {
262 trace_warn!(method, code = err.error.code, "JSON-RPC error response");
263 let a2a = a2a_protocol_types::A2aError::new(
264 a2a_protocol_types::ErrorCode::try_from(err.error.code)
265 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
266 err.error.message,
267 );
268 Err(ClientError::Protocol(a2a))
269 }
270 }
271 }
272
273 async fn execute_streaming_request(
274 &self,
275 method: &str,
276 params: serde_json::Value,
277 extra_headers: &HashMap<String, String>,
278 ) -> ClientResult<EventStream> {
279 trace_info!(method, endpoint = %self.inner.endpoint, "opening SSE stream");
280
281 let req = self.build_request(method, params, extra_headers, true)?;
282
283 let resp = tokio::time::timeout(
284 self.inner.stream_connect_timeout,
285 self.inner.client.request(req),
286 )
287 .await
288 .map_err(|_| {
289 trace_error!(method, "stream connect timed out");
290 ClientError::Timeout("stream connect timed out".into())
291 })?
292 .map_err(|e| {
293 trace_error!(method, error = %e, "HTTP client error");
294 ClientError::HttpClient(e.to_string())
295 })?;
296
297 let status = resp.status();
298 if !status.is_success() {
299 let body_bytes =
300 tokio::time::timeout(self.inner.stream_connect_timeout, resp.collect())
301 .await
302 .map_err(|_| ClientError::Timeout("error body read timed out".into()))?
303 .map_err(ClientError::Http)?
304 .to_bytes();
305 let body_str = String::from_utf8_lossy(&body_bytes);
306 return Err(ClientError::UnexpectedStatus {
307 status: status.as_u16(),
308 body: super::truncate_body(&body_str),
309 });
310 }
311
312 let actual_status = status.as_u16();
313 let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
314 let body = resp.into_body();
315
316 let task_handle = tokio::spawn(async move {
318 body_reader_task(body, tx).await;
319 });
320
321 Ok(EventStream::with_status(
322 rx,
323 task_handle.abort_handle(),
324 actual_status,
325 ))
326 }
327}
328
329impl Transport for JsonRpcTransport {
330 fn send_request<'a>(
331 &'a self,
332 method: &'a str,
333 params: serde_json::Value,
334 extra_headers: &'a HashMap<String, String>,
335 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
336 Box::pin(self.execute_request(method, params, extra_headers))
337 }
338
339 fn send_streaming_request<'a>(
340 &'a self,
341 method: &'a str,
342 params: serde_json::Value,
343 extra_headers: &'a HashMap<String, String>,
344 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
345 Box::pin(self.execute_streaming_request(method, params, extra_headers))
346 }
347}
348
349async fn body_reader_task(
356 mut body: hyper::body::Incoming,
357 tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
358) {
359 use http_body_util::BodyExt;
360
361 loop {
362 match body.frame().await {
363 None => break, Some(Err(e)) => {
365 let _ = tx.send(Err(ClientError::Http(e))).await;
366 break;
367 }
368 Some(Ok(frame)) => {
369 if let Ok(data) = frame.into_data() {
370 if tx.send(Ok(data)).await.is_err() {
371 break;
373 }
374 }
375 }
377 }
378 }
379}
380
381fn validate_url(url: &str) -> ClientResult<()> {
384 if url.is_empty() {
385 return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
386 }
387 if !url.starts_with("http://") && !url.starts_with("https://") {
388 return Err(ClientError::InvalidEndpoint(format!(
389 "URL must start with http:// or https://: {url}"
390 )));
391 }
392 Ok(())
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// A single SSE event carrying a successful JSON-RPC result.
    const SSE_EVENT: &str =
        "data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"status\":\"ok\"}}\n\n";

    /// Returns the base URL of a loopback test server.
    fn local_url(addr: std::net::SocketAddr) -> String {
        format!("http://127.0.0.1:{}", addr.port())
    }

    /// Spawns a loopback HTTP server that answers every request with
    /// `status`, the given `content-type`, and `body`; returns its address.
    async fn start_server_with_content_type(
        status: u16,
        content_type: &'static str,
        body: impl Into<String>,
    ) -> std::net::SocketAddr {
        let body: String = body.into();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            loop {
                let (stream, _) = listener.accept().await.unwrap();
                let io = hyper_util::rt::TokioIo::new(stream);
                let body = body.clone();
                tokio::spawn(async move {
                    let service = hyper::service::service_fn(move |_req| {
                        let body = body.clone();
                        async move {
                            Ok::<_, hyper::Error>(
                                hyper::Response::builder()
                                    .status(status)
                                    .header("content-type", content_type)
                                    .body(Full::new(Bytes::from(body)))
                                    .unwrap(),
                            )
                        }
                    });
                    let _ = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    )
                    .serve_connection(io, service)
                    .await;
                });
            }
        });

        addr
    }

    /// Spawns a server replying with `status` and a JSON `body`.
    async fn start_server(status: u16, body: impl Into<String>) -> std::net::SocketAddr {
        start_server_with_content_type(status, "application/json", body).await
    }

    /// Spawns a server replying `200 text/event-stream` with [`SSE_EVENT`].
    async fn start_sse_server() -> std::net::SocketAddr {
        start_server_with_content_type(200, "text/event-stream", SSE_EVENT).await
    }

    #[test]
    fn validate_url_rejects_empty() {
        assert!(validate_url("").is_err());
    }

    #[test]
    fn validate_url_rejects_non_http() {
        assert!(validate_url("ftp://example.com").is_err());
    }

    #[test]
    fn validate_url_accepts_http() {
        assert!(validate_url("http://localhost:8080").is_ok());
    }

    #[test]
    fn validate_url_accepts_https() {
        assert!(validate_url("https://agent.example.com/a2a").is_ok());
    }

    #[test]
    fn transport_new_rejects_bad_url() {
        assert!(JsonRpcTransport::new("not-a-url").is_err());
    }

    #[test]
    fn transport_new_stores_endpoint() {
        let t = JsonRpcTransport::new("http://localhost:9090").unwrap();
        assert_eq!(t.endpoint(), "http://localhost:9090");
    }

    #[tokio::test]
    async fn execute_request_non_success_status_returns_error() {
        let addr = start_server(404, "Not Found").await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let result = transport
            .execute_request("GetTask", serde_json::json!({}), &HashMap::new())
            .await;
        match result {
            Err(ClientError::UnexpectedStatus { status, .. }) => {
                assert_eq!(status, 404);
            }
            other => panic!("expected UnexpectedStatus, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn execute_request_success_parses_jsonrpc() {
        let response_body = r#"{"jsonrpc":"2.0","id":"1","result":{"hello":"world"}}"#;
        let addr = start_server(200, response_body).await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let result = transport
            .execute_request("GetTask", serde_json::json!({}), &HashMap::new())
            .await;
        let value = result.unwrap();
        assert_eq!(value["hello"], "world");
    }

    #[tokio::test]
    async fn execute_streaming_request_non_success_returns_error() {
        let addr = start_server(500, "Internal Server Error").await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let result = transport
            .execute_streaming_request(
                "SendStreamingMessage",
                serde_json::json!({}),
                &HashMap::new(),
            )
            .await;
        match result {
            Err(ClientError::UnexpectedStatus { status, .. }) => {
                assert_eq!(status, 500);
            }
            other => panic!("expected UnexpectedStatus, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn execute_request_jsonrpc_error_returns_protocol_error() {
        let response_body =
            r#"{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"internal failure"}}"#;
        let addr = start_server(200, response_body).await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let result = transport
            .execute_request("GetTask", serde_json::json!({}), &HashMap::new())
            .await;
        match result {
            Err(ClientError::Protocol(a2a_err)) => {
                assert!(
                    a2a_err.message.contains("internal failure"),
                    "got: {}",
                    a2a_err.message
                );
            }
            other => panic!("expected Protocol error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn execute_request_non_jsonrpc_returns_binding_mismatch() {
        let response_body = r#"{"status":"ok","data":42}"#;
        let addr = start_server(200, response_body).await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let result = transport
            .execute_request("GetTask", serde_json::json!({}), &HashMap::new())
            .await;
        match result {
            Err(ClientError::ProtocolBindingMismatch(msg)) => {
                assert!(msg.contains("REST"), "should mention REST transport: {msg}");
            }
            other => panic!("expected ProtocolBindingMismatch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn send_streaming_request_via_trait_delegation() {
        let addr = start_sse_server().await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let dyn_transport: &dyn Transport = &transport;
        let result = dyn_transport
            .send_streaming_request(
                "SendStreamingMessage",
                serde_json::json!({}),
                &HashMap::new(),
            )
            .await;
        assert!(result.is_ok(), "streaming via trait delegation should work");
    }

    #[tokio::test]
    async fn send_request_via_trait_delegation() {
        let response_body = r#"{"jsonrpc":"2.0","id":"1","result":{"hello":"world"}}"#;
        let addr = start_server(200, response_body).await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let dyn_transport: &dyn Transport = &transport;
        let result = dyn_transport
            .send_request("GetTask", serde_json::json!({}), &HashMap::new())
            .await;
        let value = result.unwrap();
        assert_eq!(value["hello"], "world");
    }

    #[tokio::test]
    async fn execute_streaming_request_success_returns_event_stream() {
        let addr = start_sse_server().await;
        let transport = JsonRpcTransport::new(local_url(addr)).unwrap();
        let mut stream = transport
            .execute_streaming_request(
                "SendStreamingMessage",
                serde_json::json!({}),
                &HashMap::new(),
            )
            .await
            .unwrap();
        let event = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("timed out waiting for event");
        assert!(
            event.is_some(),
            "expected at least one event from the stream"
        );
    }
}