mod background_processor;
mod error_handling;
mod state_transitions;
mod streaming_mode;
mod sync_mode;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use a2a_protocol_types::artifact::Artifact;
use a2a_protocol_types::error::{A2aError, A2aResult};
use a2a_protocol_types::events::{StreamResponse, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
use a2a_protocol_types::params::MessageSendParams;
use a2a_protocol_types::push::TaskPushNotificationConfig;
use a2a_protocol_types::responses::SendMessageResponse;
use a2a_protocol_types::task::{ContextId, Task, TaskState, TaskStatus};
use a2a_protocol_server::builder::RequestHandlerBuilder;
use a2a_protocol_server::executor::AgentExecutor;
use a2a_protocol_server::handler::SendMessageResult;
use a2a_protocol_server::push::PushSender;
use a2a_protocol_server::request_context::RequestContext;
use a2a_protocol_server::streaming::{EventQueueReader, EventQueueWriter};
struct StatusExecutor;
impl AgentExecutor for StatusExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
struct ArtifactExecutor;
impl AgentExecutor for ArtifactExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
artifact: Artifact::new("art-1", vec![Part::text("artifact content")]),
append: None,
last_chunk: Some(true),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
struct TaskEventExecutor;
impl AgentExecutor for TaskEventExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
let replacement = Task {
id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
history: None,
artifacts: Some(vec![Artifact::new(
"replaced-art",
vec![Part::text("replaced")],
)]),
metadata: Some(serde_json::json!({"replaced": true})),
};
queue.write(StreamResponse::Task(replacement)).await?;
Ok(())
})
}
}
struct ErrorExecutor;
impl AgentExecutor for ErrorExecutor {
fn execute<'a>(
&'a self,
_ctx: &'a RequestContext,
_queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move { Err(A2aError::internal("something went wrong")) })
}
}
struct MessageEventExecutor;
impl AgentExecutor for MessageEventExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
let msg = Message {
id: MessageId::new("msg-event"),
role: MessageRole::Agent,
parts: vec![Part::text("hello from agent")],
task_id: None,
context_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
};
queue.write(StreamResponse::Message(msg)).await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Completed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
struct EmptyExecutor;
impl AgentExecutor for EmptyExecutor {
fn execute<'a>(
&'a self,
_ctx: &'a RequestContext,
_queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move { Ok(()) })
}
}
struct SharedRecordingPushSender {
calls: Arc<Mutex<Vec<String>>>,
}
impl PushSender for SharedRecordingPushSender {
fn send<'a>(
&'a self,
url: &'a str,
_event: &'a StreamResponse,
_config: &'a TaskPushNotificationConfig,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
self.calls.lock().unwrap().push(url.to_string());
Ok(())
})
}
}
struct SleepForeverPushSender;
impl PushSender for SleepForeverPushSender {
fn send<'a>(
&'a self,
_url: &'a str,
_event: &'a StreamResponse,
_config: &'a TaskPushNotificationConfig,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
Ok(())
})
}
}
fn make_send_params() -> MessageSendParams {
MessageSendParams {
tenant: None,
context_id: None,
message: Message {
id: MessageId::new("msg-1"),
role: MessageRole::User,
parts: vec![Part::text("hello")],
task_id: None,
context_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
},
configuration: None,
metadata: None,
}
}
fn default_list_params() -> a2a_protocol_types::params::ListTasksParams {
a2a_protocol_types::params::ListTasksParams {
tenant: None,
context_id: None,
status: None,
page_size: None,
page_token: None,
status_timestamp_after: None,
include_artifacts: None,
history_length: None,
}
}
fn extract_task(result: SendMessageResult) -> Task {
match result {
SendMessageResult::Response(SendMessageResponse::Task(task)) => task,
_ => panic!("expected SendMessageResult::Response(Task)"),
}
}