// plane_common/sse.rs

1use super::PlaneClientError;
2use crate::exponential_backoff::ExponentialBackoff;
3use reqwest::{
4    header::{HeaderValue, ACCEPT, CONNECTION},
5    Client, Response,
6};
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use url::Url;
10
11struct RawSseStream {
12    response: Response,
13    // Bytes left over from the last chunk.
14    buffer: Vec<u8>,
15    // Bytes that are part of the current data payload.
16    data: Option<Vec<u8>>,
17    id: Option<String>,
18}
19
20impl RawSseStream {
21    fn new(response: Response) -> Self {
22        Self {
23            response,
24            buffer: Vec::new(),
25            data: None,
26            id: None,
27        }
28    }
29
30    async fn next(&mut self) -> Option<(Option<String>, Vec<u8>)> {
31        loop {
32            let chunk = match self.response.chunk().await {
33                Ok(Some(chunk)) => chunk,
34                Ok(None) => return None,
35                Err(err) => {
36                    tracing::error!(?err, "Error reading SSE stream.");
37                    return None;
38                }
39            };
40            let mut chunk = chunk.as_ref();
41
42            // For as long as there are newlines in the chunk, process it line-by-line.
43            while let Some(newline_idx) = chunk.iter().position(|&b| b == b'\n') {
44                let current_line = &chunk[..newline_idx];
45                chunk = &chunk[newline_idx + 1..];
46
47                // If we have anything in the buffer, swap it for an empty buffer and prepend it to the current line.
48                let mut buffer = std::mem::take(&mut self.buffer);
49                buffer.extend_from_slice(current_line);
50
51                if let Some(result) = buffer.strip_prefix(b"data:") {
52                    match self.data {
53                        Some(ref mut data) => {
54                            // Per https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
55                            // > When the EventSource receives multiple consecutive lines that begin with data:,
56                            // > it concatenates them, inserting a newline character between each one.
57                            data.push(b'\n');
58                            data.extend_from_slice(result)
59                        }
60                        None => self.data = Some(result.to_vec()),
61                    }
62                } else if let Some(result) = buffer.strip_prefix(b"id:") {
63                    let id = match String::from_utf8(result.to_vec()) {
64                        Ok(id) => id,
65                        Err(err) => {
66                            tracing::error!(?err, "Error parsing SSE stream ID.");
67                            continue;
68                        }
69                    };
70                    self.id = Some(id);
71                } else if buffer.is_empty() && self.data.is_some() {
72                    let data = self.data.take().unwrap_or_default();
73                    return Some((self.id.take(), data));
74                }
75            }
76
77            // Add anything left over to the buffer.
78            self.buffer.extend_from_slice(chunk);
79        }
80    }
81}
82
83pub struct SseStream<T: DeserializeOwned> {
84    url: Url,
85    client: Client,
86    stream: Option<RawSseStream>,
87    backoff: ExponentialBackoff,
88    last_id: Option<String>,
89    _phantom: PhantomData<T>,
90}
91
92impl<T: DeserializeOwned> SseStream<T> {
93    fn new(url: Url, client: Client) -> Self {
94        Self {
95            url,
96            client,
97            stream: None,
98            backoff: ExponentialBackoff::default(),
99            last_id: None,
100            _phantom: PhantomData,
101        }
102    }
103
104    async fn ensure_stream(&mut self) -> Result<(), PlaneClientError> {
105        if self.stream.is_none() {
106            let mut request = self
107                .client
108                .get(self.url.clone())
109                .header(ACCEPT, HeaderValue::from_static("text/event-stream"))
110                .header(CONNECTION, HeaderValue::from_static("keep-alive"));
111
112            if let Some(id) = &self.last_id {
113                request = request.header("Last-Event-ID", id);
114            }
115
116            let response = request.send().await?;
117
118            if response.status() != 200 {
119                let status = response.status();
120                return Err(PlaneClientError::UnexpectedStatus(status));
121            }
122
123            self.stream = Some(RawSseStream::new(response));
124            return Ok(());
125        }
126
127        Ok(())
128    }
129
130    pub async fn next(&mut self) -> Option<T> {
131        loop {
132            if let Err(err) = self.ensure_stream().await {
133                tracing::error!(?err, "Error connecting to SSE stream.");
134                self.backoff.wait().await;
135                continue;
136            }
137
138            // We can unwrap here because we just ensured the stream exists.
139            let stream = self.stream.as_mut().expect("Stream is always Some.");
140            self.backoff.defer_reset();
141
142            let (id, data) = match stream.next().await {
143                Some(data) => data,
144                None => {
145                    self.stream = None;
146                    continue;
147                }
148            };
149
150            self.last_id = id;
151
152            match serde_json::from_slice(&data) {
153                Ok(value) => return Some(value),
154                Err(err) => {
155                    let typ = std::any::type_name::<T>();
156                    tracing::error!(?err, typ, "Failed to parse SSE data as type.");
157                    continue;
158                }
159            }
160        }
161    }
162}
163
164pub async fn sse_request<T: DeserializeOwned>(
165    url: Url,
166    client: Client,
167) -> Result<SseStream<T>, PlaneClientError> {
168    let mut stream = SseStream::new(url, client);
169    stream.ensure_stream().await?;
170    Ok(stream)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use async_stream::stream;
    use axum::{
        extract::State,
        http::HeaderMap,
        response::sse::{Event, KeepAlive, Sse},
        routing::get,
        Router,
    };
    use futures_util::stream::Stream;
    use serde::{Deserialize, Serialize};
    use std::{convert::Infallible, time::Duration};
    use tokio::{net::TcpListener, sync::broadcast, task::JoinHandle, time::timeout};

    /// Event payload emitted by the demo server and decoded by the client.
    #[derive(Serialize, Deserialize, Debug)]
    struct Count {
        value: u32,
    }

    /// A local axum server that serves a counting SSE stream at `/counter`.
    struct DemoSseServer {
        // Port chosen by the OS when binding to port 0.
        port: u16,
        // Task running the server; taken and aborted in `Drop`.
        handle: Option<JoinHandle<std::result::Result<(), anyhow::Error>>>,
        // Each SSE handler subscribes to this; sending on it makes open streams stop.
        disconnect_sender: broadcast::Sender<()>,
    }

    impl Drop for DemoSseServer {
        // Stop the server task when the test no longer holds the server.
        fn drop(&mut self) {
            self.handle.take().unwrap().abort();
        }
    }

    /// SSE handler emitting `Count` events whose event ID is the count itself.
    ///
    /// Counting starts at one past the `Last-Event-ID` header when it parses
    /// as a `u32`, otherwise at 0, so a reconnecting client resumes after the
    /// last event it saw.
    async fn handle_sse(
        State(disconnect_sender): State<broadcast::Sender<()>>,
        headers: HeaderMap,
    ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
        let mut receiver = disconnect_sender.subscribe();

        let mut value = headers
            .get("Last-Event-ID")
            .and_then(|id| {
                id.to_str()
                    .ok()
                    .and_then(|id| id.parse::<u32>().ok())
                    .map(|id| id + 1)
            })
            .unwrap_or(0);

        let stream = stream! {
            loop {
                // Wait up to 100ms for a disconnect signal. If `recv` completes
                // within that window, end the stream (and so the response body).
                if (timeout(Duration::from_millis(100), receiver.recv()).await).is_ok() {
                    break;
                };

                let event = Event::default().json_data(&Count { value }).unwrap().id(value.to_string());
                yield Ok(event);
                value += 1;
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        };

        Sse::new(stream).keep_alive(KeepAlive::default())
    }

    impl DemoSseServer {
        /// Binds to an OS-assigned port on 127.0.0.1 and spawns the server task.
        async fn new() -> Self {
            let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 0));
            let listener = TcpListener::bind(addr).await.unwrap();
            let port = listener.local_addr().unwrap().port();
            // The initial receiver is dropped; each request subscribes its own.
            let (disconnect_sender, _) = broadcast::channel::<()>(1);

            let app = Router::new()
                .route("/counter", get(handle_sse))
                .with_state(disconnect_sender.clone());

            let server = axum::serve(listener, app.into_make_service());
            let handle = tokio::spawn(async move { server.await.map_err(anyhow::Error::new) });

            Self {
                port,
                handle: Some(handle),
                disconnect_sender,
            }
        }

        /// Signals every currently open SSE stream to end.
        ///
        /// Panics via `unwrap` if `send` fails, which it does when no handler
        /// is currently subscribed.
        async fn disconnect(&self) {
            self.disconnect_sender.send(()).unwrap();
        }

        /// URL of the `/counter` endpoint.
        // NOTE(review): the server binds 127.0.0.1 but this connects via
        // `localhost`; if that resolves to ::1 first, the test relies on the
        // client falling back to IPv4.
        fn url(&self) -> Url {
            let url = format!("http://localhost:{}/counter", self.port);
            url::Url::parse(&url).unwrap()
        }
    }

    /// The client receives the counter's events in order, starting at 0.
    #[tokio::test]
    async fn test_simple_sse() {
        let server = DemoSseServer::new().await;

        let client = reqwest::Client::new();
        let mut stream = super::sse_request::<Count>(server.url(), client)
            .await
            .unwrap();

        for i in 0..10 {
            let value = stream.next().await.unwrap();
            assert_eq!(value.value, i);
        }
    }

    /// After the server ends the stream, the client reconnects and, via
    /// `Last-Event-ID`, continues at 10 instead of restarting at 0.
    #[tokio::test]
    async fn test_sse_reconnect() {
        let server = DemoSseServer::new().await;

        let client = reqwest::Client::new();
        let mut stream = super::sse_request::<Count>(server.url(), client)
            .await
            .unwrap();

        for i in 0..10 {
            let value = stream.next().await.unwrap();
            assert_eq!(value.value, i);
        }

        server.disconnect().await;

        for i in 10..20 {
            let value = stream.next().await.unwrap();
            assert_eq!(value.value, i);
        }
    }
}