use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::WebSocketStream;
use a2a_protocol_types::jsonrpc::{
JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
JsonRpcVersion,
};
use crate::error::ServerError;
use crate::handler::{RequestHandler, SendMessageResult};
use crate::streaming::EventQueueReader;
pub struct WebSocketDispatcher {
handler: Arc<RequestHandler>,
}
impl WebSocketDispatcher {
#[must_use]
pub const fn new(handler: Arc<RequestHandler>) -> Self {
Self { handler }
}
pub async fn serve(
self: Arc<Self>,
addr: impl tokio::net::ToSocketAddrs,
) -> std::io::Result<()> {
let listener = TcpListener::bind(addr).await?;
trace_info!(
addr = %listener.local_addr().unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0))),
"A2A WebSocket server listening"
);
loop {
let (stream, _peer) = listener.accept().await?;
let dispatcher = Arc::clone(&self);
tokio::spawn(async move {
trace_debug!("WebSocket connection accepted");
if let Err(_e) = dispatcher.handle_connection(stream).await {
trace_warn!("WebSocket connection error");
}
});
}
}
pub async fn serve_with_addr(
self: Arc<Self>,
addr: impl tokio::net::ToSocketAddrs,
) -> std::io::Result<SocketAddr> {
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
trace_info!(%local_addr, "A2A WebSocket server listening");
tokio::spawn(async move {
loop {
let Ok((stream, _peer)) = listener.accept().await else {
break;
};
let dispatcher = Arc::clone(&self);
tokio::spawn(async move {
let _ = dispatcher.handle_connection(stream).await;
});
}
});
Ok(local_addr)
}
async fn handle_connection(&self, stream: TcpStream) -> Result<(), WsError> {
let ws_stream = tokio_tungstenite::accept_async(stream)
.await
.map_err(WsError::Handshake)?;
let (writer, mut reader) = ws_stream.split();
let writer = Arc::new(tokio::sync::Mutex::new(writer));
let semaphore = Arc::new(tokio::sync::Semaphore::new(64));
while let Some(msg) = reader.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
if text.len() > 4 * 1024 * 1024 {
let err_resp = JsonRpcErrorResponse::new(
None,
JsonRpcError::new(-32000, "message too large".to_string()),
);
send_json(&writer, &err_resp).await;
continue;
}
let Ok(permit) = semaphore.clone().try_acquire_owned() else {
let err_resp = JsonRpcErrorResponse::new(
None,
JsonRpcError::new(
-32000,
"server busy: too many concurrent requests".to_string(),
),
);
send_json(&writer, &err_resp).await;
continue;
};
let writer = Arc::clone(&writer);
let handler = Arc::clone(&self.handler);
tokio::spawn(async move {
process_ws_message(&handler, &text, writer).await;
drop(permit); });
}
Ok(WsMessage::Ping(data)) => {
let mut w = writer.lock().await;
let _ = w.send(WsMessage::Pong(data)).await;
drop(w);
}
Ok(WsMessage::Close(_)) | Err(_) => break,
Ok(_) => {} }
}
Ok(())
}
}
#[derive(Debug)]
enum WsError {
Handshake(tokio_tungstenite::tungstenite::Error),
}
impl std::fmt::Display for WsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Handshake(e) => write!(f, "WebSocket handshake failed: {e}"),
}
}
}
type WsSink = Arc<tokio::sync::Mutex<SplitSink<WebSocketStream<TcpStream>, WsMessage>>>;
#[allow(clippy::too_many_lines)]
async fn process_ws_message(handler: &RequestHandler, text: &str, writer: WsSink) {
let rpc_req: JsonRpcRequest = match serde_json::from_str(text) {
Ok(req) => req,
Err(e) => {
let err_resp = JsonRpcErrorResponse::new(
None,
JsonRpcError::new(-32700, format!("parse error: {e}")),
);
send_json(&writer, &err_resp).await;
return;
}
};
let id = rpc_req.id.clone();
let headers = HashMap::new();
match rpc_req.method.as_str() {
"SendMessage" => {
dispatch_send_message(handler, &rpc_req, false, &headers, id, &writer).await;
}
"SendStreamingMessage" => {
dispatch_send_message(handler, &rpc_req, true, &headers, id, &writer).await;
}
"GetTask" => {
dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
Box::pin(async move {
let params: a2a_protocol_types::params::TaskQueryParams =
serde_json::from_value(p).map_err(|e| {
a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
})?;
h.on_get_task(params, Some(hdr))
.await
.map(|r| serde_json::to_value(&r).unwrap_or_default())
.map_err(|e| e.to_a2a_error())
})
})
.await;
}
"ListTasks" => {
dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
Box::pin(async move {
let params: a2a_protocol_types::params::ListTasksParams =
serde_json::from_value(p).map_err(|e| {
a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
})?;
h.on_list_tasks(params, Some(hdr))
.await
.map(|r| serde_json::to_value(&r).unwrap_or_default())
.map_err(|e| e.to_a2a_error())
})
})
.await;
}
"CancelTask" => {
dispatch_simple(handler, &rpc_req, id, &headers, &writer, |h, p, hdr| {
Box::pin(async move {
let params: a2a_protocol_types::params::CancelTaskParams =
serde_json::from_value(p).map_err(|e| {
a2a_protocol_types::error::A2aError::invalid_params(e.to_string())
})?;
h.on_cancel_task(params, Some(hdr))
.await
.map(|r| serde_json::to_value(&r).unwrap_or_default())
.map_err(|e| e.to_a2a_error())
})
})
.await;
}
"SubscribeToTask" => {
let params = match parse_params::<a2a_protocol_types::params::TaskIdParams>(
rpc_req.params.as_ref(),
) {
Ok(p) => p,
Err(e) => {
send_error(&writer, id, &e).await;
return;
}
};
match handler.on_resubscribe(params, Some(&headers)).await {
Ok(reader) => {
stream_events(&writer, reader, id).await;
}
Err(e) => {
send_error(&writer, id, &e).await;
}
}
}
other => {
let err = ServerError::MethodNotFound(other.to_owned());
send_error(&writer, id, &err).await;
}
}
}
async fn dispatch_send_message(
handler: &RequestHandler,
rpc_req: &JsonRpcRequest,
streaming: bool,
headers: &HashMap<String, String>,
id: JsonRpcId,
writer: &WsSink,
) {
let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(
rpc_req.params.as_ref(),
) {
Ok(p) => p,
Err(e) => {
send_error(writer, id, &e).await;
return;
}
};
match handler
.on_send_message(params, streaming, Some(headers))
.await
{
Ok(SendMessageResult::Response(resp)) => {
let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
let success = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id,
result,
};
send_json(writer, &success).await;
}
Ok(SendMessageResult::Stream(reader)) => {
stream_events(writer, reader, id).await;
}
Err(e) => {
send_error(writer, id, &e).await;
}
}
}
async fn stream_events(
writer: &WsSink,
mut reader: crate::streaming::InMemoryQueueReader,
id: JsonRpcId,
) {
while let Some(event) = reader.read().await {
match event {
Ok(stream_resp) => {
let envelope = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id: id.clone(),
result: stream_resp,
};
let json = serde_json::to_string(&envelope).unwrap_or_default();
let mut w = writer.lock().await;
if w.send(WsMessage::Text(json)).await.is_err() {
return; }
drop(w);
}
Err(e) => {
let err_resp =
JsonRpcErrorResponse::new(id.clone(), JsonRpcError::new(-32000, e.to_string()));
send_json(writer, &err_resp).await;
return;
}
}
}
let success = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id,
result: serde_json::json!({"status": "stream_complete"}),
};
send_json(writer, &success).await;
}
async fn dispatch_simple<'a, F>(
handler: &'a RequestHandler,
rpc_req: &JsonRpcRequest,
id: JsonRpcId,
headers: &'a HashMap<String, String>,
writer: &WsSink,
f: F,
) where
F: FnOnce(
&'a RequestHandler,
serde_json::Value,
&'a HashMap<String, String>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<serde_json::Value, a2a_protocol_types::error::A2aError>,
> + Send
+ 'a,
>,
>,
{
let params = rpc_req.params.clone().unwrap_or(serde_json::Value::Null);
match f(handler, params, headers).await {
Ok(result) => {
let success = JsonRpcSuccessResponse {
jsonrpc: JsonRpcVersion,
id,
result,
};
send_json(writer, &success).await;
}
Err(e) => {
let err_resp =
JsonRpcErrorResponse::new(id, JsonRpcError::new(e.code.as_i32(), e.message));
send_json(writer, &err_resp).await;
}
}
}
async fn send_json<T: serde::Serialize + Sync>(writer: &WsSink, value: &T) {
let json = serde_json::to_string(value).unwrap_or_default();
let mut w = writer.lock().await;
let _ = w.send(WsMessage::Text(json)).await;
drop(w);
}
async fn send_error(writer: &WsSink, id: JsonRpcId, err: &ServerError) {
let a2a_err = err.to_a2a_error();
let resp = JsonRpcErrorResponse::new(
id,
JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
);
send_json(writer, &resp).await;
}
fn parse_params<T: serde::de::DeserializeOwned>(
params: Option<&serde_json::Value>,
) -> Result<T, ServerError> {
let value = params.cloned().unwrap_or(serde_json::Value::Null);
serde_json::from_value(value)
.map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_params_with_valid_json() {
let value = Some(serde_json::json!({"id": "task-1"}));
let result: Result<a2a_protocol_types::params::TaskQueryParams, _> =
parse_params(value.as_ref());
assert!(result.is_ok());
assert_eq!(result.unwrap().id, "task-1");
}
#[test]
fn parse_params_with_none_returns_error() {
let result: Result<a2a_protocol_types::params::TaskQueryParams, _> = parse_params(None);
assert!(result.is_err());
}
#[test]
fn parse_params_with_wrong_type_returns_error() {
let value = Some(serde_json::json!("not an object"));
let result: Result<a2a_protocol_types::params::TaskQueryParams, _> =
parse_params(value.as_ref());
assert!(result.is_err());
}
#[test]
fn ws_error_display_contains_message() {
let err = WsError::Handshake(tokio_tungstenite::tungstenite::Error::ConnectionClosed);
let s = err.to_string();
assert!(s.contains("WebSocket handshake failed"));
}
#[test]
fn websocket_dispatcher_new() {
use crate::agent_executor;
use crate::RequestHandlerBuilder;
use std::sync::Arc;
struct DummyExec;
agent_executor!(DummyExec, |_ctx, _queue| async { Ok(()) });
let handler = Arc::new(RequestHandlerBuilder::new(DummyExec).build().unwrap());
let _dispatcher = WebSocketDispatcher::new(handler);
}
use crate::agent_executor;
use crate::RequestHandlerBuilder;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
use futures_util::{SinkExt, StreamExt};
struct EchoExec;
agent_executor!(EchoExec, |ctx, queue| async {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
});
async fn spawn_ws_server() -> std::net::SocketAddr {
let handler = Arc::new(RequestHandlerBuilder::new(EchoExec).build().unwrap());
let dispatcher = Arc::new(WebSocketDispatcher::new(handler));
dispatcher
.serve_with_addr("127.0.0.1:0")
.await
.expect("bind to port 0")
}
async fn ws_connect(
addr: std::net::SocketAddr,
) -> tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>
{
let (ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}"))
.await
.expect("ws connect");
ws
}
async fn read_text(
ws: &mut tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
) -> String {
let msg = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next())
.await
.expect("timeout waiting for WS frame")
.expect("stream ended")
.expect("ws error");
msg.into_text().expect("not a text frame")
}
fn send_message_json(id: &str) -> String {
serde_json::json!({
"jsonrpc": "2.0",
"method": "SendMessage",
"id": id,
"params": {
"message": {
"messageId": "msg-1",
"role": "ROLE_USER",
"parts": [{"text": "hello"}]
}
}
})
.to_string()
}
#[tokio::test]
async fn ws_send_message_success() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
ws.send(WsMessage::Text(send_message_json("sm-1")))
.await
.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(v["id"], "sm-1");
assert!(v.get("result").is_some(), "expected result key: {text}");
}
#[tokio::test]
async fn ws_get_task_not_found() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "GetTask",
"id": "gt-1",
"params": {"id": "nonexistent"}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
}
#[tokio::test]
async fn ws_list_tasks_success() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "ListTasks",
"id": "lt-1",
"params": {}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(v["id"], "lt-1");
assert!(v.get("result").is_some(), "expected result: {text}");
}
#[tokio::test]
async fn ws_cancel_task_not_found() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "CancelTask",
"id": "ct-1",
"params": {"id": "nonexistent"}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
}
#[tokio::test]
async fn ws_subscribe_task_not_found() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "SubscribeToTask",
"id": "sub-1",
"params": {"id": "nonexistent"}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
}
#[tokio::test]
async fn ws_unknown_method_error() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "FooBar",
"id": "unk-1",
"params": {}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
let msg = v["error"]["message"].as_str().unwrap_or("");
assert!(
msg.to_lowercase().contains("method")
|| msg.to_lowercase().contains("not found")
|| msg.to_lowercase().contains("unsupported"),
"error message should mention method not found: {msg}"
);
}
#[tokio::test]
async fn ws_invalid_json_parse_error() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
ws.send(WsMessage::Text("this is not json {{".into()))
.await
.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(v["error"]["code"], -32700, "expected parse error code");
}
#[tokio::test]
async fn ws_oversized_message_rejected() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let big = "x".repeat(4 * 1024 * 1024 + 1);
ws.send(WsMessage::Text(big)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
let msg = v["error"]["message"].as_str().unwrap_or("");
assert!(
msg.contains("too large"),
"error should mention 'too large': {msg}"
);
}
#[tokio::test]
async fn ws_ping_pong_response() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
ws.send(WsMessage::Ping(vec![42, 43])).await.unwrap();
let pong = tokio::time::timeout(std::time::Duration::from_secs(3), async {
loop {
let msg = ws.next().await.unwrap().unwrap();
if let WsMessage::Pong(data) = msg {
return data;
}
}
})
.await
.expect("should get pong within 3s");
assert_eq!(pong, vec![42, 43]);
}
#[tokio::test]
async fn ws_get_task_invalid_params() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "GetTask",
"id": "gti-1",
"params": {"wrong_field": 123}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(
v.get("error").is_some(),
"expected error for bad params: {text}"
);
}
#[tokio::test]
async fn ws_send_streaming_message_events() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "SendStreamingMessage",
"id": "ssm-1",
"params": {
"message": {
"messageId": "msg-stream-1",
"role": "ROLE_USER",
"parts": [{"text": "stream me"}]
}
}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let mut frames = Vec::new();
let timeout = tokio::time::timeout(std::time::Duration::from_secs(5), async {
loop {
let msg = ws.next().await.unwrap().unwrap();
let text = msg.into_text().unwrap();
let done = text.contains("stream_complete");
frames.push(text);
if done {
break;
}
}
});
timeout.await.expect("streaming should complete within 5s");
assert!(
frames.len() >= 3,
"expected >= 3 frames, got {}: {:?}",
frames.len(),
frames
);
assert!(frames.last().unwrap().contains("stream_complete"));
}
#[tokio::test]
async fn ws_send_message_invalid_params() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "SendMessage",
"id": "smi-1",
"params": {"not_message": true}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(
v.get("error").is_some(),
"expected error for bad send params: {text}"
);
}
#[tokio::test]
async fn ws_subscribe_invalid_params() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "SubscribeToTask",
"id": "subi-1",
"params": {}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(
v.get("error").is_some(),
"expected error for bad subscribe params: {text}"
);
}
#[tokio::test]
async fn ws_cancel_task_invalid_params() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "CancelTask",
"id": "cti-1",
"params": {"wrong": 1}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert!(v.get("error").is_some(), "expected error: {text}");
}
#[tokio::test]
async fn ws_list_tasks_with_filters() {
let addr = spawn_ws_server().await;
let mut ws = ws_connect(addr).await;
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "ListTasks",
"id": "ltf-1",
"params": {
"contextId": "ctx-1",
"pageSize": 10
}
})
.to_string();
ws.send(WsMessage::Text(req)).await.unwrap();
let text = read_text(&mut ws).await;
let v: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(v["id"], "ltf-1");
assert!(v.get("result").is_some(), "expected result: {text}");
}
}