a2a_protocol_server/streaming/
sse.rs1use std::convert::Infallible;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::Bytes;
17use http_body_util::BodyExt;
18use hyper::body::Frame;
19
20use a2a_protocol_types::jsonrpc::{JsonRpcId, JsonRpcSuccessResponse, JsonRpcVersion};
21
22use crate::streaming::event_queue::{EventQueueReader, InMemoryQueueReader};
23
/// Idle time after which the SSE stream emits a keep-alive comment frame
/// when no event has been sent (used when the caller passes `None`).
pub(crate) const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(30);

/// Number of frames buffered between the streaming task and the HTTP body
/// when the caller does not specify a capacity.
pub(crate) const DEFAULT_SSE_CHANNEL_CAPACITY: usize = 64;
29
30#[must_use]
34pub fn write_event(event_type: &str, data: &str) -> Bytes {
35 let mut buf = String::with_capacity(event_type.len() + data.len() + 32);
36 buf.push_str("event: ");
37 buf.push_str(event_type);
38 buf.push('\n');
39 for line in data.lines() {
40 buf.push_str("data: ");
41 buf.push_str(line);
42 buf.push('\n');
43 }
44 buf.push('\n');
45 Bytes::from(buf)
46}
47
std::thread_local! {
    /// Per-thread scratch buffer that `build_sse_message_frame` clears and
    /// refills for every frame, so serialization reuses one allocation per
    /// thread instead of regrowing a fresh `Vec` each time.
    static SSE_FRAME_BUF: std::cell::RefCell<Vec<u8>> = {
        // Pre-sized so small frames fit without regrowing the buffer.
        const INITIAL_CAPACITY: usize = 1024;
        std::cell::RefCell::new(Vec::with_capacity(INITIAL_CAPACITY))
    };
}
59
60fn build_sse_message_frame<T: serde::Serialize>(value: &T) -> Result<Bytes, serde_json::Error> {
69 SSE_FRAME_BUF.with(|cell| {
70 let mut buf = cell.borrow_mut();
71 buf.clear();
72 buf.extend_from_slice(b"event: message\ndata: ");
73 serde_json::to_writer(&mut *buf, value)?;
74 buf.extend_from_slice(b"\n\n");
75 Ok(Bytes::from(buf.clone()))
76 })
77}
78
79#[must_use]
81pub const fn write_keep_alive() -> Bytes {
82 Bytes::from_static(b": keep-alive\n\n")
83}
84
85#[derive(Debug)]
89pub struct SseBodyWriter {
90 tx: tokio::sync::mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
91}
92
93impl SseBodyWriter {
94 pub async fn send_event(&self, event_type: &str, data: &str) -> Result<(), ()> {
100 let frame = Frame::data(write_event(event_type, data));
101 self.tx.send(Ok(frame)).await.map_err(|_| ())
102 }
103
104 async fn send_raw_frame(&self, bytes: Bytes) -> Result<(), ()> {
113 let frame = Frame::data(bytes);
114 self.tx.send(Ok(frame)).await.map_err(|_| ())
115 }
116
117 pub async fn send_keep_alive(&self) -> Result<(), ()> {
123 let frame = Frame::data(write_keep_alive());
124 self.tx.send(Ok(frame)).await.map_err(|_| ())
125 }
126
127 pub fn close(self) {
129 drop(self);
130 }
131}
132
133struct ChannelBody {
139 rx: tokio::sync::mpsc::Receiver<Result<Frame<Bytes>, Infallible>>,
140}
141
142impl hyper::body::Body for ChannelBody {
143 type Data = Bytes;
144 type Error = Infallible;
145
146 fn poll_frame(
147 mut self: Pin<&mut Self>,
148 cx: &mut Context<'_>,
149 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
150 self.rx.poll_recv(cx)
151 }
152}
153
154#[must_use]
170#[allow(clippy::too_many_lines)]
171pub fn build_sse_response(
172 mut reader: InMemoryQueueReader,
173 keep_alive_interval: Option<Duration>,
174 channel_capacity: Option<usize>,
175 jsonrpc_envelope: bool,
176) -> hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>> {
177 trace_info!("building SSE response stream");
178 let interval = keep_alive_interval.unwrap_or(DEFAULT_KEEP_ALIVE);
179 let cap = channel_capacity.unwrap_or(DEFAULT_SSE_CHANNEL_CAPACITY);
180 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(cap);
181
182 let body_writer = SseBodyWriter { tx };
183
184 tokio::spawn(async move {
185 tokio::task::yield_now().await;
193
194 let keep_alive_deadline = tokio::time::sleep(interval);
203 tokio::pin!(keep_alive_deadline);
204
205 loop {
206 tokio::select! {
207 biased;
208
209 event = reader.read() => {
210 match event {
211 Some(Ok(stream_response)) => {
212 let frame_bytes = if jsonrpc_envelope {
217 let envelope = JsonRpcSuccessResponse {
218 jsonrpc: JsonRpcVersion,
219 id: JsonRpcId::default(),
220 result: stream_response,
221 };
222 build_sse_message_frame(&envelope)
223 } else {
224 build_sse_message_frame(&stream_response)
226 };
227 let frame_bytes = match frame_bytes {
228 Ok(b) => b,
229 Err(e) => {
230 let err_msg = format!("{{\"error\":\"serialization failed: {e}\"}}");
231 let _ = body_writer.send_event("error", &err_msg).await;
232 break;
233 }
234 };
235 if body_writer.send_raw_frame(frame_bytes).await.is_err() {
236 break;
237 }
238 keep_alive_deadline.as_mut().reset(
240 tokio::time::Instant::now() + interval,
241 );
242 }
243 Some(Err(e)) => {
244 let Ok(data) = serde_json::to_string(&e) else {
245 break;
246 };
247 let _ = body_writer.send_event("error", &data).await;
248 break;
249 }
250 None => break,
251 }
252 }
253 () = &mut keep_alive_deadline => {
254 if body_writer.send_keep_alive().await.is_err() {
255 break;
256 }
257 keep_alive_deadline.as_mut().reset(
258 tokio::time::Instant::now() + interval,
259 );
260 }
261 }
262 }
263
264 drop(body_writer);
265 });
266
267 let body = ChannelBody { rx };
268
269 hyper::Response::builder()
270 .status(200)
271 .header("content-type", "text/event-stream")
272 .header("cache-control", "no-cache")
273 .header("transfer-encoding", "chunked")
274 .body(body.boxed())
275 .unwrap_or_else(|_| {
276 hyper::Response::new(
277 http_body_util::Full::new(Bytes::from_static(b"SSE response build error")).boxed(),
278 )
279 })
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_event_single_line_data() {
        let frame = write_event("message", r#"{"hello":"world"}"#);
        let expected = "event: message\ndata: {\"hello\":\"world\"}\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "single-line data should produce one data: line"
        );
    }

    #[test]
    fn write_event_multiline_data() {
        let frame = write_event("error", "line1\nline2\nline3");
        let expected = "event: error\ndata: line1\ndata: line2\ndata: line3\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "multiline data should produce separate data: lines"
        );
    }

    #[test]
    fn write_event_empty_data() {
        let frame = write_event("ping", "");
        let expected = "event: ping\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty data should produce no data: lines"
        );
    }

    #[test]
    fn write_event_empty_event_type() {
        let frame = write_event("", "payload");
        let expected = "event: \ndata: payload\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty event type should still produce valid SSE frame"
        );
    }

    #[test]
    fn write_keep_alive_format() {
        let frame = write_keep_alive();
        assert_eq!(
            frame,
            Bytes::from_static(b": keep-alive\n\n"),
            "keep-alive should be an SSE comment terminated by double newline"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_event_delivers_frame() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_event("message", "hello")
            .await
            .expect("send_event should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_event("message", "hello"),
            "received frame should match write_event output"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_keep_alive_delivers_comment() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_keep_alive()
            .await
            .expect("send_keep_alive should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_keep_alive(),
            "should receive keep-alive comment"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_event("message", "data").await;
        assert!(
            result.is_err(),
            "send_event should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_keep_alive_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_keep_alive().await;
        assert!(
            result.is_err(),
            "send_keep_alive should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_close_drops_sender() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer.close();

        let result = rx.recv().await;
        assert!(
            result.is_none(),
            "receiver should return None after writer is closed"
        );
    }

    #[tokio::test]
    async fn build_sse_response_has_correct_headers() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None, true);

        assert_eq!(response.status(), 200, "status should be 200 OK");
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
            "Content-Type should be text/event-stream"
        );
        assert_eq!(
            response
                .headers()
                .get("cache-control")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"no-cache".as_slice()),
            "Cache-Control should be no-cache"
        );
        assert_eq!(
            response
                .headers()
                .get("transfer-encoding")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"chunked".as_slice()),
            "Transfer-Encoding should be chunked"
        );
    }

    #[tokio::test]
    async fn build_sse_response_with_custom_keep_alive_and_capacity() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, Some(Duration::from_secs(5)), Some(16), true);

        assert_eq!(response.status(), 200);
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
        );
    }

    #[tokio::test]
    async fn build_sse_response_client_disconnect_stops_stream() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None, true);

        // Simulate the client going away by dropping the response body.
        drop(response);

        // Give the streaming task a chance to observe the disconnect.
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Writing after the disconnect must neither panic nor hang.
        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });
        let _ = writer.write(event).await;
        drop(writer);
    }

    #[tokio::test]
    async fn build_sse_response_ends_on_reader_close() {
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        // Closing the writer side exhausts the reader, so the stream is
        // expected to terminate rather than idle on keep-alives.
        drop(writer);

        let mut response = build_sse_response(reader, None, None, true);

        // Drain any frames; end-of-stream must arrive well before the default
        // keep-alive (30 s) would keep the body open indefinitely.
        let drained = tokio::time::timeout(Duration::from_secs(5), async {
            while let Some(frame) = response.body_mut().frame().await {
                assert!(frame.is_ok(), "frame result should be Ok");
            }
        })
        .await;

        assert!(
            drained.is_ok(),
            "stream should end once the reader is closed"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_error_event() {
        use a2a_protocol_types::error::A2aError;
        use http_body_util::BodyExt;

        let (tx, rx) = tokio::sync::broadcast::channel(8);
        let reader = crate::streaming::event_queue::InMemoryQueueReader::new(rx);

        let err = A2aError::internal("something broke");
        tx.send(Err(err)).expect("send should succeed");
        drop(tx);

        let mut response = build_sse_response(reader, None, None, true);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: error\n"),
            "error event frame should start with 'event: error\\n', got: {text}"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_events() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });

        writer.write(event).await.expect("write should succeed");
        drop(writer);

        let mut response = build_sse_response(reader, None, None, true);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: message\n"),
            "SSE frame should start with 'event: message\\n', got: {text}"
        );
        assert!(
            text.contains("data: "),
            "SSE frame should contain a data: line"
        );
        assert!(
            text.contains("\"jsonrpc\""),
            "data should contain JSON-RPC envelope"
        );
        assert!(
            text.contains("\"result\""),
            "data should contain result field"
        );
    }
}