use bytes::Bytes;
use futures_core::Stream;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A single event dispatched from a Server-Sent Events stream.
///
/// Instances are produced by [`parse_sse_chunk`] and yielded by [`SseStream`].
#[derive(Clone, Debug)]
#[derive(Eq, PartialEq)]
pub struct SseEvent {
    /// Value of the last `event:` field seen for this event, if any.
    pub event: Option<String>,
    /// Every `data:` line of the event, joined with `\n`.
    pub data: String,
    /// Value of the `id:` field, if any.
    pub id: Option<String>,
    /// Parsed value of the `retry:` field, if present and valid
    /// (the SSE spec defines it as a reconnection time in milliseconds).
    pub retry: Option<u64>,
}
/// Parses a chunk of Server-Sent Events text into the events it contains.
///
/// Lines may end in `\n` or `\r\n`. A blank line dispatches the accumulated
/// fields as an [`SseEvent`] if at least one `data` line was seen. A blank
/// line with no pending data instead clears any buffered `event`, `id` and
/// `retry` values, so they cannot leak into a later event. Text after the
/// last blank line is flushed as a final event when it contains data, so
/// unterminated input is not lost.
///
/// Field lines follow the WHATWG "interpreting an event stream" rules:
/// - a line is split at its first `:` into field name and value;
/// - a line without a colon is a field whose value is empty;
/// - one leading space is removed from the value;
/// - lines starting with `:` are comments, and unknown field names are ignored;
/// - `retry` is honoured only when its value is all ASCII digits and fits in
///   a `u64`, otherwise the field is ignored;
/// - an `id` value containing NUL is ignored.
pub fn parse_sse_chunk(text: &str) -> Vec<SseEvent> {
    let normalized = text.replace("\r\n", "\n");
    let mut events = Vec::new();
    let mut event_type: Option<String> = None;
    let mut data_lines: Vec<&str> = Vec::new();
    let mut id: Option<String> = None;
    let mut retry: Option<u64> = None;

    for line in normalized.lines() {
        if line.is_empty() {
            if data_lines.is_empty() {
                // Nothing to dispatch: drop stale metadata instead of carrying it forward.
                event_type = None;
                id = None;
                retry = None;
            } else {
                events.push(SseEvent {
                    event: event_type.take(),
                    data: data_lines.join("\n"),
                    id: id.take(),
                    retry: retry.take(),
                });
                data_lines.clear();
            }
            continue;
        }

        // No colon: the whole line is the field name and the value is empty.
        let (field, value) = line.split_once(':').unwrap_or((line, ""));
        let value = value.strip_prefix(' ').unwrap_or(value);
        match field {
            "data" => data_lines.push(value),
            "event" => event_type = Some(value.to_string()),
            "id" if !value.contains('\0') => id = Some(value.to_string()),
            "retry" if value.bytes().all(|b| b.is_ascii_digit()) => {
                // Empty or overflowing digit strings fail to parse and are ignored,
                // leaving any earlier valid `retry` in place.
                if let Ok(ms) = value.parse::<u64>() {
                    retry = Some(ms);
                }
            }
            // Comments (empty field name), invalid `id`/`retry` values and
            // unknown fields are ignored.
            _ => {}
        }
    }

    if !data_lines.is_empty() {
        events.push(SseEvent {
            event: event_type,
            data: data_lines.join("\n"),
            id,
            retry,
        });
    }
    events
}
pin_project! {
    /// Adapts a stream of raw response-body chunks (for example a `reqwest`
    /// byte stream) into a stream of parsed [`SseEvent`]s.
    ///
    /// Incoming bytes are buffered undecoded and split into event blocks on
    /// blank lines. Only a complete block is decoded as UTF-8 and handed to
    /// [`parse_sse_chunk`]. Working on bytes means a multi-byte character or a
    /// `\r\n` pair that the transport splits across two chunks is reassembled.
    /// It is not turned into replacement characters, and it does not delay
    /// the event.
    pub struct SseStream {
        #[pin]
        inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
        // Raw bytes received but not yet consumed as part of a complete event block.
        buffer: Vec<u8>,
        // Events already parsed but not yet handed to the caller.
        pending: std::collections::VecDeque<SseEvent>,
        // Set once `inner` has returned `None`; a finished stream must not be polled again.
        done: bool,
    }
}

impl SseStream {
    /// Wraps `byte_stream`, whose items are successive chunks of an SSE body.
    pub fn new(
        byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
    ) -> Self {
        Self {
            inner: byte_stream,
            buffer: Vec::new(),
            pending: std::collections::VecDeque::new(),
            done: false,
        }
    }
}

impl Stream for SseStream {
    type Item = Result<SseEvent, crate::ApiError>;

    /// Yields parsed events in arrival order.
    ///
    /// An error from the inner stream is yielded as `ApiError::Http`. It does
    /// not end this stream: the next poll resumes reading from the inner
    /// stream. When the inner stream ends, a trailing block that was never
    /// terminated by a blank line is still parsed, and its events are yielded
    /// before `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if let Some(ev) = this.pending.pop_front() {
                return Poll::Ready(Some(Ok(ev)));
            }

            if let Some((block_end, next_start)) = find_event_boundary(this.buffer.as_slice()) {
                // `\n` never occurs inside a multi-byte UTF-8 sequence, so a complete
                // block always ends on a character boundary and decodes cleanly.
                let events = parse_sse_chunk(&String::from_utf8_lossy(&this.buffer[..block_end]));
                this.buffer.drain(..next_start);
                this.pending.extend(events);
                continue;
            }

            if *this.done {
                return Poll::Ready(None);
            }

            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => this.buffer.extend_from_slice(&bytes),
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(crate::ApiError::Http(e))));
                }
                Poll::Ready(None) => {
                    *this.done = true;
                    // Flush whatever is left; it was never terminated by a blank line.
                    let tail = parse_sse_chunk(&String::from_utf8_lossy(this.buffer.as_slice()));
                    this.buffer.clear();
                    this.pending.extend(tail);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Finds the first blank line in `buf`.
///
/// A blank line is a line terminator (`\n` or `\r\n`) that directly follows a
/// `\n`, so the patterns are `\n\n` and `\n\r\n` (`\r\n\r\n` contains the
/// latter). Returns `(block_end, next_start)`:
/// - `buf[..block_end]` is the event block, including the terminator of its
///   last line;
/// - `buf[next_start..]` is everything after the blank line.
///
/// Returns `None` while no blank line is complete. This includes a buffer
/// ending in `\n\r` whose closing `\n` has not arrived yet.
fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> {
    buf.iter()
        .enumerate()
        .filter(|&(_, &b)| b == b'\n')
        .find_map(|(i, _)| match &buf[i + 1..] {
            [b'\n', ..] => Some((i + 1, i + 2)),
            [b'\r', b'\n', ..] => Some((i + 1, i + 3)),
            _ => None,
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    #[test]
    fn parse_all_fields() {
        let events = parse_sse_chunk("data: hello\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "hello");
        assert_eq!(events[0].event, None);
        let events = parse_sse_chunk("event: update\ndata: {\"foo\":1}\n\n");
        assert_eq!(events[0].event.as_deref(), Some("update"));
        let events = parse_sse_chunk("data: line1\ndata: line2\n\n");
        assert_eq!(events[0].data, "line1\nline2");
        let events = parse_sse_chunk("id: 42\nretry: 3000\ndata: ping\n\n");
        assert_eq!(events[0].id.as_deref(), Some("42"));
        assert_eq!(events[0].retry, Some(3000));
    }

    #[test]
    fn parse_multiple_events_and_blanks() {
        let events = parse_sse_chunk("data: first\n\ndata: second\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "first");
        assert_eq!(events[1].data, "second");
        let events = parse_sse_chunk("data: a\n\n\n\ndata: b\n\n");
        assert_eq!(events.len(), 2);
        let events = parse_sse_chunk("data: hello\r\n\r\ndata: world\r\n\r\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "hello");
    }

    #[test]
    fn parse_edge_cases() {
        assert!(parse_sse_chunk("").is_empty());
        let events = parse_sse_chunk("data: unterminated");
        assert_eq!(events[0].data, "unterminated");
        let events = parse_sse_chunk("data:nospace\n\n");
        assert_eq!(events[0].data, "nospace");
        let events = parse_sse_chunk(": comment\ndata: real\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "real");
        let events = parse_sse_chunk("retry: notanumber\ndata: x\n\n");
        assert_eq!(events[0].retry, None);
    }

    #[test]
    fn parse_state_reset_on_empty_block() {
        let events = parse_sse_chunk("event: stale\n\ndata: clean\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, None);
        let events = parse_sse_chunk("id: old\nretry: 5000\n\ndata: fresh\n\n");
        assert_eq!(events[0].id, None);
        assert_eq!(events[0].retry, None);
    }

    fn mock_byte_stream(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>> {
        Box::pin(futures_util::stream::iter(chunks))
    }

    #[tokio::test]
    async fn stream_chunked_delivery() {
        let chunks = vec![
            Ok(Bytes::from("data: first\r\n\r\ndata: sec")),
            Ok(Bytes::from("ond\n\n")),
        ];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "first");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "second");
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_multiple_events_single_chunk() {
        let chunks = vec![Ok(Bytes::from("data: a\n\ndata: b\n\ndata: c\n\n"))];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "a");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "b");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "c");
        assert!(stream.next().await.is_none());
        let chunks = vec![Ok(Bytes::from("data: x\n\n\n\ndata: y\n\n"))];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "x");
        assert_eq!(stream.next().await.unwrap().unwrap().data, "y");
    }

    #[tokio::test]
    async fn stream_end_of_inner() {
        let mut stream = SseStream::new(mock_byte_stream(vec![]));
        assert!(stream.next().await.is_none());
        let mut stream = SseStream::new(mock_byte_stream(vec![Ok(Bytes::from("data: trailing"))]));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "trailing");
        assert!(stream.next().await.is_none());
        let mut stream = SseStream::new(mock_byte_stream(vec![Ok(Bytes::from(": comment only"))]));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_error_from_inner() {
        let err = reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err();
        let mut stream = SseStream::new(mock_byte_stream(vec![Err(err)]));
        assert!(stream.next().await.unwrap().is_err());
    }

    #[tokio::test]
    async fn stream_pending_then_data() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, reqwest::Error>>(2);
        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        let mut stream = SseStream::new(Box::pin(rx_stream));
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            tx.send(Ok(Bytes::from("data: delayed\n\n"))).await.unwrap();
            drop(tx);
        });
        assert_eq!(stream.next().await.unwrap().unwrap().data, "delayed");
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_utf8_split_across_chunks() {
        // "é" is encoded as 0xC3 0xA9; the transport splits it between chunks.
        let chunks = vec![
            Ok(Bytes::from_static(b"data: caf\xC3")),
            Ok(Bytes::from_static(b"\xA9\n\n")),
        ];
        let mut stream = SseStream::new(mock_byte_stream(chunks));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "café");
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_crlf_split_across_chunks() {
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, reqwest::Error>>(4);
        let mut stream = SseStream::new(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)));
        tx.send(Ok(Bytes::from_static(b"data: x\r\n\r"))).await.unwrap();
        tx.send(Ok(Bytes::from_static(b"\n"))).await.unwrap();
        // The sender stays open, so the blank line alone must dispatch the event.
        let ev = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next())
            .await
            .expect("event should be dispatched without waiting for more input");
        assert_eq!(ev.unwrap().unwrap().data, "x");
        drop(tx);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn stream_does_not_poll_inner_after_end() {
        let mut polls = 0u32;
        let inner = futures_util::stream::poll_fn(
            move |_cx| -> Poll<Option<Result<Bytes, reqwest::Error>>> {
                polls += 1;
                match polls {
                    1 => Poll::Ready(Some(Ok(Bytes::from_static(b"data: last")))),
                    2 => Poll::Ready(None),
                    _ => panic!("inner stream polled after completion"),
                }
            },
        );
        let mut stream = SseStream::new(Box::pin(inner));
        assert_eq!(stream.next().await.unwrap().unwrap().data, "last");
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
    }
}