use std::collections::HashMap;
use std::convert::Infallible;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use a2a_protocol_types::jsonrpc::{
JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
JsonRpcVersion,
};
use crate::error::ServerError;
pub(super) fn extract_headers(headers: &hyper::HeaderMap) -> HashMap<String, String> {
let mut map = HashMap::with_capacity(headers.len());
for (key, value) in headers {
if let Ok(v) = value.to_str() {
map.insert(key.as_str().to_owned(), v.to_owned());
}
}
map
}
pub(super) fn success_response_bytes<T: serde::Serialize>(id: JsonRpcId, result: &T) -> Vec<u8> {
match serde_json::to_value(result) {
Ok(value) => {
let resp = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id,
result: value,
};
serde_json::to_vec(&resp).unwrap_or_else(|_| {
br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"internal serialization error"}}"#.to_vec()
})
}
Err(_) => error_response_bytes(
id,
&ServerError::Internal("result serialization failed".into()),
),
}
}
pub(super) fn error_response_bytes(id: JsonRpcId, err: &ServerError) -> Vec<u8> {
let a2a_err = err.to_a2a_error();
let resp = JsonRpcErrorResponse::new(
id,
JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
);
serde_json::to_vec(&resp).unwrap_or_default()
}
pub(super) fn parse_params<T: serde::de::DeserializeOwned>(
rpc_req: &JsonRpcRequest,
) -> Result<T, ServerError> {
let params = rpc_req
.params
.as_ref()
.ok_or_else(|| ServerError::InvalidParams("missing params".into()))?;
serde_json::from_value(params.clone())
.map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
}
pub(super) fn success_response<T: serde::Serialize>(
id: JsonRpcId,
result: &T,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let value = match serde_json::to_value(result) {
Ok(v) => v,
Err(e) => return internal_serialization_error(id, &e),
};
let resp = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id: id.clone(),
result: value,
};
match serde_json::to_vec(&resp) {
Ok(body) => json_response(200, body),
Err(e) => internal_serialization_error(id, &e),
}
}
pub(super) fn error_response(
id: JsonRpcId,
err: &ServerError,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let a2a_err = err.to_a2a_error();
let resp = JsonRpcErrorResponse::new(
id.clone(),
JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
);
match serde_json::to_vec(&resp) {
Ok(body) => json_response(200, body),
Err(e) => internal_serialization_error(id, &e),
}
}
pub(super) fn parse_error_response(
id: JsonRpcId,
message: &str,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let resp = JsonRpcErrorResponse::new(
id.clone(),
JsonRpcError::new(
a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
format!("Parse error: {message}"),
),
);
match serde_json::to_vec(&resp) {
Ok(body) => json_response(200, body),
Err(e) => internal_serialization_error(id, &e),
}
}
pub(super) fn internal_serialization_error(
_id: JsonRpcId,
_err: &serde_json::Error,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
trace_error!(error = %_err, "JSON-RPC response serialization failed");
let body = br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"internal serialization error"}}"#;
json_response(200, body.to_vec())
}
pub(super) async fn read_body_limited(
body: Incoming,
max_size: usize,
read_timeout: std::time::Duration,
) -> Result<Bytes, String> {
let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
if let Some(upper) = size_hint.upper() {
if upper > max_size as u64 {
return Err(format!(
"request body too large: {upper} bytes exceeds {max_size} byte limit"
));
}
}
let collected = tokio::time::timeout(read_timeout, body.collect())
.await
.map_err(|_| "request body read timed out".to_owned())?
.map_err(|e| e.to_string())?;
let bytes = collected.to_bytes();
if bytes.len() > max_size {
return Err(format!(
"request body too large: {} bytes exceeds {max_size} byte limit",
bytes.len()
));
}
Ok(bytes)
}
pub(super) fn json_response(
status: u16,
body: Vec<u8>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
hyper::Response::builder()
.status(status)
.header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
.header(a2a_protocol_types::A2A_VERSION_HEADER, a2a_protocol_types::A2A_VERSION)
.body(Full::new(Bytes::from(body)).boxed())
.unwrap_or_else(|_| {
hyper::Response::new(
Full::new(Bytes::from_static(
br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"response build error"}}"#,
))
.boxed(),
)
})
}
#[cfg(test)]
mod tests {
use super::*;
use http_body_util::BodyExt;
use hyper::header::HeaderValue;
#[test]
fn extract_headers_lowercases_keys() {
let mut headers = hyper::HeaderMap::new();
headers.insert(
hyper::header::AUTHORIZATION,
HeaderValue::from_static("Bearer token"),
);
let map = extract_headers(&headers);
assert_eq!(
map.get("authorization").map(String::as_str),
Some("Bearer token")
);
}
#[test]
fn extract_headers_skips_non_ascii_values() {
let mut headers = hyper::HeaderMap::new();
let bad_value = HeaderValue::from_bytes(b"caf\xe9").unwrap();
headers.insert(hyper::header::X_CONTENT_TYPE_OPTIONS, bad_value);
let map = extract_headers(&headers);
assert!(!map.contains_key("x-content-type-options"));
}
#[test]
fn extract_headers_empty() {
let headers = hyper::HeaderMap::new();
let map = extract_headers(&headers);
assert!(map.is_empty());
}
#[test]
fn parse_params_missing_returns_invalid_params() {
use a2a_protocol_types::params::TaskQueryParams;
let req = JsonRpcRequest {
jsonrpc: JsonRpcVersion,
id: None,
method: "GetTask".to_owned(),
params: None,
};
let result: Result<TaskQueryParams, _> = parse_params(&req);
assert!(result.is_err(), "expected error when params are missing");
let err = result.unwrap_err();
assert!(
matches!(err, ServerError::InvalidParams(_)),
"expected InvalidParams, got {err:?}"
);
}
#[test]
fn parse_params_invalid_type_returns_error() {
use a2a_protocol_types::params::TaskQueryParams;
let req = JsonRpcRequest {
jsonrpc: JsonRpcVersion,
id: Some(serde_json::json!(1)),
method: "GetTask".to_owned(),
params: Some(serde_json::json!(42)),
};
let result: Result<TaskQueryParams, _> = parse_params(&req);
assert!(result.is_err(), "expected error for wrong params type");
}
#[test]
fn json_response_200_status() {
let resp = json_response(200, b"{}".to_vec());
assert_eq!(resp.status().as_u16(), 200);
}
#[test]
fn json_response_404_status() {
let resp = json_response(404, b"not found".to_vec());
assert_eq!(resp.status().as_u16(), 404);
}
#[test]
fn parse_error_response_returns_200_with_error_body() {
let resp = parse_error_response(None, "bad json");
assert_eq!(resp.status().as_u16(), 200);
}
#[tokio::test]
async fn parse_error_response_has_error_code() {
let resp = parse_error_response(None, "something went wrong");
let body_bytes = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(val["error"]["code"], -32700);
assert!(val["error"]["message"].is_string());
}
#[test]
fn success_response_bytes_structure() {
let id: JsonRpcId = Some(serde_json::json!(1));
let bytes = success_response_bytes(id, &serde_json::json!({"key": "val"}));
let val: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(val["result"]["key"], "val");
assert_eq!(val["id"], 1);
}
#[test]
fn success_response_includes_jsonrpc_version() {
let id: JsonRpcId = Some(serde_json::json!(42));
let bytes = success_response_bytes(id, &serde_json::json!(null));
let val: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(val["jsonrpc"], "2.0");
}
#[test]
fn error_response_bytes_contains_error_object() {
let id: JsonRpcId = Some(serde_json::json!(1));
let err = ServerError::MethodNotFound("Foo".into());
let bytes = error_response_bytes(id, &err);
let val: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert!(
val["error"].is_object(),
"expected 'error' key to be an object"
);
assert!(val["error"]["code"].is_number());
assert!(val["error"]["message"].is_string());
}
#[test]
fn error_response_has_jsonrpc_version() {
let id: JsonRpcId = Some(serde_json::json!(7));
let err = ServerError::Internal("oops".into());
let bytes = error_response_bytes(id, &err);
let val: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(val["jsonrpc"], "2.0");
}
#[tokio::test]
async fn success_response_http_200_with_result() {
let id: JsonRpcId = Some(serde_json::json!(1));
let resp = success_response(id, &serde_json::json!({"status": "ok"}));
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(val["result"]["status"], "ok");
assert_eq!(val["jsonrpc"], "2.0");
}
#[tokio::test]
async fn error_response_http_200_with_error() {
let id: JsonRpcId = Some(serde_json::json!(2));
let err = ServerError::TaskNotFound("t-123".into());
let resp = error_response(id, &err);
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(val["error"]["code"].is_number());
assert!(val["error"]["message"].as_str().unwrap().contains("t-123"));
}
#[tokio::test]
async fn internal_serialization_error_returns_200() {
let err = serde_json::from_str::<String>("bad").unwrap_err();
let resp = internal_serialization_error(None, &err);
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(val["error"]["code"], -32603);
}
#[test]
fn json_response_has_content_type_and_version_header() {
let resp = json_response(200, b"{}".to_vec());
assert!(resp.headers().get("content-type").is_some());
assert!(resp
.headers()
.get(a2a_protocol_types::A2A_VERSION_HEADER)
.is_some());
}
}