Skip to main content

a2a_protocol_server/streaming/
sse.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Server-Sent Events (SSE) response builder.
5//!
6//! Builds a `hyper::Response` with `Content-Type: text/event-stream` and
7//! streams events from an [`InMemoryQueueReader`] as SSE frames.
8
9use std::convert::Infallible;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use bytes::Bytes;
15use http_body_util::BodyExt;
16use hyper::body::Frame;
17
18use a2a_protocol_types::jsonrpc::{JsonRpcId, JsonRpcSuccessResponse, JsonRpcVersion};
19
20use crate::streaming::event_queue::{EventQueueReader, InMemoryQueueReader};
21
/// Keep-alive period applied to SSE streams whose caller supplies none
/// (thirty seconds).
const DEFAULT_KEEP_ALIVE: Duration = Duration::from_millis(30_000);
24
25// ── SSE frame formatting ─────────────────────────────────────────────────────
26
27/// Formats a single SSE frame with the given event type and data.
28#[must_use]
29pub fn write_event(event_type: &str, data: &str) -> Bytes {
30    let mut buf = String::with_capacity(event_type.len() + data.len() + 32);
31    buf.push_str("event: ");
32    buf.push_str(event_type);
33    buf.push('\n');
34    for line in data.lines() {
35        buf.push_str("data: ");
36        buf.push_str(line);
37        buf.push('\n');
38    }
39    buf.push('\n');
40    Bytes::from(buf)
41}
42
43/// Formats a keep-alive SSE comment.
44#[must_use]
45pub const fn write_keep_alive() -> Bytes {
46    Bytes::from_static(b": keep-alive\n\n")
47}
48
49// ── SseBodyWriter ────────────────────────────────────────────────────────────
50
51/// Wraps an `mpsc::Sender` for writing SSE frames to a response body.
52#[derive(Debug)]
53pub struct SseBodyWriter {
54    tx: tokio::sync::mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
55}
56
57impl SseBodyWriter {
58    /// Sends an SSE event frame.
59    ///
60    /// # Errors
61    ///
62    /// Returns `Err(())` if the receiver has been dropped (client disconnected).
63    pub async fn send_event(&self, event_type: &str, data: &str) -> Result<(), ()> {
64        let frame = Frame::data(write_event(event_type, data));
65        self.tx.send(Ok(frame)).await.map_err(|_| ())
66    }
67
68    /// Sends a keep-alive comment.
69    ///
70    /// # Errors
71    ///
72    /// Returns `Err(())` if the receiver has been dropped.
73    pub async fn send_keep_alive(&self) -> Result<(), ()> {
74        let frame = Frame::data(write_keep_alive());
75        self.tx.send(Ok(frame)).await.map_err(|_| ())
76    }
77
78    /// Closes the SSE stream by dropping the sender.
79    pub fn close(self) {
80        drop(self);
81    }
82}
83
84// ── ChannelBody ──────────────────────────────────────────────────────────────
85
86/// A `hyper::body::Body` implementation backed by an `mpsc::Receiver`.
87///
88/// This allows streaming SSE frames through hyper's response pipeline.
89struct ChannelBody {
90    rx: tokio::sync::mpsc::Receiver<Result<Frame<Bytes>, Infallible>>,
91}
92
93impl hyper::body::Body for ChannelBody {
94    type Data = Bytes;
95    type Error = Infallible;
96
97    fn poll_frame(
98        mut self: Pin<&mut Self>,
99        cx: &mut Context<'_>,
100    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
101        self.rx.poll_recv(cx)
102    }
103}
104
105// ── build_sse_response ───────────────────────────────────────────────────────
106
107/// Builds an SSE streaming response from an event queue reader.
108///
109/// Each event is wrapped in a JSON-RPC 2.0 success response envelope so that
110/// clients can uniformly parse SSE frames regardless of transport binding.
111///
112/// Spawns a background task that:
113/// 1. Reads events from `reader` and serializes them as SSE `message` frames.
114/// 2. Sends periodic keep-alive comments at the specified interval.
115///
116/// The keep-alive ticker is cancelled when the reader is exhausted.
117#[must_use]
118#[allow(clippy::too_many_lines)]
119pub fn build_sse_response(
120    mut reader: InMemoryQueueReader,
121    keep_alive_interval: Option<Duration>,
122) -> hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>> {
123    trace_info!("building SSE response stream");
124    let interval = keep_alive_interval.unwrap_or(DEFAULT_KEEP_ALIVE);
125    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(64);
126
127    let body_writer = SseBodyWriter { tx };
128
129    tokio::spawn(async move {
130        let mut keep_alive = tokio::time::interval(interval);
131        // The first tick fires immediately; skip it.
132        keep_alive.tick().await;
133
134        loop {
135            tokio::select! {
136                biased;
137
138                event = reader.read() => {
139                    match event {
140                        Some(Ok(stream_response)) => {
141                            let envelope = JsonRpcSuccessResponse {
142                                jsonrpc: JsonRpcVersion,
143                                id: JsonRpcId::default(),
144                                result: stream_response,
145                            };
146                            let data = match serde_json::to_string(&envelope) {
147                                Ok(d) => d,
148                                Err(e) => {
149                                    // PR-6: Send error event before closing.
150                                    let err_msg = format!("{{\"error\":\"serialization failed: {e}\"}}");
151                                    let _ = body_writer.send_event("error", &err_msg).await;
152                                    break;
153                                }
154                            };
155                            if body_writer.send_event("message", &data).await.is_err() {
156                                break;
157                            }
158                        }
159                        Some(Err(e)) => {
160                            let Ok(data) = serde_json::to_string(&e) else {
161                                break;
162                            };
163                            let _ = body_writer.send_event("error", &data).await;
164                            break;
165                        }
166                        None => break,
167                    }
168                }
169                _ = keep_alive.tick() => {
170                    if body_writer.send_keep_alive().await.is_err() {
171                        break;
172                    }
173                }
174            }
175        }
176
177        drop(body_writer);
178    });
179
180    let body = ChannelBody { rx };
181
182    hyper::Response::builder()
183        .status(200)
184        .header("content-type", "text/event-stream")
185        .header("cache-control", "no-cache")
186        .header("transfer-encoding", "chunked")
187        .body(body.boxed())
188        .unwrap_or_else(|_| {
189            hyper::Response::new(
190                http_body_util::Full::new(Bytes::from_static(b"SSE response build error")).boxed(),
191            )
192        })
193}