use alloc::sync::Arc;
use core::{
pin::Pin,
str,
task::{self, Poll, ready},
};
use thiserror::Error;
use bytes::Buf;
use futures_core::{
TryStream,
stream::{FusedStream, Stream},
};
use pin_project_lite::pin_project;
use crate::{PayloadTooLargeError, SseDecoder, SseEvent};
pin_project! {
    /// A [`Stream`] of decoded Server-Sent Events read from an inner
    /// [`TryStream`] of byte chunks (the `Stream` impl requires `T::Ok: Buf`).
    ///
    /// The inner stream can be detached (`close*` methods) and replaced
    /// (`attach*` methods) while the same decoder is kept.
    #[derive(Debug, Clone, Default)]
    pub struct SseStream<T: TryStream> {
        // Source of byte chunks; `None` when closed or after it has ended.
        #[pin]
        inner: Option<T>,
        // Latest chunk pulled from `inner`; the decoder drains it before the
        // next chunk is requested.
        buf: Option<T::Ok>,
        // Incremental event parser; also the source of `last_event_id`.
        decoder: SseDecoder,
    }
}
impl<T: TryStream> SseStream<T> {
    /// Creates a closed stream (no inner source attached) whose decoder comes
    /// from `SseDecoder::new`.
    ///
    /// Polling it yields `None` until a source is attached with
    /// [`attach`](Self::attach) or [`attach_with_id`](Self::attach_with_id).
    #[inline]
    #[must_use]
    pub fn disconnected() -> Self {
        Self::with_decoder(SseDecoder::new())
    }

    /// Creates a closed stream (no inner source attached) that will decode
    /// with the given `decoder`, used as-is.
    #[inline]
    #[must_use]
    pub fn with_decoder(decoder: SseDecoder) -> Self {
        Self {
            inner: None,
            buf: None,
            decoder,
        }
    }

    /// Creates a stream that decodes events from `inner`, using a decoder
    /// from `SseDecoder::new`.
    #[inline]
    #[must_use]
    pub fn new(inner: T) -> Self {
        let mut slf = Self::disconnected();
        slf.inner = Some(inner);
        slf
    }

    /// Consumes the stream and returns its decoder after calling
    /// `SseDecoder::reconnect` on it.
    ///
    /// The inner stream (if any) and any buffered chunk are dropped.
    #[inline]
    pub fn take_decoder(self) -> SseDecoder {
        let Self { mut decoder, .. } = self;
        decoder.reconnect();
        decoder
    }

    /// Returns `true` when no inner stream is attached.
    ///
    /// That is the case before anything is attached, after any of the
    /// `close*` methods, and once the inner stream has ended.
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.inner.is_none()
    }

    /// Returns the decoder's last event id, as reported by
    /// `SseDecoder::last_event_id`.
    #[inline]
    #[must_use]
    pub fn last_event_id(&self) -> Option<&Arc<str>> {
        self.decoder.last_event_id()
    }

    /// Detaches the inner stream, dropping it together with any chunk still
    /// buffered from it. The decoder is left untouched.
    #[inline]
    pub fn close(&mut self) {
        self.inner = None;
        // The buffered chunk can never be read again: polling a closed stream
        // returns `None` without looking at it, and both `attach` methods
        // overwrite it. Release it now instead of holding it until then.
        self.buf = None;
    }

    /// Calls `SseDecoder::clear` on the decoder, then [`close`](Self::close)s
    /// the stream.
    #[inline]
    pub fn close_and_clear(&mut self) {
        self.decoder.clear();
        self.close();
    }

    /// Calls `SseDecoder::reconnect_with_id` with `id`, then
    /// [`close`](Self::close)s the stream.
    #[inline]
    pub fn close_with_id(&mut self, id: Option<Arc<str>>) {
        self.decoder.reconnect_with_id(id);
        self.close();
    }

    /// Attaches `inner` as the new source, replacing (and dropping) any
    /// current one.
    ///
    /// Calls `SseDecoder::reconnect` first and discards any chunk buffered
    /// from the previous source.
    #[inline]
    pub fn attach(&mut self, inner: T) {
        self.decoder.reconnect();
        self.buf = None;
        self.inner = Some(inner);
    }

    /// Like [`attach`](Self::attach), but calls
    /// `SseDecoder::reconnect_with_id` with `id` instead of
    /// `SseDecoder::reconnect`.
    #[inline]
    pub fn attach_with_id(&mut self, inner: T, id: Option<Arc<str>>) {
        self.decoder.reconnect_with_id(id);
        self.buf = None;
        self.inner = Some(inner);
    }
}
pub type SseStreamResult<T, E> = Result<T, SseStreamError<E>>;
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SseStreamError<T> {
#[error("{0}")]
PayloadTooLarge(PayloadTooLargeError),
#[error("{0}")]
Inner(#[from] T),
}
impl<T: TryStream> Stream for SseStream<T>
where
    T::Ok: Buf,
{
    type Item = SseStreamResult<SseEvent, T::Error>;

    /// Polls for the next decoded event.
    ///
    /// Events still decodable from the buffered chunk are returned first; the
    /// inner stream is polled only when the decoder needs more input. Decoder
    /// failures surface as [`SseStreamError::PayloadTooLarge`] and inner-stream
    /// failures as [`SseStreamError::Inner`]; neither detaches the source. Once
    /// the inner stream ends it is dropped and every later poll yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        let mut source = match this.inner.as_mut().as_pin_mut() {
            Some(source) => source,
            // Closed (or never attached): behave as an exhausted stream.
            None => return Poll::Ready(None),
        };
        loop {
            // Drain whatever the current chunk can still produce.
            if let Some(chunk) = this.buf.as_mut() {
                match this.decoder.next(chunk) {
                    Ok(Some(event)) => return Poll::Ready(Some(Ok(event))),
                    Ok(None) => {}
                    Err(err) => {
                        return Poll::Ready(Some(Err(SseStreamError::PayloadTooLarge(err))));
                    }
                }
            }
            match ready!(source.as_mut().try_poll_next(cx)) {
                Some(Ok(chunk)) => *this.buf = Some(chunk),
                Some(Err(err)) => return Poll::Ready(Some(Err(SseStreamError::Inner(err)))),
                None => {
                    // End of input: drop the source so later polls short-circuit.
                    *this.buf = None;
                    this.inner.set(None);
                    return Poll::Ready(None);
                }
            }
        }
    }
}
impl<T> FusedStream for SseStream<T>
where
    T: TryStream,
    T::Ok: Buf,
{
    /// Reports termination exactly when no inner stream is attached, which is
    /// also when `poll_next` short-circuits to `Poll::Ready(None)`.
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
/// Decodes a payload fed one byte per chunk, then an inner-stream error, then
/// a final chunk, and checks every yielded item plus the closed state at the end.
#[test]
fn hard_parse() -> Result<(), PayloadTooLargeError> {
use crate::MessageEvent;
use std::slice;
use tokio_stream::StreamExt;
tokio_test::block_on(async {
// Covers comment lines, a `\r` directly before a newline, a comment ended by a
// bare `\r`, a `retry:42` and an empty `retry:` field, and multi-line `data:`.
let bytes = "
:
event: my-event\r
data:line1
data: line2
:
id: my-id
:should be ignored too\rretry:42
retry:
data:second
";
// One mock item per byte, so fields and line endings straddle chunk boundaries.
let mut inner = tokio_test::stream_mock::StreamMockBuilder::new();
for b in bytes.as_bytes() {
inner = inner.next(Ok(slice::from_ref(b)));
}
// Then an inner error, then one chunk whose trailing `data:ignored` line is not
// followed by a blank line and is not expected to produce an event.
inner = inner
.next(Err(()))
.next(Ok(b"data: hello\n\ndata:ignored\n"));
let id = Some("my-id".into());
let mut stream = SseStream::new(inner.build());
// Collect through a borrow so the stream can still be inspected afterwards.
let events: Vec<_> = (&mut stream).collect().await;
// The retry comes before the first message, the inner error does not end the
// stream, and every message carries the `my-id` last event id.
assert_eq!(
events,
&[
Ok(SseEvent::Retry(42)),
Ok(SseEvent::Message(MessageEvent {
event: "my-event".into(),
data: "line1\nline2".into(),
last_event_id: id.clone()
})),
Ok(SseEvent::Message(MessageEvent {
event: "message".into(),
data: "second".into(),
last_event_id: id.clone()
})),
Err(SseStreamError::Inner(())),
Ok(SseEvent::Message(MessageEvent {
event: "message".into(),
data: "hello".into(),
last_event_id: id.clone()
})),
]
);
// Exhausting the mock closes the stream.
assert!(stream.is_closed());
Ok(())
})
}