use a2a_protocol_types::events::StreamResponse;
use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
use crate::error::{ServerError, ServerResult};
use crate::streaming::{EventQueueReader, InMemoryQueueReader};
use super::super::RequestHandler;
impl RequestHandler {
    /// Consumes events from `reader` until the queue closes, applying each to
    /// the task identified by `task_id` and persisting every change, while
    /// concurrently watching `executor_handle` so an executor panic is
    /// recorded as a task failure.
    ///
    /// Returns the final [`Task`] snapshot.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::TaskNotFound`] when the task is missing from the
    /// store, and propagates the first error produced while processing an
    /// event (invalid transition, store failure, protocol error).
    pub(crate) async fn collect_events(
        &self,
        mut reader: InMemoryQueueReader,
        task_id: TaskId,
        executor_handle: tokio::task::JoinHandle<()>,
    ) -> ServerResult<Task> {
        let mut last_task = self
            .task_store
            .get(&task_id)
            .await?
            .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;
        let mut executor_done = false;
        // Polling a JoinHandle after it has already completed panics, so once
        // it resolves we set `executor_done` and never select on it again.
        let mut handle_fuse = executor_handle;
        loop {
            if executor_done {
                // Executor finished: drain the remaining queued events until
                // the writer side closes (read() returns None).
                match reader.read().await {
                    Some(event) => {
                        self.process_event(event, &task_id, &mut last_task).await?;
                    }
                    None => break,
                }
            } else {
                tokio::select! {
                    // `biased` polls branches in declaration order, so pending
                    // events are consumed before we react to executor exit.
                    biased;
                    event = reader.read() => {
                        match event {
                            Some(event) => {
                                self.process_event(event, &task_id, &mut last_task).await?;
                            }
                            None => break,
                        }
                    }
                    result = &mut handle_fuse => {
                        executor_done = true;
                        // JoinHandle yields Err only when the spawned task
                        // panicked or was aborted.
                        if result.is_err() {
                            trace_error!(
                                task_id = %task_id,
                                "executor task panicked"
                            );
                            // Do not clobber a state that is already terminal.
                            if !last_task.status.state.is_terminal() {
                                last_task.status = TaskStatus::with_timestamp(TaskState::Failed);
                                self.task_store.save(last_task.clone()).await?;
                            }
                        }
                    }
                }
            }
        }
        Ok(last_task)
    }

    /// Applies one queue event to `last_task`: status updates are validated
    /// against the task state machine, artifacts are appended up to the
    /// configured limit, task snapshots replace the whole task, and message
    /// events are ignored. Every persisted change to status/artifacts is also
    /// fanned out to push-notification subscribers.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::InvalidStateTransition`] for rejected status
    /// updates, [`ServerError::Protocol`] for stream-level errors, and
    /// propagates task-store failures.
    async fn process_event(
        &self,
        event: a2a_protocol_types::error::A2aResult<StreamResponse>,
        task_id: &TaskId,
        last_task: &mut Task,
    ) -> ServerResult<()> {
        match event {
            Ok(ref stream_resp @ StreamResponse::StatusUpdate(ref update)) => {
                let current = last_task.status.state;
                let next = update.status.state;
                if !current.can_transition_to(next) {
                    trace_warn!(
                        task_id = %task_id,
                        from = %current,
                        to = %next,
                        "invalid state transition rejected"
                    );
                    return Err(ServerError::InvalidStateTransition {
                        task_id: task_id.clone(),
                        from: current,
                        to: next,
                    });
                }
                last_task.status = TaskStatus {
                    state: next,
                    message: update.status.message.clone(),
                    timestamp: update.status.timestamp.clone(),
                };
                self.task_store.save(last_task.clone()).await?;
                self.deliver_push(task_id, stream_resp).await;
            }
            Ok(ref stream_resp @ StreamResponse::ArtifactUpdate(ref update)) => {
                let artifacts = last_task.artifacts.get_or_insert_with(Vec::new);
                if artifacts.len() >= self.limits.max_artifacts_per_task {
                    // Best-effort: over-limit artifacts are dropped, not fatal.
                    trace_warn!(
                        task_id = %task_id,
                        max = self.limits.max_artifacts_per_task,
                        "artifact limit reached; dropping artifact update"
                    );
                } else {
                    artifacts.push(update.artifact.clone());
                    self.task_store.save(last_task.clone()).await?;
                    self.deliver_push(task_id, stream_resp).await;
                }
            }
            Ok(StreamResponse::Task(task)) => {
                // A full snapshot replaces the accumulated task wholesale.
                *last_task = task;
                self.task_store.save(last_task.clone()).await?;
            }
            // Message events (and any remaining variants) carry no task state
            // to fold in; `Ok(_)` replaces the redundant
            // `Ok(StreamResponse::Message(_) | _)` alternation.
            Ok(_) => {}
            Err(e) => {
                // Consistency fix: mirror the executor-panic path above and
                // only mark Failed when the task is not already terminal.
                if !last_task.status.state.is_terminal() {
                    last_task.status = TaskStatus::with_timestamp(TaskState::Failed);
                    self.task_store.save(last_task.clone()).await?;
                }
                return Err(ServerError::Protocol(e));
            }
        }
        Ok(())
    }

    /// Best-effort push delivery of `event` to every config registered for
    /// `task_id`. Send failures and timeouts are logged, never propagated; an
    /// overall deadline bounds total time spent on one event.
    async fn deliver_push(&self, task_id: &TaskId, event: &StreamResponse) {
        // Upper bound on the total time spent delivering one event across all
        // configs, independent of the per-send timeout below.
        const TOTAL_DELIVERY_DEADLINE: std::time::Duration = std::time::Duration::from_secs(30);
        let Some(ref sender) = self.push_sender else {
            return;
        };
        // Listing configs is best-effort too: a store error skips delivery.
        let Ok(configs) = self.push_config_store.list(task_id.as_ref()).await else {
            return;
        };
        let deadline = tokio::time::Instant::now() + TOTAL_DELIVERY_DEADLINE;
        for config in &configs {
            if tokio::time::Instant::now() >= deadline {
                trace_warn!(
                    task_id = %task_id,
                    "push delivery deadline exceeded; skipping remaining configs"
                );
                break;
            }
            let result = tokio::time::timeout(
                self.limits.push_delivery_timeout,
                sender.send(&config.url, event, config),
            )
            .await;
            match result {
                Ok(Err(_err)) => {
                    trace_warn!(
                        task_id = %task_id,
                        url = %config.url,
                        error = %_err,
                        "push notification delivery failed"
                    );
                }
                Err(_) => {
                    trace_warn!(
                        task_id = %task_id,
                        url = %config.url,
                        "push notification delivery timed out"
                    );
                }
                Ok(Ok(())) => {}
            }
        }
    }
}
// Unit tests for `collect_events`, `process_event`, and push delivery.
// They bypass the executor entirely: a no-op DummyExecutor satisfies the
// builder, and each test writes events into the queue by hand.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use a2a_protocol_types::events::StreamResponse;
use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
use crate::agent_executor;
use crate::builder::RequestHandlerBuilder;
use crate::store::{InMemoryTaskStore, TaskStore};
use crate::streaming::event_queue::new_in_memory_queue;
use crate::streaming::EventQueueWriter;
// Executor stub that does nothing; the tests drive the queue directly.
struct DummyExecutor;
agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });
// Builds a minimal task (no history/artifacts/metadata) in the given state.
fn make_task(id: &str, state: TaskState) -> Task {
Task {
id: id.into(),
context_id: ContextId::new("ctx-1"),
status: TaskStatus::new(state),
history: None,
artifacts: None,
metadata: None,
}
}
// Builds a StatusUpdate stream event moving the task to `state`.
fn make_status_event(task_id: &str, state: TaskState) -> StreamResponse {
use a2a_protocol_types::events::TaskStatusUpdateEvent;
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: TaskId::new(task_id),
context_id: ContextId::new("ctx-1"),
status: TaskStatus::new(state),
metadata: None,
})
}
// A legal Submitted -> Working transition is applied and persisted.
#[tokio::test]
async fn process_event_self_valid_state_transition() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t1");
task_store
.save(make_task("t1", TaskState::Submitted))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
writer
.write(make_status_event("t1", TaskState::Working))
.await
.unwrap();
// Dropping the writer closes the queue so collect_events terminates.
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok(), "collect_events should succeed");
let final_task = result.unwrap();
assert_eq!(final_task.status.state, TaskState::Working);
// The new state must also be visible through the store, not just returned.
let stored = task_store.get(&task_id).await.unwrap().unwrap();
assert_eq!(stored.status.state, TaskState::Working);
}
// A Completed -> Working update is rejected with InvalidStateTransition.
#[tokio::test]
async fn process_event_invalid_state_transition_returns_error() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-invalid-trans");
task_store
.save(make_task("t-invalid-trans", TaskState::Completed))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
writer
.write(make_status_event("t-invalid-trans", TaskState::Working))
.await
.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(
matches!(
result,
Err(crate::error::ServerError::InvalidStateTransition { .. })
),
"expected InvalidStateTransition error, got: {result:?}"
);
}
// An ArtifactUpdate event appends its artifact to the task.
#[tokio::test]
async fn process_event_artifact_update_appends() {
use a2a_protocol_types::artifact::{Artifact, ArtifactId};
use a2a_protocol_types::events::TaskArtifactUpdateEvent;
use a2a_protocol_types::message::Part;
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-art");
task_store
.save(make_task("t-art", TaskState::Working))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
let artifact_event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
task_id: TaskId::new("t-art"),
context_id: a2a_protocol_types::task::ContextId::new("ctx-1"),
artifact: Artifact::new(ArtifactId::new("art-1"), vec![Part::text("output data")]),
append: None,
last_chunk: Some(true),
metadata: None,
});
writer.write(artifact_event).await.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok(), "collect_events should succeed");
let final_task = result.unwrap();
let artifacts = final_task.artifacts.expect("artifacts should be Some");
assert_eq!(artifacts.len(), 1);
assert_eq!(artifacts[0].id, ArtifactId::new("art-1"));
}
// A full Task snapshot event replaces the accumulated task state wholesale.
#[tokio::test]
async fn process_event_task_snapshot_replaces() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-snap");
task_store
.save(make_task("t-snap", TaskState::Submitted))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
let replacement = make_task("t-snap", TaskState::Completed);
writer
.write(StreamResponse::Task(replacement))
.await
.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().status.state, TaskState::Completed);
}
// Message events carry no task state: the task must be left untouched.
#[tokio::test]
async fn process_event_message_event_is_ignored() {
use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-msg");
task_store
.save(make_task("t-msg", TaskState::Working))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
let msg_event = StreamResponse::Message(Message {
id: MessageId::new("m1"),
role: MessageRole::Agent,
parts: vec![Part::text("hello")],
context_id: None,
task_id: None,
reference_task_ids: None,
extensions: None,
metadata: None,
});
writer.write(msg_event).await.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().status.state, TaskState::Working);
}
// An Err event marks the task Failed in the store and surfaces a Protocol error.
#[tokio::test]
async fn process_event_error_marks_task_failed() {
use a2a_protocol_types::error::A2aError;
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-err-evt");
task_store
.save(make_task("t-err-evt", TaskState::Working))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
// Builds the reader over a raw broadcast channel so an Err can be injected
// directly (the EventQueueWriter API only accepts Ok events).
let (tx, rx) = tokio::sync::broadcast::channel(8);
let reader = crate::streaming::event_queue::InMemoryQueueReader::new(rx);
let err = A2aError::internal("executor failure");
tx.send(Err(err)).expect("send should succeed");
drop(tx);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(
matches!(result, Err(crate::error::ServerError::Protocol(_))),
"expected Protocol error, got: {result:?}"
);
let stored = task_store.get(&task_id).await.unwrap().unwrap();
assert_eq!(stored.status.state, TaskState::Failed);
}
// With a push sender configured, each persisted status update is delivered
// to the registered webhook config.
#[allow(clippy::too_many_lines)]
#[tokio::test]
async fn collect_events_with_push_sender_delivers_notifications() {
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::push::TaskPushNotificationConfig;
// Push sender that counts invocations and always succeeds.
struct CountingPushSender {
count: Arc<AtomicU64>,
}
impl crate::push::PushSender for CountingPushSender {
fn send<'a>(
&'a self,
_url: &'a str,
_event: &'a a2a_protocol_types::events::StreamResponse,
_config: &'a TaskPushNotificationConfig,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
self.count.fetch_add(1, Ordering::Relaxed);
Box::pin(async { Ok(()) })
}
}
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-push");
task_store
.save(make_task("t-push", TaskState::Submitted))
.await
.unwrap();
let counter = Arc::new(AtomicU64::new(0));
let sender = CountingPushSender {
count: Arc::clone(&counter),
};
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.with_push_sender(sender)
.build()
.unwrap();
let config = TaskPushNotificationConfig {
tenant: None,
id: Some("cfg-1".to_owned()),
task_id: "t-push".to_owned(),
url: "https://example.com/webhook".to_owned(),
token: None,
authentication: None,
};
handler.push_config_store.set(config).await.unwrap();
let (writer, reader) = new_in_memory_queue();
writer
.write(make_status_event("t-push", TaskState::Working))
.await
.unwrap();
writer
.write(make_status_event("t-push", TaskState::Completed))
.await
.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id, executor_handle)
.await;
assert!(result.is_ok());
// Two status updates were persisted, so at least two deliveries expected.
assert!(
counter.load(Ordering::Relaxed) >= 2,
"push sender should have been called at least twice"
);
}
// Events written before the executor finishes are still drained afterwards
// (exercises the `executor_done` drain path in collect_events).
#[tokio::test]
async fn collect_events_executor_done_drains_remaining() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-drain");
task_store
.save(make_task("t-drain", TaskState::Submitted))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
let writer_clone = writer.clone();
// The "executor" writes both events and exits; collect_events must keep
// reading after the JoinHandle completes.
let executor_handle = tokio::spawn(async move {
writer_clone
.write(make_status_event("t-drain", TaskState::Working))
.await
.unwrap();
writer_clone
.write(make_status_event("t-drain", TaskState::Completed))
.await
.unwrap();
drop(writer_clone);
});
// The original writer handle is dropped later so the queue only closes
// after the drain has had a chance to run.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
drop(writer);
});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok());
let final_task = result.unwrap();
assert_eq!(
final_task.status.state,
TaskState::Completed,
"task should drain remaining events after executor completes"
);
}
// An executor panic is caught via the JoinHandle and the (non-terminal)
// task is marked Failed; collect_events itself still returns Ok.
#[tokio::test]
async fn collect_events_executor_panic_marks_failed() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-panic");
task_store
.save(make_task("t-panic", TaskState::Submitted))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
let executor_handle = tokio::spawn(async {
panic!("executor panicked!");
});
// Keep the queue open long enough for the panic to be observed first.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
drop(writer);
});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok(), "collect_events should still return Ok");
let final_task = result.unwrap();
assert_eq!(
final_task.status.state,
TaskState::Failed,
"task should be marked Failed after executor panic"
);
}
// With max_artifacts_per_task = 1, the second artifact update is dropped.
#[tokio::test]
async fn collect_events_artifact_limit_enforced() {
use crate::handler::limits::HandlerLimits;
use a2a_protocol_types::artifact::{Artifact, ArtifactId};
use a2a_protocol_types::events::TaskArtifactUpdateEvent;
use a2a_protocol_types::message::Part;
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-art-limit");
task_store
.save(make_task("t-art-limit", TaskState::Working))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.with_handler_limits(HandlerLimits::default().with_max_artifacts_per_task(1))
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
for i in 0..2 {
let artifact_event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
task_id: TaskId::new("t-art-limit"),
context_id: a2a_protocol_types::task::ContextId::new("ctx-1"),
artifact: Artifact::new(
ArtifactId::new(format!("art-{i}")),
vec![Part::text(format!("data {i}"))],
),
append: None,
last_chunk: Some(true),
metadata: None,
});
writer.write(artifact_event).await.unwrap();
}
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(result.is_ok());
let final_task = result.unwrap();
let artifacts = final_task.artifacts.expect("artifacts should be Some");
assert_eq!(artifacts.len(), 1, "artifact count should not exceed limit");
}
// Push-delivery errors are best-effort: they must not fail collect_events.
#[allow(clippy::too_many_lines)]
#[tokio::test]
async fn collect_events_push_delivery_failure_does_not_block() {
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::push::TaskPushNotificationConfig;
// Push sender that counts invocations and always fails.
struct FailingPushSender {
count: Arc<AtomicU64>,
}
impl crate::push::PushSender for FailingPushSender {
fn send<'a>(
&'a self,
_url: &'a str,
_event: &'a a2a_protocol_types::events::StreamResponse,
_config: &'a TaskPushNotificationConfig,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
self.count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Box::pin(async {
Err(a2a_protocol_types::error::A2aError::internal("push failed"))
})
}
}
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-push-fail");
task_store
.save(make_task("t-push-fail", TaskState::Submitted))
.await
.unwrap();
let counter = Arc::new(AtomicU64::new(0));
let sender = FailingPushSender {
count: Arc::clone(&counter),
};
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.with_push_sender(sender)
.build()
.unwrap();
let config = TaskPushNotificationConfig {
tenant: None,
id: Some("cfg-1".to_owned()),
task_id: "t-push-fail".to_owned(),
url: "https://example.com/webhook".to_owned(),
token: None,
authentication: None,
};
handler.push_config_store.set(config).await.unwrap();
let (writer, reader) = new_in_memory_queue();
writer
.write(make_status_event("t-push-fail", TaskState::Working))
.await
.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let result = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await;
assert!(
result.is_ok(),
"collect_events should succeed despite push failure"
);
assert!(
counter.load(Ordering::Relaxed) >= 1,
"push sender should have been called"
);
}
// End-to-end happy path: the returned task reflects the last status event.
#[tokio::test]
async fn collect_events_returns_final_task() {
let task_store = Arc::new(InMemoryTaskStore::new());
let task_id = TaskId::new("t-collect");
task_store
.save(make_task("t-collect", TaskState::Submitted))
.await
.unwrap();
let handler = RequestHandlerBuilder::new(DummyExecutor)
.with_task_store_arc(Arc::clone(&task_store) as Arc<dyn crate::store::TaskStore>)
.build()
.unwrap();
let (writer, reader) = new_in_memory_queue();
writer
.write(make_status_event("t-collect", TaskState::Working))
.await
.unwrap();
writer
.write(make_status_event("t-collect", TaskState::Completed))
.await
.unwrap();
drop(writer);
let executor_handle = tokio::spawn(async {});
let final_task = handler
.collect_events(reader, task_id.clone(), executor_handle)
.await
.expect("collect_events should not fail");
assert_eq!(
final_task.status.state,
TaskState::Completed,
"collect_events should return the task in its final state"
);
}
}