use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use a2a_protocol_client::ClientBuilder;
use a2a_protocol_types::agent_card::{AgentCapabilities, AgentCard, AgentInterface, AgentSkill};
use a2a_protocol_types::artifact::Artifact;
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::events::{StreamResponse, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
use a2a_protocol_types::params::MessageSendParams;
use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
use a2a_protocol_server::builder::RequestHandlerBuilder;
use a2a_protocol_server::dispatch::RestDispatcher;
use a2a_protocol_server::executor::AgentExecutor;
use a2a_protocol_server::request_context::RequestContext;
use a2a_protocol_server::streaming::EventQueueWriter;
const TEST_TIMEOUT: Duration = Duration::from_secs(10);
struct StreamingExecutor;
impl AgentExecutor for StreamingExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
artifact: Artifact::new("echo", vec![Part::text("Hello from REST stream")]),
append: None,
last_chunk: Some(true),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
fn agent_card(url: &str) -> AgentCard {
AgentCard {
url: Some(url.to_owned()),
name: "test-rest-agent".into(),
description: "REST streaming test agent".into(),
version: "1.0.0".into(),
supported_interfaces: vec![AgentInterface {
url: url.to_owned(),
protocol_binding: "REST".into(),
protocol_version: "1.0.0".into(),
tenant: None,
}],
default_input_modes: vec!["text/plain".into()],
default_output_modes: vec!["text/plain".into()],
skills: vec![AgentSkill {
id: "echo".into(),
name: "Echo".into(),
description: "Echoes input".into(),
tags: vec![],
examples: None,
input_modes: None,
output_modes: None,
security_requirements: None,
}],
capabilities: AgentCapabilities::none(),
provider: None,
icon_url: None,
documentation_url: None,
security_schemes: None,
security_requirements: None,
signatures: None,
}
}
fn send_params() -> MessageSendParams {
MessageSendParams {
tenant: None,
message: Message {
id: MessageId::new("msg-1"),
role: MessageRole::User,
parts: vec![Part::text("hello")],
task_id: None,
context_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
},
configuration: None,
metadata: None,
}
}
async fn start_rest_server() -> (SocketAddr, tokio::task::JoinHandle<()>) {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local addr");
let url = format!("http://{addr}");
let handler = Arc::new(
RequestHandlerBuilder::new(StreamingExecutor)
.with_agent_card(agent_card(&url))
.build()
.expect("build handler"),
);
let dispatcher = Arc::new(RestDispatcher::new(handler));
let handle = tokio::spawn(async move {
loop {
let Ok((stream, _)) = listener.accept().await else {
break;
};
let io = hyper_util::rt::TokioIo::new(stream);
let dispatcher = Arc::clone(&dispatcher);
tokio::spawn(async move {
let service = hyper::service::service_fn(move |req| {
let d = Arc::clone(&dispatcher);
async move { Ok::<_, std::convert::Infallible>(d.dispatch(req).await) }
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
(addr, handle)
}
#[tokio::test]
async fn rest_stream_message_receives_all_events() {
let (addr, _handle) = start_rest_server().await;
let url = format!("http://{addr}");
let client = ClientBuilder::new(&url)
.with_protocol_binding("REST")
.build()
.expect("build REST client");
let mut stream = tokio::time::timeout(TEST_TIMEOUT, client.stream_message(send_params()))
.await
.expect("timed out connecting")
.expect("stream_message should succeed");
let mut events = Vec::new();
while let Some(event) = tokio::time::timeout(TEST_TIMEOUT, stream.next())
.await
.expect("timed out waiting for event")
{
events.push(event.expect("stream event should deserialize"));
}
assert_eq!(
events.len(),
4,
"should receive exactly 4 events, got {}: {:?}",
events.len(),
events
);
assert!(
matches!(&events[0], StreamResponse::Task(_)),
"first event should be Task snapshot"
);
assert!(
matches!(&events[1], StreamResponse::StatusUpdate(e) if e.status.state == TaskState::Working),
"second event should be Working"
);
assert!(
matches!(&events[2], StreamResponse::ArtifactUpdate(_)),
"third event should be ArtifactUpdate"
);
assert!(
matches!(&events[3], StreamResponse::StatusUpdate(e) if e.status.state == TaskState::Completed),
"fourth event should be Completed"
);
}
#[tokio::test]
async fn rest_send_message_returns_completed_task() {
let (addr, _handle) = start_rest_server().await;
let url = format!("http://{addr}");
let client = ClientBuilder::new(&url)
.with_protocol_binding("REST")
.build()
.expect("build REST client");
let result = tokio::time::timeout(TEST_TIMEOUT, client.send_message(send_params()))
.await
.expect("timed out")
.expect("send_message should succeed");
match result {
a2a_protocol_types::responses::SendMessageResponse::Task(task) => {
assert_eq!(task.status.state, TaskState::Completed);
}
other => panic!("expected Task, got {other:?}"),
}
}
#[tokio::test]
async fn rest_get_task_roundtrip() {
let (addr, _handle) = start_rest_server().await;
let url = format!("http://{addr}");
let client = ClientBuilder::new(&url)
.with_protocol_binding("REST")
.build()
.expect("build REST client");
let result = tokio::time::timeout(TEST_TIMEOUT, client.send_message(send_params()))
.await
.expect("timed out")
.expect("send_message");
let task_id = match result {
a2a_protocol_types::responses::SendMessageResponse::Task(t) => t.id,
other => panic!("expected Task, got {other:?}"),
};
let params = a2a_protocol_types::TaskQueryParams {
tenant: None,
id: task_id.0.clone(),
history_length: None,
};
let fetched = tokio::time::timeout(TEST_TIMEOUT, client.get_task(params))
.await
.expect("timed out")
.expect("get_task");
assert_eq!(fetched.id, task_id);
assert_eq!(fetched.status.state, TaskState::Completed);
}
#[tokio::test]
async fn rest_list_tasks() {
let (addr, _handle) = start_rest_server().await;
let url = format!("http://{addr}");
let client = ClientBuilder::new(&url)
.with_protocol_binding("REST")
.build()
.expect("build REST client");
let params = a2a_protocol_types::ListTasksParams::default();
let result = tokio::time::timeout(TEST_TIMEOUT, client.list_tasks(params))
.await
.expect("timed out")
.expect("list_tasks");
let _ = result.tasks.len();
}
#[tokio::test]
async fn rest_cancel_nonexistent_task_returns_error() {
let (addr, _handle) = start_rest_server().await;
let url = format!("http://{addr}");
let client = ClientBuilder::new(&url)
.with_protocol_binding("REST")
.build()
.expect("build REST client");
let result = tokio::time::timeout(TEST_TIMEOUT, client.cancel_task("does-not-exist"))
.await
.expect("timed out");
assert!(result.is_err(), "cancel nonexistent task should fail");
}