use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::header;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::client::legacy::Client;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::rt::TokioExecutor;
use tokio::sync::mpsc;
use uuid::Uuid;
use a2a_protocol_types::{JsonRpcRequest, JsonRpcResponse};
use crate::error::{ClientError, ClientResult};
use crate::streaming::EventStream;
use crate::transport::Transport;
#[cfg(not(feature = "tls-rustls"))]
type HttpClient = Client<HttpConnector, Full<Bytes>>;
#[cfg(feature = "tls-rustls")]
type HttpClient = crate::tls::HttpsClient;
#[derive(Clone, Debug)]
pub struct JsonRpcTransport {
inner: Arc<Inner>,
}
#[derive(Debug)]
struct Inner {
client: HttpClient,
endpoint: String,
request_timeout: Duration,
stream_connect_timeout: Duration,
}
impl JsonRpcTransport {
pub fn new(endpoint: impl Into<String>) -> ClientResult<Self> {
Self::with_timeout(endpoint, Duration::from_secs(30))
}
pub fn with_timeout(
endpoint: impl Into<String>,
request_timeout: Duration,
) -> ClientResult<Self> {
Self::with_timeouts(endpoint, request_timeout, request_timeout)
}
pub fn with_timeouts(
endpoint: impl Into<String>,
request_timeout: Duration,
stream_connect_timeout: Duration,
) -> ClientResult<Self> {
Self::with_all_timeouts(
endpoint,
request_timeout,
stream_connect_timeout,
Duration::from_secs(10),
)
}
pub fn with_all_timeouts(
endpoint: impl Into<String>,
request_timeout: Duration,
stream_connect_timeout: Duration,
connection_timeout: Duration,
) -> ClientResult<Self> {
let endpoint = endpoint.into();
validate_url(&endpoint)?;
#[cfg(not(feature = "tls-rustls"))]
let client = {
let mut connector = HttpConnector::new();
connector.set_connect_timeout(Some(connection_timeout));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build(connector)
};
#[cfg(feature = "tls-rustls")]
let client = crate::tls::build_https_client_with_connect_timeout(
crate::tls::default_tls_config(),
connection_timeout,
);
Ok(Self {
inner: Arc::new(Inner {
client,
endpoint,
request_timeout,
stream_connect_timeout,
}),
})
}
#[must_use]
pub fn endpoint(&self) -> &str {
&self.inner.endpoint
}
fn build_request(
&self,
method: &str,
params: serde_json::Value,
extra_headers: &HashMap<String, String>,
accept_sse: bool,
) -> ClientResult<(serde_json::Value, hyper::Request<Full<Bytes>>)> {
let id = serde_json::Value::String(Uuid::new_v4().to_string());
let rpc_req = JsonRpcRequest::with_params(id.clone(), method, params);
let body_bytes = serde_json::to_vec(&rpc_req).map_err(ClientError::Serialization)?;
let accept = if accept_sse {
"text/event-stream"
} else {
"application/json"
};
let mut builder = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(&self.inner.endpoint)
.header(header::CONTENT_TYPE, a2a_protocol_types::A2A_CONTENT_TYPE)
.header(
a2a_protocol_types::A2A_VERSION_HEADER,
a2a_protocol_types::A2A_VERSION,
)
.header(header::ACCEPT, accept);
for (k, v) in extra_headers {
builder = builder.header(k.as_str(), v.as_str());
}
let req = builder
.body(Full::new(Bytes::from(body_bytes)))
.map_err(|e| ClientError::Transport(e.to_string()))?;
Ok((id, req))
}
#[allow(clippy::too_many_lines)]
async fn execute_request(
&self,
method: &str,
params: serde_json::Value,
extra_headers: &HashMap<String, String>,
) -> ClientResult<serde_json::Value> {
trace_info!(method, endpoint = %self.inner.endpoint, "sending JSON-RPC request");
let (request_id, req) = self.build_request(method, params, extra_headers, false)?;
let resp = tokio::time::timeout(self.inner.request_timeout, self.inner.client.request(req))
.await
.map_err(|_| {
trace_error!(method, "request timed out");
ClientError::Timeout("request timed out".into())
})?
.map_err(|e| {
trace_error!(method, error = %e, "HTTP client error");
ClientError::HttpClient(e.to_string())
})?;
let status = resp.status();
trace_debug!(method, %status, "received response");
let body_bytes = tokio::time::timeout(self.inner.request_timeout, resp.collect())
.await
.map_err(|_| {
trace_error!(method, "response body read timed out");
ClientError::Timeout("response body read timed out".into())
})?
.map_err(ClientError::Http)?
.to_bytes();
if !status.is_success() {
let body_str = String::from_utf8_lossy(&body_bytes);
trace_warn!(method, %status, "unexpected HTTP status");
return Err(ClientError::UnexpectedStatus {
status: status.as_u16(),
body: super::truncate_body(&body_str),
});
}
let envelope: JsonRpcResponse<serde_json::Value> = serde_json::from_slice(&body_bytes)
.map_err(|e| {
let preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(200)]);
if preview.contains("jsonrpc") {
ClientError::Serialization(e)
} else {
ClientError::ProtocolBindingMismatch(format!(
"response is not JSON-RPC ({e}); the server may use REST transport",
))
}
})?;
match envelope {
JsonRpcResponse::Success(ok) => {
if ok.id != Some(request_id.clone()) {
trace_warn!(
method,
"JSON-RPC response id mismatch (expected {request_id}, got {:?})",
ok.id
);
return Err(ClientError::Transport(
"JSON-RPC response id does not match request id".into(),
));
}
trace_info!(method, "request succeeded");
Ok(ok.result)
}
JsonRpcResponse::Error(err) => {
if err.id != Some(request_id.clone()) {
trace_warn!(
method,
"JSON-RPC error response id mismatch (expected {request_id}, got {:?})",
err.id
);
}
trace_warn!(method, code = err.error.code, "JSON-RPC error response");
let a2a = a2a_protocol_types::A2aError::new(
a2a_protocol_types::ErrorCode::try_from(err.error.code)
.unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
err.error.message,
);
Err(ClientError::Protocol(a2a))
}
}
}
async fn execute_streaming_request(
&self,
method: &str,
params: serde_json::Value,
extra_headers: &HashMap<String, String>,
) -> ClientResult<EventStream> {
trace_info!(method, endpoint = %self.inner.endpoint, "opening SSE stream");
let (_request_id, req) = self.build_request(method, params, extra_headers, true)?;
let resp = tokio::time::timeout(
self.inner.stream_connect_timeout,
self.inner.client.request(req),
)
.await
.map_err(|_| {
trace_error!(method, "stream connect timed out");
ClientError::Timeout("stream connect timed out".into())
})?
.map_err(|e| {
trace_error!(method, error = %e, "HTTP client error");
ClientError::HttpClient(e.to_string())
})?;
let status = resp.status();
if !status.is_success() {
let body_bytes =
tokio::time::timeout(self.inner.stream_connect_timeout, resp.collect())
.await
.map_err(|_| ClientError::Timeout("error body read timed out".into()))?
.map_err(ClientError::Http)?
.to_bytes();
let body_str = String::from_utf8_lossy(&body_bytes);
return Err(ClientError::UnexpectedStatus {
status: status.as_u16(),
body: super::truncate_body(&body_str),
});
}
let actual_status = status.as_u16();
let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
let body = resp.into_body();
let task_handle = tokio::spawn(async move {
body_reader_task(body, tx).await;
});
Ok(EventStream::with_status(
rx,
task_handle.abort_handle(),
actual_status,
))
}
}
impl Transport for JsonRpcTransport {
fn send_request<'a>(
&'a self,
method: &'a str,
params: serde_json::Value,
extra_headers: &'a HashMap<String, String>,
) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
Box::pin(self.execute_request(method, params, extra_headers))
}
fn send_streaming_request<'a>(
&'a self,
method: &'a str,
params: serde_json::Value,
extra_headers: &'a HashMap<String, String>,
) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
Box::pin(self.execute_streaming_request(method, params, extra_headers))
}
}
async fn body_reader_task(
mut body: hyper::body::Incoming,
tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
) {
use http_body_util::BodyExt;
loop {
match body.frame().await {
None => break, Some(Err(e)) => {
let _ = tx.send(Err(ClientError::Http(e))).await;
break;
}
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
if tx.send(Ok(data)).await.is_err() {
break;
}
}
}
}
}
}
fn validate_url(url: &str) -> ClientResult<()> {
if url.is_empty() {
return Err(ClientError::InvalidEndpoint("URL must not be empty".into()));
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err(ClientError::InvalidEndpoint(format!(
"URL must start with http:// or https://: {url}"
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validate_url_rejects_empty() {
assert!(validate_url("").is_err());
}
#[test]
fn validate_url_rejects_non_http() {
assert!(validate_url("ftp://example.com").is_err());
}
#[test]
fn validate_url_accepts_http() {
assert!(validate_url("http://localhost:8080").is_ok());
}
#[test]
fn validate_url_accepts_https() {
assert!(validate_url("https://agent.example.com/a2a").is_ok());
}
#[test]
fn transport_new_rejects_bad_url() {
assert!(JsonRpcTransport::new("not-a-url").is_err());
}
#[test]
fn transport_new_stores_endpoint() {
let t = JsonRpcTransport::new("http://localhost:9090").unwrap();
assert_eq!(t.endpoint(), "http://localhost:9090");
}
async fn start_server(status: u16, body: impl Into<String>) -> std::net::SocketAddr {
let body: String = body.into();
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = hyper_util::rt::TokioIo::new(stream);
let body = body.clone();
tokio::spawn(async move {
let service = hyper::service::service_fn(
move |req: hyper::Request<hyper::body::Incoming>| {
let body_template = body.clone();
async move {
let req_bytes = req
.into_body()
.collect()
.await
.map(http_body_util::Collected::to_bytes)
.unwrap_or_default();
let response_body = if let Ok(req_json) =
serde_json::from_slice::<serde_json::Value>(&req_bytes)
{
if let Some(id) = req_json.get("id") {
body_template.replace("\"__ID__\"", &id.to_string())
} else {
body_template
}
} else {
body_template
};
Ok::<_, hyper::Error>(
hyper::Response::builder()
.status(status)
.header("content-type", "application/json")
.body(Full::new(Bytes::from(response_body)))
.unwrap(),
)
}
},
);
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
addr
}
#[tokio::test]
async fn execute_request_non_success_status_returns_error() {
let addr = start_server(404, "Not Found").await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let result = transport
.execute_request("GetTask", serde_json::json!({}), &HashMap::new())
.await;
match result {
Err(ClientError::UnexpectedStatus { status, .. }) => {
assert_eq!(status, 404);
}
other => panic!("expected UnexpectedStatus, got {other:?}"),
}
}
#[tokio::test]
async fn execute_request_success_parses_jsonrpc() {
let response_body = r#"{"jsonrpc":"2.0","id":"__ID__","result":{"hello":"world"}}"#;
let addr = start_server(200, response_body).await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let result = transport
.execute_request("GetTask", serde_json::json!({}), &HashMap::new())
.await;
let value = result.unwrap();
assert_eq!(value["hello"], "world");
}
#[tokio::test]
async fn execute_streaming_request_non_success_returns_error() {
let addr = start_server(500, "Internal Server Error").await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let result = transport
.execute_streaming_request(
"SendStreamingMessage",
serde_json::json!({}),
&HashMap::new(),
)
.await;
match result {
Err(ClientError::UnexpectedStatus { status, .. }) => {
assert_eq!(status, 500);
}
other => panic!("expected UnexpectedStatus, got {other:?}"),
}
}
#[tokio::test]
async fn execute_request_jsonrpc_error_returns_protocol_error() {
let response_body = r#"{"jsonrpc":"2.0","id":"__ID__","error":{"code":-32603,"message":"internal failure"}}"#;
let addr = start_server(200, response_body).await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let result = transport
.execute_request("GetTask", serde_json::json!({}), &HashMap::new())
.await;
match result {
Err(ClientError::Protocol(a2a_err)) => {
assert!(
a2a_err.message.contains("internal failure"),
"got: {}",
a2a_err.message
);
}
other => panic!("expected Protocol error, got {other:?}"),
}
}
#[tokio::test]
async fn execute_request_non_jsonrpc_returns_binding_mismatch() {
let response_body = r#"{"status":"ok","data":42}"#;
let addr = start_server(200, response_body).await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let result = transport
.execute_request("GetTask", serde_json::json!({}), &HashMap::new())
.await;
match result {
Err(ClientError::ProtocolBindingMismatch(msg)) => {
assert!(msg.contains("REST"), "should mention REST transport: {msg}");
}
other => panic!("expected ProtocolBindingMismatch, got {other:?}"),
}
}
#[tokio::test]
async fn send_streaming_request_via_trait_delegation() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = hyper_util::rt::TokioIo::new(stream);
tokio::spawn(async move {
let service = hyper::service::service_fn(|_req| async {
let sse_body = "data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"status\":\"ok\"}}\n\n";
Ok::<_, hyper::Error>(
hyper::Response::builder()
.status(200)
.header("content-type", "text/event-stream")
.body(Full::new(Bytes::from(sse_body)))
.unwrap(),
)
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let dyn_transport: &dyn Transport = &transport;
let result = dyn_transport
.send_streaming_request(
"SendStreamingMessage",
serde_json::json!({}),
&HashMap::new(),
)
.await;
assert!(result.is_ok(), "streaming via trait delegation should work");
}
#[tokio::test]
async fn send_request_via_trait_delegation() {
let response_body = r#"{"jsonrpc":"2.0","id":"__ID__","result":{"hello":"world"}}"#;
let addr = start_server(200, response_body).await;
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let dyn_transport: &dyn Transport = &transport;
let result = dyn_transport
.send_request("GetTask", serde_json::json!({}), &HashMap::new())
.await;
let value = result.unwrap();
assert_eq!(value["hello"], "world");
}
#[tokio::test]
async fn execute_streaming_request_success_returns_event_stream() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = hyper_util::rt::TokioIo::new(stream);
tokio::spawn(async move {
let service = hyper::service::service_fn(|_req| async {
let sse_body = "data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"status\":\"ok\"}}\n\n";
Ok::<_, hyper::Error>(
hyper::Response::builder()
.status(200)
.header("content-type", "text/event-stream")
.body(Full::new(Bytes::from(sse_body)))
.unwrap(),
)
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = JsonRpcTransport::new(&url).unwrap();
let mut stream = transport
.execute_streaming_request(
"SendStreamingMessage",
serde_json::json!({}),
&HashMap::new(),
)
.await
.unwrap();
let event = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
.await
.expect("timed out waiting for event");
assert!(
event.is_some(),
"expected at least one event from the stream"
);
}
}