1use super::PlaneClientError;
2use crate::exponential_backoff::ExponentialBackoff;
3use reqwest::{
4 header::{HeaderValue, ACCEPT, CONNECTION},
5 Client, Response,
6};
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use url::Url;
10
11struct RawSseStream {
12 response: Response,
13 buffer: Vec<u8>,
15 data: Option<Vec<u8>>,
17 id: Option<String>,
18}
19
20impl RawSseStream {
21 fn new(response: Response) -> Self {
22 Self {
23 response,
24 buffer: Vec::new(),
25 data: None,
26 id: None,
27 }
28 }
29
30 async fn next(&mut self) -> Option<(Option<String>, Vec<u8>)> {
31 loop {
32 let chunk = match self.response.chunk().await {
33 Ok(Some(chunk)) => chunk,
34 Ok(None) => return None,
35 Err(err) => {
36 tracing::error!(?err, "Error reading SSE stream.");
37 return None;
38 }
39 };
40 let mut chunk = chunk.as_ref();
41
42 while let Some(newline_idx) = chunk.iter().position(|&b| b == b'\n') {
44 let current_line = &chunk[..newline_idx];
45 chunk = &chunk[newline_idx + 1..];
46
47 let mut buffer = std::mem::take(&mut self.buffer);
49 buffer.extend_from_slice(current_line);
50
51 if let Some(result) = buffer.strip_prefix(b"data:") {
52 match self.data {
53 Some(ref mut data) => {
54 data.push(b'\n');
58 data.extend_from_slice(result)
59 }
60 None => self.data = Some(result.to_vec()),
61 }
62 } else if let Some(result) = buffer.strip_prefix(b"id:") {
63 let id = match String::from_utf8(result.to_vec()) {
64 Ok(id) => id,
65 Err(err) => {
66 tracing::error!(?err, "Error parsing SSE stream ID.");
67 continue;
68 }
69 };
70 self.id = Some(id);
71 } else if buffer.is_empty() && self.data.is_some() {
72 let data = self.data.take().unwrap_or_default();
73 return Some((self.id.take(), data));
74 }
75 }
76
77 self.buffer.extend_from_slice(chunk);
79 }
80 }
81}
82
83pub struct SseStream<T: DeserializeOwned> {
84 url: Url,
85 client: Client,
86 stream: Option<RawSseStream>,
87 backoff: ExponentialBackoff,
88 last_id: Option<String>,
89 _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> SseStream<T> {
93 fn new(url: Url, client: Client) -> Self {
94 Self {
95 url,
96 client,
97 stream: None,
98 backoff: ExponentialBackoff::default(),
99 last_id: None,
100 _phantom: PhantomData,
101 }
102 }
103
104 async fn ensure_stream(&mut self) -> Result<(), PlaneClientError> {
105 if self.stream.is_none() {
106 let mut request = self
107 .client
108 .get(self.url.clone())
109 .header(ACCEPT, HeaderValue::from_static("text/event-stream"))
110 .header(CONNECTION, HeaderValue::from_static("keep-alive"));
111
112 if let Some(id) = &self.last_id {
113 request = request.header("Last-Event-ID", id);
114 }
115
116 let response = request.send().await?;
117
118 if response.status() != 200 {
119 let status = response.status();
120 return Err(PlaneClientError::UnexpectedStatus(status));
121 }
122
123 self.stream = Some(RawSseStream::new(response));
124 return Ok(());
125 }
126
127 Ok(())
128 }
129
130 pub async fn next(&mut self) -> Option<T> {
131 loop {
132 if let Err(err) = self.ensure_stream().await {
133 tracing::error!(?err, "Error connecting to SSE stream.");
134 self.backoff.wait().await;
135 continue;
136 }
137
138 let stream = self.stream.as_mut().expect("Stream is always Some.");
140 self.backoff.defer_reset();
141
142 let (id, data) = match stream.next().await {
143 Some(data) => data,
144 None => {
145 self.stream = None;
146 continue;
147 }
148 };
149
150 self.last_id = id;
151
152 match serde_json::from_slice(&data) {
153 Ok(value) => return Some(value),
154 Err(err) => {
155 let typ = std::any::type_name::<T>();
156 tracing::error!(?err, typ, "Failed to parse SSE data as type.");
157 continue;
158 }
159 }
160 }
161 }
162}
163
164pub async fn sse_request<T: DeserializeOwned>(
165 url: Url,
166 client: Client,
167) -> Result<SseStream<T>, PlaneClientError> {
168 let mut stream = SseStream::new(url, client);
169 stream.ensure_stream().await?;
170 Ok(stream)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use async_stream::stream;
    use axum::{
        extract::State,
        http::HeaderMap,
        response::sse::{Event, KeepAlive, Sse},
        routing::get,
        Router,
    };
    use futures_util::stream::Stream;
    use serde::{Deserialize, Serialize};
    use std::{convert::Infallible, time::Duration};
    use tokio::{net::TcpListener, sync::broadcast, task::JoinHandle, time::timeout};

    /// JSON payload carried by each event of the demo server.
    #[derive(Serialize, Deserialize, Debug)]
    struct Count {
        value: u32,
    }

    /// Local axum server exposing a counting SSE endpoint at `/counter`.
    ///
    /// The server task is aborted when the fixture is dropped.
    struct DemoSseServer {
        port: u16,
        handle: Option<JoinHandle<std::result::Result<(), anyhow::Error>>>,
        disconnect_sender: broadcast::Sender<()>,
    }

    impl Drop for DemoSseServer {
        fn drop(&mut self) {
            let server_task = self.handle.take().unwrap();
            server_task.abort();
        }
    }

    /// Streams `Count` events whose id is the counter value.
    ///
    /// Counting resumes after the value in `Last-Event-ID` when that header
    /// parses as a `u32`; otherwise it starts at 0. The stream ends as soon as
    /// a `recv` on the disconnect channel completes within its wait window.
    async fn handle_sse(
        State(disconnect_sender): State<broadcast::Sender<()>>,
        headers: HeaderMap,
    ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
        let mut disconnect = disconnect_sender.subscribe();

        let mut value = headers
            .get("Last-Event-ID")
            .and_then(|header| header.to_str().ok())
            .and_then(|text| text.parse::<u32>().ok())
            .map_or(0, |last_seen| last_seen + 1);

        let events = stream! {
            loop {
                // Wait up to 100ms for a disconnect; any completed `recv`
                // (message or channel error) ends the stream.
                let disconnect_requested =
                    timeout(Duration::from_millis(100), disconnect.recv()).await.is_ok();
                if disconnect_requested {
                    break;
                }

                let event = Event::default()
                    .json_data(&Count { value })
                    .unwrap()
                    .id(value.to_string());
                yield Ok(event);

                value += 1;
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        };

        Sse::new(events).keep_alive(KeepAlive::default())
    }

    impl DemoSseServer {
        /// Binds an ephemeral port on 127.0.0.1 and serves the app on a
        /// background task.
        async fn new() -> Self {
            let bind_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 0));
            let listener = TcpListener::bind(bind_addr).await.unwrap();
            let port = listener.local_addr().unwrap().port();
            let (disconnect_sender, _) = broadcast::channel::<()>(1);

            let router = Router::new()
                .route("/counter", get(handle_sse))
                .with_state(disconnect_sender.clone());

            let serve_future = axum::serve(listener, router.into_make_service());
            let handle =
                tokio::spawn(async move { serve_future.await.map_err(anyhow::Error::new) });

            Self {
                port,
                handle: Some(handle),
                disconnect_sender,
            }
        }

        /// Broadcasts a disconnect to every open event stream.
        async fn disconnect(&self) {
            self.disconnect_sender.send(()).unwrap();
        }

        /// URL of the counting endpoint.
        fn url(&self) -> Url {
            let endpoint = format!("http://localhost:{}/counter", self.port);
            url::Url::parse(&endpoint).unwrap()
        }
    }

    /// Reads one event per value in `expected`, asserting each arrives in order.
    async fn assert_next_values(stream: &mut SseStream<Count>, expected: std::ops::Range<u32>) {
        for want in expected {
            let got = stream.next().await.unwrap();
            assert_eq!(got.value, want);
        }
    }

    #[tokio::test]
    async fn test_simple_sse() {
        let server = DemoSseServer::new().await;

        let mut stream = super::sse_request::<Count>(server.url(), reqwest::Client::new())
            .await
            .unwrap();

        assert_next_values(&mut stream, 0..10).await;
    }

    #[tokio::test]
    async fn test_sse_reconnect() {
        let server = DemoSseServer::new().await;

        let mut stream = super::sse_request::<Count>(server.url(), reqwest::Client::new())
            .await
            .unwrap();

        assert_next_values(&mut stream, 0..10).await;

        server.disconnect().await;

        assert_next_values(&mut stream, 10..20).await;
    }
}