a2a_protocol_client/streaming/
event_stream.rs1use a2a_protocol_types::{JsonRpcResponse, StreamResponse};
35use hyper::body::Bytes;
36use tokio::sync::mpsc;
37use tokio::task::AbortHandle;
38
39use crate::error::{ClientError, ClientResult};
40use crate::streaming::sse_parser::SseParser;
41
/// Item delivered over the channel that feeds an `EventStream`: either a raw
/// chunk of response-body bytes or an error.
///
/// `EventStream::next` returns such an error to the caller and then stops
/// reading from the channel.
pub(crate) type BodyChunk = ClientResult<Bytes>;
46
47pub struct EventStream {
58 rx: mpsc::Receiver<BodyChunk>,
60 parser: SseParser,
62 done: bool,
64 abort_handle: Option<AbortHandle>,
66 status_code: u16,
73}
74
75impl EventStream {
76 #[must_use]
82 #[cfg(any(test, feature = "websocket"))]
83 pub(crate) fn new(rx: mpsc::Receiver<BodyChunk>) -> Self {
84 Self {
85 rx,
86 parser: SseParser::new(),
87 done: false,
88 abort_handle: None,
89 status_code: 200,
90 }
91 }
92
93 #[must_use]
98 #[cfg(test)]
99 pub(crate) fn with_abort_handle(
100 rx: mpsc::Receiver<BodyChunk>,
101 abort_handle: AbortHandle,
102 ) -> Self {
103 Self {
104 rx,
105 parser: SseParser::new(),
106 done: false,
107 abort_handle: Some(abort_handle),
108 status_code: 200,
109 }
110 }
111
112 #[must_use]
115 pub(crate) fn with_status(
116 rx: mpsc::Receiver<BodyChunk>,
117 abort_handle: AbortHandle,
118 status_code: u16,
119 ) -> Self {
120 Self {
121 rx,
122 parser: SseParser::new(),
123 done: false,
124 abort_handle: Some(abort_handle),
125 status_code,
126 }
127 }
128
129 #[must_use]
134 pub const fn status_code(&self) -> u16 {
135 self.status_code
136 }
137
138 pub async fn next(&mut self) -> Option<ClientResult<StreamResponse>> {
145 loop {
146 if let Some(result) = self.parser.next_frame() {
148 match result {
149 Ok(frame) => return Some(self.decode_frame(&frame.data)),
150 Err(e) => {
151 return Some(Err(ClientError::Transport(e.to_string())));
152 }
153 }
154 }
155
156 if self.done {
157 return None;
158 }
159
160 match self.rx.recv().await {
162 None => {
163 self.done = true;
165 if let Some(result) = self.parser.next_frame() {
167 match result {
168 Ok(frame) => return Some(self.decode_frame(&frame.data)),
169 Err(e) => {
170 return Some(Err(ClientError::Transport(e.to_string())));
171 }
172 }
173 }
174 return None;
175 }
176 Some(Err(e)) => {
177 self.done = true;
178 return Some(Err(e));
179 }
180 Some(Ok(bytes)) => {
181 self.parser.feed(&bytes);
182 }
183 }
184 }
185 }
186
187 fn decode_frame(&mut self, data: &str) -> ClientResult<StreamResponse> {
190 let envelope: JsonRpcResponse<StreamResponse> =
192 serde_json::from_str(data).map_err(ClientError::Serialization)?;
193
194 match envelope {
195 JsonRpcResponse::Success(ok) => {
196 if is_terminal(&ok.result) {
198 self.done = true;
199 }
200 Ok(ok.result)
201 }
202 JsonRpcResponse::Error(err) => {
203 self.done = true;
204 let a2a = a2a_protocol_types::A2aError::new(
205 a2a_protocol_types::ErrorCode::try_from(err.error.code)
206 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
207 err.error.message,
208 );
209 Err(ClientError::Protocol(a2a))
210 }
211 }
212 }
213}
214
215impl Drop for EventStream {
216 fn drop(&mut self) {
217 if let Some(handle) = self.abort_handle.take() {
218 handle.abort();
219 }
220 }
221}
222
223#[allow(clippy::missing_fields_in_debug)]
224impl std::fmt::Debug for EventStream {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("EventStream")
228 .field("done", &self.done)
229 .field("pending_frames", &self.parser.pending_count())
230 .finish()
231 }
232}
233
234const fn is_terminal(event: &StreamResponse) -> bool {
236 matches!(
237 event,
238 StreamResponse::StatusUpdate(ev) if ev.status.state.is_terminal()
239 )
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{
        ContextId, JsonRpcError, JsonRpcErrorResponse, JsonRpcSuccessResponse, JsonRpcVersion,
        TaskId, TaskState, TaskStatus, TaskStatusUpdateEvent,
    };
    use std::time::Duration;

    /// Upper bound for each awaited `next()`, so a regression fails the test
    /// instead of hanging it.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a status-update event for task `t1` in context `c1`.
    ///
    /// Whether the event ends the stream is decided solely by `state` (via
    /// `is_terminal`), so no separate "final" flag is taken.
    fn make_status_event(state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    /// Wraps `event` in a JSON-RPC success envelope and renders it as one SSE
    /// `data:` frame terminated by a blank line.
    fn sse_frame(event: &StreamResponse) -> String {
        let resp = JsonRpcSuccessResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            result: event.clone(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        format!("data: {json}\n\n")
    }

    /// Awaits the next item from `stream`, panicking with `context` if nothing
    /// arrives within `TEST_TIMEOUT`.
    async fn next_or_timeout(
        stream: &mut EventStream,
        context: &str,
    ) -> Option<ClientResult<StreamResponse>> {
        tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .unwrap_or_else(|_| panic!("timed out {context}"))
    }

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let event = make_status_event(TaskState::Working);
        tx.send(Ok(Bytes::from(sse_frame(&event)))).await.unwrap();
        drop(tx);

        let result = next_or_timeout(&mut stream, "waiting for event")
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );
    }

    #[tokio::test]
    async fn stream_ends_on_final_event() {
        // `tx` stays alive for the whole test: the stream must end because the
        // event is terminal, not because the channel closed.
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let event = make_status_event(TaskState::Completed);
        tx.send(Ok(Bytes::from(sse_frame(&event)))).await.unwrap();

        let result = next_or_timeout(&mut stream, "waiting for final event")
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );

        let end = next_or_timeout(&mut stream, "waiting for stream end").await;
        assert!(end.is_none());
    }

    #[tokio::test]
    async fn stream_propagates_body_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        tx.send(Err(ClientError::Transport("network error".into())))
            .await
            .unwrap();

        let result = next_or_timeout(&mut stream, "waiting for body error")
            .await
            .unwrap();
        match result {
            Err(ClientError::Transport(msg)) => {
                assert!(msg.contains("network error"), "body error should pass through");
            }
            other => panic!("expected Transport error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);
        drop(tx);

        let result = next_or_timeout(&mut stream, "waiting for stream end").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn drop_aborts_background_task() {
        let (tx, rx) = mpsc::channel::<BodyChunk>(8);
        // Stand-in producer: holds the sender and would otherwise idle for an hour.
        let handle = tokio::spawn(async move {
            let _tx = tx;
            tokio::time::sleep(Duration::from_secs(60 * 60)).await;
        });
        let stream = EventStream::with_abort_handle(rx, handle.abort_handle());
        drop(stream);

        let result = tokio::time::timeout(TEST_TIMEOUT, handle)
            .await
            .expect("timed out waiting for task abort");
        let err = result.expect_err("task should have been aborted");
        assert!(err.is_cancelled(), "task should be cancelled");
    }

    #[test]
    fn debug_output_contains_fields() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let stream = EventStream::new(rx);
        let debug = format!("{stream:?}");
        assert!(debug.contains("EventStream"), "should contain struct name");
        assert!(debug.contains("done"), "should contain 'done' field");
        assert!(
            debug.contains("pending_frames"),
            "should contain 'pending_frames' field"
        );
    }

    #[test]
    fn is_terminal_returns_false_for_working() {
        let event = make_status_event(TaskState::Working);
        assert!(!is_terminal(&event), "Working state should not be terminal");
    }

    #[test]
    fn is_terminal_returns_true_for_completed() {
        let event = make_status_event(TaskState::Completed);
        assert!(is_terminal(&event), "Completed state should be terminal");
    }

    /// A JSON-RPC error envelope surfaces as `ClientError::Protocol` and ends
    /// the stream.
    #[tokio::test]
    async fn stream_decodes_jsonrpc_error_as_protocol_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let error_resp = JsonRpcErrorResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            error: JsonRpcError {
                code: -32601,
                message: "method not found".into(),
                data: None,
            },
        };
        let json = serde_json::to_string(&error_resp).unwrap();
        tx.send(Ok(Bytes::from(format!("data: {json}\n\n"))))
            .await
            .unwrap();
        drop(tx);

        let result = next_or_timeout(&mut stream, "waiting for error frame")
            .await
            .unwrap();
        match result {
            Err(ClientError::Protocol(err)) => {
                assert!(
                    err.to_string().contains("method not found"),
                    "error message should be preserved"
                );
            }
            other => panic!("expected Protocol error, got {other:?}"),
        }

        let end = next_or_timeout(&mut stream, "waiting for stream end").await;
        assert!(end.is_none(), "stream should end after JSON-RPC error");
    }

    /// A frame whose data is not JSON surfaces as `ClientError::Serialization`.
    #[tokio::test]
    async fn stream_invalid_json_returns_serialization_error() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        tx.send(Ok(Bytes::from("data: {not valid json}\n\n")))
            .await
            .unwrap();
        drop(tx);

        let result = next_or_timeout(&mut stream, "waiting for bad frame")
            .await
            .unwrap();
        assert!(
            matches!(result, Err(ClientError::Serialization(_))),
            "invalid JSON should produce a Serialization error"
        );
    }

    /// An SSE frame split across two body chunks is buffered by the parser and
    /// delivered as one event once the second half has been fed.
    #[tokio::test]
    async fn stream_reassembles_frame_split_across_chunks() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let sse_bytes = sse_frame(&make_status_event(TaskState::Working));
        let (first_half, second_half) = sse_bytes.split_at(sse_bytes.len() / 2);

        tx.send(Ok(Bytes::from(first_half.to_owned())))
            .await
            .unwrap();
        tx.send(Ok(Bytes::from(second_half.to_owned())))
            .await
            .unwrap();
        drop(tx);

        let event = next_or_timeout(&mut stream, "waiting for reassembled event")
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(event, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working),
            "should deliver the Working event reassembled from both chunks"
        );
    }

    #[test]
    fn status_code_returns_set_value() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let stream = EventStream::new(rx);
        assert_eq!(stream.status_code(), 200, "default status should be 200");
    }

    #[tokio::test]
    async fn status_code_with_custom_value() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let task = tokio::spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        let stream = EventStream::with_status(rx, task.abort_handle(), 201);
        assert_eq!(stream.status_code(), 201);
    }

    /// A non-Transport body error is passed through unchanged and ends the stream.
    #[tokio::test]
    async fn stream_transport_error_from_channel() {
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        tx.send(Err(ClientError::HttpClient("connection reset".into())))
            .await
            .unwrap();

        let result = next_or_timeout(&mut stream, "waiting for body error")
            .await
            .unwrap();
        match result {
            Err(ClientError::HttpClient(msg)) => {
                assert!(msg.contains("connection reset"));
            }
            other => panic!("expected HttpClient error, got {other:?}"),
        }

        let end = next_or_timeout(&mut stream, "waiting for stream end").await;
        assert!(end.is_none(), "stream should end after transport error");
    }

    #[tokio::test]
    async fn non_terminal_event_does_not_end_stream() {
        // `tx` stays open, so only the terminal Completed event can end the stream.
        let (tx, rx) = mpsc::channel(8);
        let mut stream = EventStream::new(rx);

        let working = make_status_event(TaskState::Working);
        let completed = make_status_event(TaskState::Completed);
        tx.send(Ok(Bytes::from(sse_frame(&working)))).await.unwrap();
        tx.send(Ok(Bytes::from(sse_frame(&completed))))
            .await
            .unwrap();

        let first = next_or_timeout(&mut stream, "on first event")
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(first, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Working)
        );

        let second = next_or_timeout(&mut stream, "on second event")
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(second, StreamResponse::StatusUpdate(ref ev) if ev.status.state == TaskState::Completed)
        );

        let end = next_or_timeout(&mut stream, "waiting for stream end").await;
        assert!(end.is_none());
    }
}