use axum::{
extract::{Path, Request, State},
http::{HeaderMap, StatusCode},
middleware::{self, Next},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
},
routing::{get, post},
Json, Router,
};
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::BroadcastStream;
use tracing::warn;
use crate::auth::{AuthError, AuthValidator, NoAuth};
use crate::events::StreamEvent;
use crate::server::A2aDispatcher;
use crate::types::{SendMessageParams, Task};
/// Body of an incoming JSON-RPC request as accepted by [`handle_rpc`].
///
/// Only `method` is mandatory. The other fields fall back to their `Default`
/// when absent:
/// - `jsonrpc` becomes an empty string, which [`handle_rpc`] reports as a missing version.
/// - `params` and `id` become `Value::Null`.
#[derive(Debug, Deserialize)]
struct JsonRpcRequest {
    /// Protocol version tag; [`handle_rpc`] only accepts `"2.0"`.
    #[serde(default)]
    jsonrpc: String,
    /// Name of the method to invoke.
    method: String,
    /// Method parameters (`Value::Null` when omitted).
    #[serde(default)]
    params: Value,
    /// Caller-chosen request id, echoed back in the response envelope.
    #[serde(default)]
    id: Value,
}
/// Outgoing JSON-RPC response envelope.
///
/// Exactly one of `result` / `error` is populated by the constructors below.
/// The empty one is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct JsonRpcResponse {
    /// Protocol version tag; the constructors always set `"2.0"`.
    jsonrpc: &'static str,
    /// Successful method result, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    /// Error object, if the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcErrorBody>,
    /// Id copied from the request this envelope answers.
    id: Value,
}
/// The `error` member of a JSON-RPC response: a numeric code plus a
/// human-readable message.
#[derive(Debug, Serialize)]
struct JsonRpcErrorBody {
    /// JSON-RPC error code (e.g. -32600 / -32602 as used in this module).
    code: i32,
    /// Human-readable description of the failure.
    message: String,
}
impl JsonRpcResponse {
fn ok(id: Value, result: Value) -> Self {
Self {
jsonrpc: "2.0",
result: Some(result),
error: None,
id,
}
}
fn err(id: Value, code: i32, message: String) -> Self {
Self {
jsonrpc: "2.0",
result: None,
error: Some(JsonRpcErrorBody { code, message }),
id,
}
}
}
/// Assembles the complete axum [`Router`] for the A2A server.
///
/// Routes and access:
/// - `/.well-known/agent-card.json` and `/.well-known/agent.json` are mounted
///   without authentication.
/// - The JSON-RPC endpoints (`POST /`, `POST /a2a`) and the SSE endpoint
///   (`GET /a2a/stream/:task_id`) are wrapped in [`auth_middleware`], driven
///   by `auth`.
///
/// All handlers share the dispatcher through an `Arc`.
fn router(dispatcher: A2aDispatcher, auth: Arc<dyn AuthValidator>) -> Router {
    let shared = Arc::new(dispatcher);

    // Discovery endpoints: reachable without credentials.
    let discovery = Router::new()
        .route("/.well-known/agent-card.json", get(handle_agent_card))
        .route("/.well-known/agent.json", get(handle_agent_card));

    // RPC and streaming endpoints: every request passes through the validator first.
    let rpc = Router::new()
        .route("/", post(handle_rpc))
        .route("/a2a", post(handle_rpc))
        .route("/a2a/stream/:task_id", get(handle_stream))
        .layer(middleware::from_fn_with_state(auth, auth_middleware));

    discovery.merge(rpc).with_state(shared)
}
pub fn build_router(dispatcher: A2aDispatcher) -> Router {
router(dispatcher, Arc::new(NoAuth))
}
/// Builds the A2A router, guarding the JSON-RPC and SSE routes with `auth`.
///
/// The agent-card discovery routes are always left unauthenticated.
pub fn build_router_with_auth(
    dispatcher: A2aDispatcher,
    auth: Arc<dyn AuthValidator>,
) -> Router {
    router(dispatcher, auth)
}
pub async fn serve(
dispatcher: A2aDispatcher,
addr: SocketAddr,
) -> std::io::Result<(SocketAddr, JoinHandle<()>)> {
serve_with_auth(dispatcher, addr, Arc::new(NoAuth)).await
}
pub async fn serve_with_auth(
dispatcher: A2aDispatcher,
addr: SocketAddr,
auth: Arc<dyn AuthValidator>,
) -> std::io::Result<(SocketAddr, JoinHandle<()>)> {
let listener = TcpListener::bind(addr).await?;
let bound = listener.local_addr()?;
let app = router(dispatcher, auth);
let handle = tokio::spawn(async move {
if let Err(e) = axum::serve(listener, app).await {
warn!(error = %e, "a2a HTTP server exited");
}
});
Ok((bound, handle))
}
async fn auth_middleware(
State(auth): State<Arc<dyn AuthValidator>>,
headers: HeaderMap,
req: Request,
next: Next,
) -> Response {
match auth.validate(&headers).await {
Ok(()) => next.run(req).await,
Err(err) => {
let path = req.uri().path().to_string();
let challenge = match err {
AuthError::Missing => auth.challenge(),
AuthError::Invalid => format!("{} error=\"invalid_token\"", auth.challenge()),
};
tracing::debug!(?err, %path, "a2a auth rejected");
let mut resp = (StatusCode::UNAUTHORIZED, "unauthorized").into_response();
if let Ok(value) = axum::http::HeaderValue::from_str(&challenge) {
resp.headers_mut().insert("www-authenticate", value);
}
resp
}
}
}
async fn handle_agent_card(
State(dispatcher): State<Arc<A2aDispatcher>>,
) -> Result<Json<Value>, (StatusCode, String)> {
match dispatcher
.dispatch(
"agent/getAuthenticatedExtendedCard",
Value::Null,
)
.await
{
Ok(card) => Ok(Json(card)),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
}
}
/// JSON-RPC entry point for `POST /` and `POST /a2a`.
///
/// Behavior by request:
/// 1. A `jsonrpc` other than `"2.0"` yields error -32600 (Invalid Request),
///    with a distinct message when the field is missing.
/// 2. The streaming methods are answered with an SSE stream (see
///    [`sse_response`]) that starts with the task snapshot:
///    - `message/stream` / `SendStreamingMessage`
///    - `tasks/resubscribe` / `SubscribeToTask`
/// 3. Every other method is forwarded to the dispatcher and answered with a
///    single JSON-RPC envelope.
///
/// All JSON-RPC errors, including bad params (-32602) and dispatcher errors
/// (which supply their own code), are returned in the body with HTTP 200.
///
/// NOTE(review): malformed or non-JSON bodies are rejected by axum's `Json`
/// extractor before this handler runs. They therefore receive axum's plain
/// rejection response rather than a JSON-RPC -32700 / -32600 error object.
async fn handle_rpc(
    State(dispatcher): State<Arc<A2aDispatcher>>,
    Json(req): Json<JsonRpcRequest>,
) -> Response {
    // -32600: JSON-RPC "Invalid Request".
    if req.jsonrpc != "2.0" {
        let message = if req.jsonrpc.is_empty() {
            "missing required `jsonrpc` field (must be \"2.0\")".to_string()
        } else {
            format!("unsupported jsonrpc version `{}`", req.jsonrpc)
        };
        return rpc_error_response(req.id, -32600, message);
    }

    match req.method.as_str() {
        "message/stream" | "SendStreamingMessage" => {
            // -32602: JSON-RPC "Invalid params".
            let params: SendMessageParams = match serde_json::from_value(req.params) {
                Ok(p) => p,
                Err(e) => return rpc_error_response(req.id, -32602, e.to_string()),
            };
            return match dispatcher.start_message_stream(params).await {
                Ok((task, receiver)) => sse_response(
                    Arc::clone(&dispatcher),
                    task.id.clone(),
                    Some(req.id),
                    Some(task),
                    receiver,
                ),
                Err(err) => rpc_error_response(req.id, err.code(), err.to_string()),
            };
        }
        "tasks/resubscribe" | "SubscribeToTask" => {
            #[derive(Deserialize)]
            struct ResubscribeParams {
                id: String,
            }
            let params: ResubscribeParams = match serde_json::from_value(req.params) {
                Ok(p) => p,
                Err(e) => return rpc_error_response(req.id, -32602, e.to_string()),
            };
            return match dispatcher.resubscribe_task(&params.id).await {
                Ok((task, receiver)) => sse_response(
                    Arc::clone(&dispatcher),
                    task.id.clone(),
                    Some(req.id),
                    Some(task),
                    receiver,
                ),
                Err(err) => rpc_error_response(req.id, err.code(), err.to_string()),
            };
        }
        _ => {}
    }

    let response = match dispatcher.dispatch(&req.method, req.params).await {
        Ok(value) => JsonRpcResponse::ok(req.id, value),
        Err(err) => JsonRpcResponse::err(req.id, err.code(), err.to_string()),
    };
    // Serialize the envelope directly rather than via `to_value(..).unwrap()`.
    (StatusCode::OK, Json(response)).into_response()
}
/// Builds an HTTP `200 OK` response whose body is a JSON-RPC error envelope
/// for request `id`.
///
/// This module reports every JSON-RPC failure in the response body and never
/// through the HTTP status.
fn rpc_error_response(id: Value, code: i32, message: String) -> Response {
    // Serialized straight from the `Serialize` struct: no intermediate `Value`
    // tree and no `unwrap` panic path.
    let body = JsonRpcResponse::err(id, code, message);
    (StatusCode::OK, Json(body)).into_response()
}
/// `GET /a2a/stream/:task_id`: subscribes to a task's event broadcast and
/// relays it as SSE.
///
/// Unlike the JSON-RPC streaming methods, this route sends no initial task
/// snapshot and does not wrap frames in a JSON-RPC envelope (no request id).
async fn handle_stream(
    State(dispatcher): State<Arc<A2aDispatcher>>,
    Path(id): Path<String>,
) -> Response {
    let events = dispatcher.subscribe(&id).await;
    sse_response(dispatcher, id, None, None, events)
}
fn sse_response(
dispatcher: Arc<A2aDispatcher>,
task_id: String,
request_id: Option<Value>,
initial: Option<Task>,
receiver: broadcast::Receiver<StreamEvent>,
) -> Response {
let task_id = Arc::new(task_id);
let request_id = Arc::new(request_id);
let initial_frame = initial.and_then(|task| {
encode_frame(&StreamEvent::Task(task), (*request_id).as_ref()).map(Ok::<_, Infallible>)
});
let initial_stream = futures::stream::iter(initial_frame);
let raw = BroadcastStream::new(receiver);
let main = {
let dispatcher = dispatcher.clone();
let task_id = task_id.clone();
let request_id = request_id.clone();
raw.filter_map(move |res| {
let dispatcher = dispatcher.clone();
let task_id = task_id.clone();
let request_id = request_id.clone();
async move {
match res {
Ok(event) => encode_frame(&event, (*request_id).as_ref())
.map(Ok::<_, Infallible>)
.or_else(|| {
warn!("sse serialize failed");
None
}),
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
warn!(count = n, "sse subscriber lagged, emitting resync snapshot");
match dispatcher.current_task(&task_id).await {
Some(task) => encode_frame(
&StreamEvent::Task(task),
(*request_id).as_ref(),
)
.map(Ok::<_, Infallible>),
None => None,
}
}
}
}
})
};
let combined: futures::stream::BoxStream<'static, Result<Event, Infallible>> =
initial_stream.chain(main).boxed();
Sse::new(combined)
.keep_alive(KeepAlive::default())
.into_response()
}
fn encode_frame(event: &StreamEvent, request_id: Option<&Value>) -> Option<Event> {
let payload = match request_id {
Some(id) => serde_json::to_string(&serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": event,
}))
.ok()?,
None => serde_json::to_string(event).ok()?,
};
Some(Event::default().data(payload))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::InMemoryTaskStore;
    use crate::types::{AgentCard, AgentProvider};
    use crate::AgentCardSource;
    use car_engine::Runtime;
    use std::sync::Arc;

    /// Builds a dispatcher backed by an in-memory task store and a fixed test card.
    fn make_dispatcher() -> A2aDispatcher {
        let runtime = Arc::new(Runtime::new());
        let store = Arc::new(InMemoryTaskStore::new());
        let card: Arc<AgentCardSource> = Arc::new(|| AgentCard {
            name: "CAR".into(),
            description: "test".into(),
            url: "http://localhost".into(),
            version: "1.0.0".into(),
            protocol_version: "1.0".into(),
            preferred_transport: Some("JSONRPC".into()),
            provider: AgentProvider {
                organization: "Parslee".into(),
                url: None,
            },
            capabilities: Default::default(),
            default_input_modes: vec!["text".into()],
            default_output_modes: vec!["text".into()],
            skills: vec![],
            documentation_url: None,
            icon_url: None,
            supported_interfaces: vec![],
            additional_interfaces: vec![],
            security_schemes: Default::default(),
            supports_authenticated_extended_card: false,
            security_requirements: vec![],
            signatures: vec![],
        });
        A2aDispatcher::new(runtime, store, card)
    }

    /// Starts an unauthenticated server on an ephemeral localhost port and
    /// returns its bound address.
    ///
    /// Dropping the `JoinHandle` detaches the server task rather than
    /// aborting it; the task is torn down together with the test runtime.
    async fn spawn_server() -> SocketAddr {
        let (addr, _handle) = serve(make_dispatcher(), "127.0.0.1:0".parse().unwrap())
            .await
            .unwrap();
        addr
    }

    /// POSTs `body` as JSON to the root JSON-RPC endpoint and returns the raw response.
    async fn post_rpc(addr: SocketAddr, body: &Value) -> reqwest::Response {
        reqwest::Client::new()
            .post(format!("http://{}/", addr))
            .json(body)
            .send()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn agent_card_endpoint_round_trips() {
        let addr = spawn_server().await;
        // Both well-known paths are routed to the same handler.
        for path in ["/.well-known/agent-card.json", "/.well-known/agent.json"] {
            let resp = reqwest::get(format!("http://{}{}", addr, path)).await.unwrap();
            assert!(resp.status().is_success(), "GET {path} failed");
            let body: Value = resp.json().await.unwrap();
            assert_eq!(body["name"], "CAR");
            assert_eq!(body["version"], "1.0.0");
        }
    }

    #[tokio::test]
    async fn rpc_endpoint_handles_method_call() {
        let addr = spawn_server().await;
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "agent/getAuthenticatedExtendedCard",
            "params": null,
            "id": 1
        });
        let resp = post_rpc(addr, &body).await;
        assert!(resp.status().is_success());
        let json: Value = resp.json().await.unwrap();
        assert_eq!(json["id"], 1);
        assert_eq!(json["result"]["name"], "CAR");
    }

    #[tokio::test]
    async fn rpc_endpoint_returns_error_for_unknown_method() {
        let addr = spawn_server().await;
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "nonexistent",
            "params": {},
            "id": 7
        });
        let json: Value = post_rpc(addr, &body).await.json().await.unwrap();
        assert_eq!(json["id"], 7);
        assert_eq!(json["error"]["code"], -32601);
    }

    #[tokio::test]
    async fn rpc_endpoint_rejects_unsupported_jsonrpc_version() {
        let addr = spawn_server().await;
        let body = serde_json::json!({
            "jsonrpc": "1.0",
            "method": "agent/getAuthenticatedExtendedCard",
            "id": 3
        });
        let resp = post_rpc(addr, &body).await;
        // JSON-RPC errors travel in the body with HTTP 200.
        assert!(resp.status().is_success());
        let json: Value = resp.json().await.unwrap();
        assert_eq!(json["id"], 3);
        assert_eq!(json["error"]["code"], -32600);
        assert!(json.get("result").is_none());
    }

    #[tokio::test]
    async fn rpc_endpoint_rejects_missing_jsonrpc_field() {
        let addr = spawn_server().await;
        let body = serde_json::json!({
            "method": "agent/getAuthenticatedExtendedCard",
            "id": 4
        });
        let json: Value = post_rpc(addr, &body).await.json().await.unwrap();
        assert_eq!(json["id"], 4);
        assert_eq!(json["error"]["code"], -32600);
        assert!(json["error"]["message"]
            .as_str()
            .unwrap()
            .contains("missing required `jsonrpc` field"));
    }
}