a2a_protocol_server/streaming/
sse.rs1use std::convert::Infallible;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use bytes::Bytes;
15use http_body_util::BodyExt;
16use hyper::body::Frame;
17
18use a2a_protocol_types::jsonrpc::{JsonRpcId, JsonRpcSuccessResponse, JsonRpcVersion};
19
20use crate::streaming::event_queue::{EventQueueReader, InMemoryQueueReader};
21
/// Keep-alive period used when the caller of `build_sse_response` supplies none.
const DEFAULT_KEEP_ALIVE: Duration = Duration::from_millis(30_000);

25#[must_use]
29pub fn write_event(event_type: &str, data: &str) -> Bytes {
30 let mut buf = String::with_capacity(event_type.len() + data.len() + 32);
31 buf.push_str("event: ");
32 buf.push_str(event_type);
33 buf.push('\n');
34 for line in data.lines() {
35 buf.push_str("data: ");
36 buf.push_str(line);
37 buf.push('\n');
38 }
39 buf.push('\n');
40 Bytes::from(buf)
41}
42
43#[must_use]
45pub const fn write_keep_alive() -> Bytes {
46 Bytes::from_static(b": keep-alive\n\n")
47}
48
49#[derive(Debug)]
53pub struct SseBodyWriter {
54 tx: tokio::sync::mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
55}
56
57impl SseBodyWriter {
58 pub async fn send_event(&self, event_type: &str, data: &str) -> Result<(), ()> {
64 let frame = Frame::data(write_event(event_type, data));
65 self.tx.send(Ok(frame)).await.map_err(|_| ())
66 }
67
68 pub async fn send_keep_alive(&self) -> Result<(), ()> {
74 let frame = Frame::data(write_keep_alive());
75 self.tx.send(Ok(frame)).await.map_err(|_| ())
76 }
77
78 pub fn close(self) {
80 drop(self);
81 }
82}
83
84struct ChannelBody {
90 rx: tokio::sync::mpsc::Receiver<Result<Frame<Bytes>, Infallible>>,
91}
92
93impl hyper::body::Body for ChannelBody {
94 type Data = Bytes;
95 type Error = Infallible;
96
97 fn poll_frame(
98 mut self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
101 self.rx.poll_recv(cx)
102 }
103}
104
105#[must_use]
118#[allow(clippy::too_many_lines)]
119pub fn build_sse_response(
120 mut reader: InMemoryQueueReader,
121 keep_alive_interval: Option<Duration>,
122) -> hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>> {
123 trace_info!("building SSE response stream");
124 let interval = keep_alive_interval.unwrap_or(DEFAULT_KEEP_ALIVE);
125 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(64);
126
127 let body_writer = SseBodyWriter { tx };
128
129 tokio::spawn(async move {
130 let mut keep_alive = tokio::time::interval(interval);
131 keep_alive.tick().await;
133
134 loop {
135 tokio::select! {
136 biased;
137
138 event = reader.read() => {
139 match event {
140 Some(Ok(stream_response)) => {
141 let envelope = JsonRpcSuccessResponse {
142 jsonrpc: JsonRpcVersion,
143 id: JsonRpcId::default(),
144 result: stream_response,
145 };
146 let data = match serde_json::to_string(&envelope) {
147 Ok(d) => d,
148 Err(e) => {
149 let err_msg = format!("{{\"error\":\"serialization failed: {e}\"}}");
151 let _ = body_writer.send_event("error", &err_msg).await;
152 break;
153 }
154 };
155 if body_writer.send_event("message", &data).await.is_err() {
156 break;
157 }
158 }
159 Some(Err(e)) => {
160 let Ok(data) = serde_json::to_string(&e) else {
161 break;
162 };
163 let _ = body_writer.send_event("error", &data).await;
164 break;
165 }
166 None => break,
167 }
168 }
169 _ = keep_alive.tick() => {
170 if body_writer.send_keep_alive().await.is_err() {
171 break;
172 }
173 }
174 }
175 }
176
177 drop(body_writer);
178 });
179
180 let body = ChannelBody { rx };
181
182 hyper::Response::builder()
183 .status(200)
184 .header("content-type", "text/event-stream")
185 .header("cache-control", "no-cache")
186 .header("transfer-encoding", "chunked")
187 .body(body.boxed())
188 .unwrap_or_else(|_| {
189 hyper::Response::new(
190 http_body_util::Full::new(Bytes::from_static(b"SSE response build error")).boxed(),
191 )
192 })
193}