use crate::error::LingerError;
use crate::transport::BodyStream;
use futures_core::Stream;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
/// One event decoded from a Server-Sent Events stream.
///
/// Marked `#[non_exhaustive]` so that code outside this crate cannot build it
/// with a struct literal or destructure it exhaustively; new fields can be
/// added without a breaking change.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SseEvent {
    /// Value of the last `event` field seen in the frame, or `None` if the
    /// frame had no `event` field.
    pub event_type: Option<String>,
    /// All `data` field values of the frame joined with `\n`; empty when the
    /// frame carried no `data` field.
    pub data: String,
}
/// Incremental Server-Sent Events decoder layered over a [`BodyStream`].
///
/// Incoming bytes are buffered until a complete frame (terminated by a blank
/// line) is available; each complete frame is parsed into an [`SseEvent`].
/// Frames that produce no event (for example comment-only keep-alives) are
/// skipped.
pub struct SseStream {
    /// Underlying byte stream.
    body: BodyStream,
    /// Bytes received but not yet consumed as part of a complete frame.
    buffer: Vec<u8>,
    /// Decoded events waiting to be yielded, in arrival order.
    pending: VecDeque<SseEvent>,
    /// A decode error that must be reported only after every event decoded
    /// ahead of the failing frame has been yielded.
    deferred_error: Option<LingerError>,
    /// Set once the body is exhausted or an error has ended the stream; the
    /// body is never polled again afterwards.
    finished: bool,
}

impl SseStream {
    /// Wraps `body` in an SSE decoder with empty buffers.
    pub fn new(body: BodyStream) -> Self {
        Self {
            body,
            buffer: Vec::new(),
            pending: VecDeque::new(),
            deferred_error: None,
            finished: false,
        }
    }

    /// Parses every complete frame currently in `buffer`, queueing resulting
    /// events on `pending` and removing the consumed bytes.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first frame that fails to parse. Events decoded
    /// from earlier frames in the same call remain queued in `pending`.
    fn decode_ready_frames(&mut self) -> Result<(), LingerError> {
        while let Some(frame_end) = find_frame_end(&self.buffer) {
            // Parse the borrowed bytes first, then drop them from the buffer;
            // this avoids copying each frame into a temporary Vec.
            let parsed = parse_frame(&self.buffer[..frame_end.payload_len]);
            self.buffer.drain(..frame_end.consumed);
            if let Some(event) = parsed? {
                self.pending.push_back(event);
            }
        }
        Ok(())
    }

    /// Parses whatever is left in `buffer` once the body has ended, even though
    /// it was never terminated by a blank line.
    ///
    /// # Errors
    ///
    /// Returns the parse error for the leftover bytes, if any.
    fn decode_trailing_frame(&mut self) -> Result<(), LingerError> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let frame = std::mem::take(&mut self.buffer);
        if let Some(event) = parse_frame(&frame)? {
            self.pending.push_back(event);
        }
        Ok(())
    }
}

impl Stream for SseStream {
    type Item = Result<SseEvent, LingerError>;

    /// Yields decoded events in the order they appear on the wire.
    ///
    /// - A transport error from the body is yielded immediately and ends the stream.
    /// - A frame decode error is yielded after all events decoded before the
    ///   failing frame, and then ends the stream.
    /// - When the body ends, any unterminated trailing frame is decoded before
    ///   `None` is returned.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if let Some(event) = this.pending.pop_front() {
                return Poll::Ready(Some(Ok(event)));
            }
            // Checked only after `pending` is empty so the error never jumps
            // ahead of events that preceded the failing frame.
            if let Some(error) = this.deferred_error.take() {
                return Poll::Ready(Some(Err(error)));
            }
            if this.finished {
                return Poll::Ready(None);
            }
            match Pin::new(&mut this.body).poll_next(cx) {
                Poll::Ready(Some(Ok(chunk))) => {
                    this.buffer.extend_from_slice(&chunk);
                    if let Err(error) = this.decode_ready_frames() {
                        this.finished = true;
                        this.deferred_error = Some(error);
                    }
                }
                Poll::Ready(Some(Err(error))) => {
                    // `pending` is empty whenever the body is polled, so nothing
                    // can be reordered here.
                    this.finished = true;
                    return Poll::Ready(Some(Err(error)));
                }
                Poll::Ready(None) => {
                    this.finished = true;
                    if let Err(error) = this.decode_trailing_frame() {
                        this.deferred_error = Some(error);
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Location of the first complete frame at the front of the receive buffer.
struct FrameEnd {
    /// Number of leading bytes that precede the frame's terminating blank line
    /// (the terminator of the last non-blank line is included); this prefix is
    /// what gets parsed.
    payload_len: usize,
    /// Number of leading bytes to remove from the buffer, blank line included.
    consumed: usize,
}

/// Finds the end of the first complete SSE frame in `buffer`.
///
/// A frame ends at the first blank line. A blank line is a line terminator
/// that directly follows another terminator or the start of the buffer.
/// Both `\n` and `\r\n` terminators are recognised, and they may be mixed
/// (for example `\n\r\n`). Bare `\r` terminators are not supported.
///
/// Returns `None` when `buffer` does not yet contain a blank line.
fn find_frame_end(buffer: &[u8]) -> Option<FrameEnd> {
    let mut line_start = 0;
    for (index, &byte) in buffer.iter().enumerate() {
        if byte != b'\n' {
            continue;
        }
        // The line ending at this `\n` is blank if nothing but an optional
        // `\r` precedes it.
        if matches!(&buffer[line_start..index], [] | [b'\r']) {
            return Some(FrameEnd {
                payload_len: line_start,
                consumed: index + 1,
            });
        }
        line_start = index + 1;
    }
    None
}
/// Parses one SSE frame (the bytes before its terminating blank line) into an
/// event.
///
/// Field parsing follows the WHATWG event-stream rules:
/// - each line is split at its first `:`;
/// - exactly one leading space is removed from the value;
/// - a line without a colon is a field name with an empty value;
/// - comment lines (starting with `:`) and blank lines are skipped.
///
/// Multiple `data` values are joined with `\n`, and the last `event` value
/// wins. Other fields (`id`, `retry`, ...) are ignored.
///
/// Returns `Ok(None)` when the frame has neither an `event` nor a `data`
/// field, e.g. a comment-only keep-alive.
///
/// # Errors
///
/// Returns a `LingerError::streaming` error if `frame` is not valid UTF-8.
fn parse_frame(frame: &[u8]) -> Result<Option<SseEvent>, LingerError> {
    let text = std::str::from_utf8(frame)
        .map_err(|error| LingerError::streaming(format!("SSE frame is not UTF-8: {error}")))?;
    let mut event_type = None;
    let mut data: Vec<&str> = Vec::new();
    for raw_line in text.lines() {
        // `lines` strips `\n` / `\r\n` terminators; drop any leftover `\r`
        // defensively (e.g. on the final line of the frame).
        let line = raw_line.strip_suffix('\r').unwrap_or(raw_line);
        if line.is_empty() || line.starts_with(':') {
            continue;
        }
        let (field, value) = match line.split_once(':') {
            // Only a single leading space is syntax; any further whitespace
            // belongs to the payload and must be preserved.
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "event" => event_type = Some(value.to_string()),
            "data" => data.push(value),
            _ => {}
        }
    }
    if event_type.is_none() && data.is_empty() {
        return Ok(None);
    }
    Ok(Some(SseEvent {
        event_type,
        data: data.join("\n"),
    }))
}