Skip to main content

a2a_protocol_client/streaming/
event_stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Async SSE event stream with typed deserialization.
5//!
6//! [`EventStream`] provides an async `next()` iterator over
7//! [`a2a_protocol_types::StreamResponse`] events received via Server-Sent Events.
8//!
9//! The stream terminates when:
10//! - The underlying HTTP body closes (normal end-of-stream).
11//! - A [`a2a_protocol_types::TaskStatusUpdateEvent`] carrying a terminal task state (e.g. completed) is received.
12//! - A protocol or transport error occurs (returned as `Some(Err(...))`).
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! let mut stream = client.stream_message(params).await?;
18//! while let Some(event) = stream.next().await {
19//!     match event? {
20//!         StreamResponse::StatusUpdate(ev) => {
21//!             println!("State: {:?}", ev.status.state);
22//!             if ev.status.state.is_terminal() { break; }
23//!         }
24//!         StreamResponse::ArtifactUpdate(ev) => {
25//!             println!("Artifact: {:?}", ev.artifact);
26//!         }
27//!         _ => {}
28//!     }
29//! }
30//! ```
31
32use a2a_protocol_types::{JsonRpcResponse, StreamResponse};
33use hyper::body::Bytes;
34use tokio::sync::mpsc;
35use tokio::task::AbortHandle;
36
37use crate::error::{ClientError, ClientResult};
38use crate::streaming::sse_parser::SseParser;
39
40// ── Chunk ─────────────────────────────────────────────────────────────────────
41
/// A raw byte chunk from the HTTP body reader task.
///
/// Each message sent over the [`EventStream`] channel is either the next
/// slice of the HTTP response body (`Ok`) or an error reported by the reader
/// (`Err`). An `Err` chunk is surfaced by [`EventStream::next`] and ends the
/// stream.
pub(crate) type BodyChunk = ClientResult<Bytes>;
44
45// ── EventStream ───────────────────────────────────────────────────────────────
46
47/// An async stream of [`StreamResponse`] events from an SSE endpoint.
48///
49/// Created by [`crate::A2aClient::stream_message`] or
50/// [`crate::A2aClient::subscribe_to_task`]. Call [`EventStream::next`] in a loop
51/// to consume events.
52///
53/// When dropped, the background body-reader task is aborted to prevent
54/// resource leaks.
55pub struct EventStream {
56    /// Channel receiver delivering raw byte chunks from the HTTP body.
57    rx: mpsc::Receiver<BodyChunk>,
58    /// SSE parser state machine.
59    parser: SseParser,
60    /// Whether the stream has been signalled as terminated.
61    done: bool,
62    /// Handle to abort the background body-reader task on drop.
63    abort_handle: Option<AbortHandle>,
64}
65
66impl EventStream {
67    /// Creates a new [`EventStream`] from a channel receiver (without abort handle).
68    ///
69    /// The channel must be fed raw HTTP body bytes from a background task.
70    /// Prefer [`EventStream::with_abort_handle`] to ensure the background task
71    /// is cancelled when the stream is dropped.
72    #[must_use]
73    #[cfg(test)]
74    pub(crate) fn new(rx: mpsc::Receiver<BodyChunk>) -> Self {
75        Self {
76            rx,
77            parser: SseParser::new(),
78            done: false,
79            abort_handle: None,
80        }
81    }
82
83    /// Creates a new [`EventStream`] with an abort handle for the body-reader task.
84    ///
85    /// When the `EventStream` is dropped, the abort handle is used to cancel
86    /// the background task, preventing resource leaks.
87    #[must_use]
88    pub(crate) fn with_abort_handle(
89        rx: mpsc::Receiver<BodyChunk>,
90        abort_handle: AbortHandle,
91    ) -> Self {
92        Self {
93            rx,
94            parser: SseParser::new(),
95            done: false,
96            abort_handle: Some(abort_handle),
97        }
98    }
99
100    /// Returns the next event from the stream.
101    ///
102    /// Returns `None` when the stream ends normally (either the HTTP body
103    /// closed or a `final: true` event was received).
104    ///
105    /// Returns `Some(Err(...))` on transport or protocol errors.
106    pub async fn next(&mut self) -> Option<ClientResult<StreamResponse>> {
107        loop {
108            // First, drain any frames the parser already has buffered.
109            if let Some(result) = self.parser.next_frame() {
110                match result {
111                    Ok(frame) => return Some(self.decode_frame(&frame.data)),
112                    Err(e) => {
113                        return Some(Err(ClientError::Transport(e.to_string())));
114                    }
115                }
116            }
117
118            if self.done {
119                return None;
120            }
121
122            // Need more bytes — wait for the next chunk from the body reader.
123            match self.rx.recv().await {
124                None => {
125                    // Channel closed — body reader task exited.
126                    self.done = true;
127                    // Drain any remaining parser frames.
128                    if let Some(result) = self.parser.next_frame() {
129                        match result {
130                            Ok(frame) => return Some(self.decode_frame(&frame.data)),
131                            Err(e) => {
132                                return Some(Err(ClientError::Transport(e.to_string())));
133                            }
134                        }
135                    }
136                    return None;
137                }
138                Some(Err(e)) => {
139                    self.done = true;
140                    return Some(Err(e));
141                }
142                Some(Ok(bytes)) => {
143                    self.parser.feed(&bytes);
144                }
145            }
146        }
147    }
148
149    // ── internals ─────────────────────────────────────────────────────────────
150
151    fn decode_frame(&mut self, data: &str) -> ClientResult<StreamResponse> {
152        // Each SSE frame's `data` is a JSON-RPC response carrying a StreamResponse.
153        let envelope: JsonRpcResponse<StreamResponse> =
154            serde_json::from_str(data).map_err(ClientError::Serialization)?;
155
156        match envelope {
157            JsonRpcResponse::Success(ok) => {
158                // Check for terminal event so callers don't need to.
159                if is_terminal(&ok.result) {
160                    self.done = true;
161                }
162                Ok(ok.result)
163            }
164            JsonRpcResponse::Error(err) => {
165                self.done = true;
166                let a2a = a2a_protocol_types::A2aError::new(
167                    a2a_protocol_types::ErrorCode::try_from(err.error.code)
168                        .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
169                    err.error.message,
170                );
171                Err(ClientError::Protocol(a2a))
172            }
173        }
174    }
175}
176
177impl Drop for EventStream {
178    fn drop(&mut self) {
179        if let Some(handle) = self.abort_handle.take() {
180            handle.abort();
181        }
182    }
183}
184
185#[allow(clippy::missing_fields_in_debug)]
186impl std::fmt::Debug for EventStream {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        // `rx` and `parser` don't implement Debug in a useful way; show key state only.
189        f.debug_struct("EventStream")
190            .field("done", &self.done)
191            .field("pending_frames", &self.parser.pending_count())
192            .finish()
193    }
194}
195
196/// Returns `true` if `event` is the terminal event for its stream.
197const fn is_terminal(event: &StreamResponse) -> bool {
198    matches!(
199        event,
200        StreamResponse::StatusUpdate(ev) if ev.status.state.is_terminal()
201    )
202}
203
204// ── Tests ─────────────────────────────────────────────────────────────────────
205
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{
        JsonRpcSuccessResponse, JsonRpcVersion, TaskId, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    /// Builds a status-update event for task `t1` / context `c1` in `state`.
    ///
    /// Termination is driven by whether `state` is terminal; the event type
    /// has no separate "final" flag.
    fn make_status_event(state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: a2a_protocol_types::ContextId::new("c1"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    /// Wraps `event` in a JSON-RPC success envelope and encodes it as one SSE frame.
    fn sse_frame(event: &StreamResponse) -> String {
        let resp = JsonRpcSuccessResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            result: event.clone(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        format!("data: {json}\n\n")
    }

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let event = make_status_event(TaskState::Working);
        let sse_bytes = sse_frame(&event);
        tx.send(Ok(Bytes::from(sse_bytes))).await.unwrap();
        drop(tx);

        let result = stream.next().await.unwrap();
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), StreamResponse::StatusUpdate(_)));
    }

    #[tokio::test]
    async fn stream_delivers_multiple_events_from_one_chunk() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let event = make_status_event(TaskState::Working);
        let two_frames = sse_frame(&event).repeat(2);
        tx.send(Ok(Bytes::from(two_frames))).await.unwrap();
        drop(tx);

        for _ in 0..2 {
            let result = stream.next().await.unwrap();
            assert!(matches!(result, Ok(StreamResponse::StatusUpdate(_))));
        }
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_ends_on_terminal_event() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let event = make_status_event(TaskState::Completed);
        let sse_bytes = sse_frame(&event);
        tx.send(Ok(Bytes::from(sse_bytes))).await.unwrap();

        // First next() returns the terminal event.
        let result = stream.next().await.unwrap();
        assert!(result.is_ok());

        // Second next() returns None even though the sender is still alive.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_propagates_body_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        tx.send(Err(ClientError::Transport("network error".into())))
            .await
            .unwrap();

        let result = stream.next().await.unwrap();
        assert!(result.is_err());

        // A body-reader error is terminal: no further events, no hang.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        drop(tx);

        assert!(stream.next().await.is_none());
    }
}