mod response;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::body::Incoming;
use a2a_protocol_types::jsonrpc::{
JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
JsonRpcVersion,
};
use crate::agent_card::StaticAgentCardHandler;
use crate::dispatch::cors::CorsConfig;
use crate::error::ServerError;
use crate::handler::{RequestHandler, SendMessageResult};
use crate::serve::Dispatcher;
use crate::streaming::build_sse_response;
use response::{
error_response, error_response_bytes, extract_headers, json_response, parse_error_response,
parse_params, read_body_limited, success_response, success_response_bytes,
};
/// JSON-RPC 2.0 HTTP dispatcher for the A2A protocol.
///
/// Answers CORS preflights, serves the static agent card at
/// `/.well-known/agent-card.json`, and forwards JSON-RPC calls (single or
/// batched) to the shared `RequestHandler`.
pub struct JsonRpcDispatcher {
/// Handler that every JSON-RPC method is forwarded to.
handler: Arc<RequestHandler>,
/// Pre-built agent-card responder; `None` when the handler has no card or
/// the card could not be turned into a `StaticAgentCardHandler`.
card_handler: Option<StaticAgentCardHandler>,
/// CORS policy; `None` disables CORS handling.
cors: Option<CorsConfig>,
/// Body-size, body-read timeout, batch-size and SSE settings.
config: super::DispatchConfig,
}
impl JsonRpcDispatcher {
#[must_use]
pub fn new(handler: Arc<RequestHandler>) -> Self {
Self::with_config(handler, super::DispatchConfig::default())
}
#[must_use]
pub fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
let card_handler = handler
.agent_card
.as_ref()
.and_then(|card| StaticAgentCardHandler::new(card).ok());
Self {
handler,
card_handler,
cors: None,
config,
}
}
#[must_use]
pub fn with_cors(mut self, cors: CorsConfig) -> Self {
self.cors = Some(cors);
self
}
pub async fn dispatch(
&self,
req: hyper::Request<Incoming>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
if req.method() == "OPTIONS" {
if let Some(ref cors) = self.cors {
return cors.preflight_response();
}
return json_response(204, Vec::new());
}
if req.method() == "GET" && req.uri().path() == "/.well-known/agent-card.json" {
let mut resp = self.card_handler.as_ref().map_or_else(
|| json_response(404, br#"{"error":"agent card not configured"}"#.to_vec()),
|h| h.handle(&req).map(http_body_util::BodyExt::boxed),
);
if let Some(ref cors) = self.cors {
cors.apply_headers(&mut resp);
}
return resp;
}
let mut resp = self.dispatch_inner(req).await;
if let Some(ref cors) = self.cors {
cors.apply_headers(&mut resp);
}
resp
}
#[allow(clippy::too_many_lines)]
async fn dispatch_inner(
&self,
req: hyper::Request<Incoming>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
if let Some(ct) = req.headers().get("content-type") {
let ct_str = ct.to_str().unwrap_or("");
if !ct_str.starts_with("application/json")
&& !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
{
return parse_error_response(
None,
&format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
);
}
}
if let Some(version) = req.headers().get(a2a_protocol_types::A2A_VERSION_HEADER) {
if let Ok(v) = version.to_str() {
let v = v.trim();
if !v.is_empty() {
let major = v.split('.').next().and_then(|s| s.parse::<u32>().ok());
if major != Some(1) {
return error_response(
None,
&ServerError::Protocol(a2a_protocol_types::error::A2aError::new(
a2a_protocol_types::error::ErrorCode::VersionNotSupported,
format!("unsupported A2A version: {v}; this server supports 1.x"),
)),
);
}
}
}
}
let headers = extract_headers(req.headers());
let body_bytes = match read_body_limited(
req.into_body(),
self.config.max_request_body_size,
self.config.body_read_timeout,
)
.await
{
Ok(bytes) => bytes,
Err(msg) => return parse_error_response(None, &msg),
};
let raw: serde_json::Value = match serde_json::from_slice(&body_bytes) {
Ok(v) => v,
Err(e) => return parse_error_response(None, &e.to_string()),
};
if raw.is_array() {
let serde_json::Value::Array(items) = raw else {
unreachable!()
};
if items.is_empty() {
return parse_error_response(None, "empty batch request");
}
if items.len() > self.config.max_batch_size {
return parse_error_response(
None,
&format!(
"batch too large: {} requests exceeds {} limit",
items.len(),
self.config.max_batch_size
),
);
}
let mut responses: Vec<serde_json::Value> = Vec::with_capacity(items.len());
for item in items {
let rpc_req: JsonRpcRequest = match serde_json::from_value(item) {
Ok(r) => r,
Err(e) => {
let err_resp = JsonRpcErrorResponse::new(
None,
JsonRpcError::new(
a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
format!("Parse error: {e}"),
),
);
if let Ok(v) = serde_json::to_value(&err_resp) {
responses.push(v);
}
continue;
}
};
let resp_body = self.dispatch_single_request(&rpc_req, &headers).await;
if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&resp_body) {
responses.push(v);
}
}
let body = serde_json::to_vec(&responses).unwrap_or_default();
json_response(200, body)
} else {
let rpc_req: JsonRpcRequest = match serde_json::from_value(raw) {
Ok(r) => r,
Err(e) => return parse_error_response(None, &e.to_string()),
};
self.dispatch_single_request_http(&rpc_req, &headers).await
}
}
#[allow(clippy::too_many_lines)]
async fn dispatch_single_request_http(
&self,
rpc_req: &JsonRpcRequest,
headers: &HashMap<String, String>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let id = rpc_req.id.clone();
trace_info!(method = %rpc_req.method, "dispatching JSON-RPC request");
match rpc_req.method.as_str() {
"SendStreamingMessage" | "message/stream" => {
return self.dispatch_send_message(id, rpc_req, true, headers).await;
}
"SubscribeToTask" | "tasks/subscribe" => {
return match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
Ok(p) => match self.handler.on_resubscribe(p, Some(headers)).await {
Ok(reader) => build_sse_response(
reader,
Some(self.config.sse_keep_alive_interval),
Some(self.config.sse_channel_capacity),
true, ),
Err(e) => error_response(id, &e),
},
Err(e) => error_response(id, &e),
};
}
_ => {}
}
let body = self.dispatch_single_request(rpc_req, headers).await;
json_response(200, body)
}
#[allow(clippy::too_many_lines)]
async fn dispatch_single_request(
&self,
rpc_req: &JsonRpcRequest,
headers: &HashMap<String, String>,
) -> Vec<u8> {
let id = rpc_req.id.clone();
match rpc_req.method.as_str() {
"SendMessage" | "message/send" => {
match self
.dispatch_send_message_inner(id.clone(), rpc_req, false, headers)
.await
{
Ok(resp) => serde_json::to_vec(&resp).unwrap_or_default(),
Err(body) => body,
}
}
"SendStreamingMessage" | "message/stream" => {
let err = ServerError::InvalidParams(
"SendStreamingMessage not supported in batch requests".into(),
);
let a2a_err = err.to_a2a_error();
let resp = JsonRpcErrorResponse::new(
id,
JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
);
serde_json::to_vec(&resp).unwrap_or_default()
}
"GetTask" | "tasks/get" => {
match parse_params::<a2a_protocol_types::params::TaskQueryParams>(rpc_req) {
Ok(p) => match self.handler.on_get_task(p, Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"ListTasks" | "tasks/list" => {
match parse_params::<a2a_protocol_types::params::ListTasksParams>(rpc_req) {
Ok(p) => match self.handler.on_list_tasks(p, Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"CancelTask" | "tasks/cancel" => {
match parse_params::<a2a_protocol_types::params::CancelTaskParams>(rpc_req) {
Ok(p) => match self.handler.on_cancel_task(p, Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"SubscribeToTask" | "tasks/subscribe" => {
let err = ServerError::InvalidParams(
"SubscribeToTask not supported in batch requests".into(),
);
error_response_bytes(id, &err)
}
"CreateTaskPushNotificationConfig" | "tasks/pushNotificationConfig/set" => {
match parse_params::<a2a_protocol_types::push::TaskPushNotificationConfig>(rpc_req)
{
Ok(p) => match self.handler.on_set_push_config(p, Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"GetTaskPushNotificationConfig" | "tasks/pushNotificationConfig/get" => {
match parse_params::<a2a_protocol_types::params::GetPushConfigParams>(rpc_req) {
Ok(p) => match self.handler.on_get_push_config(p, Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"ListTaskPushNotificationConfigs" | "tasks/pushNotificationConfig/list" => {
match parse_params::<a2a_protocol_types::params::ListPushConfigsParams>(rpc_req) {
Ok(p) => match self
.handler
.on_list_push_configs(&p.task_id, p.tenant.as_deref(), Some(headers))
.await
{
Ok(configs) => {
let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
configs,
next_page_token: None,
};
success_response_bytes(id, &resp)
}
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"DeleteTaskPushNotificationConfig" | "tasks/pushNotificationConfig/delete" => {
match parse_params::<a2a_protocol_types::params::DeletePushConfigParams>(rpc_req) {
Ok(p) => match self.handler.on_delete_push_config(p, Some(headers)).await {
Ok(()) => success_response_bytes(id, &serde_json::json!({})),
Err(e) => error_response_bytes(id, &e),
},
Err(e) => error_response_bytes(id, &e),
}
}
"GetExtendedAgentCard" | "agent/authenticatedExtendedCard" => {
match self.handler.on_get_extended_agent_card(Some(headers)).await {
Ok(r) => success_response_bytes(id, &r),
Err(e) => error_response_bytes(id, &e),
}
}
other => {
let err = ServerError::MethodNotFound(other.to_owned());
error_response_bytes(id, &err)
}
}
}
async fn dispatch_send_message_inner(
&self,
id: JsonRpcId,
rpc_req: &JsonRpcRequest,
streaming: bool,
headers: &HashMap<String, String>,
) -> Result<JsonRpcSuccessResponse<serde_json::Value>, Vec<u8>> {
let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
Ok(p) => p,
Err(e) => return Err(error_response_bytes(id, &e)),
};
match self
.handler
.on_send_message(params, streaming, Some(headers))
.await
{
Ok(SendMessageResult::Response(resp)) => {
let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
Ok(JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id,
result,
})
}
Ok(SendMessageResult::Stream(_)) => {
let err = ServerError::Internal("unexpected stream response".into());
Err(error_response_bytes(id, &err))
}
Err(e) => Err(error_response_bytes(id, &e)),
}
}
async fn dispatch_send_message(
&self,
id: JsonRpcId,
rpc_req: &JsonRpcRequest,
streaming: bool,
headers: &HashMap<String, String>,
) -> hyper::Response<BoxBody<Bytes, Infallible>> {
let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
Ok(p) => p,
Err(e) => return error_response(id, &e),
};
match self
.handler
.on_send_message(params, streaming, Some(headers))
.await
{
Ok(SendMessageResult::Response(resp)) => success_response(id, &resp),
Ok(SendMessageResult::Stream(reader)) => build_sse_response(
reader,
Some(self.config.sse_keep_alive_interval),
Some(self.config.sse_channel_capacity),
true, ),
Err(e) => error_response(id, &e),
}
}
}
/// Compact debug view. It shows only whether an agent card and CORS are
/// configured and marks the remaining fields as elided (`..`), so the output
/// does not claim to be a complete description of the dispatcher.
impl std::fmt::Debug for JsonRpcDispatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonRpcDispatcher")
            .field("has_agent_card", &self.card_handler.is_some())
            .field("cors_enabled", &self.cors.is_some())
            .finish_non_exhaustive()
    }
}
/// Exposes the inherent async `dispatch` through the object-safe `Dispatcher`
/// trait by boxing and pinning the future it returns.
impl Dispatcher for JsonRpcDispatcher {
    fn dispatch(
        &self,
        req: hyper::Request<Incoming>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = crate::serve::DispatchResponse> + Send + '_>,
    > {
        // Method-call resolution prefers the inherent `dispatch` over this
        // trait method, so this call does not recurse.
        let response_future = self.dispatch(req);
        Box::pin(response_future)
    }
}