a2a_protocol_server/streaming/
sse.rs1use std::convert::Infallible;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::Bytes;
17use http_body_util::BodyExt;
18use hyper::body::Frame;
19
20use a2a_protocol_types::jsonrpc::{JsonRpcId, JsonRpcSuccessResponse, JsonRpcVersion};
21
22use crate::streaming::event_queue::{EventQueueReader, InMemoryQueueReader};
23
/// Interval between `: keep-alive` SSE comments used when the caller of
/// `build_sse_response` does not supply one (thirty seconds).
pub(crate) const DEFAULT_KEEP_ALIVE: Duration = Duration::from_millis(30_000);

/// Number of encoded SSE frames that may sit in the channel between the
/// producer task and the HTTP body when the caller does not supply a capacity.
pub(crate) const DEFAULT_SSE_CHANNEL_CAPACITY: usize = 64;
30#[must_use]
34pub fn write_event(event_type: &str, data: &str) -> Bytes {
35 let mut buf = String::with_capacity(event_type.len() + data.len() + 32);
36 buf.push_str("event: ");
37 buf.push_str(event_type);
38 buf.push('\n');
39 for line in data.lines() {
40 buf.push_str("data: ");
41 buf.push_str(line);
42 buf.push('\n');
43 }
44 buf.push('\n');
45 Bytes::from(buf)
46}
47
48#[must_use]
50pub const fn write_keep_alive() -> Bytes {
51 Bytes::from_static(b": keep-alive\n\n")
52}
53
54#[derive(Debug)]
58pub struct SseBodyWriter {
59 tx: tokio::sync::mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
60}
61
62impl SseBodyWriter {
63 pub async fn send_event(&self, event_type: &str, data: &str) -> Result<(), ()> {
69 let frame = Frame::data(write_event(event_type, data));
70 self.tx.send(Ok(frame)).await.map_err(|_| ())
71 }
72
73 pub async fn send_keep_alive(&self) -> Result<(), ()> {
79 let frame = Frame::data(write_keep_alive());
80 self.tx.send(Ok(frame)).await.map_err(|_| ())
81 }
82
83 pub fn close(self) {
85 drop(self);
86 }
87}
88
89struct ChannelBody {
95 rx: tokio::sync::mpsc::Receiver<Result<Frame<Bytes>, Infallible>>,
96}
97
98impl hyper::body::Body for ChannelBody {
99 type Data = Bytes;
100 type Error = Infallible;
101
102 fn poll_frame(
103 mut self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
106 self.rx.poll_recv(cx)
107 }
108}
109
110#[must_use]
126#[allow(clippy::too_many_lines)]
127pub fn build_sse_response(
128 mut reader: InMemoryQueueReader,
129 keep_alive_interval: Option<Duration>,
130 channel_capacity: Option<usize>,
131 jsonrpc_envelope: bool,
132) -> hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>> {
133 trace_info!("building SSE response stream");
134 let interval = keep_alive_interval.unwrap_or(DEFAULT_KEEP_ALIVE);
135 let cap = channel_capacity.unwrap_or(DEFAULT_SSE_CHANNEL_CAPACITY);
136 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(cap);
137
138 let body_writer = SseBodyWriter { tx };
139
140 tokio::spawn(async move {
141 let mut keep_alive = tokio::time::interval(interval);
142 keep_alive.tick().await;
144
145 loop {
146 tokio::select! {
147 biased;
148
149 event = reader.read() => {
150 match event {
151 Some(Ok(stream_response)) => {
152 let data = if jsonrpc_envelope {
153 let envelope = JsonRpcSuccessResponse {
154 jsonrpc: JsonRpcVersion,
155 id: JsonRpcId::default(),
156 result: stream_response,
157 };
158 serde_json::to_string(&envelope)
159 } else {
160 serde_json::to_string(&stream_response)
162 };
163 let data = match data {
164 Ok(d) => d,
165 Err(e) => {
166 let err_msg = format!("{{\"error\":\"serialization failed: {e}\"}}");
167 let _ = body_writer.send_event("error", &err_msg).await;
168 break;
169 }
170 };
171 if body_writer.send_event("message", &data).await.is_err() {
172 break;
173 }
174 }
175 Some(Err(e)) => {
176 let Ok(data) = serde_json::to_string(&e) else {
177 break;
178 };
179 let _ = body_writer.send_event("error", &data).await;
180 break;
181 }
182 None => break,
183 }
184 }
185 _ = keep_alive.tick() => {
186 if body_writer.send_keep_alive().await.is_err() {
187 break;
188 }
189 }
190 }
191 }
192
193 drop(body_writer);
194 });
195
196 let body = ChannelBody { rx };
197
198 hyper::Response::builder()
199 .status(200)
200 .header("content-type", "text/event-stream")
201 .header("cache-control", "no-cache")
202 .header("transfer-encoding", "chunked")
203 .body(body.boxed())
204 .unwrap_or_else(|_| {
205 hyper::Response::new(
206 http_body_util::Full::new(Bytes::from_static(b"SSE response build error")).boxed(),
207 )
208 })
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_event_single_line_data() {
        let frame = write_event("message", r#"{"hello":"world"}"#);
        let expected = "event: message\ndata: {\"hello\":\"world\"}\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "single-line data should produce one data: line"
        );
    }

    #[test]
    fn write_event_multiline_data() {
        let frame = write_event("error", "line1\nline2\nline3");
        let expected = "event: error\ndata: line1\ndata: line2\ndata: line3\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "multiline data should produce separate data: lines"
        );
    }

    #[test]
    fn write_event_empty_data() {
        let frame = write_event("ping", "");
        let expected = "event: ping\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty data should produce no data: lines"
        );
    }

    #[test]
    fn write_event_empty_event_type() {
        let frame = write_event("", "payload");
        let expected = "event: \ndata: payload\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty event type should still produce valid SSE frame"
        );
    }

    #[test]
    fn write_keep_alive_format() {
        let frame = write_keep_alive();
        assert_eq!(
            frame,
            Bytes::from_static(b": keep-alive\n\n"),
            "keep-alive should be an SSE comment terminated by double newline"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_event_delivers_frame() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_event("message", "hello")
            .await
            .expect("send_event should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_event("message", "hello"),
            "received frame should match write_event output"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_keep_alive_delivers_comment() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_keep_alive()
            .await
            .expect("send_keep_alive should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_keep_alive(),
            "should receive keep-alive comment"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_event("message", "data").await;
        assert!(
            result.is_err(),
            "send_event should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_keep_alive_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_keep_alive().await;
        assert!(
            result.is_err(),
            "send_keep_alive should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_close_drops_sender() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer.close();

        let result = rx.recv().await;
        assert!(
            result.is_none(),
            "receiver should return None after writer is closed"
        );
    }

    #[tokio::test]
    async fn build_sse_response_has_correct_headers() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None, true);

        assert_eq!(response.status(), 200, "status should be 200 OK");
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
            "Content-Type should be text/event-stream"
        );
        assert_eq!(
            response
                .headers()
                .get("cache-control")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"no-cache".as_slice()),
            "Cache-Control should be no-cache"
        );
        assert_eq!(
            response
                .headers()
                .get("transfer-encoding")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"chunked".as_slice()),
            "Transfer-Encoding should be chunked"
        );
    }

    #[tokio::test]
    async fn build_sse_response_with_custom_keep_alive_and_capacity() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, Some(Duration::from_secs(5)), Some(16), true);

        assert_eq!(response.status(), 200);
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
        );
    }

    /// Smoke test: writing after the response (and so its body) is dropped
    /// must not panic or hang.
    #[tokio::test]
    async fn build_sse_response_client_disconnect_stops_stream() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None, true);

        // Simulate the client going away.
        drop(response);

        // Give the background task a chance to observe the disconnect.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });
        // The result is irrelevant: the reader side may already be gone.
        let _ = writer.write(event).await;
        drop(writer);
    }

    #[tokio::test]
    async fn build_sse_response_ends_on_reader_close() {
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        // Closing the writer before building the response means the reader
        // reports end-of-queue and the body must finish.
        drop(writer);

        let mut response = build_sse_response(reader, None, None, true);

        // Drain whatever frames arrive; the body must reach end-of-stream
        // (`None`) well before the timeout rather than staying open.
        let drained = tokio::time::timeout(Duration::from_secs(5), async {
            while response.body_mut().frame().await.is_some() {}
        })
        .await;
        assert!(
            drained.is_ok(),
            "stream should end once the reader is closed"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_error_event() {
        use a2a_protocol_types::error::A2aError;
        use http_body_util::BodyExt;

        let (tx, rx) = tokio::sync::broadcast::channel(8);
        let reader = crate::streaming::event_queue::InMemoryQueueReader::new(rx);

        let err = A2aError::internal("something broke");
        tx.send(Err(err)).expect("send should succeed");
        drop(tx);

        let mut response = build_sse_response(reader, None, None, true);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: error\n"),
            "error event frame should start with 'event: error\\n', got: {text}"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_events() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });

        writer.write(event).await.expect("write should succeed");
        drop(writer);

        let mut response = build_sse_response(reader, None, None, true);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: message\n"),
            "SSE frame should start with 'event: message\\n', got: {text}"
        );
        assert!(
            text.contains("data: "),
            "SSE frame should contain a data: line"
        );
        assert!(
            text.contains("\"jsonrpc\""),
            "data should contain JSON-RPC envelope"
        );
        assert!(
            text.contains("\"result\""),
            "data should contain result field"
        );
    }
}