use eventsource_stream::{Event, Eventsource};
use futures_util::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A stream of Server-Sent Events parsed from an underlying byte-chunk stream.
///
/// This is a thin wrapper around [`eventsource_stream::EventStream`]. It is
/// polled as a [`Stream`] when `S` yields `Result<B, E>` with
/// `B: AsRef<[u8]>`. Each item is either a parsed [`Event`] or an
/// [`eventsource_stream::EventStreamError`] wrapping the source error `E`.
pub struct SseStream<S> {
    inner: eventsource_stream::EventStream<S>,
}

impl<S> SseStream<S> {
    /// Wraps `stream` in an SSE parser.
    ///
    /// Any type implementing [`Eventsource`] is accepted. Only that trait is
    /// required to build the parser, so no separate `Stream` bound is imposed.
    pub fn new(stream: S) -> Self
    where
        S: Eventsource,
    {
        Self {
            inner: stream.eventsource(),
        }
    }

    /// Overwrites the last event ID stored by the underlying parser.
    ///
    /// Delegates to `EventStream::set_last_event_id`.
    pub fn set_last_event_id(&mut self, id: impl Into<String>) {
        self.inner.set_last_event_id(id);
    }

    /// Returns the last event ID currently stored by the underlying parser.
    ///
    /// Delegates to `EventStream::last_event_id`.
    pub fn last_event_id(&self) -> &str {
        self.inner.last_event_id()
    }
}
impl<S, B, E> Stream for SseStream<S>
where
    S: Stream<Item = Result<B, E>>,
    B: AsRef<[u8]>,
{
    type Item = Result<Event, eventsource_stream::EventStreamError<E>>;

    /// Polls the wrapped parser for the next SSE event.
    ///
    /// `S` is not required to be `Unpin`. The pin is projected onto `inner`
    /// instead of being re-created with `Pin::new`. As a result, every stream
    /// accepted by `SseStreamExt::into_sse_stream` (including `!Unpin` ones,
    /// once the `SseStream` itself is pinned, e.g. with `Box::pin`) can
    /// actually be polled.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // SAFETY: `inner` is treated as structurally pinned. This stays sound
        // as long as `SseStream`:
        //   - has no `Drop` impl that moves `inner`;
        //   - is `Unpin` only through the auto trait (i.e. only when `inner` is);
        //   - is not `#[repr(packed)]`;
        //   - offers no API that moves `inner` out of a pinned value.
        // `set_last_event_id` takes `&mut self`, which a pinned `!Unpin`
        // value cannot hand out.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
        inner.poll_next(cx)
    }
}
/// Adapter trait: converts a stream of byte chunks into an [`SseStream`].
pub trait SseStreamExt: Sized {
    /// Consumes `self` and returns an [`SseStream`] that parses the chunks
    /// as Server-Sent Events.
    fn into_sse_stream(self) -> SseStream<Self>;
}

// Blanket implementation: every stream yielding `Result<B, E>` with
// byte-like `B` gets `into_sse_stream()` for free.
impl<S, B, E> SseStreamExt for S
where
    S: Stream<Item = Result<B, E>>,
    B: AsRef<[u8]>,
{
    fn into_sse_stream(self) -> SseStream<Self> {
        SseStream::<S>::new(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    // Two complete events in two chunks are yielded in order, then the
    // stream ends.
    #[tokio::test]
    async fn test_sse_stream_basic() {
        let data: Vec<Result<&[u8], std::io::Error>> = vec![
            Ok(b"data: hello\n\n".as_slice()),
            Ok(b"data: world\n\n".as_slice()),
        ];
        let stream = futures_util::stream::iter(data);
        let mut sse_stream = stream.into_sse_stream();
        let event1 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event1.data, "hello");
        let event2 = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event2.data, "world");
        assert!(sse_stream.next().await.is_none());
    }

    // Multi-byte UTF-8 delivered in a single chunk.
    #[tokio::test]
    async fn test_sse_stream_utf8_handling() {
        let chinese_text = "你好世界";
        let sse_data = format!("data: {chinese_text}\n\n");
        let full_data: Vec<Result<&[u8], std::io::Error>> = vec![Ok(sse_data.as_bytes())];
        let stream = futures_util::stream::iter(full_data);
        let mut sse_stream = stream.into_sse_stream();
        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.data, chinese_text);
    }

    // Multi-byte UTF-8 whose first character is cut in half by a chunk
    // boundary, as happens with arbitrary network reads.
    #[tokio::test]
    async fn test_sse_stream_utf8_split_across_chunks() {
        let chinese_text = "你好世界";
        let sse_data = format!("data: {chinese_text}\n\n");
        // "data: " is 6 bytes and '你' is 3 bytes in UTF-8, so byte 7 falls
        // inside the first character.
        let (head, tail) = sse_data.as_bytes().split_at(7);
        let chunks: Vec<Result<&[u8], std::io::Error>> = vec![Ok(head), Ok(tail)];
        let mut sse_stream = futures_util::stream::iter(chunks).into_sse_stream();
        let event = sse_stream.next().await.unwrap().unwrap();
        assert_eq!(event.data, chinese_text);
    }

    // The setter and getter forward to the same underlying value.
    #[test]
    fn test_last_event_id_round_trip() {
        let chunks: Vec<Result<&[u8], std::io::Error>> = Vec::new();
        let mut sse_stream = futures_util::stream::iter(chunks).into_sse_stream();
        sse_stream.set_last_event_id("42");
        assert_eq!(sse_stream.last_event_id(), "42");
    }
}