a2a_protocol_client/streaming/
event_stream.rs1use a2a_protocol_types::{JsonRpcResponse, StreamResponse};
33use hyper::body::Bytes;
34use tokio::sync::mpsc;
35use tokio::task::AbortHandle;
36
37use crate::error::{ClientError, ClientResult};
38use crate::streaming::sse_parser::SseParser;
39
40pub(crate) type BodyChunk = ClientResult<Bytes>;
44
45pub struct EventStream {
56 rx: mpsc::Receiver<BodyChunk>,
58 parser: SseParser,
60 done: bool,
62 abort_handle: Option<AbortHandle>,
64}
65
66impl EventStream {
67 #[must_use]
73 #[cfg(test)]
74 pub(crate) fn new(rx: mpsc::Receiver<BodyChunk>) -> Self {
75 Self {
76 rx,
77 parser: SseParser::new(),
78 done: false,
79 abort_handle: None,
80 }
81 }
82
83 #[must_use]
88 pub(crate) fn with_abort_handle(
89 rx: mpsc::Receiver<BodyChunk>,
90 abort_handle: AbortHandle,
91 ) -> Self {
92 Self {
93 rx,
94 parser: SseParser::new(),
95 done: false,
96 abort_handle: Some(abort_handle),
97 }
98 }
99
100 pub async fn next(&mut self) -> Option<ClientResult<StreamResponse>> {
107 loop {
108 if let Some(result) = self.parser.next_frame() {
110 match result {
111 Ok(frame) => return Some(self.decode_frame(&frame.data)),
112 Err(e) => {
113 return Some(Err(ClientError::Transport(e.to_string())));
114 }
115 }
116 }
117
118 if self.done {
119 return None;
120 }
121
122 match self.rx.recv().await {
124 None => {
125 self.done = true;
127 if let Some(result) = self.parser.next_frame() {
129 match result {
130 Ok(frame) => return Some(self.decode_frame(&frame.data)),
131 Err(e) => {
132 return Some(Err(ClientError::Transport(e.to_string())));
133 }
134 }
135 }
136 return None;
137 }
138 Some(Err(e)) => {
139 self.done = true;
140 return Some(Err(e));
141 }
142 Some(Ok(bytes)) => {
143 self.parser.feed(&bytes);
144 }
145 }
146 }
147 }
148
149 fn decode_frame(&mut self, data: &str) -> ClientResult<StreamResponse> {
152 let envelope: JsonRpcResponse<StreamResponse> =
154 serde_json::from_str(data).map_err(ClientError::Serialization)?;
155
156 match envelope {
157 JsonRpcResponse::Success(ok) => {
158 if is_terminal(&ok.result) {
160 self.done = true;
161 }
162 Ok(ok.result)
163 }
164 JsonRpcResponse::Error(err) => {
165 self.done = true;
166 let a2a = a2a_protocol_types::A2aError::new(
167 a2a_protocol_types::ErrorCode::try_from(err.error.code)
168 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
169 err.error.message,
170 );
171 Err(ClientError::Protocol(a2a))
172 }
173 }
174 }
175}
176
177impl Drop for EventStream {
178 fn drop(&mut self) {
179 if let Some(handle) = self.abort_handle.take() {
180 handle.abort();
181 }
182 }
183}
184
185#[allow(clippy::missing_fields_in_debug)]
186impl std::fmt::Debug for EventStream {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 f.debug_struct("EventStream")
190 .field("done", &self.done)
191 .field("pending_frames", &self.parser.pending_count())
192 .finish()
193 }
194}
195
196const fn is_terminal(event: &StreamResponse) -> bool {
198 matches!(
199 event,
200 StreamResponse::StatusUpdate(ev) if ev.status.state.is_terminal()
201 )
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{
        JsonRpcSuccessResponse, JsonRpcVersion, TaskId, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    /// Builds a status-update event for task `t1` in context `c1` with the
    /// given `state`.
    ///
    /// `_is_final` is accepted but unused: nothing in the event is derived
    /// from it.
    fn make_status_event(state: TaskState, _is_final: bool) -> StreamResponse {
        let status = TaskStatus {
            state,
            message: None,
            timestamp: None,
        };
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: a2a_protocol_types::ContextId::new("c1"),
            status,
            metadata: None,
        })
    }

    /// Wraps `event` in a JSON-RPC success envelope (id `1`) and renders it
    /// as a single SSE `data:` frame terminated by a blank line.
    fn sse_frame(event: &StreamResponse) -> String {
        let envelope = JsonRpcSuccessResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            result: event.clone(),
        };
        let json = serde_json::to_string(&envelope).unwrap();
        format!("data: {json}\n\n")
    }

    /// A frame sent before the channel closes is decoded and delivered.
    #[tokio::test]
    async fn stream_delivers_events() {
        let (sender, receiver) = mpsc::channel(8);
        let mut stream = EventStream::new(receiver);

        let frame = sse_frame(&make_status_event(TaskState::Working, false));
        sender.send(Ok(Bytes::from(frame))).await.unwrap();
        drop(sender);

        let item = stream.next().await.unwrap();
        assert!(matches!(item, Ok(StreamResponse::StatusUpdate(_))));
    }

    /// After a terminal status update the stream ends even though the sender
    /// is still alive.
    #[tokio::test]
    async fn stream_ends_on_final_event() {
        // `sender` must stay alive for the whole test so the final `None`
        // comes from the done flag rather than from a closed channel.
        let (sender, receiver) = mpsc::channel(8);
        let mut stream = EventStream::new(receiver);

        let frame = sse_frame(&make_status_event(TaskState::Completed, true));
        sender.send(Ok(Bytes::from(frame))).await.unwrap();

        let first = stream.next().await.unwrap();
        assert!(first.is_ok());

        let second = stream.next().await;
        assert!(second.is_none());
    }

    /// An error sent over the channel is surfaced to the caller.
    #[tokio::test]
    async fn stream_propagates_body_error() {
        let (sender, receiver) = mpsc::channel(8);
        let mut stream = EventStream::new(receiver);

        let failure = ClientError::Transport("network error".into());
        sender.send(Err(failure)).await.unwrap();

        let item = stream.next().await.unwrap();
        assert!(item.is_err());
    }

    /// A channel closed without any data yields an immediately finished stream.
    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (sender, receiver) = mpsc::channel(8);
        let mut stream = EventStream::new(receiver);
        drop(sender);

        let item = stream.next().await;
        assert!(item.is_none());
    }
}