use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use a2a_protocol_types::agent_card::{AgentCapabilities, AgentCard, AgentInterface, AgentSkill};
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::jsonrpc::{JsonRpcRequest, JsonRpcSuccessResponse};
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
use a2a_protocol_types::params::MessageSendParams;
use a2a_protocol_types::push::TaskPushNotificationConfig;
use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
use a2a_protocol_server::builder::RequestHandlerBuilder;
use a2a_protocol_server::dispatch::JsonRpcDispatcher;
use a2a_protocol_server::executor::AgentExecutor;
use a2a_protocol_server::push::PushSender;
use a2a_protocol_server::request_context::RequestContext;
use a2a_protocol_server::streaming::EventQueueWriter;
struct StressExecutor {
completed_count: Arc<AtomicUsize>,
}
impl AgentExecutor for StressExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
let completed = Arc::clone(&self.completed_count);
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
tokio::time::sleep(Duration::from_millis(1)).await;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
completed.fetch_add(1, Ordering::Relaxed);
Ok(())
})
}
}
/// Push sender that performs no delivery and always reports success.
///
/// Installed as the stress server's push sender.
struct NoopPushSender;

impl PushSender for NoopPushSender {
    fn send<'a>(
        &'a self,
        _url: &'a str,
        _event: &'a StreamResponse,
        _config: &'a TaskPushNotificationConfig,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        // Nothing to deliver: resolve immediately with success.
        Box::pin(async { Ok(()) })
    }
}
/// Builds the agent card served by the stress server.
///
/// The card has a single JSON-RPC interface, a single `stress` skill and
/// plain-text default I/O modes. Capabilities are `AgentCapabilities::none()`,
/// and every optional field (provider, icons, docs, security, signatures) is left unset.
fn minimal_agent_card() -> AgentCard {
    let rpc_interface = AgentInterface {
        url: "http://localhost/rpc".into(),
        protocol_binding: "JSONRPC".into(),
        protocol_version: "1.0.0".into(),
        tenant: None,
    };
    let stress_skill = AgentSkill {
        id: "stress".into(),
        name: "Stress".into(),
        description: "Stress test skill".into(),
        tags: Vec::new(),
        examples: None,
        input_modes: None,
        output_modes: None,
        security_requirements: None,
    };
    AgentCard {
        name: "Stress Test Agent".into(),
        description: "Handles load tests".into(),
        version: "1.0.0".into(),
        url: None,
        supported_interfaces: vec![rpc_interface],
        default_input_modes: vec!["text/plain".into()],
        default_output_modes: vec!["text/plain".into()],
        capabilities: AgentCapabilities::none(),
        skills: vec![stress_skill],
        provider: None,
        icon_url: None,
        documentation_url: None,
        security_schemes: None,
        security_requirements: None,
        signatures: None,
    }
}
/// Builds `SendMessage` parameters for stress request number `id`.
///
/// The message id (`stress-msg-{id}`) and its single text part
/// (`stress request {id}`) both embed `id`. No task, context, tenant,
/// configuration or metadata is set.
fn make_send_params(id: usize) -> MessageSendParams {
    let text = format!("stress request {id}");
    let message = Message {
        id: MessageId::new(format!("stress-msg-{id}")),
        role: MessageRole::User,
        parts: vec![Part::text(text)],
        task_id: None,
        context_id: None,
        reference_task_ids: None,
        extensions: None,
        metadata: None,
    };
    MessageSendParams {
        tenant: None,
        message,
        configuration: None,
        metadata: None,
    }
}
async fn start_stress_server(completed_count: Arc<AtomicUsize>) -> SocketAddr {
let handler = Arc::new(
RequestHandlerBuilder::new(StressExecutor { completed_count })
.with_agent_card(minimal_agent_card())
.with_push_sender(NoopPushSender)
.build()
.expect("build handler"),
);
let dispatcher = Arc::new(JsonRpcDispatcher::new(handler));
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
tokio::spawn(async move {
loop {
let Ok((stream, _)) = listener.accept().await else {
break;
};
let io = hyper_util::rt::TokioIo::new(stream);
let dispatcher = Arc::clone(&dispatcher);
tokio::spawn(async move {
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&dispatcher);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection(io, service)
.await;
});
}
});
addr
}
/// Plain-HTTP hyper-util client whose request bodies are buffered `Full<Bytes>`.
type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>;

/// Creates a plain-HTTP client whose background work runs on the Tokio executor.
fn build_http_client() -> HttpClient {
    let executor = TokioExecutor::new();
    Client::builder(executor).build_http()
}
/// Sends one `SendMessage` JSON-RPC request, built by [`make_send_params`], to
/// the server at `addr` and checks that it succeeded.
///
/// The JSON-RPC request id is the string `"stress-{id}"`.
///
/// # Errors
/// Returns a human-readable description when any of these happen:
/// - the request cannot be serialized, built or sent;
/// - the response body cannot be read;
/// - the HTTP status is not 2xx;
/// - the body does not parse as a JSON-RPC success response. A JSON-RPC
///   error object, which lacks `result`, lands here too.
async fn send_request(client: &HttpClient, addr: SocketAddr, id: usize) -> Result<(), String> {
    let params = make_send_params(id);
    let params_json =
        serde_json::to_value(&params).map_err(|e| format!("serialize params: {e}"))?;
    let rpc_req = JsonRpcRequest::with_params(
        serde_json::json!(format!("stress-{id}")),
        "SendMessage",
        params_json,
    );
    let body = serde_json::to_vec(&rpc_req).map_err(|e| format!("serialize request: {e}"))?;
    let req = hyper::Request::builder()
        .method(hyper::Method::POST)
        .uri(format!("http://{addr}/"))
        .header("content-type", "application/json")
        .body(Full::new(Bytes::from(body)))
        .map_err(|e| format!("build request: {e}"))?;
    let resp = client
        .request(req)
        .await
        .map_err(|e| format!("request failed: {e}"))?;
    let status = resp.status();
    let body_bytes = resp
        .collect()
        .await
        .map_err(|e| format!("read body: {e}"))?
        .to_bytes();
    if !status.is_success() {
        return Err(format!(
            "unexpected status {status}: {}",
            String::from_utf8_lossy(&body_bytes)
        ));
    }
    // Include the raw body so a JSON-RPC error response is diagnosable.
    let _: JsonRpcSuccessResponse<serde_json::Value> =
        serde_json::from_slice(&body_bytes).map_err(|e| {
            format!(
                "parse response: {e}; body: {}",
                String::from_utf8_lossy(&body_bytes)
            )
        })?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_200_requests_all_succeed() {
let completed = Arc::new(AtomicUsize::new(0));
let addr = start_stress_server(Arc::clone(&completed)).await;
let client = build_http_client();
let mut handles = Vec::new();
for i in 0..200 {
let client = client.clone();
handles.push(tokio::spawn(
async move { send_request(&client, addr, i).await },
));
}
let mut success_count = 0;
let mut error_count = 0;
for handle in handles {
match handle.await.unwrap() {
Ok(()) => success_count += 1,
Err(e) => {
error_count += 1;
eprintln!("request error: {e}");
}
}
}
assert_eq!(error_count, 0, "all requests should succeed");
assert_eq!(success_count, 200);
tokio::time::sleep(Duration::from_millis(500)).await;
assert_eq!(
completed.load(Ordering::Relaxed),
200,
"all executors should have completed"
);
}
/// Sends 10 waves of 50 concurrent `SendMessage` requests with a 100 ms pause
/// after each wave. Checks that every request succeeds, that the whole run
/// stays under 30 s, and that every executor run completes.
///
/// NOTE(review): despite the name, the run is not time-based. It is a fixed 10
/// waves, and the inter-wave pauses add up to 1 s.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn sustained_load_10_seconds() {
    const WAVES: usize = 10;
    const WAVE_SIZE: usize = 50;
    let completed = Arc::new(AtomicUsize::new(0));
    let addr = start_stress_server(Arc::clone(&completed)).await;
    let client = build_http_client();
    let start = Instant::now();
    for wave in 0..WAVES {
        // Request ids are globally unique and sequential across waves.
        let handles: Vec<_> = (wave * WAVE_SIZE..(wave + 1) * WAVE_SIZE)
            .map(|id| {
                let client = client.clone();
                tokio::spawn(async move { send_request(&client, addr, id).await })
            })
            .collect();
        for handle in handles {
            handle
                .await
                .expect("request task failed to join")
                .expect("each wave request should return a valid JSON-RPC success response");
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
    let elapsed = start.elapsed();
    assert!(
        elapsed < Duration::from_secs(30),
        "sustained load test took too long: {elapsed:?}"
    );
    let total_sent = WAVES * WAVE_SIZE;
    // The executor increments the counter after writing its final event, so
    // poll with a deadline rather than sleeping a fixed amount.
    let deadline = Instant::now() + Duration::from_secs(5);
    while completed.load(Ordering::SeqCst) < total_sent && Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    let completed_val = completed.load(Ordering::SeqCst);
    assert_eq!(
        completed_val, total_sent,
        "all {total_sent} executors should have completed, but only {completed_val} did"
    );
}
#[tokio::test]
async fn task_store_eviction_under_load() {
use a2a_protocol_server::store::{InMemoryTaskStore, TaskStore, TaskStoreConfig};
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
let config = TaskStoreConfig {
max_capacity: Some(50),
..Default::default()
};
let store = InMemoryTaskStore::with_config(config);
for i in 0..200 {
let task = Task {
id: TaskId::new(format!("evict-task-{i}")),
context_id: ContextId::new("ctx-evict"),
status: TaskStatus::new(TaskState::Completed),
history: None,
artifacts: None,
metadata: None,
};
store.save(&task).await.unwrap();
}
store.run_eviction().await;
let params = ListTasksParams::default();
let all = store.list(¶ms).await.unwrap();
assert!(
all.tasks.len() <= 50,
"store should have at most 50 tasks after eviction, but has {}",
all.tasks.len()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_multi_tenant_isolation() {
use a2a_protocol_server::store::TaskStore;
use a2a_protocol_server::store::{TenantAwareInMemoryTaskStore, TenantContext};
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
let store = Arc::new(TenantAwareInMemoryTaskStore::new());
let mut handles = Vec::new();
for tenant_idx in 0..10 {
for task_idx in 0..50 {
let store = Arc::clone(&store);
let tenant = format!("tenant-{tenant_idx}");
handles.push(tokio::spawn(TenantContext::scope(
tenant.clone(),
async move {
let task = Task {
id: TaskId::new(format!("task-{tenant_idx}-{task_idx}")),
context_id: ContextId::new(format!("ctx-{tenant_idx}")),
status: TaskStatus::new(TaskState::Completed),
history: None,
artifacts: None,
metadata: None,
};
store.save(&task).await.unwrap();
},
)));
}
}
for handle in handles {
handle.await.unwrap();
}
let params = ListTasksParams::default();
for tenant_idx in 0..10 {
let store = Arc::clone(&store);
let params = params.clone();
let tenant = format!("tenant-{tenant_idx}");
let count = TenantContext::scope(tenant, async move {
store.list(¶ms).await.unwrap().tasks.len()
})
.await;
assert_eq!(
count, 50,
"tenant-{tenant_idx} should have exactly 50 tasks, got {count}"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn rapid_connect_disconnect_cycles() {
let completed = Arc::new(AtomicUsize::new(0));
let addr = start_stress_server(Arc::clone(&completed)).await;
for i in 0..100 {
let client = build_http_client();
let result = send_request(&client, addr, i).await;
result.unwrap_or_else(|e| {
panic!("request {i} should return valid JSON-RPC success response: {e}")
});
drop(client);
}
tokio::time::sleep(Duration::from_secs(1)).await;
assert_eq!(completed.load(Ordering::Relaxed), 100);
}