use std::collections::HashMap;
use std::convert::Infallible;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use crate::error::ServerError;
pub(super) fn extract_headers(headers: &hyper::HeaderMap) -> HashMap<String, String> {
let mut map = HashMap::with_capacity(headers.len());
for (key, value) in headers {
if let Ok(v) = value.to_str() {
map.insert(key.as_str().to_owned(), v.to_owned());
}
}
map
}
pub(super) fn json_ok_response<T: serde::Serialize>(
value: &T,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
match serde_json::to_vec(value) {
Ok(body) => build_json_response(200, body),
Err(_err) => {
trace_error!(error = %_err, "REST response serialization failed");
internal_error_response()
}
}
}
pub(super) fn error_json_response(
status: u16,
message: &str,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let body = serde_json::json!({
"error": {
"code": status,
"message": message
}
});
serde_json::to_vec(&body).map_or_else(
|_| internal_error_response(),
|bytes| build_json_response(status, bytes),
)
}
pub(super) fn internal_error_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
let body = br#"{"error":{"code":500,"message":"internal serialization error"}}"#;
build_json_response(500, body.to_vec())
}
pub(super) fn not_found_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
error_json_response(404, "not found")
}
pub(super) fn server_error_to_response(
err: &ServerError,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let a2a_err = err.to_a2a_error();
let status = a2a_err.code.http_status();
let grpc_status = a2a_err.code.grpc_status();
let details = a2a_err.error_info_data(None);
let mut error_obj = serde_json::json!({
"error": {
"code": status,
"status": grpc_status,
"message": a2a_err.message
}
});
if !details.is_null() {
error_obj["error"]["details"] = details;
}
serde_json::to_vec(&error_obj).map_or_else(
|_| internal_error_response(),
|body| build_json_response(status, body),
)
}
pub(super) fn health_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
let body = br#"{"status":"ok"}"#;
build_json_response(200, body.to_vec())
}
pub(super) fn build_json_response(
status: u16,
body: Vec<u8>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
hyper::Response::builder()
.status(status)
.header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
.header(
a2a_protocol_types::A2A_VERSION_HEADER,
a2a_protocol_types::A2A_VERSION,
)
.body(Full::new(Bytes::from(body)).boxed())
.unwrap_or_else(|_| {
hyper::Response::new(
Full::new(Bytes::from_static(br#"{"error":"response build error"}"#)).boxed(),
)
})
}
pub(super) async fn read_body_limited(
body: Incoming,
max_size: usize,
read_timeout: std::time::Duration,
) -> Result<Bytes, String> {
use http_body_util::BodyExt;
let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
if let Some(upper) = size_hint.upper() {
if upper > max_size as u64 {
return Err(format!(
"request body too large: {upper} bytes exceeds {max_size} byte limit"
));
}
}
let collected = tokio::time::timeout(read_timeout, body.collect())
.await
.map_err(|_| "request body read timed out".to_owned())?
.map_err(|e| e.to_string())?;
let bytes = collected.to_bytes();
if bytes.len() > max_size {
return Err(format!(
"request body too large: {} bytes exceeds {max_size} byte limit",
bytes.len()
));
}
Ok(bytes)
}
pub(super) fn inject_field_if_missing(
mut value: serde_json::Value,
field: &str,
path_value: &str,
) -> serde_json::Value {
if let Some(obj) = value.as_object_mut() {
obj.entry(field.to_owned())
.or_insert_with(|| serde_json::Value::String(path_value.to_owned()));
}
value
}
#[cfg(test)]
mod tests {
use super::*;
use http_body_util::BodyExt;
#[test]
fn extract_headers_lowercased_keys() {
let mut hm = hyper::HeaderMap::new();
hm.insert("Content-Type", "application/json".parse().unwrap());
hm.insert("Authorization", "Bearer tok".parse().unwrap());
let map = extract_headers(&hm);
assert_eq!(map.get("content-type").unwrap(), "application/json");
assert_eq!(map.get("authorization").unwrap(), "Bearer tok");
}
#[test]
fn extract_headers_empty() {
let hm = hyper::HeaderMap::new();
let map = extract_headers(&hm);
assert!(map.is_empty());
}
#[tokio::test]
async fn health_response_status_and_body() {
let resp = health_response();
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(val["status"], "ok");
}
#[tokio::test]
async fn error_json_response_status_and_body() {
let resp = error_json_response(400, "bad request");
assert_eq!(resp.status().as_u16(), 400);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(val["error"]["message"], "bad request");
assert_eq!(val["error"]["code"], 400);
}
#[tokio::test]
async fn error_json_response_has_a2a_content_type() {
let resp = error_json_response(404, "not found");
assert_eq!(
resp.headers()
.get("content-type")
.and_then(|v| v.to_str().ok()),
Some(a2a_protocol_types::A2A_CONTENT_TYPE),
);
}
#[tokio::test]
async fn not_found_response_is_404() {
let resp = not_found_response();
assert_eq!(resp.status().as_u16(), 404);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(val["error"]["message"], "not found");
}
#[tokio::test]
async fn internal_error_response_is_500() {
let resp = internal_error_response();
assert_eq!(resp.status().as_u16(), 500);
}
#[tokio::test]
async fn build_json_response_includes_version_header() {
let resp = build_json_response(200, b"{}".to_vec());
assert_eq!(
resp.headers()
.get(a2a_protocol_types::A2A_VERSION_HEADER)
.and_then(|v| v.to_str().ok()),
Some(a2a_protocol_types::A2A_VERSION),
);
}
#[tokio::test]
async fn json_ok_response_serializes_value() {
let val = serde_json::json!({"key": "value"});
let resp = json_ok_response(&val);
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(parsed["key"], "value");
}
#[tokio::test]
async fn server_error_task_not_found_maps_to_404() {
let err = ServerError::TaskNotFound("t1".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 404);
}
#[tokio::test]
async fn server_error_method_not_found_maps_to_404() {
let err = ServerError::MethodNotFound("foo".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 404);
}
#[tokio::test]
async fn server_error_task_not_cancelable_maps_to_409() {
let err = ServerError::TaskNotCancelable("t1".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 409);
}
#[tokio::test]
async fn server_error_invalid_params_maps_to_400() {
let err = ServerError::InvalidParams("bad".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 400);
}
#[tokio::test]
async fn server_error_push_not_supported_maps_to_400() {
let err = ServerError::PushNotSupported;
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 400);
}
#[tokio::test]
async fn server_error_internal_maps_to_500() {
let err = ServerError::Internal("oops".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 500);
}
#[test]
fn inject_field_when_missing() {
let val = serde_json::json!({"url": "https://example.com"});
let result = inject_field_if_missing(val, "taskId", "task-1");
assert_eq!(result["taskId"], "task-1");
assert_eq!(result["url"], "https://example.com");
}
#[test]
fn inject_field_preserves_existing() {
let val = serde_json::json!({"taskId": "existing", "url": "https://example.com"});
let result = inject_field_if_missing(val, "taskId", "task-1");
assert_eq!(
result["taskId"], "existing",
"should not overwrite existing field"
);
}
#[tokio::test]
async fn internal_error_response_has_json_body() {
let resp = internal_error_response();
assert_eq!(resp.status().as_u16(), 500);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let text = String::from_utf8_lossy(&body);
assert!(
text.contains("internal serialization error"),
"internal error response should contain error message: {text}"
);
}
#[tokio::test]
async fn error_json_response_various_statuses() {
for status in [400, 403, 404, 422, 500, 503] {
let resp = error_json_response(status, &format!("error {status}"));
assert_eq!(resp.status().as_u16(), status);
}
}
#[tokio::test]
async fn server_error_serialization_maps_to_400() {
let err = ServerError::Serialization(serde_json::from_str::<()>("bad").unwrap_err());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 400);
}
#[tokio::test]
async fn build_json_response_various_statuses() {
for status in [200, 201, 400, 404, 500] {
let resp = build_json_response(status, b"{}".to_vec());
assert_eq!(resp.status().as_u16(), status);
assert_eq!(
resp.headers()
.get("content-type")
.and_then(|v| v.to_str().ok()),
Some(a2a_protocol_types::A2A_CONTENT_TYPE),
);
}
}
#[test]
fn inject_field_on_non_object_is_noop() {
let val = serde_json::json!("string value");
let result = inject_field_if_missing(val.clone(), "taskId", "task-1");
assert_eq!(result, val);
}
#[tokio::test]
async fn server_error_transport_maps_to_500() {
let err = ServerError::Transport("transport broke".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 500);
}
#[tokio::test]
async fn server_error_payload_too_large_maps_to_400() {
let err = ServerError::PayloadTooLarge("too big".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 400);
}
#[tokio::test]
async fn server_error_http_client_maps_to_500() {
let err = ServerError::HttpClient("connection refused".into());
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 500);
}
#[tokio::test]
async fn server_error_invalid_state_transition_maps_to_400() {
use a2a_protocol_types::task::TaskState;
let err = ServerError::InvalidStateTransition {
task_id: "t1".into(),
from: TaskState::Completed,
to: TaskState::Working,
};
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 400);
}
#[tokio::test]
async fn server_error_protocol_maps_to_500() {
let err = ServerError::Protocol(a2a_protocol_types::error::A2aError::internal("proto err"));
let resp = server_error_to_response(&err);
assert_eq!(resp.status().as_u16(), 500);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(val["error"]["message"]
.as_str()
.unwrap_or("")
.contains("proto err"));
}
#[tokio::test]
async fn build_json_response_with_empty_body() {
let resp = build_json_response(200, vec![]);
assert_eq!(resp.status().as_u16(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
assert!(body.is_empty());
}
}