use bytes::{Bytes, BytesMut};
use futures_util::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Size limit, in bytes (64 KiB), that `SseBufferedStream` enforces on the data
/// it buffers; exceeding it is logged as an error and ends the stream.
const MAX_SSE_BUFFER_SIZE: usize = 64 * 1024;
/// Stream adapter that accumulates the byte chunks of an inner stream `S` and
/// re-emits them split at SSE event boundaries, so that a chunk boundary in the
/// upstream does not cut an event in half.
pub struct SseBufferedStream<S> {
/// The wrapped upstream stream that chunks are pulled from.
inner: S,
/// Bytes received from `inner` that have not been emitted yet.
buffer: BytesMut,
}
impl<S> SseBufferedStream<S> {
    /// Wraps `inner` in a new adapter whose pending-data buffer starts out empty.
    pub fn new(inner: S) -> Self {
        let buffer = BytesMut::new();
        Self { inner, buffer }
    }
}
impl<S, E> Stream for SseBufferedStream<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
{
    type Item = Result<Bytes, E>;

    /// Yields the next complete SSE event, pulling more chunks from the inner
    /// stream as needed.
    ///
    /// Behaviour per poll:
    /// * If the buffer already holds a complete event (as reported by
    ///   `find_event_boundary`) that fits within `MAX_SSE_BUFFER_SIZE`, that
    ///   event, terminator included, is split off and returned.
    /// * Otherwise, if the buffered data exceeds `MAX_SSE_BUFFER_SIZE`, an
    ///   error is logged, the buffer is discarded and the stream ends
    ///   (`None`). The item type carries the upstream error type `E`, which
    ///   cannot be constructed here, so the condition is signalled by ending
    ///   the stream.
    /// * Upstream errors are forwarded unchanged; buffered bytes are kept.
    /// * When the upstream ends, any unterminated trailing bytes are emitted
    ///   as a final item before `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        loop {
            // Drain complete events *before* enforcing the size limit. A single
            // large (or coalesced) upstream chunk may legitimately contain many
            // small, complete events; those must be delivered rather than
            // dropped just because their sum is larger than the limit.
            if let Some(pos) = find_event_boundary(&this.buffer) {
                // `pos + 2` is the exclusive end of the event's terminator.
                let end = pos + 2;
                if end <= MAX_SSE_BUFFER_SIZE {
                    let complete = this.buffer.split_to(end);
                    return Poll::Ready(Some(Ok(complete.freeze())));
                }
                // The next event on its own is over the limit: fall through to
                // the overflow handling below (buffer.len() >= end > limit).
            }

            // No deliverable event: either nothing is complete yet or the next
            // event is oversized. Refuse to keep buffering past the limit.
            if this.buffer.len() > MAX_SSE_BUFFER_SIZE {
                tracing::error!(
                    "SSE buffer exceeded maximum size of {} bytes, terminating stream",
                    MAX_SSE_BUFFER_SIZE
                );
                this.buffer.clear();
                return Poll::Ready(None);
            }

            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(chunk))) => {
                    // Append and loop: the new bytes may complete an event.
                    this.buffer.extend_from_slice(&chunk);
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(e)));
                }
                Poll::Ready(None) => {
                    if this.buffer.is_empty() {
                        return Poll::Ready(None);
                    }
                    // Upstream finished mid-event: flush what is left so no
                    // bytes are silently lost.
                    let remaining = this.buffer.split().freeze();
                    return Poll::Ready(Some(Ok(remaining)));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
/// Locates the terminator of the first complete SSE event in `buf`.
///
/// An event ends at a blank line. Both LF and CRLF line endings are accepted,
/// including mixed ones, so the recognised terminators are `\n\n`,
/// `\r\n\r\n`, `\n\r\n` and `\r\n\n` (i.e. `\n`, an optional `\r`, then `\n`).
/// Bare-CR line endings are not recognised.
///
/// Returns the index of the *last two bytes* of the terminator, so that
/// `index + 2` is the exclusive end of the event including its terminator.
/// For LF-only input this is the index of the `\n\n` pair itself. Returns
/// `None` when `buf` holds no complete event yet, which includes a terminator
/// that is only partially received (e.g. a trailing `\r\n\r`).
fn find_event_boundary(buf: &[u8]) -> Option<usize> {
    buf.iter().enumerate().find_map(|(idx, &byte)| {
        if byte != b'\n' {
            return None;
        }
        // `idx < buf.len()`, so `idx + 1..` is always a valid (possibly empty) range.
        match &buf[idx + 1..] {
            // "\n\n": the terminator's last two bytes start at `idx`.
            [b'\n', ..] => Some(idx),
            // "\n\r\n" (also the tail of "\r\n\r\n"): last two bytes are the
            // trailing "\r\n", which starts at `idx + 1`.
            [b'\r', b'\n', ..] => Some(idx + 1),
            _ => None,
        }
    })
}
// Unit tests for `SseBufferedStream`: each test feeds a fixed list of byte
// chunks through the adapter and checks the items it yields.
#[cfg(test)]
mod tests {
use super::*;
use futures_util::StreamExt;
use std::convert::Infallible;
// Builds an in-memory stream that yields every slice in `chunks` as one
// `Ok(Bytes)` item, in order.
fn chunks_to_stream(
chunks: Vec<&'static [u8]>,
) -> impl Stream<Item = Result<Bytes, Infallible>> + Unpin {
futures_util::stream::iter(chunks.into_iter().map(|c| Ok(Bytes::from_static(c))))
}
// A single chunk holding exactly one terminated event is yielded unchanged.
#[tokio::test]
async fn test_complete_event_passes_through() {
let chunks = vec![b"data: {\"hello\": \"world\"}\n\n".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(
results[0].as_ref().unwrap().as_ref(),
b"data: {\"hello\": \"world\"}\n\n"
);
}
// An event cut in the middle of its payload is reassembled into one item.
#[tokio::test]
async fn test_split_event_is_buffered() {
let chunks = vec![
b"data: {\"hel".as_slice(),
b"lo\": \"world\"}\n\n".as_slice(),
];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(
results[0].as_ref().unwrap().as_ref(),
b"data: {\"hello\": \"world\"}\n\n"
);
}
// Two events arriving in the same chunk are yielded as two separate items.
#[tokio::test]
async fn test_multiple_events_in_one_chunk() {
let chunks = vec![b"data: first\n\ndata: second\n\n".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].as_ref().unwrap().as_ref(), b"data: first\n\n");
assert_eq!(results[1].as_ref().unwrap().as_ref(), b"data: second\n\n");
}
// The chunk boundary falls between the two newlines of the terminator.
#[tokio::test]
async fn test_event_split_at_newline() {
let chunks = vec![b"data: test\n".as_slice(), b"\n".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(results[0].as_ref().unwrap().as_ref(), b"data: test\n\n");
}
// Chunk boundaries that do not line up with event boundaries: three events
// spread over two chunks come out as three items.
#[tokio::test]
async fn test_multiple_events_across_chunks() {
let chunks = vec![
b"data: first\n\ndata: sec".as_slice(),
b"ond\n\ndata: third\n\n".as_slice(),
];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 3);
assert_eq!(results[0].as_ref().unwrap().as_ref(), b"data: first\n\n");
assert_eq!(results[1].as_ref().unwrap().as_ref(), b"data: second\n\n");
assert_eq!(results[2].as_ref().unwrap().as_ref(), b"data: third\n\n");
}
// Unterminated trailing data is still yielded once the upstream ends.
#[tokio::test]
async fn test_incomplete_event_at_stream_end() {
let chunks = vec![b"data: incomplete".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(results[0].as_ref().unwrap().as_ref(), b"data: incomplete");
}
// An upstream with no chunks produces no items.
#[tokio::test]
async fn test_empty_stream() {
let chunks: Vec<&[u8]> = vec![];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 0);
}
// One event delivered in seven small fragments is reassembled into one item.
#[tokio::test]
async fn test_json_split_across_many_chunks() {
let chunks = vec![
b"da".as_slice(),
b"ta: ".as_slice(),
b"{\"delta\"".as_slice(),
b": {\"".as_slice(),
b"content\": \"Hello".as_slice(),
b"\"}}\n".as_slice(),
b"\n".as_slice(),
];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(
results[0].as_ref().unwrap().as_ref(),
b"data: {\"delta\": {\"content\": \"Hello\"}}\n\n"
);
}
// A single CRLF-delimited event is delivered as one item with its bytes
// unchanged.
#[tokio::test]
async fn test_handles_crlf_events() {
let chunks = vec![b"data: test\r\n\r\n".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(results[0].as_ref().unwrap().as_ref(), b"data: test\r\n\r\n");
}
// A single newline between `data:` lines does not split the event.
#[tokio::test]
async fn test_preserves_multiline_data() {
let chunks = vec![b"data: line1\ndata: line2\n\n".as_slice()];
let stream = SseBufferedStream::new(chunks_to_stream(chunks));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert_eq!(
results[0].as_ref().unwrap().as_ref(),
b"data: line1\ndata: line2\n\n"
);
}
// One byte more than the limit, with no terminator anywhere: the stream ends
// without yielding anything.
#[tokio::test]
async fn test_buffer_overflow_terminates_stream() {
let large_chunk = vec![b'x'; MAX_SSE_BUFFER_SIZE + 1];
let chunks: Vec<&[u8]> = vec![&large_chunk];
let stream = SseBufferedStream::new(futures_util::stream::iter(
chunks
.into_iter()
.map(|c| Ok::<_, Infallible>(Bytes::from(c.to_vec()))),
));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 0);
}
// An event whose total length (terminator included) is exactly the limit is
// still delivered.
#[tokio::test]
async fn test_buffer_at_limit_still_works() {
let mut chunk = vec![b'x'; MAX_SSE_BUFFER_SIZE - 2];
chunk.extend_from_slice(b"\n\n");
let chunks: Vec<&[u8]> = vec![&chunk];
let stream = SseBufferedStream::new(futures_util::stream::iter(
chunks
.into_iter()
.map(|c| Ok::<_, Infallible>(Bytes::from(c.to_vec()))),
));
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1);
assert!(results[0].is_ok());
assert_eq!(results[0].as_ref().unwrap().len(), MAX_SSE_BUFFER_SIZE);
}
}