use super::*;
/// Test executor that publishes `Working` and then `Submitted` for the current task.
///
/// The tests in this module expect the request handler to treat the
/// `Working -> Submitted` step as an invalid state transition.
struct InvalidTransitionExecutor;

impl AgentExecutor for InvalidTransitionExecutor {
    fn execute<'a>(
        &'a self,
        ctx: &'a RequestContext,
        queue: &'a dyn EventQueueWriter,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            // Updates are written strictly in order; the first failed write aborts
            // the executor via `?`.
            let sequence = [TaskState::Working, TaskState::Submitted];
            for next in sequence {
                let update = TaskStatusUpdateEvent {
                    task_id: ctx.task_id.clone(),
                    context_id: ContextId::new(ctx.context_id.clone()),
                    status: TaskStatus::new(next),
                    metadata: None,
                };
                queue.write(StreamResponse::StatusUpdate(update)).await?;
            }
            Ok(())
        })
    }
}
/// Test executor that publishes `Working`, then the terminal `Completed`, and then
/// tries to go back to `Working`.
///
/// The final `Completed -> Working` write is the invalid step. A blocking caller
/// is expected to have returned at `Completed` before it happens.
struct TerminalTransitionExecutor;

impl AgentExecutor for TerminalTransitionExecutor {
    fn execute<'a>(
        &'a self,
        ctx: &'a RequestContext,
        queue: &'a dyn EventQueueWriter,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            // Written one after another; `?` stops at the first failed write.
            let sequence = [TaskState::Working, TaskState::Completed, TaskState::Working];
            for next in sequence {
                let update = TaskStatusUpdateEvent {
                    task_id: ctx.task_id.clone(),
                    context_id: ContextId::new(ctx.context_id.clone()),
                    status: TaskStatus::new(next),
                    metadata: None,
                };
                queue.write(StreamResponse::StatusUpdate(update)).await?;
            }
            Ok(())
        })
    }
}
/// Test executor that walks the task through
/// `Working -> InputRequired -> Working -> Completed`.
struct MultiTransitionExecutor;

impl AgentExecutor for MultiTransitionExecutor {
    fn execute<'a>(
        &'a self,
        ctx: &'a RequestContext,
        queue: &'a dyn EventQueueWriter,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            let transitions = [
                TaskState::Working,
                TaskState::InputRequired,
                TaskState::Working,
                TaskState::Completed,
            ];
            for target in transitions {
                // One status-update event per state, emitted in declaration order.
                let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
                    task_id: ctx.task_id.clone(),
                    context_id: ContextId::new(ctx.context_id.clone()),
                    status: TaskStatus::new(target),
                    metadata: None,
                });
                queue.write(event).await?;
            }
            Ok(())
        })
    }
}
struct CanceledExecutor;
impl AgentExecutor for CanceledExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Canceled),
metadata: None,
}))
.await?;
Ok(())
})
}
}
struct FailedStatusExecutor;
impl AgentExecutor for FailedStatusExecutor {
fn execute<'a>(
&'a self,
ctx: &'a RequestContext,
queue: &'a dyn EventQueueWriter,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Working),
metadata: None,
}))
.await?;
queue
.write(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: ctx.task_id.clone(),
context_id: ContextId::new(ctx.context_id.clone()),
status: TaskStatus::new(TaskState::Failed),
metadata: None,
}))
.await?;
Ok(())
})
}
}
/// A blocking send driven by `InvalidTransitionExecutor` (`Working -> Submitted`)
/// must fail with `ServerError::InvalidStateTransition`.
#[tokio::test]
async fn sync_mode_invalid_state_transition_returns_error() {
    let handler = RequestHandlerBuilder::new(InvalidTransitionExecutor)
        .build()
        .expect("build handler");
    let outcome = handler
        .on_send_message(make_send_params(), false, None)
        .await;
    match outcome {
        Ok(_) => panic!("expected error for invalid state transition"),
        Err(err) => assert!(
            matches!(
                err,
                a2a_protocol_server::ServerError::InvalidStateTransition { .. }
            ),
            "expected InvalidStateTransition, got {err:?}"
        ),
    }
}
/// `TerminalTransitionExecutor` emits `Working -> Completed -> Working`; the last
/// step is the invalid one this test is named after.
///
/// In blocking mode the send is expected to succeed and to report `Completed`.
/// The handler returns at the terminal state, so the trailing `Working` update is
/// not reflected in the returned task.
#[tokio::test]
async fn sync_mode_completed_to_working_is_invalid() {
    let handler = RequestHandlerBuilder::new(TerminalTransitionExecutor)
        .build()
        .expect("build handler");
    let response = handler
        .on_send_message(make_send_params(), false, None)
        .await
        .expect("send should succeed — early exit at terminal state");
    let task = extract_task(response);
    assert_eq!(
        task.status.state,
        TaskState::Completed,
        "should return at terminal state before invalid transition"
    );
}
/// In streaming mode, an executor that makes an invalid transition
/// (`Working -> Submitted`) must still yield at least one event before the stream
/// ends.
#[tokio::test]
async fn streaming_mode_invalid_transition_does_not_crash_stream() {
    let handler = RequestHandlerBuilder::new(InvalidTransitionExecutor)
        .build()
        .expect("build handler");
    let outcome = handler
        .on_send_message(make_send_params(), true, None)
        .await
        .expect("send streaming");
    let mut reader = if let SendMessageResult::Stream(stream) = outcome {
        stream
    } else {
        panic!("expected Stream");
    };
    // Drain everything the reader produces until it reports end-of-stream.
    let mut received = Vec::new();
    while let Some(item) = reader.read().await {
        received.push(item);
    }
    assert!(!received.is_empty(), "stream should still produce events");
}
/// `MultiTransitionExecutor` emits `Working -> InputRequired -> Working -> Completed`.
/// A blocking send is expected to return the task in `InputRequired`, the first
/// state in that sequence where the task is waiting on the caller.
#[tokio::test]
async fn sync_mode_multiple_valid_transitions() {
    let handler = RequestHandlerBuilder::new(MultiTransitionExecutor)
        .build()
        .expect("build handler");
    let response = handler
        .on_send_message(make_send_params(), false, None)
        .await
        .expect("send");
    let task = extract_task(response);
    assert_eq!(
        task.status.state,
        TaskState::InputRequired,
        "blocking mode should return at first interrupted state"
    );
}
/// A blocking send with `CanceledExecutor` (`Working -> Canceled`) returns the task
/// in the `Canceled` state.
#[tokio::test]
async fn sync_mode_working_to_canceled() {
    let handler = RequestHandlerBuilder::new(CanceledExecutor)
        .build()
        .expect("build handler");
    let response = handler
        .on_send_message(make_send_params(), false, None)
        .await
        .expect("send");
    let final_state = extract_task(response).status.state;
    assert_eq!(final_state, TaskState::Canceled);
}
/// A blocking send with `FailedStatusExecutor` (`Working -> Failed`, reported as a
/// status update rather than an executor error) returns the task in the `Failed`
/// state.
#[tokio::test]
async fn sync_mode_working_to_failed_via_status_update() {
    let handler = RequestHandlerBuilder::new(FailedStatusExecutor)
        .build()
        .expect("build handler");
    let response = handler
        .on_send_message(make_send_params(), false, None)
        .await
        .expect("send");
    let final_state = extract_task(response).status.state;
    assert_eq!(final_state, TaskState::Failed);
}