use a2a_protocol_types::{JsonRpcResponse, StreamResponse};
use hyper::body::Bytes;
use tokio::sync::mpsc;
use tokio::task::AbortHandle;
use crate::error::{ClientError, ClientResult};
use crate::streaming::sse_parser::SseParser;
/// One item delivered over the body channel: a chunk of raw bytes, or an
/// error reported by whatever is producing the chunks.
pub(crate) type BodyChunk = ClientResult<Bytes>;
/// Async stream of `StreamResponse` events decoded from Server-Sent Events.
///
/// Raw body chunks arrive on an `mpsc` channel, are fed to an incremental SSE
/// parser, and each complete frame's `data` is decoded either as a JSON-RPC
/// response envelope (the default) or as a bare `StreamResponse`. Dropping
/// the stream aborts the associated task, if an abort handle was supplied.
pub struct EventStream {
    /// Source of body chunks; `recv` returning `None` ends the input.
    rx: mpsc::Receiver<BodyChunk>,
    /// Incremental SSE parser: fed raw chunks, yields complete frames.
    parser: SseParser,
    /// Once set, no further chunks are read from `rx`; frames already
    /// buffered in `parser` are still drained by `next`.
    done: bool,
    /// Task to abort when the stream is dropped, if any.
    abort_handle: Option<AbortHandle>,
    /// Status code supplied at construction (200 unless overridden);
    /// presumably the HTTP response status — confirm at call sites.
    status_code: u16,
    /// `true`: frames carry JSON-RPC response envelopes; `false`: bare events.
    jsonrpc_envelope: bool,
}
impl EventStream {
    /// Builds a stream around `rx` with a fresh SSE parser.
    ///
    /// Every constructor funnels through here so the default state (not done,
    /// JSON-RPC envelope decoding enabled) is defined in exactly one place.
    fn from_parts(
        rx: mpsc::Receiver<BodyChunk>,
        abort_handle: Option<AbortHandle>,
        status_code: u16,
    ) -> Self {
        Self {
            rx,
            parser: SseParser::new(),
            done: false,
            abort_handle,
            status_code,
            jsonrpc_envelope: true,
        }
    }

    /// Creates a stream with no task to abort on drop and a status code of 200.
    #[must_use]
    #[cfg(any(test, feature = "websocket"))]
    pub(crate) fn new(rx: mpsc::Receiver<BodyChunk>) -> Self {
        Self::from_parts(rx, None, 200)
    }

    /// Creates a stream that aborts `abort_handle` when dropped; the status
    /// code is 200.
    #[must_use]
    #[cfg(test)]
    pub(crate) fn with_abort_handle(
        rx: mpsc::Receiver<BodyChunk>,
        abort_handle: AbortHandle,
    ) -> Self {
        Self::from_parts(rx, Some(abort_handle), 200)
    }

    /// Creates a stream that aborts `abort_handle` when dropped and reports
    /// `status_code` from [`EventStream::status_code`].
    #[must_use]
    pub(crate) fn with_status(
        rx: mpsc::Receiver<BodyChunk>,
        abort_handle: AbortHandle,
        status_code: u16,
    ) -> Self {
        Self::from_parts(rx, Some(abort_handle), status_code)
    }

    /// Chooses how each SSE `data` payload is decoded: as a JSON-RPC response
    /// envelope (`true`, the default) or as a bare `StreamResponse` (`false`).
    #[must_use]
    pub(crate) const fn with_jsonrpc_envelope(mut self, envelope: bool) -> Self {
        self.jsonrpc_envelope = envelope;
        self
    }

    /// Returns the status code this stream was constructed with.
    #[must_use]
    pub const fn status_code(&self) -> u16 {
        self.status_code
    }

    /// Returns the next decoded event, or `None` once the stream has ended.
    ///
    /// Frames already buffered in the parser are always yielded before the
    /// stream reports its end, even after a terminal event, a JSON-RPC error,
    /// or a channel error has marked the stream done. After that point no
    /// further chunks are read from the channel.
    ///
    /// Yielded errors:
    /// - channel errors are passed through unchanged and end the stream;
    /// - SSE parser errors become `ClientError::Transport`;
    /// - undecodable payloads become `ClientError::Serialization`;
    /// - JSON-RPC error responses become `ClientError::Protocol` and end the stream.
    pub async fn next(&mut self) -> Option<ClientResult<StreamResponse>> {
        loop {
            if let Some(item) = self.next_buffered() {
                return Some(item);
            }
            if self.done {
                return None;
            }
            match self.rx.recv().await {
                // Sender side closed: mark done and go round the loop once more
                // so a frame still held by the parser is delivered before `None`.
                None => {
                    self.done = true;
                }
                Some(Err(e)) => {
                    self.done = true;
                    return Some(Err(e));
                }
                Some(Ok(bytes)) => {
                    self.parser.feed(&bytes);
                }
            }
        }
    }

    /// Pops one complete frame from the parser, if any, and decodes it.
    ///
    /// Parser-level failures are surfaced as `ClientError::Transport`.
    fn next_buffered(&mut self) -> Option<ClientResult<StreamResponse>> {
        let parsed = self.parser.next_frame()?;
        Some(match parsed {
            Ok(frame) => self.decode_frame(&frame.data),
            Err(e) => Err(ClientError::Transport(e.to_string())),
        })
    }

    /// Decodes one SSE `data` payload into a `StreamResponse`.
    ///
    /// Marks the stream done after a terminal status update or a JSON-RPC
    /// error response.
    ///
    /// # Errors
    ///
    /// - `ClientError::Serialization` if `data` does not deserialize into the
    ///   expected shape (envelope or bare event, per `jsonrpc_envelope`).
    /// - `ClientError::Protocol` if the envelope carries a JSON-RPC error;
    ///   codes that do not convert to `ErrorCode` map to `InternalError`.
    fn decode_frame(&mut self, data: &str) -> ClientResult<StreamResponse> {
        if self.jsonrpc_envelope {
            let envelope: JsonRpcResponse<StreamResponse> =
                serde_json::from_str(data).map_err(ClientError::Serialization)?;
            match envelope {
                JsonRpcResponse::Success(ok) => {
                    if is_terminal(&ok.result) {
                        self.done = true;
                    }
                    Ok(ok.result)
                }
                JsonRpcResponse::Error(err) => {
                    // An error response ends the stream.
                    self.done = true;
                    let a2a = a2a_protocol_types::A2aError::new(
                        a2a_protocol_types::ErrorCode::try_from(err.error.code)
                            .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
                        err.error.message,
                    );
                    Err(ClientError::Protocol(a2a))
                }
            }
        } else {
            let event: StreamResponse =
                serde_json::from_str(data).map_err(ClientError::Serialization)?;
            if is_terminal(&event) {
                self.done = true;
            }
            Ok(event)
        }
    }
}
impl Drop for EventStream {
    /// Aborts the task associated with this stream, if a handle was supplied.
    fn drop(&mut self) {
        if let Some(task) = &self.abort_handle {
            task.abort();
        }
    }
}
/// Diagnostic view of the stream's state.
///
/// The channel receiver and abort handle are intentionally left out;
/// `finish_non_exhaustive` marks their omission with `..` in the output.
impl std::fmt::Debug for EventStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EventStream")
            .field("done", &self.done)
            .field("status_code", &self.status_code)
            .field("jsonrpc_envelope", &self.jsonrpc_envelope)
            .field("pending_frames", &self.parser.pending_count())
            .finish_non_exhaustive()
    }
}
/// Returns `true` only for a status update whose task state is terminal;
/// every other event kind yields `false`.
const fn is_terminal(event: &StreamResponse) -> bool {
    if let StreamResponse::StatusUpdate(update) = event {
        update.status.state.is_terminal()
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{
        JsonRpcSuccessResponse, JsonRpcVersion, TaskId, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };
    use std::time::Duration;

    /// Upper bound on every await in these tests, so a hang fails fast.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a status-update event for task `t1` / context `c1` in `state`.
    fn make_status_event(state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: a2a_protocol_types::ContextId::new("c1"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    /// Serializes `event` inside a JSON-RPC success envelope as one SSE frame.
    fn sse_frame(event: &StreamResponse) -> String {
        let resp = JsonRpcSuccessResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            result: event.clone(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        format!("data: {json}\n\n")
    }

    /// Serializes `event` without any envelope as one SSE frame.
    fn bare_sse_frame(event: &StreamResponse) -> String {
        let json = serde_json::to_string(event).unwrap();
        format!("data: {json}\n\n")
    }

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let event = make_status_event(TaskState::Working);
        let sse_bytes = sse_frame(&event);
        tx.send(Ok(Bytes::from(sse_bytes))).await.unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
    }

    #[tokio::test]
    async fn stream_delivers_multiple_frames_from_one_chunk() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let working = make_status_event(TaskState::Working);
        let completed = make_status_event(TaskState::Completed);
        let mut chunk = sse_frame(&working);
        chunk.push_str(&sse_frame(&completed));
        tx.send(Ok(Bytes::from(chunk))).await.unwrap();
        drop(tx);
        let first = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out on first event")
            .unwrap()
            .unwrap();
        assert!(
            matches!(first, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
        let second = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out on second event")
            .unwrap()
            .unwrap();
        assert!(
            matches!(second, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for stream end");
        assert!(end.is_none());
    }

    #[tokio::test]
    async fn stream_ends_on_final_event() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let event = make_status_event(TaskState::Completed);
        let sse_bytes = sse_frame(&event);
        tx.send(Ok(Bytes::from(sse_bytes))).await.unwrap();
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for final event")
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for stream end");
        assert!(end.is_none());
    }

    #[tokio::test]
    async fn stream_propagates_body_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        tx.send(Err(ClientError::Transport("network error".into())))
            .await
            .unwrap();
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn drop_aborts_background_task() {
        let (tx, rx) = mpsc::channel::<BodyChunk>(8);
        let handle = tokio::spawn(async move {
            let _tx = tx;
            tokio::time::sleep(Duration::from_secs(60 * 60)).await;
        });
        let abort_handle = handle.abort_handle();
        let stream = EventStream::with_abort_handle(rx, abort_handle);
        drop(stream);
        let result = tokio::time::timeout(TEST_TIMEOUT, handle)
            .await
            .expect("timed out waiting for task abort");
        assert!(result.is_err(), "task should have been aborted");
        assert!(
            result.unwrap_err().is_cancelled(),
            "task should be cancelled"
        );
    }

    #[test]
    fn debug_output_contains_fields() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let stream = EventStream::new(rx);
        let debug = format!("{stream:?}");
        assert!(debug.contains("EventStream"), "should contain struct name");
        assert!(debug.contains("done"), "should contain 'done' field");
        assert!(
            debug.contains("pending_frames"),
            "should contain 'pending_frames' field"
        );
    }

    #[test]
    fn is_terminal_returns_false_for_working() {
        let event = make_status_event(TaskState::Working);
        assert!(!is_terminal(&event), "Working state should not be terminal");
    }

    #[test]
    fn is_terminal_returns_true_for_completed() {
        let event = make_status_event(TaskState::Completed);
        assert!(is_terminal(&event), "Completed state should be terminal");
    }

    #[tokio::test]
    async fn stream_decodes_jsonrpc_error_as_protocol_error() {
        use a2a_protocol_types::JsonRpcErrorResponse;
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let error_resp = JsonRpcErrorResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            error: a2a_protocol_types::JsonRpcError {
                code: -32601,
                message: "method not found".into(),
                data: None,
            },
        };
        let json = serde_json::to_string(&error_resp).unwrap();
        let sse_data = format!("data: {json}\n\n");
        tx.send(Ok(Bytes::from(sse_data))).await.unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        assert!(result.is_err(), "JSON-RPC error should produce Err");
        match result.unwrap_err() {
            ClientError::Protocol(err) => {
                assert!(
                    format!("{err}").contains("method not found"),
                    "error message should be preserved"
                );
            }
            other => panic!("expected Protocol error, got {other:?}"),
        }
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out");
        assert!(end.is_none(), "stream should end after JSON-RPC error");
    }

    #[tokio::test]
    async fn stream_invalid_json_returns_serialization_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let sse_data = "data: {not valid json}\n\n";
        tx.send(Ok(Bytes::from(sse_data))).await.unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        assert!(result.is_err(), "invalid JSON should produce Err");
        assert!(
            matches!(result.unwrap_err(), ClientError::Serialization(_)),
            "should be a Serialization error"
        );
    }

    #[tokio::test]
    async fn stream_drains_parser_after_channel_close() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let event = make_status_event(TaskState::Working);
        let sse_bytes = sse_frame(&event);
        let (first_half, second_half) = sse_bytes.split_at(sse_bytes.len() / 2);
        tx.send(Ok(Bytes::from(first_half.to_owned())))
            .await
            .unwrap();
        tx.send(Ok(Bytes::from(second_half.to_owned())))
            .await
            .unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        let event = result.unwrap();
        assert!(
            matches!(event, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working),
            "should deliver Working event from drained parser"
        );
    }

    #[test]
    fn status_code_returns_set_value() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let stream = EventStream::new(rx);
        assert_eq!(stream.status_code(), 200, "default status should be 200");
    }

    #[tokio::test]
    async fn status_code_with_custom_value() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let task = tokio::spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        let stream = EventStream::with_status(rx, task.abort_handle(), 201);
        assert_eq!(stream.status_code(), 201);
    }

    #[tokio::test]
    async fn stream_transport_error_from_channel() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        tx.send(Err(ClientError::HttpClient("connection reset".into())))
            .await
            .unwrap();
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        match result {
            Err(ClientError::HttpClient(msg)) => {
                assert!(msg.contains("connection reset"));
            }
            other => panic!("expected HttpClient error, got {other:?}"),
        }
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out");
        assert!(end.is_none(), "stream should end after transport error");
    }

    #[tokio::test]
    async fn non_terminal_event_does_not_end_stream() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let working = make_status_event(TaskState::Working);
        let completed = make_status_event(TaskState::Completed);
        tx.send(Ok(Bytes::from(sse_frame(&working)))).await.unwrap();
        tx.send(Ok(Bytes::from(sse_frame(&completed))))
            .await
            .unwrap();
        let first = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out on first event")
            .unwrap()
            .unwrap();
        assert!(
            matches!(first, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
        let second = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out on second event")
            .unwrap()
            .unwrap();
        assert!(
            matches!(second, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for stream end");
        assert!(end.is_none());
    }

    #[tokio::test]
    async fn bare_stream_delivers_events() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx).with_jsonrpc_envelope(false);
        let event = make_status_event(TaskState::Working);
        tx.send(Ok(Bytes::from(bare_sse_frame(&event))))
            .await
            .unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
    }

    #[tokio::test]
    async fn bare_stream_ends_on_terminal() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx).with_jsonrpc_envelope(false);
        let event = make_status_event(TaskState::Completed);
        tx.send(Ok(Bytes::from(bare_sse_frame(&event))))
            .await
            .unwrap();
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out");
        assert!(end.is_none(), "bare stream should end after terminal event");
    }

    #[tokio::test]
    async fn bare_stream_rejects_jsonrpc_envelope() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx).with_jsonrpc_envelope(false);
        let event = make_status_event(TaskState::Working);
        let envelope_frame = sse_frame(&event);
        tx.send(Ok(Bytes::from(envelope_frame))).await.unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        assert!(
            result.is_err(),
            "bare stream should reject JSON-RPC envelope as invalid"
        );
    }

    #[tokio::test]
    async fn envelope_stream_rejects_bare_response() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        let event = make_status_event(TaskState::Working);
        let bare_frame = bare_sse_frame(&event);
        tx.send(Ok(Bytes::from(bare_frame))).await.unwrap();
        drop(tx);
        let result = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap();
        assert!(
            result.is_err(),
            "envelope stream should reject bare StreamResponse"
        );
    }

    #[tokio::test]
    async fn bare_stream_multiple_events() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx).with_jsonrpc_envelope(false);
        let working = make_status_event(TaskState::Working);
        let completed = make_status_event(TaskState::Completed);
        tx.send(Ok(Bytes::from(bare_sse_frame(&working))))
            .await
            .unwrap();
        tx.send(Ok(Bytes::from(bare_sse_frame(&completed))))
            .await
            .unwrap();
        let first = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap()
            .unwrap();
        assert!(
            matches!(first, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
        let second = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out")
            .unwrap()
            .unwrap();
        assert!(
            matches!(second, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );
        let end = tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .expect("timed out");
        assert!(end.is_none());
    }
}