a2a_protocol_client/streaming/
event_stream.rs1use a2a_protocol_types::{JsonRpcResponse, StreamResponse};
35use hyper::body::Bytes;
36use tokio::sync::mpsc;
37use tokio::task::AbortHandle;
38
39use crate::error::{ClientError, ClientResult};
40use crate::streaming::sse_parser::SseParser;
41
42pub(crate) type BodyChunk = ClientResult<Bytes>;
46
47pub struct EventStream {
58 rx: mpsc::Receiver<BodyChunk>,
60 parser: SseParser,
62 done: bool,
64 abort_handle: Option<AbortHandle>,
66 status_code: u16,
73 jsonrpc_envelope: bool,
79}
80
81impl EventStream {
82 #[must_use]
88 #[cfg(any(test, feature = "websocket"))]
89 pub(crate) fn new(rx: mpsc::Receiver<BodyChunk>) -> Self {
90 Self {
91 rx,
92 parser: SseParser::new(),
93 done: false,
94 abort_handle: None,
95 status_code: 200,
96 jsonrpc_envelope: true,
97 }
98 }
99
100 #[must_use]
105 #[cfg(test)]
106 pub(crate) fn with_abort_handle(
107 rx: mpsc::Receiver<BodyChunk>,
108 abort_handle: AbortHandle,
109 ) -> Self {
110 Self {
111 rx,
112 parser: SseParser::new(),
113 done: false,
114 abort_handle: Some(abort_handle),
115 status_code: 200,
116 jsonrpc_envelope: true,
117 }
118 }
119
120 #[must_use]
123 pub(crate) fn with_status(
124 rx: mpsc::Receiver<BodyChunk>,
125 abort_handle: AbortHandle,
126 status_code: u16,
127 ) -> Self {
128 Self {
129 rx,
130 parser: SseParser::new(),
131 done: false,
132 abort_handle: Some(abort_handle),
133 status_code,
134 jsonrpc_envelope: true,
135 }
136 }
137
138 #[must_use]
143 pub(crate) const fn with_jsonrpc_envelope(mut self, envelope: bool) -> Self {
144 self.jsonrpc_envelope = envelope;
145 self
146 }
147
148 #[must_use]
153 pub const fn status_code(&self) -> u16 {
154 self.status_code
155 }
156
157 pub async fn next(&mut self) -> Option<ClientResult<StreamResponse>> {
164 loop {
165 if let Some(result) = self.parser.next_frame() {
167 match result {
168 Ok(frame) => return Some(self.decode_frame(&frame.data)),
169 Err(e) => {
170 return Some(Err(ClientError::Transport(e.to_string())));
171 }
172 }
173 }
174
175 if self.done {
176 return None;
177 }
178
179 match self.rx.recv().await {
181 None => {
182 self.done = true;
184 if let Some(result) = self.parser.next_frame() {
186 match result {
187 Ok(frame) => return Some(self.decode_frame(&frame.data)),
188 Err(e) => {
189 return Some(Err(ClientError::Transport(e.to_string())));
190 }
191 }
192 }
193 return None;
194 }
195 Some(Err(e)) => {
196 self.done = true;
197 return Some(Err(e));
198 }
199 Some(Ok(bytes)) => {
200 self.parser.feed(&bytes);
201 }
202 }
203 }
204 }
205
206 fn decode_frame(&mut self, data: &str) -> ClientResult<StreamResponse> {
209 if self.jsonrpc_envelope {
210 let envelope: JsonRpcResponse<StreamResponse> =
212 serde_json::from_str(data).map_err(ClientError::Serialization)?;
213
214 match envelope {
215 JsonRpcResponse::Success(ok) => {
216 if is_terminal(&ok.result) {
217 self.done = true;
218 }
219 Ok(ok.result)
220 }
221 JsonRpcResponse::Error(err) => {
222 self.done = true;
223 let a2a = a2a_protocol_types::A2aError::new(
224 a2a_protocol_types::ErrorCode::try_from(err.error.code)
225 .unwrap_or(a2a_protocol_types::ErrorCode::InternalError),
226 err.error.message,
227 );
228 Err(ClientError::Protocol(a2a))
229 }
230 }
231 } else {
232 let event: StreamResponse =
235 serde_json::from_str(data).map_err(ClientError::Serialization)?;
236 if is_terminal(&event) {
237 self.done = true;
238 }
239 Ok(event)
240 }
241 }
242}
243
244impl Drop for EventStream {
245 fn drop(&mut self) {
246 if let Some(handle) = self.abort_handle.take() {
247 handle.abort();
248 }
249 }
250}
251
252#[allow(clippy::missing_fields_in_debug)]
253impl std::fmt::Debug for EventStream {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("EventStream")
257 .field("done", &self.done)
258 .field("pending_frames", &self.parser.pending_count())
259 .finish()
260 }
261}
262
263const fn is_terminal(event: &StreamResponse) -> bool {
265 matches!(
266 event,
267 StreamResponse::StatusUpdate(ev) if ev.status.state.is_terminal()
268 )
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::{
        JsonRpcSuccessResponse, JsonRpcVersion, TaskId, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };
    use std::time::Duration;

    /// Upper bound on every await in these tests so a hung stream fails the
    /// test instead of blocking the run.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Sending half of the body channel used to drive a stream under test.
    type ChunkSender = mpsc::Sender<BodyChunk>;

    // ── Fixtures ────────────────────────────────────────────────────────

    /// Channel plus a stream that expects JSON-RPC-enveloped frames.
    fn envelope_stream() -> (ChunkSender, EventStream) {
        let (tx, rx) = mpsc::channel(8);
        (tx, EventStream::new(rx))
    }

    /// Channel plus a stream that expects bare `StreamResponse` frames.
    fn bare_stream() -> (ChunkSender, EventStream) {
        let (tx, stream) = envelope_stream();
        (tx, stream.with_jsonrpc_envelope(false))
    }

    /// Status-update event for task `t1` / context `c1` in the given state.
    fn make_status_event(state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: a2a_protocol_types::ContextId::new("c1"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    /// SSE frame whose data is `event` wrapped in a JSON-RPC success response.
    fn sse_frame(event: &StreamResponse) -> String {
        let resp = JsonRpcSuccessResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            result: event.clone(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        format!("data: {json}\n\n")
    }

    /// SSE frame whose data is `event` serialized without any envelope.
    fn bare_sse_frame(event: &StreamResponse) -> String {
        let json = serde_json::to_string(event).unwrap();
        format!("data: {json}\n\n")
    }

    /// Pushes one successful body chunk into the channel.
    async fn send_chunk(tx: &ChunkSender, chunk: impl Into<Bytes>) {
        tx.send(Ok(chunk.into()))
            .await
            .expect("stream receiver should still be alive");
    }

    /// Awaits `stream.next()`, panicking if it exceeds `TEST_TIMEOUT`.
    /// `what` names the awaited item in the panic message.
    async fn next_item(
        stream: &mut EventStream,
        what: &str,
    ) -> Option<ClientResult<StreamResponse>> {
        tokio::time::timeout(TEST_TIMEOUT, stream.next())
            .await
            .unwrap_or_else(|_| panic!("timed out waiting for {what}"))
    }

    /// Like `next_item`, but the stream must not have ended.
    async fn next_result(stream: &mut EventStream, what: &str) -> ClientResult<StreamResponse> {
        next_item(stream, what)
            .await
            .unwrap_or_else(|| panic!("stream ended while waiting for {what}"))
    }

    /// Like `next_result`, but the item must be a successfully decoded event.
    async fn next_event(stream: &mut EventStream, what: &str) -> StreamResponse {
        next_result(stream, what)
            .await
            .unwrap_or_else(|err| panic!("expected {what}, got error {err:?}"))
    }

    /// Asserts that the stream reports its end on the next poll.
    async fn assert_stream_ended(stream: &mut EventStream, why: &str) {
        let end = next_item(stream, "stream end").await;
        assert!(end.is_none(), "{why}");
    }

    /// Asserts that `event` is a status update in the `expected` state.
    #[track_caller]
    fn assert_state(event: &StreamResponse, expected: TaskState) {
        assert!(
            matches!(event, StreamResponse::StatusUpdate(ev) if ev.status.state == expected),
            "unexpected event: {event:?}"
        );
    }

    // ── Envelope-mode streams ───────────────────────────────────────────

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, mut stream) = envelope_stream();
        send_chunk(&tx, sse_frame(&make_status_event(TaskState::Working))).await;
        drop(tx);

        let event = next_event(&mut stream, "working event").await;
        assert_state(&event, TaskState::Working);
    }

    #[tokio::test]
    async fn stream_ends_on_final_event() {
        // The sender stays alive for the whole test: the stream must end
        // because of the terminal event, not because the channel closed.
        let (tx, mut stream) = envelope_stream();
        send_chunk(&tx, sse_frame(&make_status_event(TaskState::Completed))).await;

        let event = next_event(&mut stream, "final event").await;
        assert_state(&event, TaskState::Completed);

        assert_stream_ended(&mut stream, "stream should end after terminal event").await;
    }

    #[tokio::test]
    async fn stream_propagates_body_error() {
        let (tx, mut stream) = envelope_stream();
        tx.send(Err(ClientError::Transport("network error".into())))
            .await
            .unwrap();

        let result = next_result(&mut stream, "body error").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, mut stream) = envelope_stream();
        drop(tx);

        assert_stream_ended(&mut stream, "stream should end when channel closes").await;
    }

    #[tokio::test]
    async fn drop_aborts_background_task() {
        let (tx, rx) = mpsc::channel::<BodyChunk>(8);
        let handle = tokio::spawn(async move {
            // Keep the sender alive inside the task for as long as it runs.
            let _tx = tx;
            tokio::time::sleep(Duration::from_secs(60 * 60)).await;
        });
        let stream = EventStream::with_abort_handle(rx, handle.abort_handle());
        drop(stream);

        let joined = tokio::time::timeout(TEST_TIMEOUT, handle)
            .await
            .expect("timed out waiting for task abort");
        let join_err = joined.expect_err("task should have been aborted");
        assert!(join_err.is_cancelled(), "task should be cancelled");
    }

    #[test]
    fn debug_output_contains_fields() {
        let (_tx, stream) = envelope_stream();
        let debug = format!("{stream:?}");
        assert!(debug.contains("EventStream"), "should contain struct name");
        assert!(debug.contains("done"), "should contain 'done' field");
        assert!(
            debug.contains("pending_frames"),
            "should contain 'pending_frames' field"
        );
    }

    #[test]
    fn is_terminal_returns_false_for_working() {
        let event = make_status_event(TaskState::Working);
        assert!(!is_terminal(&event), "Working state should not be terminal");
    }

    #[test]
    fn is_terminal_returns_true_for_completed() {
        let event = make_status_event(TaskState::Completed);
        assert!(is_terminal(&event), "Completed state should be terminal");
    }

    #[tokio::test]
    async fn stream_decodes_jsonrpc_error_as_protocol_error() {
        use a2a_protocol_types::JsonRpcErrorResponse;

        let (tx, mut stream) = envelope_stream();

        let error_resp = JsonRpcErrorResponse {
            jsonrpc: JsonRpcVersion,
            id: Some(serde_json::json!(1)),
            error: a2a_protocol_types::JsonRpcError {
                code: -32601,
                message: "method not found".into(),
                data: None,
            },
        };
        let json = serde_json::to_string(&error_resp).unwrap();
        send_chunk(&tx, format!("data: {json}\n\n")).await;
        drop(tx);

        match next_result(&mut stream, "JSON-RPC error").await {
            Err(ClientError::Protocol(err)) => {
                assert!(
                    err.to_string().contains("method not found"),
                    "error message should be preserved"
                );
            }
            other => panic!("expected Protocol error, got {other:?}"),
        }

        assert_stream_ended(&mut stream, "stream should end after JSON-RPC error").await;
    }

    #[tokio::test]
    async fn stream_invalid_json_returns_serialization_error() {
        let (tx, mut stream) = envelope_stream();
        send_chunk(&tx, "data: {not valid json}\n\n").await;
        drop(tx);

        let result = next_result(&mut stream, "serialization error").await;
        assert!(
            matches!(result, Err(ClientError::Serialization(_))),
            "should be a Serialization error"
        );
    }

    #[tokio::test]
    async fn stream_drains_parser_after_channel_close() {
        let (tx, mut stream) = envelope_stream();

        // Deliver a single frame split across two chunks, then close.
        let sse_bytes = sse_frame(&make_status_event(TaskState::Working));
        let (first_half, second_half) = sse_bytes.split_at(sse_bytes.len() / 2);
        send_chunk(&tx, first_half.to_owned()).await;
        send_chunk(&tx, second_half.to_owned()).await;
        drop(tx);

        let event = next_event(&mut stream, "event split across chunks").await;
        assert_state(&event, TaskState::Working);
    }

    #[test]
    fn status_code_returns_set_value() {
        let (_tx, stream) = envelope_stream();
        assert_eq!(stream.status_code(), 200, "default status should be 200");
    }

    #[tokio::test]
    async fn status_code_with_custom_value() {
        let (_tx, rx) = mpsc::channel::<BodyChunk>(8);
        let task = tokio::spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        let stream = EventStream::with_status(rx, task.abort_handle(), 201);
        assert_eq!(stream.status_code(), 201);
    }

    #[tokio::test]
    async fn stream_transport_error_from_channel() {
        let (tx, mut stream) = envelope_stream();
        tx.send(Err(ClientError::HttpClient("connection reset".into())))
            .await
            .unwrap();

        match next_result(&mut stream, "transport error").await {
            Err(ClientError::HttpClient(msg)) => {
                assert!(msg.contains("connection reset"));
            }
            other => panic!("expected HttpClient error, got {other:?}"),
        }

        assert_stream_ended(&mut stream, "stream should end after transport error").await;
    }

    #[tokio::test]
    async fn non_terminal_event_does_not_end_stream() {
        let (tx, mut stream) = envelope_stream();
        send_chunk(&tx, sse_frame(&make_status_event(TaskState::Working))).await;
        send_chunk(&tx, sse_frame(&make_status_event(TaskState::Completed))).await;

        let first = next_event(&mut stream, "first event").await;
        assert_state(&first, TaskState::Working);

        let second = next_event(&mut stream, "second event").await;
        assert_state(&second, TaskState::Completed);

        assert_stream_ended(&mut stream, "stream should end after terminal event").await;
    }

    // ── Bare-mode streams ───────────────────────────────────────────────

    #[tokio::test]
    async fn bare_stream_delivers_events() {
        let (tx, mut stream) = bare_stream();
        send_chunk(&tx, bare_sse_frame(&make_status_event(TaskState::Working))).await;
        drop(tx);

        let event = next_event(&mut stream, "bare working event").await;
        assert_state(&event, TaskState::Working);
    }

    #[tokio::test]
    async fn bare_stream_ends_on_terminal() {
        let (tx, mut stream) = bare_stream();
        send_chunk(&tx, bare_sse_frame(&make_status_event(TaskState::Completed))).await;

        let event = next_event(&mut stream, "bare terminal event").await;
        assert_state(&event, TaskState::Completed);

        assert_stream_ended(&mut stream, "bare stream should end after terminal event").await;
    }

    #[tokio::test]
    async fn bare_stream_rejects_jsonrpc_envelope() {
        let (tx, mut stream) = bare_stream();
        send_chunk(&tx, sse_frame(&make_status_event(TaskState::Working))).await;
        drop(tx);

        let result = next_result(&mut stream, "rejection of enveloped frame").await;
        assert!(
            result.is_err(),
            "bare stream should reject JSON-RPC envelope as invalid"
        );
    }

    #[tokio::test]
    async fn envelope_stream_rejects_bare_response() {
        let (tx, mut stream) = envelope_stream();
        send_chunk(&tx, bare_sse_frame(&make_status_event(TaskState::Working))).await;
        drop(tx);

        let result = next_result(&mut stream, "rejection of bare frame").await;
        assert!(
            result.is_err(),
            "envelope stream should reject bare StreamResponse"
        );
    }

    #[tokio::test]
    async fn bare_stream_multiple_events() {
        let (tx, mut stream) = bare_stream();
        send_chunk(&tx, bare_sse_frame(&make_status_event(TaskState::Working))).await;
        send_chunk(&tx, bare_sse_frame(&make_status_event(TaskState::Completed))).await;

        let first = next_event(&mut stream, "first bare event").await;
        assert_state(&first, TaskState::Working);

        let second = next_event(&mut stream, "second bare event").await;
        assert_state(&second, TaskState::Completed);

        assert_stream_ended(&mut stream, "bare stream should end after terminal event").await;
    }
}