use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use tracing::trace;
use crate::error::Result;
/// Boxed, type-erased stream of raw body chunks consumed by `SseFrameStream`.
///
/// On native targets the stream must be `Send`. The trait object's lifetime
/// bound is the implicit `'static` default for `Box<dyn Trait>`.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>;
/// wasm32 variant of `ByteStream`: same shape, but without the `Send` bound.
#[cfg(target_arch = "wasm32")]
pub(crate) type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes>>>>;
/// Adapts a raw SSE byte stream into a stream of per-event `data` payloads.
///
/// Upstream chunks are accumulated until a blank line terminates an event.
/// Each complete event is reduced to its `data` lines by `extract_data_payload`.
/// When `sentinel` is set, an event whose payload equals it ends the stream.
pub(crate) struct SseFrameStream {
    /// Source of raw body chunks.
    upstream: ByteStream,
    /// Received bytes not yet emitted as part of a complete frame.
    buffer: BytesMut,
    /// Offset in `buffer` at which `take_frame` resumes its terminator search.
    /// Each new chunk then costs a scan of the new bytes only, not a rescan of
    /// everything buffered so far.
    scan_from: usize,
    /// Set once upstream is exhausted or the sentinel has been seen.
    done: bool,
    /// Payload that marks end-of-stream (e.g. `[DONE]`), if any.
    sentinel: Option<&'static [u8]>,
    /// Tag included in trace output to identify this stream.
    label: &'static str,
}

impl SseFrameStream {
    /// Creates a frame stream over `upstream`.
    ///
    /// `sentinel`, if given, is the payload that terminates the stream.
    /// `label` is used only for tracing.
    pub(crate) fn new(
        upstream: ByteStream,
        sentinel: Option<&'static [u8]>,
        label: &'static str,
    ) -> Self {
        Self {
            upstream,
            buffer: BytesMut::with_capacity(8 * 1024),
            scan_from: 0,
            done: false,
            sentinel,
            label,
        }
    }

    /// Removes the first complete frame from the buffer and returns its data payload.
    ///
    /// A frame ends at the first blank line: a `\n` or `\r\n` line ending
    /// immediately followed by another one (`\n\n`, `\r\n\r\n`, `\r\n\n` or
    /// `\n\r\n`). Returns `None` if no complete frame is buffered yet.
    fn take_frame(&mut self) -> Option<Vec<u8>> {
        // Clamp in case the buffer was cleared after the last scan.
        let start = self.scan_from.min(self.buffer.len());
        match Self::find_frame_end(&self.buffer, start) {
            Some(end) => {
                let frame = self.buffer.split_to(end);
                // The remaining bytes have not been fully searched yet.
                self.scan_from = 0;
                Some(extract_data_payload(&frame))
            }
            None => {
                // The longest terminator is 4 bytes. One completed by the next
                // chunk must therefore start within the last 3 buffered bytes.
                self.scan_from = self.buffer.len().saturating_sub(3);
                None
            }
        }
    }

    /// Returns the index just past the first blank-line terminator found at or
    /// after `from`, or `None` if `bytes` contains none from that point on.
    fn find_frame_end(bytes: &[u8], from: usize) -> Option<usize> {
        (from..bytes.len()).find_map(|i| {
            let rest = &bytes[i..];
            if rest.starts_with(b"\r\n\r\n") {
                Some(i + 4)
            } else if rest.starts_with(b"\n\n") {
                Some(i + 2)
            } else if rest.starts_with(b"\n\r\n") {
                // LF-terminated line followed by a CRLF blank line.
                Some(i + 3)
            } else {
                None
            }
        })
    }

    /// Drains whatever is left in the buffer as a final, unterminated frame.
    ///
    /// Returns `None` if the buffer is empty.
    fn take_remaining(&mut self) -> Option<Vec<u8>> {
        if self.buffer.is_empty() {
            return None;
        }
        let frame = self.buffer.split();
        self.scan_from = 0;
        Some(extract_data_payload(&frame))
    }
}
impl Stream for SseFrameStream {
    type Item = Result<Vec<u8>>;

    /// Yields the next non-empty `data` payload.
    ///
    /// Complete frames already buffered are always drained before upstream is
    /// polled again. Once upstream is exhausted, any unterminated trailing bytes
    /// are flushed as a final frame. Empty payloads (e.g. comment-only
    /// heartbeats) are skipped.
    ///
    /// A payload equal to the sentinel ends the stream and discards everything
    /// buffered after it. This applies both mid-stream and in the final flush.
    ///
    /// An upstream error is forwarded unchanged and does not end this stream;
    /// a later poll resumes polling upstream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let upstream_finished = self.done;
            let next = match self.take_frame() {
                Some(payload) => Some(payload),
                // Nothing more is coming: flush the unterminated tail, if any.
                None if upstream_finished => self.take_remaining(),
                None => None,
            };

            if let Some(payload) = next {
                if payload.is_empty() {
                    continue;
                }
                if Some(payload.as_slice()) == self.sentinel {
                    // Sentinel ends the stream; drop anything that followed it.
                    self.done = true;
                    self.buffer.clear();
                    continue;
                }
                return Poll::Ready(Some(Ok(payload)));
            }
            if upstream_finished {
                return Poll::Ready(None);
            }

            match self.upstream.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(bytes))) => {
                    trace!(len = bytes.len(), "{} sse bytes", self.label);
                    self.buffer.extend_from_slice(&bytes);
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => self.done = true,
            }
        }
    }
}
/// Extracts the `data` field value(s) from one raw SSE frame.
///
/// Line handling:
/// - The frame is split into lines on `\n`, and trailing `\r`s are removed
///   from each line.
/// - For a `data:` line, the value is the text after the colon with at most
///   one leading space removed.
/// - A bare `data` line (no colon) contributes an empty value.
/// - Successive values are joined with `\n`.
/// - All other lines (`event`, `id`, `retry`, `:` comments) are ignored.
///
/// Operates on raw bytes, so payloads that are not valid UTF-8 are passed
/// through to the caller instead of being silently dropped.
fn extract_data_payload(frame: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(frame.len());
    // Whether a `data` value has already been appended; decides when a `\n`
    // separator is needed. Testing `out.is_empty()` instead would lose the
    // separator after an empty first value (`data:\ndata: x` must give "\nx").
    let mut seen_data = false;
    for line in frame.split(|&b| b == b'\n') {
        let keep = line.iter().rposition(|&b| b != b'\r').map_or(0, |p| p + 1);
        let line = &line[..keep];
        let value: &[u8] = if line == b"data" {
            &[]
        } else if let Some(rest) = line.strip_prefix(b"data:") {
            rest.strip_prefix(b" ").unwrap_or(rest)
        } else {
            continue;
        };
        if seen_data {
            out.push(b'\n');
        }
        out.extend_from_slice(value);
        seen_data = true;
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Builds an upstream that yields each slice of `parts` as one `Ok` chunk.
    fn byte_stream(parts: &[&[u8]]) -> ByteStream {
        let owned: Vec<Bytes> = parts.iter().map(|b| Bytes::copy_from_slice(b)).collect();
        Box::pin(stream::iter(owned.into_iter().map(Ok)))
    }

    /// Runs `parts` through an `SseFrameStream` and collects every payload as UTF-8.
    async fn frames(parts: &[&[u8]], sentinel: Option<&'static [u8]>) -> Vec<String> {
        let mut s = SseFrameStream::new(byte_stream(parts), sentinel, "test");
        let mut out = Vec::new();
        while let Some(item) = s.next().await {
            out.push(String::from_utf8(item.unwrap()).unwrap());
        }
        out
    }

    #[tokio::test]
    async fn lf_and_crlf_boundaries_both_split() {
        assert_eq!(frames(&[b"data: a\n\n", b"data: b\r\n\r\n"], None).await, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn multiple_data_lines_in_one_frame_concat_with_newline() {
        assert_eq!(frames(&[b"data: line1\ndata: line2\n\n"], None).await, vec!["line1\nline2"]);
    }

    #[tokio::test]
    async fn non_data_fields_and_comments_ignored() {
        let parts: &[&[u8]] = &[b"event: msg\ndata: x\nid: 7\n: a comment\nretry: 100\n\n"];
        assert_eq!(frames(parts, None).await, vec!["x"]);
    }

    #[tokio::test]
    async fn only_one_leading_space_stripped() {
        // Two spaces after the colon: exactly one is stripped.
        assert_eq!(frames(&[b"data:  x\n\n", b"data:y\n\n"], None).await, vec![" x", "y"]);
    }

    #[tokio::test]
    async fn heartbeat_and_empty_frames_are_skipped() {
        assert_eq!(frames(&[b": ping\n\n", b"\n\n", b"data: real\n\n"], None).await, vec!["real"]);
    }

    #[tokio::test]
    async fn eof_flushes_an_unterminated_final_frame() {
        assert_eq!(frames(&[b"data: a\n\n", b"data: last"], None).await, vec!["a", "last"]);
    }

    #[tokio::test]
    async fn empty_upstream_yields_nothing() {
        assert!(frames(&[], None).await.is_empty());
    }

    #[tokio::test]
    async fn sentinel_terminates_and_drops_anything_after_it() {
        let parts: &[&[u8]] = &[b"data: a\n\n", b"data: [DONE]\n\ndata: leaked\n\n"];
        assert_eq!(frames(parts, Some(b"[DONE]")).await, vec!["a"]);
    }

    #[tokio::test]
    async fn unterminated_sentinel_at_eof_is_not_emitted() {
        let parts: &[&[u8]] = &[b"data: a\n\n", b"data: [DONE]"];
        assert_eq!(frames(parts, Some(b"[DONE]")).await, vec!["a"]);
    }

    #[tokio::test]
    async fn frame_split_across_chunks_is_buffered_until_complete() {
        assert_eq!(frames(&[b"data: hel", b"lo\n\n"], None).await, vec!["hello"]);
        assert_eq!(frames(&[b"data: z\r\n", b"\r\n"], None).await, vec!["z"]);
        assert_eq!(frames(&[b"data: z\r\n\r", b"\n"], None).await, vec!["z"]);
    }

    #[tokio::test]
    async fn frames_delivered_one_byte_at_a_time() {
        let body = format!("data: {}\n\ndata: tail\r\n\r\n", "x".repeat(300));
        let parts: Vec<&[u8]> = body.as_bytes().chunks(1).collect();
        let expected = vec!["x".repeat(300), "tail".to_string()];
        assert_eq!(frames(&parts, None).await, expected);
    }
}