a2a_protocol_server/streaming/
sse.rs1use std::convert::Infallible;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::Bytes;
17use http_body_util::BodyExt;
18use hyper::body::Frame;
19
20use a2a_protocol_types::jsonrpc::{JsonRpcId, JsonRpcSuccessResponse, JsonRpcVersion};
21
22use crate::streaming::event_queue::{EventQueueReader, InMemoryQueueReader};
23
/// Interval between SSE keep-alive comments used when the caller passes no
/// explicit interval (30 seconds).
pub(crate) const DEFAULT_KEEP_ALIVE: Duration = Duration::from_millis(30_000);

/// Number of frames the bounded channel between the SSE writer task and the
/// response body can hold before further sends wait, used when the caller
/// passes no explicit capacity.
pub(crate) const DEFAULT_SSE_CHANNEL_CAPACITY: usize = 64;
30#[must_use]
34pub fn write_event(event_type: &str, data: &str) -> Bytes {
35 let mut buf = String::with_capacity(event_type.len() + data.len() + 32);
36 buf.push_str("event: ");
37 buf.push_str(event_type);
38 buf.push('\n');
39 for line in data.lines() {
40 buf.push_str("data: ");
41 buf.push_str(line);
42 buf.push('\n');
43 }
44 buf.push('\n');
45 Bytes::from(buf)
46}
47
48#[must_use]
50pub const fn write_keep_alive() -> Bytes {
51 Bytes::from_static(b": keep-alive\n\n")
52}
53
54#[derive(Debug)]
58pub struct SseBodyWriter {
59 tx: tokio::sync::mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
60}
61
62impl SseBodyWriter {
63 pub async fn send_event(&self, event_type: &str, data: &str) -> Result<(), ()> {
69 let frame = Frame::data(write_event(event_type, data));
70 self.tx.send(Ok(frame)).await.map_err(|_| ())
71 }
72
73 pub async fn send_keep_alive(&self) -> Result<(), ()> {
79 let frame = Frame::data(write_keep_alive());
80 self.tx.send(Ok(frame)).await.map_err(|_| ())
81 }
82
83 pub fn close(self) {
85 drop(self);
86 }
87}
88
89struct ChannelBody {
95 rx: tokio::sync::mpsc::Receiver<Result<Frame<Bytes>, Infallible>>,
96}
97
98impl hyper::body::Body for ChannelBody {
99 type Data = Bytes;
100 type Error = Infallible;
101
102 fn poll_frame(
103 mut self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
106 self.rx.poll_recv(cx)
107 }
108}
109
110#[must_use]
123#[allow(clippy::too_many_lines)]
124pub fn build_sse_response(
125 mut reader: InMemoryQueueReader,
126 keep_alive_interval: Option<Duration>,
127 channel_capacity: Option<usize>,
128) -> hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>> {
129 trace_info!("building SSE response stream");
130 let interval = keep_alive_interval.unwrap_or(DEFAULT_KEEP_ALIVE);
131 let cap = channel_capacity.unwrap_or(DEFAULT_SSE_CHANNEL_CAPACITY);
132 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(cap);
133
134 let body_writer = SseBodyWriter { tx };
135
136 tokio::spawn(async move {
137 let mut keep_alive = tokio::time::interval(interval);
138 keep_alive.tick().await;
140
141 loop {
142 tokio::select! {
143 biased;
144
145 event = reader.read() => {
146 match event {
147 Some(Ok(stream_response)) => {
148 let envelope = JsonRpcSuccessResponse {
149 jsonrpc: JsonRpcVersion,
150 id: JsonRpcId::default(),
151 result: stream_response,
152 };
153 let data = match serde_json::to_string(&envelope) {
154 Ok(d) => d,
155 Err(e) => {
156 let err_msg = format!("{{\"error\":\"serialization failed: {e}\"}}");
158 let _ = body_writer.send_event("error", &err_msg).await;
159 break;
160 }
161 };
162 if body_writer.send_event("message", &data).await.is_err() {
163 break;
164 }
165 }
166 Some(Err(e)) => {
167 let Ok(data) = serde_json::to_string(&e) else {
168 break;
169 };
170 let _ = body_writer.send_event("error", &data).await;
171 break;
172 }
173 None => break,
174 }
175 }
176 _ = keep_alive.tick() => {
177 if body_writer.send_keep_alive().await.is_err() {
178 break;
179 }
180 }
181 }
182 }
183
184 drop(body_writer);
185 });
186
187 let body = ChannelBody { rx };
188
189 hyper::Response::builder()
190 .status(200)
191 .header("content-type", "text/event-stream")
192 .header("cache-control", "no-cache")
193 .header("transfer-encoding", "chunked")
194 .body(body.boxed())
195 .unwrap_or_else(|_| {
196 hyper::Response::new(
197 http_body_util::Full::new(Bytes::from_static(b"SSE response build error")).boxed(),
198 )
199 })
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_event_single_line_data() {
        let frame = write_event("message", r#"{"hello":"world"}"#);
        let expected = "event: message\ndata: {\"hello\":\"world\"}\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "single-line data should produce one data: line"
        );
    }

    #[test]
    fn write_event_multiline_data() {
        let frame = write_event("error", "line1\nline2\nline3");
        let expected = "event: error\ndata: line1\ndata: line2\ndata: line3\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "multiline data should produce separate data: lines"
        );
    }

    #[test]
    fn write_event_empty_data() {
        let frame = write_event("ping", "");
        let expected = "event: ping\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty data should produce no data: lines"
        );
    }

    #[test]
    fn write_event_empty_event_type() {
        let frame = write_event("", "payload");
        let expected = "event: \ndata: payload\n\n";
        assert_eq!(
            frame,
            Bytes::from(expected),
            "empty event type should still produce valid SSE frame"
        );
    }

    #[test]
    fn write_keep_alive_format() {
        let frame = write_keep_alive();
        assert_eq!(
            frame,
            Bytes::from_static(b": keep-alive\n\n"),
            "keep-alive should be an SSE comment terminated by double newline"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_event_delivers_frame() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_event("message", "hello")
            .await
            .expect("send_event should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_event("message", "hello"),
            "received frame should match write_event output"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_keep_alive_delivers_comment() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer
            .send_keep_alive()
            .await
            .expect("send_keep_alive should succeed while receiver is alive");

        let received = rx.recv().await.expect("should receive a frame");
        let frame = received.expect("frame result should be Ok");
        let data = frame.into_data().expect("frame should be a data frame");
        assert_eq!(
            data,
            write_keep_alive(),
            "should receive keep-alive comment"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_send_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_event("message", "data").await;
        assert!(
            result.is_err(),
            "send_event should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_keep_alive_fails_after_receiver_dropped() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };
        drop(rx);

        let result = writer.send_keep_alive().await;
        assert!(
            result.is_err(),
            "send_keep_alive should return Err after receiver is dropped"
        );
    }

    #[tokio::test]
    async fn sse_body_writer_close_drops_sender() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(8);
        let writer = SseBodyWriter { tx };

        writer.close();

        let result = rx.recv().await;
        assert!(
            result.is_none(),
            "receiver should return None after writer is closed"
        );
    }

    #[tokio::test]
    async fn build_sse_response_has_correct_headers() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None);

        assert_eq!(response.status(), 200, "status should be 200 OK");
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
            "Content-Type should be text/event-stream"
        );
        assert_eq!(
            response
                .headers()
                .get("cache-control")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"no-cache".as_slice()),
            "Cache-Control should be no-cache"
        );
        assert_eq!(
            response
                .headers()
                .get("transfer-encoding")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"chunked".as_slice()),
            "Transfer-Encoding should be chunked"
        );
    }

    #[tokio::test]
    async fn build_sse_response_with_custom_keep_alive_and_capacity() {
        let (_writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, Some(Duration::from_secs(5)), Some(16));

        assert_eq!(response.status(), 200);
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .map(hyper::http::HeaderValue::as_bytes),
            Some(b"text/event-stream".as_slice()),
        );
    }

    #[tokio::test]
    async fn build_sse_response_client_disconnect_stops_stream() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let response = build_sse_response(reader, None, None);

        // Simulate a client disconnect by dropping the response and its body.
        drop(response);

        // Give the background task a chance to observe the closed channel.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });
        // NOTE(review): this only checks that writing after a disconnect does
        // not panic here; the write's outcome and the task's exit are not asserted.
        let _ = writer.write(event).await;

        drop(writer);
    }

    #[tokio::test]
    async fn build_sse_response_ends_on_reader_close() {
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        // Close the queue before streaming starts so the reader is already exhausted.
        drop(writer);

        let mut response = build_sse_response(reader, None, None);

        // With the default 30 s keep-alive, no keep-alive can be due yet, so the
        // only acceptable outcome is the end of the body. The timeout turns a
        // hang into a test failure instead of a stuck test run.
        let frame = tokio::time::timeout(Duration::from_secs(5), response.body_mut().frame())
            .await
            .expect("SSE body should end promptly after the reader closes");
        assert!(
            frame.is_none(),
            "stream should end once the reader is closed, but a frame was produced"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_error_event() {
        use a2a_protocol_types::error::A2aError;
        use http_body_util::BodyExt;

        let (tx, rx) = tokio::sync::broadcast::channel(8);
        let reader = crate::streaming::event_queue::InMemoryQueueReader::new(rx);

        let err = A2aError::internal("something broke");
        tx.send(Err(err)).expect("send should succeed");
        drop(tx);

        let mut response = build_sse_response(reader, None, None);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: error\n"),
            "error event frame should start with 'event: error\\n', got: {text}"
        );
    }

    #[tokio::test]
    async fn build_sse_response_streams_events() {
        use crate::streaming::event_queue::EventQueueWriter;
        use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
        use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
        use http_body_util::BodyExt;

        let (writer, reader) = crate::streaming::event_queue::new_in_memory_queue();

        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new("t1"),
            context_id: ContextId::new("c1"),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });

        writer.write(event).await.expect("write should succeed");
        drop(writer);

        let mut response = build_sse_response(reader, None, None);

        let frame = response
            .body_mut()
            .frame()
            .await
            .expect("should have a frame")
            .expect("frame should be Ok");
        let data = frame.into_data().expect("should be a data frame");
        let text = String::from_utf8_lossy(&data);

        assert!(
            text.starts_with("event: message\n"),
            "SSE frame should start with 'event: message\\n', got: {text}"
        );
        assert!(
            text.contains("data: "),
            "SSE frame should contain a data: line"
        );
        assert!(
            text.contains("\"jsonrpc\""),
            "data should contain JSON-RPC envelope"
        );
        assert!(
            text.contains("\"result\""),
            "data should contain result field"
        );
    }
}