Skip to main content

aster/streaming/
sse.rs

1//! Server-Sent Events (SSE) Parser
2//!
3//! Implements standard SSE protocol parsing with support for:
4//! - event: and data: field parsing
5//! - Multi-line data fields
6//! - CRLF and LF line endings
//! - Tracking of `id` and `retry` fields (reconnection itself is left to the caller)
8//!
9
10use serde::{Deserialize, Serialize};
11use std::collections::VecDeque;
12
13/// SSE Event structure
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SSEEvent {
16    /// Event type (defaults to "message")
17    pub event: String,
18    /// Event data
19    pub data: String,
20    /// Raw lines that made up this event
21    pub raw: Vec<String>,
22    /// Event ID (optional)
23    pub id: Option<String>,
24    /// Retry time in milliseconds (optional)
25    pub retry: Option<u64>,
26}
27
28impl SSEEvent {
29    /// Create a new SSE event with default values
30    pub fn new(data: String) -> Self {
31        Self {
32            event: "message".to_string(),
33            data,
34            raw: Vec::new(),
35            id: None,
36            retry: None,
37        }
38    }
39
40    /// Create with specific event type
41    pub fn with_event(mut self, event: impl Into<String>) -> Self {
42        self.event = event.into();
43        self
44    }
45
46    /// Parse the data as JSON
47    pub fn parse_json<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
48        serde_json::from_str(&self.data)
49    }
50}
51
52/// SSE Event Decoder
53/// Parses SSE protocol line by line
54pub struct SSEDecoder {
55    event_type: Option<String>,
56    data_lines: Vec<String>,
57    chunks: Vec<String>,
58    event_id: Option<String>,
59    retry_time: Option<u64>,
60}
61
62impl Default for SSEDecoder {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl SSEDecoder {
69    /// Create a new SSE decoder
70    pub fn new() -> Self {
71        Self {
72            event_type: None,
73            data_lines: Vec::new(),
74            chunks: Vec::new(),
75            event_id: None,
76            retry_time: None,
77        }
78    }
79
80    /// Decode a single line of SSE data
81    /// Returns a complete SSE event if the line is empty (event boundary)
82    pub fn decode(&mut self, line: &str) -> Option<SSEEvent> {
83        self.chunks.push(line.to_string());
84
85        // Empty line indicates event end
86        if line.trim().is_empty() {
87            if self.data_lines.is_empty() {
88                self.reset();
89                return None;
90            }
91
92            let event = SSEEvent {
93                event: self
94                    .event_type
95                    .take()
96                    .unwrap_or_else(|| "message".to_string()),
97                data: self.data_lines.join("\n"),
98                raw: std::mem::take(&mut self.chunks),
99                id: self.event_id.clone(),
100                retry: self.retry_time,
101            };
102
103            self.reset();
104            return Some(event);
105        }
106
107        // Comment line (starts with :)
108        if line.starts_with(':') {
109            return None;
110        }
111
112        // Parse field
113        if let Some((field, value)) = split_first(line, ':') {
114            let value = value.strip_prefix(' ').unwrap_or(value);
115
116            match field {
117                "event" => self.event_type = Some(value.to_string()),
118                "data" => self.data_lines.push(value.to_string()),
119                "id" => self.event_id = Some(value.to_string()),
120                "retry" => {
121                    if let Ok(retry) = value.parse::<u64>() {
122                        self.retry_time = Some(retry);
123                    }
124                }
125                _ => {}
126            }
127        }
128
129        None
130    }
131
132    /// Flush the buffer (force complete current event)
133    pub fn flush(&mut self) -> Option<SSEEvent> {
134        if self.data_lines.is_empty() {
135            return None;
136        }
137
138        let event = SSEEvent {
139            event: self
140                .event_type
141                .take()
142                .unwrap_or_else(|| "message".to_string()),
143            data: self.data_lines.join("\n"),
144            raw: std::mem::take(&mut self.chunks),
145            id: self.event_id.clone(),
146            retry: self.retry_time,
147        };
148
149        self.reset();
150        Some(event)
151    }
152
153    fn reset(&mut self) {
154        self.event_type = None;
155        self.data_lines.clear();
156        self.chunks.clear();
157        // id and retry are not reset per SSE spec
158    }
159}
160
161/// Split string at first occurrence of separator
162fn split_first(s: &str, sep: char) -> Option<(&str, &str)> {
163    let idx = s.find(sep)?;
164    Some((s.get(..idx)?, s.get(idx + 1..)?))
165}
166
/// Splits an incoming byte stream into text lines.
///
/// Bytes are buffered until a line terminator arrives, so a line may span any
/// number of [`NewlineDecoder::decode`] calls. LF, CRLF and bare CR are all
/// recognised as terminators (the three forms the SSE spec allows), and the
/// terminator is not included in the returned line. Invalid UTF-8 is replaced
/// with U+FFFD rather than dropping the line.
#[derive(Debug, Default)]
pub struct NewlineDecoder {
    /// Bytes of the current, not yet terminated line.
    buffer: Vec<u8>,
    /// Number of leading bytes of `buffer` already scanned without finding a
    /// terminator, so a long line delivered in small chunks is not rescanned.
    scanned: usize,
}

impl NewlineDecoder {
    /// Create a new newline decoder with an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append `chunk` to the buffer and return every line it completes.
    ///
    /// A CR that is the last buffered byte is held back until more input (or
    /// [`NewlineDecoder::flush`]) arrives, because it may be the first half of
    /// a CRLF pair split across chunks.
    pub fn decode(&mut self, chunk: &[u8]) -> Vec<String> {
        self.buffer.extend_from_slice(chunk);

        let mut lines = Vec::new();
        let mut line_start = 0;
        let mut pos = self.scanned;

        while let Some(&byte) = self.buffer.get(pos) {
            let terminator_len = match byte {
                b'\n' => 1,
                b'\r' => match self.buffer.get(pos + 1).copied() {
                    Some(b'\n') => 2,
                    Some(_) => 1,
                    // Cannot tell CR from CRLF yet; wait for the next byte.
                    None => break,
                },
                _ => {
                    pos += 1;
                    continue;
                }
            };
            lines.push(String::from_utf8_lossy(&self.buffer[line_start..pos]).into_owned());
            pos += terminator_len;
            line_start = pos;
        }

        // Drop all consumed lines in one pass instead of re-copying the
        // remaining buffer once per line.
        self.buffer.drain(..line_start);
        self.scanned = pos - line_start;
        lines
    }

    /// Return whatever is left in the buffer as a final line.
    ///
    /// A trailing CR is treated as that line's terminator, since no LF can
    /// follow it anymore. Returns an empty vector if nothing is buffered.
    pub fn flush(&mut self) -> Vec<String> {
        self.scanned = 0;
        if self.buffer.is_empty() {
            return Vec::new();
        }

        let line = self
            .buffer
            .strip_suffix(b"\r")
            .unwrap_or(self.buffer.as_slice());
        let line = String::from_utf8_lossy(line).into_owned();
        self.buffer.clear();
        vec![line]
    }
}
249
250/// SSE Stream wrapper for high-level SSE processing
251pub struct SSEStream<T> {
252    decoder: SSEDecoder,
253    newline_decoder: NewlineDecoder,
254    event_queue: VecDeque<SSEEvent>,
255    aborted: bool,
256    _phantom: std::marker::PhantomData<T>,
257}
258
259impl<T> Default for SSEStream<T> {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265impl<T> SSEStream<T> {
266    /// Create a new SSE stream
267    pub fn new() -> Self {
268        Self {
269            decoder: SSEDecoder::new(),
270            newline_decoder: NewlineDecoder::new(),
271            event_queue: VecDeque::new(),
272            aborted: false,
273            _phantom: std::marker::PhantomData,
274        }
275    }
276
277    /// Process incoming bytes
278    pub fn process_bytes(&mut self, bytes: &[u8]) {
279        if self.aborted {
280            return;
281        }
282
283        let lines = self.newline_decoder.decode(bytes);
284        for line in lines {
285            if let Some(event) = self.decoder.decode(&line) {
286                self.event_queue.push_back(event);
287            }
288        }
289    }
290
291    /// Get next event from queue
292    pub fn next_event(&mut self) -> Option<SSEEvent> {
293        self.event_queue.pop_front()
294    }
295
296    /// Flush and get remaining events
297    pub fn flush(&mut self) -> Vec<SSEEvent> {
298        let mut events = Vec::new();
299
300        // Flush newline decoder
301        for line in self.newline_decoder.flush() {
302            if let Some(event) = self.decoder.decode(&line) {
303                events.push(event);
304            }
305        }
306
307        // Flush SSE decoder
308        if let Some(event) = self.decoder.flush() {
309            events.push(event);
310        }
311
312        // Drain queue
313        events.extend(self.event_queue.drain(..));
314        events
315    }
316
317    /// Abort the stream
318    pub fn abort(&mut self) {
319        self.aborted = true;
320    }
321
322    /// Check if stream is aborted
323    pub fn is_aborted(&self) -> bool {
324        self.aborted
325    }
326
327    /// Check if there are pending events
328    pub fn has_events(&self) -> bool {
329        !self.event_queue.is_empty()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_sse_event_new() {
        let event = SSEEvent::new("test data".to_string());
        assert_eq!(event.event, "message");
        assert_eq!(event.data, "test data");
    }

    #[test]
    fn test_sse_event_with_event() {
        let event = SSEEvent::new("data".to_string()).with_event("custom");
        assert_eq!(event.event, "custom");
    }

    #[test]
    fn test_sse_event_parse_json() {
        let event = SSEEvent::new("{\"answer\":42}".to_string());
        let parsed: HashMap<String, u32> = event.parse_json().unwrap();
        assert_eq!(parsed.get("answer"), Some(&42));

        let bad = SSEEvent::new("not json".to_string());
        assert!(bad.parse_json::<HashMap<String, u32>>().is_err());
    }

    #[test]
    fn test_sse_decoder_simple() {
        let mut decoder = SSEDecoder::new();

        assert!(decoder.decode("data: hello").is_none());
        let event = decoder.decode("").unwrap();

        assert_eq!(event.event, "message");
        assert_eq!(event.data, "hello");
    }

    #[test]
    fn test_sse_decoder_with_event_type() {
        let mut decoder = SSEDecoder::new();

        decoder.decode("event: custom");
        decoder.decode("data: test");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.event, "custom");
        assert_eq!(event.data, "test");
    }

    #[test]
    fn test_sse_decoder_multiline_data() {
        let mut decoder = SSEDecoder::new();

        decoder.decode("data: line1");
        decoder.decode("data: line2");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.data, "line1\nline2");
    }

    #[test]
    fn test_sse_decoder_field_value_spacing() {
        let mut decoder = SSEDecoder::new();

        // Only a single leading space is stripped from a field value.
        decoder.decode("data:no-space");
        decoder.decode("data:  two-spaces");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.data, "no-space\n two-spaces");
    }

    #[test]
    fn test_sse_decoder_comment() {
        let mut decoder = SSEDecoder::new();

        decoder.decode(": this is a comment");
        decoder.decode("data: actual data");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.data, "actual data");
    }

    #[test]
    fn test_sse_decoder_raw_lines() {
        let mut decoder = SSEDecoder::new();

        decoder.decode(": note");
        decoder.decode("data: x");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.raw, vec![": note", "data: x", ""]);
    }

    #[test]
    fn test_sse_decoder_blank_line_without_data() {
        let mut decoder = SSEDecoder::new();

        // A data-less event is discarded along with its event type.
        decoder.decode("event: ping");
        assert!(decoder.decode("").is_none());

        decoder.decode("data: after");
        let event = decoder.decode("").unwrap();
        assert_eq!(event.event, "message");
        assert_eq!(event.data, "after");
    }

    #[test]
    fn test_sse_decoder_id_and_retry() {
        let mut decoder = SSEDecoder::new();

        decoder.decode("id: 123");
        decoder.decode("retry: 5000");
        decoder.decode("data: test");
        let event = decoder.decode("").unwrap();

        assert_eq!(event.id, Some("123".to_string()));
        assert_eq!(event.retry, Some(5000));
    }

    #[test]
    fn test_sse_decoder_id_persists_across_events() {
        let mut decoder = SSEDecoder::new();

        decoder.decode("id: 1");
        decoder.decode("data: a");
        assert_eq!(decoder.decode("").unwrap().id.as_deref(), Some("1"));

        decoder.decode("data: b");
        assert_eq!(decoder.decode("").unwrap().id.as_deref(), Some("1"));
    }

    #[test]
    fn test_sse_decoder_flush() {
        let mut decoder = SSEDecoder::new();

        decoder.decode("data: incomplete");
        let event = decoder.flush().unwrap();

        assert_eq!(event.data, "incomplete");
    }

    #[test]
    fn test_sse_decoder_flush_empty() {
        let mut decoder = SSEDecoder::new();
        assert!(decoder.flush().is_none());
    }

    #[test]
    fn test_newline_decoder_lf() {
        let mut decoder = NewlineDecoder::new();
        let lines = decoder.decode(b"line1\nline2\n");
        assert_eq!(lines, vec!["line1", "line2"]);
    }

    #[test]
    fn test_newline_decoder_crlf() {
        let mut decoder = NewlineDecoder::new();
        let lines = decoder.decode(b"line1\r\nline2\r\n");
        assert_eq!(lines, vec!["line1", "line2"]);
    }

    #[test]
    fn test_newline_decoder_crlf_split_across_chunks() {
        let mut decoder = NewlineDecoder::new();
        assert!(decoder.decode(b"line1\r").is_empty());
        assert_eq!(decoder.decode(b"\nline2\n"), vec!["line1", "line2"]);
    }

    #[test]
    fn test_newline_decoder_empty_lines() {
        let mut decoder = NewlineDecoder::new();
        assert_eq!(decoder.decode(b"a\n\n"), vec!["a", ""]);
        assert!(decoder.flush().is_empty());
    }

    #[test]
    fn test_newline_decoder_partial() {
        let mut decoder = NewlineDecoder::new();

        let lines1 = decoder.decode(b"par");
        assert!(lines1.is_empty());

        let lines2 = decoder.decode(b"tial\n");
        assert_eq!(lines2, vec!["partial"]);
    }

    #[test]
    fn test_newline_decoder_flush() {
        let mut decoder = NewlineDecoder::new();
        decoder.decode(b"incomplete");
        let lines = decoder.flush();
        assert_eq!(lines, vec!["incomplete"]);
    }

    #[test]
    fn test_sse_stream_process() {
        let mut stream: SSEStream<()> = SSEStream::new();

        stream.process_bytes(b"data: hello\n\n");

        let event = stream.next_event().unwrap();
        assert_eq!(event.data, "hello");
    }

    #[test]
    fn test_sse_stream_event_split_across_chunks() {
        let mut stream: SSEStream<()> = SSEStream::new();

        stream.process_bytes(b"data: hel");
        assert!(!stream.has_events());

        stream.process_bytes(b"lo\n\nevent: second\ndata: world\n\n");
        assert!(stream.has_events());

        let first = stream.next_event().unwrap();
        assert_eq!(first.data, "hello");
        let second = stream.next_event().unwrap();
        assert_eq!(second.event, "second");
        assert_eq!(second.data, "world");
        assert!(stream.next_event().is_none());
    }

    #[test]
    fn test_sse_stream_flush_trailing_event() {
        let mut stream: SSEStream<()> = SSEStream::new();

        stream.process_bytes(b"data: tail");
        let events = stream.flush();

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "tail");
    }

    #[test]
    fn test_sse_stream_abort() {
        let mut stream: SSEStream<()> = SSEStream::new();

        stream.abort();
        assert!(stream.is_aborted());

        stream.process_bytes(b"data: ignored\n\n");
        assert!(stream.next_event().is_none());
    }
}