use axum::{
Router,
body::Body,
extract::{Path, State},
http::{StatusCode, header},
response::Response,
routing::{get, post},
};
use super::{ServerState, handle_request};
/// Assembles a response with the given status, `Content-Type` header and body.
///
/// If the builder rejects its inputs (e.g. `content_type` is not a valid header value),
/// a bare `500 Internal Server Error` with an empty body is returned instead of panicking.
fn build_response(status: StatusCode, content_type: &str, body: impl Into<Body>) -> Response {
    let built = Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, content_type)
        .body(body.into());
    match built {
        Ok(response) => response,
        Err(_) => Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(Body::empty())
            .unwrap_or_default(),
    }
}
/// Builds a `200 OK` response whose body is labelled `Content-Type: application/json`.
///
/// The body is passed through as-is; callers are responsible for it being valid JSON.
fn json_ok(body: impl Into<Body>) -> Response {
    let payload: Body = body.into();
    build_response(StatusCode::OK, "application/json", payload)
}
fn json_ok_cors(body: impl Into<Body>) -> Response {
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(body.into())
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap_or_default()
})
}
/// Builds the minimal A2A router.
///
/// Routes:
/// - the agent card at [`crate::WELL_KNOWN_AGENT_CARD_PATH`] (GET)
/// - JSON-RPC at `/` (POST)
/// - streaming JSON-RPC answered with SSE at `/stream` (POST)
pub fn a2a_router(state: ServerState) -> Router {
    let routes: Router<ServerState> = Router::new()
        .route(crate::WELL_KNOWN_AGENT_CARD_PATH, get(handle_agent_card))
        .route("/", post(handle_jsonrpc))
        .route("/stream", post(handle_sse));
    routes.with_state(state)
}
/// Builds the A2A router with both transports.
///
/// Serves the agent card, JSON-RPC (`/`) and SSE (`/stream`) endpoints, merged with the
/// routes returned by `super::rest::rest_router_inner()`.
pub fn a2a_full_router(state: ServerState) -> Router {
    let rest_routes = super::rest::rest_router_inner();
    let jsonrpc_routes: Router<ServerState> = Router::new()
        .route(crate::WELL_KNOWN_AGENT_CARD_PATH, get(handle_agent_card))
        .route("/", post(handle_jsonrpc))
        .route("/stream", post(handle_sse));
    jsonrpc_routes.merge(rest_routes).with_state(state)
}
/// Builds the multi-tenant A2A router.
///
/// The root serves the agent card, JSON-RPC (`/`), SSE (`/stream`) and the REST routes.
/// The same JSON-RPC/SSE/REST surface is also mounted under `/{tenant}`. There the tenant
/// handlers record the path tenant in the request metadata, and the
/// `inject_tenant_header` middleware copies it into the `x-a2a-tenant` header for the
/// REST routes.
pub fn a2a_tenant_router(state: ServerState) -> Router {
    let rest_routes = super::rest::rest_router_inner();

    // Untenanted endpoints served at the root.
    let root: Router<ServerState> = Router::new()
        .route(crate::WELL_KNOWN_AGENT_CARD_PATH, get(handle_agent_card))
        .route("/", post(handle_jsonrpc))
        .route("/stream", post(handle_sse))
        .merge(rest_routes.clone());

    // Tenant-scoped endpoints, nested below `/{tenant}`.
    let per_tenant: Router<ServerState> = Router::new()
        .route("/", post(handle_tenant_jsonrpc))
        .route("/stream", post(handle_tenant_sse))
        .merge(rest_routes)
        .layer(axum::middleware::from_fn(inject_tenant_header));

    Router::new()
        .merge(root)
        .nest("/{tenant}", per_tenant)
        .with_state(state)
}
async fn inject_tenant_header(
Path(params): Path<std::collections::HashMap<String, String>>,
mut req: axum::http::Request<Body>,
next: axum::middleware::Next,
) -> Response {
if let Some(tenant) = params.get("tenant")
&& let Ok(val) = axum::http::HeaderValue::from_str(tenant)
{
req.headers_mut().insert("x-a2a-tenant", val);
}
next.run(req).await
}
/// Serves the agent card from `state.card_producer` as JSON.
///
/// On success, the response is `200 OK` with `Access-Control-Allow-Origin: *` (via
/// `json_ok_cors`). If the producer fails, or the card cannot be serialized to JSON, the
/// response is `500 Internal Server Error` with an empty `text/plain` body.
pub async fn handle_agent_card(State(state): State<ServerState>) -> Response {
    // A card that fails to serialize is a server-side failure. It must not be reported
    // as a successful response with an empty body.
    let card_json = state
        .card_producer
        .card()
        .await
        .ok()
        .and_then(|card| serde_json::to_string(&card).ok());
    match card_json {
        Some(json) => json_ok_cors(json),
        None => build_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            "text/plain",
            Body::empty(),
        ),
    }
}
pub async fn handle_jsonrpc(
State(state): State<ServerState>,
headers: axum::http::HeaderMap,
body: String,
) -> Response {
let meta = super::RequestMeta::from_header_map(&headers);
match super::REQUEST_META
.scope(meta, handle_request(&state, &body))
.await
{
Ok(response) => json_ok(response),
Err(e) => json_ok(
serde_json::json!({
"jsonrpc": "2.0",
"error": { "code": -32603, "message": e.to_string() },
"id": null
})
.to_string(),
),
}
}
pub async fn handle_sse(
State(state): State<ServerState>,
headers: axum::http::HeaderMap,
body: String,
) -> Response {
use axum::response::IntoResponse;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use futures::StreamExt;
use super::handler::parse_params;
use crate::error::JsonRpcError;
use crate::jsonrpc::{self, JsonRpcRequest};
use crate::types::{SendMessageRequest, SubscribeToTaskRequest};
let meta = super::RequestMeta::from_header_map(&headers);
let request: JsonRpcRequest<serde_json::Value> = match serde_json::from_str(&body) {
Ok(req) => req,
Err(_) => {
return sse_error_response(None, &JsonRpcError::parse_error());
}
};
let request_id = request.id.clone();
let handler = &state.handler;
let event_stream = super::REQUEST_META
.scope(meta, async {
match request.method.as_str() {
jsonrpc::METHOD_MESSAGE_STREAM => {
match parse_params::<SendMessageRequest>(&request) {
Ok(p) => handler.on_message_stream(p).await,
Err(e) => Err(e),
}
}
jsonrpc::METHOD_TASKS_RESUBSCRIBE => {
match parse_params::<SubscribeToTaskRequest>(&request) {
Ok(p) => handler.on_subscribe_to_task(p).await,
Err(e) => Err(e),
}
}
_ => Err(JsonRpcError::method_not_found(&request.method).into()),
}
})
.await;
let event_stream = match event_stream {
Ok(s) => s,
Err(e) => {
return sse_error_response(Some(&request_id), &e.to_jsonrpc_error());
}
};
let id_for_stream = request_id.clone();
let sse_stream = event_stream.map(move |item| {
let data = match item {
Ok(event) => serde_json::json!({
"jsonrpc": "2.0",
"id": id_for_stream,
"result": event,
}),
Err(e) => {
let rpc_err = e.to_jsonrpc_error();
serde_json::json!({
"jsonrpc": "2.0",
"id": id_for_stream,
"error": { "code": rpc_err.code, "message": rpc_err.message },
})
}
};
Ok::<_, std::convert::Infallible>(SseEvent::default().data(data.to_string()))
});
Sse::new(sse_stream)
.keep_alive(KeepAlive::default())
.into_response()
}
async fn handle_tenant_jsonrpc(
State(state): State<ServerState>,
Path(tenant): Path<String>,
headers: axum::http::HeaderMap,
body: String,
) -> Response {
let mut meta = super::RequestMeta::from_header_map(&headers);
meta.set("x-a2a-tenant", tenant);
match super::REQUEST_META
.scope(meta, handle_request(&state, &body))
.await
{
Ok(response) => json_ok(response),
Err(e) => json_ok(
serde_json::json!({
"jsonrpc": "2.0",
"error": { "code": -32603, "message": e.to_string() },
"id": null
})
.to_string(),
),
}
}
async fn handle_tenant_sse(
State(state): State<ServerState>,
Path(tenant): Path<String>,
headers: axum::http::HeaderMap,
body: String,
) -> Response {
use axum::response::IntoResponse;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use futures::StreamExt;
use super::handler::parse_params;
use crate::error::JsonRpcError;
use crate::jsonrpc::{self, JsonRpcRequest};
use crate::types::{SendMessageRequest, SubscribeToTaskRequest};
let mut meta = super::RequestMeta::from_header_map(&headers);
meta.set("x-a2a-tenant", tenant);
let request: JsonRpcRequest<serde_json::Value> = match serde_json::from_str(&body) {
Ok(req) => req,
Err(_) => {
return sse_error_response(None, &JsonRpcError::parse_error());
}
};
let request_id = request.id.clone();
let handler = &state.handler;
let event_stream = super::REQUEST_META
.scope(meta, async {
match request.method.as_str() {
jsonrpc::METHOD_MESSAGE_STREAM => {
match parse_params::<SendMessageRequest>(&request) {
Ok(p) => handler.on_message_stream(p).await,
Err(e) => Err(e),
}
}
jsonrpc::METHOD_TASKS_RESUBSCRIBE => {
match parse_params::<SubscribeToTaskRequest>(&request) {
Ok(p) => handler.on_subscribe_to_task(p).await,
Err(e) => Err(e),
}
}
_ => Err(JsonRpcError::method_not_found(&request.method).into()),
}
})
.await;
let event_stream = match event_stream {
Ok(s) => s,
Err(e) => {
return sse_error_response(Some(&request_id), &e.to_jsonrpc_error());
}
};
let id_for_stream = request_id.clone();
let sse_stream = event_stream.map(move |item| {
let data = match item {
Ok(event) => serde_json::json!({
"jsonrpc": "2.0",
"id": id_for_stream,
"result": event,
}),
Err(e) => {
let rpc_err = e.to_jsonrpc_error();
serde_json::json!({
"jsonrpc": "2.0",
"id": id_for_stream,
"error": { "code": rpc_err.code, "message": rpc_err.message },
})
}
};
Ok::<_, std::convert::Infallible>(SseEvent::default().data(data.to_string()))
});
Sse::new(sse_stream)
.keep_alive(KeepAlive::default())
.into_response()
}
/// Encodes a JSON-RPC error object as a plain `200 OK` JSON response.
///
/// The streaming endpoints use this for failures detected before an SSE stream is opened,
/// so the client receives one JSON-RPC error document instead of an event stream.
/// `id` is `None` when the request body could not be parsed; it then serializes as `null`.
fn sse_error_response(
    id: Option<&crate::jsonrpc::RequestId>,
    error: &crate::error::JsonRpcError,
) -> Response {
    let payload = serde_json::json!({
        "jsonrpc": "2.0",
        "id": id,
        "error": { "code": error.code, "message": error.message },
    });
    json_ok(payload.to_string())
}