use super::*;
#[tokio::test]
async fn rest_rejects_wrong_content_type_on_post() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let body = serde_json::to_vec(&make_send_params()).unwrap();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/message:send"))
.header("content-type", "text/xml")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 415, "wrong content type should return 415");
}
#[tokio::test]
async fn rest_health_endpoint_returns_ok() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let req = hyper::Request::builder()
.method("GET")
.uri(format!("http://{addr}/health"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let value: serde_json::Value = serde_json::from_slice(&body).expect("parse health");
assert_eq!(value["status"], "ok");
}
#[tokio::test]
async fn rest_ready_endpoint_returns_ok() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let req = hyper::Request::builder()
.method("GET")
.uri(format!("http://{addr}/ready"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn rest_rejects_path_traversal() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let req = hyper::Request::builder()
.method("GET")
.uri(format!("http://{addr}/tasks/../../../etc/passwd"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 400, "path traversal should be rejected");
}
async fn start_jsonrpc_server_small_body() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
use a2a_protocol_server::dispatch::DispatchConfig;
let handler = Arc::new(
a2a_protocol_server::builder::RequestHandlerBuilder::new(SimpleExecutor)
.with_push_sender(MockPushSender)
.build()
.expect("build handler"),
);
let config = DispatchConfig::default().with_max_request_body_size(64);
let dispatcher = Arc::new(JsonRpcDispatcher::with_config(handler, config));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
let handle = tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(s) => s,
Err(_) => break,
};
let io = hyper_util::rt::TokioIo::new(stream);
let dispatcher = Arc::clone(&dispatcher);
tokio::spawn(async move {
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&dispatcher);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
(addr, handle)
}
async fn start_rest_server_small_body() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
use a2a_protocol_server::dispatch::DispatchConfig;
let handler = Arc::new(
a2a_protocol_server::builder::RequestHandlerBuilder::new(SimpleExecutor)
.with_push_sender(MockPushSender)
.build()
.expect("build handler"),
);
let config = DispatchConfig::default().with_max_request_body_size(64);
let dispatcher = Arc::new(RestDispatcher::with_config(handler, config));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
let handle = tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(s) => s,
Err(_) => break,
};
let io = hyper_util::rt::TokioIo::new(stream);
let dispatcher = Arc::clone(&dispatcher);
tokio::spawn(async move {
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&dispatcher);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
(addr, handle)
}
#[tokio::test]
async fn jsonrpc_rejects_oversized_body() {
let (addr, _handle) = start_jsonrpc_server_small_body().await;
let client = http_client();
let oversized = "x".repeat(200);
let rpc = a2a_protocol_types::jsonrpc::JsonRpcRequest::with_params(
serde_json::json!(1),
"SendMessage",
serde_json::json!({ "data": oversized }),
);
let body = serde_json::to_vec(&rpc).unwrap();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/"))
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
val["error"]["message"]
.as_str()
.unwrap_or("")
.contains("too large"),
"expected 'too large' in error message, got: {val}"
);
}
#[tokio::test]
async fn rest_rejects_oversized_body() {
let (addr, _handle) = start_rest_server_small_body().await;
let client = http_client();
let oversized = "x".repeat(200);
let body_json = serde_json::json!({ "data": oversized });
let body = serde_json::to_vec(&body_json).unwrap();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/message:send"))
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 413, "oversized body should return 413");
let body = resp.into_body().collect().await.unwrap().to_bytes();
let val: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
val["error"]["message"]
.as_str()
.unwrap_or("")
.contains("too large"),
"expected 'too large' in error message, got: {val}"
);
}