use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use a2a_protocol_types::agent_card::{AgentCapabilities, AgentCard, AgentInterface, AgentSkill};
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
use a2a_protocol_types::params::MessageSendParams;
use a2a_protocol_types::push::TaskPushNotificationConfig;
use a2a_protocol_types::responses::SendMessageResponse;
use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
use a2a_protocol_server::builder::RequestHandlerBuilder;
use a2a_protocol_server::dispatch::cors::CorsConfig;
use a2a_protocol_server::dispatch::RestDispatcher;
use a2a_protocol_server::executor::AgentExecutor;
use a2a_protocol_server::push::PushSender;
use a2a_protocol_server::request_context::RequestContext;
use a2a_protocol_server::serve::Dispatcher;
use a2a_protocol_server::streaming::EventQueueWriter;
struct SimpleExecutor;
impl AgentExecutor for SimpleExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
struct MockPushSender;
impl PushSender for MockPushSender {
fn send<'a>(
&'a self,
_url: &'a str,
_event: &'a StreamResponse,
_config: &'a TaskPushNotificationConfig,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move { Ok(()) })
}
}
fn minimal_agent_card() -> AgentCard {
AgentCard {
url: None,
name: "Test Agent".into(),
description: "A test agent".into(),
version: "1.0.0".into(),
supported_interfaces: vec![AgentInterface {
url: "https://agent.example.com/rpc".into(),
protocol_binding: "JSONRPC".into(),
protocol_version: "1.0.0".into(),
tenant: None,
}],
default_input_modes: vec!["text/plain".into()],
default_output_modes: vec!["text/plain".into()],
skills: vec![AgentSkill {
id: "echo".into(),
name: "Echo".into(),
description: "Echoes input".into(),
tags: vec!["echo".into()],
examples: None,
input_modes: None,
output_modes: None,
security_requirements: None,
}],
capabilities: AgentCapabilities::none().with_extended_agent_card(true),
provider: None,
icon_url: None,
documentation_url: None,
security_schemes: None,
security_requirements: None,
signatures: None,
}
}
fn make_handler() -> Arc<a2a_protocol_server::RequestHandler> {
Arc::new(
RequestHandlerBuilder::new(SimpleExecutor)
.with_agent_card(minimal_agent_card())
.with_push_sender(MockPushSender)
.build()
.expect("build handler"),
)
}
fn make_handler_no_push() -> Arc<a2a_protocol_server::RequestHandler> {
Arc::new(
RequestHandlerBuilder::new(SimpleExecutor)
.with_agent_card(minimal_agent_card())
.build()
.expect("build handler"),
)
}
fn make_send_params() -> MessageSendParams {
MessageSendParams {
tenant: None,
message: Message {
id: MessageId::new("msg-1"),
role: MessageRole::User,
parts: vec![Part::text("hello")],
task_id: None,
context_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
},
configuration: None,
metadata: None,
}
}
async fn start_rest_server_with_cors(
handler: Arc<a2a_protocol_server::RequestHandler>,
cors: Option<CorsConfig>,
) -> (SocketAddr, tokio::task::JoinHandle<()>) {
let mut dispatcher = RestDispatcher::new(handler);
if let Some(c) = cors {
dispatcher = dispatcher.with_cors(c);
}
let dispatcher = Arc::new(dispatcher);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
let handle = tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(s) => s,
Err(_) => break,
};
let io = hyper_util::rt::TokioIo::new(stream);
let d = Arc::clone(&dispatcher);
tokio::spawn(async move {
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&d);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
(addr, handle)
}
async fn start_rest_server() -> (SocketAddr, tokio::task::JoinHandle<()>) {
start_rest_server_with_cors(make_handler(), None).await
}
fn http_client() -> hyper_util::client::legacy::Client<
hyper_util::client::legacy::connect::HttpConnector,
Full<Bytes>,
> {
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()).build_http()
}
async fn http_request_full(
addr: SocketAddr,
method: &str,
path: &str,
body: Option<&str>,
content_type: Option<&str>,
) -> (u16, hyper::HeaderMap, String) {
let client = http_client();
let mut builder = hyper::Request::builder()
.method(method)
.uri(format!("http://{addr}{path}"));
if let Some(ct) = content_type {
builder = builder.header("content-type", ct);
}
let body_bytes = body.unwrap_or("").as_bytes().to_vec();
let req = builder.body(Full::new(Bytes::from(body_bytes))).unwrap();
let resp = client.request(req).await.unwrap();
let status = resp.status().as_u16();
let headers = resp.headers().clone();
let body = resp.collect().await.unwrap().to_bytes();
(status, headers, String::from_utf8_lossy(&body).into_owned())
}
async fn http_request(
addr: SocketAddr,
method: &str,
path: &str,
body: Option<&str>,
content_type: Option<&str>,
) -> (u16, String) {
let (status, _headers, body) = http_request_full(addr, method, path, body, content_type).await;
(status, body)
}
#[tokio::test]
async fn cors_preflight_returns_204_with_cors_headers() {
let cors = CorsConfig::new("https://example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let (status, headers, _body) =
http_request_full(addr, "OPTIONS", "/message:send", None, None).await;
assert_eq!(status, 204, "CORS preflight should return 204");
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://example.com"
);
let methods = headers
.get("access-control-allow-methods")
.expect("should have allow-methods");
assert!(
methods.to_str().unwrap().contains("POST"),
"allow-methods should include POST"
);
let allow_headers = headers
.get("access-control-allow-headers")
.expect("should have allow-headers");
assert!(
!allow_headers.is_empty(),
"allow-headers should be non-empty"
);
let max_age = headers
.get("access-control-max-age")
.expect("should have max-age");
assert!(!max_age.is_empty(), "max-age should be non-empty");
}
#[tokio::test]
async fn options_without_cors_returns_health() {
let (addr, _handle) = start_rest_server().await;
let (status, body) = http_request(addr, "OPTIONS", "/anything", None, None).await;
assert_eq!(status, 200);
assert!(
body.contains("ok"),
"OPTIONS without CORS should return health response"
);
}
#[tokio::test]
async fn cors_headers_on_oversized_query_string() {
let cors = CorsConfig::new("https://cors-test.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let long_query = format!("/tasks?q={}", "a".repeat(5000));
let (status, headers, _body) = http_request_full(addr, "GET", &long_query, None, None).await;
assert_eq!(status, 414);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://cors-test.example.com",
"CORS headers should be present on error responses"
);
}
#[tokio::test]
async fn cors_headers_on_health_check() {
let cors = CorsConfig::new("https://health-cors.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let (status, headers, _body) = http_request_full(addr, "GET", "/health", None, None).await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://health-cors.example.com",
"CORS headers should be on health response"
);
}
#[tokio::test]
async fn cors_headers_on_normal_response() {
let cors = CorsConfig::new("https://normal-cors.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let (status, headers, _body) = http_request_full(addr, "GET", "/tasks", None, None).await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://normal-cors.example.com",
"CORS headers should be on normal dispatch responses"
);
}
#[tokio::test]
async fn unknown_action_on_task_returns_404() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) =
http_request(addr, "POST", "/tasks/some-task:unknownAction", None, None).await;
assert_eq!(status, 404, "unknown colon-action should return 404");
}
#[tokio::test]
async fn get_on_cancel_action_returns_404() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(addr, "GET", "/tasks/some-task:cancel", None, None).await;
assert_eq!(status, 404, "GET on :cancel should return 404");
}
#[tokio::test]
async fn empty_task_id_with_action_returns_404() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(addr, "POST", "/tasks/:cancel", None, None).await;
assert_eq!(status, 404, "empty task id with action should return 404");
}
#[tokio::test]
async fn send_message_streaming_returns_sse() {
let (addr, _handle) = start_rest_server().await;
let body = serde_json::to_vec(&make_send_params()).unwrap();
let client = http_client();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/message:stream"))
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 200);
assert_eq!(
resp.headers()
.get("content-type")
.and_then(|v| v.to_str().ok()),
Some("text/event-stream")
);
}
#[tokio::test]
async fn list_tasks_returns_200_with_empty_list() {
let (addr, _handle) = start_rest_server().await;
let (status, body) = http_request(addr, "GET", "/tasks", None, None).await;
assert_eq!(status, 200);
assert!(body.contains("tasks"));
}
#[tokio::test]
async fn cancel_task_nonexistent_returns_error() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) =
http_request(addr, "POST", "/tasks/nonexistent-task:cancel", None, None).await;
assert_eq!(status, 404);
}
#[tokio::test]
async fn subscribe_nonexistent_task_returns_error() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(
addr,
"POST",
"/tasks/nonexistent-task:subscribe",
None,
None,
)
.await;
assert_eq!(status, 404);
}
#[tokio::test]
async fn subscribe_via_get_nonexistent_task() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) =
http_request(addr, "GET", "/tasks/nonexistent-task:subscribe", None, None).await;
assert_eq!(status, 404);
}
#[tokio::test]
async fn set_push_config_invalid_json_returns_400() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(
addr,
"POST",
"/tasks/task-1/pushNotificationConfigs",
Some("not valid json"),
Some("application/json"),
)
.await;
assert_eq!(status, 400, "invalid JSON body should return 400");
}
#[tokio::test]
async fn set_push_config_missing_fields_returns_400() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(
addr,
"POST",
"/tasks/task-1/pushNotificationConfigs",
Some(r#"{"someField": "value"}"#),
Some("application/json"),
)
.await;
assert_eq!(
status, 400,
"JSON missing required push config fields should return 400"
);
}
#[tokio::test]
async fn set_push_config_not_supported_returns_error() {
let handler = make_handler_no_push();
let (addr, _handle) = start_rest_server_with_cors(handler, None).await;
let config = TaskPushNotificationConfig::new("task-1", "https://example.com/hook");
let body = serde_json::to_vec(&config).unwrap();
let (status, _body) = http_request(
addr,
"POST",
"/tasks/task-1/pushNotificationConfigs",
Some(&String::from_utf8(body).unwrap()),
Some("application/json"),
)
.await;
assert_eq!(status, 400, "push not supported should return 400");
}
#[tokio::test]
async fn get_push_config_nonexistent_returns_error() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(
addr,
"GET",
"/tasks/task-1/pushNotificationConfigs/nonexistent-id",
None,
None,
)
.await;
assert_eq!(status, 400);
}
#[tokio::test]
async fn list_push_configs_empty_returns_200() {
let (addr, _handle) = start_rest_server().await;
let (status, body) = http_request(
addr,
"GET",
"/tasks/task-1/pushNotificationConfigs",
None,
None,
)
.await;
assert_eq!(status, 200);
assert!(
body.contains("\"configs\""),
"response should contain configs field, got: {body}"
);
}
#[tokio::test]
async fn delete_push_config_nonexistent_returns_200() {
let (addr, _handle) = start_rest_server().await;
let (status, _body) = http_request(
addr,
"DELETE",
"/tasks/task-1/pushNotificationConfigs/nonexistent-id",
None,
None,
)
.await;
assert_eq!(status, 200);
}
#[tokio::test]
async fn list_push_configs_no_push_sender_still_works() {
let handler = make_handler_no_push();
let (addr, _handle) = start_rest_server_with_cors(handler, None).await;
let (status, body) = http_request(
addr,
"GET",
"/tasks/task-1/pushNotificationConfigs",
None,
None,
)
.await;
assert_eq!(
status, 200,
"list push configs should succeed even without push sender"
);
assert!(
body.contains("\"configs\""),
"response should contain configs field, got: {body}"
);
}
#[tokio::test]
async fn delete_push_config_no_push_sender_still_works() {
let handler = make_handler_no_push();
let (addr, _handle) = start_rest_server_with_cors(handler, None).await;
let (status, _body) = http_request(
addr,
"DELETE",
"/tasks/task-1/pushNotificationConfigs/some-id",
None,
None,
)
.await;
assert_eq!(status, 200);
}
#[tokio::test]
async fn get_push_config_no_push_sender_returns_400() {
let handler = make_handler_no_push();
let (addr, _handle) = start_rest_server_with_cors(handler, None).await;
let (status, _body) = http_request(
addr,
"GET",
"/tasks/task-1/pushNotificationConfigs/some-id",
None,
None,
)
.await;
assert_eq!(status, 400);
}
#[tokio::test]
async fn extended_card_returns_200() {
let (addr, _handle) = start_rest_server().await;
let (status, body) = http_request(addr, "GET", "/extendedAgentCard", None, None).await;
assert_eq!(status, 200);
assert!(body.contains("Test Agent"));
}
#[test]
fn rest_dispatcher_debug_impl() {
let handler = RequestHandlerBuilder::new(SimpleExecutor)
.build()
.expect("build handler");
let dispatcher = RestDispatcher::new(Arc::new(handler));
let debug_str = format!("{:?}", dispatcher);
assert!(
debug_str.contains("RestDispatcher"),
"Debug impl should contain 'RestDispatcher'"
);
}
#[tokio::test]
async fn dispatcher_trait_dispatch_via_real_server() {
let handler = make_handler();
let dispatcher: Arc<dyn Dispatcher> = Arc::new(RestDispatcher::new(handler));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
let d = Arc::clone(&dispatcher);
tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept");
let io = hyper_util::rt::TokioIo::new(stream);
let d = Arc::clone(&d);
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&d);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(io, service)
.await;
});
let client = http_client();
let req = hyper::Request::builder()
.method("GET")
.uri(format!("http://{addr}/health"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("request");
assert_eq!(resp.status(), 200);
}
#[test]
fn with_cors_returns_self() {
let handler = RequestHandlerBuilder::new(SimpleExecutor)
.build()
.expect("build handler");
let dispatcher = RestDispatcher::new(Arc::new(handler));
let _dispatcher = dispatcher.with_cors(CorsConfig::permissive());
}
#[tokio::test]
async fn push_config_crud_with_cors_headers() {
let cors = CorsConfig::new("https://crud-cors.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let config = TaskPushNotificationConfig::new("task-1", "https://example.com/hook");
let body = serde_json::to_vec(&config).unwrap();
let (status, headers, resp_body) = http_request_full(
addr,
"POST",
"/tasks/task-1/pushNotificationConfigs",
Some(&String::from_utf8(body).unwrap()),
Some("application/json"),
)
.await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://crud-cors.example.com",
"CORS headers should be present on push config create response"
);
let created: TaskPushNotificationConfig =
serde_json::from_str(&resp_body).expect("parse config");
let config_id = created.id.unwrap();
let (status, headers, _body) = http_request_full(
addr,
"GET",
&format!("/tasks/task-1/pushNotificationConfigs/{config_id}"),
None,
None,
)
.await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://crud-cors.example.com"
);
let (status, headers, _body) = http_request_full(
addr,
"GET",
"/tasks/task-1/pushNotificationConfigs",
None,
None,
)
.await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://crud-cors.example.com"
);
let (status, headers, _body) = http_request_full(
addr,
"DELETE",
&format!("/tasks/task-1/pushNotificationConfigs/{config_id}"),
None,
None,
)
.await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://crud-cors.example.com"
);
}
#[tokio::test]
async fn subscribe_existing_task_returns_sse() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let body = serde_json::to_vec(&make_send_params()).unwrap();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/message:send"))
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("send");
assert_eq!(resp.status(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let result: SendMessageResponse = serde_json::from_slice(&body).expect("parse");
let task_id = match result {
SendMessageResponse::Task(t) => t.id.0,
_ => panic!("expected Task variant"),
};
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/tasks/{task_id}:subscribe"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("subscribe");
assert_ne!(
resp.status(),
404,
"subscribe to existing task should be routed"
);
}
#[tokio::test]
async fn cancel_existing_task() {
let (addr, _handle) = start_rest_server().await;
let client = http_client();
let body = serde_json::to_vec(&make_send_params()).unwrap();
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/message:send"))
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body)))
.unwrap();
let resp = client.request(req).await.expect("send");
assert_eq!(resp.status(), 200);
let body = resp.into_body().collect().await.unwrap().to_bytes();
let result: SendMessageResponse = serde_json::from_slice(&body).expect("parse");
let task_id = match result {
SendMessageResponse::Task(t) => t.id.0,
_ => panic!("expected Task variant"),
};
let req = hyper::Request::builder()
.method("POST")
.uri(format!("http://{addr}/tasks/{task_id}:cancel"))
.body(Full::new(Bytes::new()))
.unwrap();
let resp = client.request(req).await.expect("cancel");
let status = resp.status().as_u16();
assert!(
status == 200 || status == 400 || status == 409,
"cancel should return 200, 400, or 409, got {status}"
);
}
#[tokio::test]
async fn cors_headers_on_send_message_response() {
let cors = CorsConfig::new("https://send-cors.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let body = serde_json::to_vec(&make_send_params()).unwrap();
let (status, headers, _body) = http_request_full(
addr,
"POST",
"/message:send",
Some(&String::from_utf8(body).unwrap()),
Some("application/json"),
)
.await;
assert_eq!(status, 200);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://send-cors.example.com",
"CORS headers should be on send_message response"
);
}
#[tokio::test]
async fn cors_headers_on_not_found_response() {
let cors = CorsConfig::new("https://notfound-cors.example.com");
let (addr, _handle) = start_rest_server_with_cors(make_handler(), Some(cors)).await;
let (status, headers, _body) =
http_request_full(addr, "GET", "/nonexistent/path", None, None).await;
assert_eq!(status, 404);
assert_eq!(
headers.get("access-control-allow-origin").unwrap(),
"https://notfound-cors.example.com",
"CORS headers should be on 404 response"
);
}