use crate::{ByteStream, error::Error, utils::str::memchr_split};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::stream::{Stream, TryStream};
use pin_project_lite::pin_project;
use serde::Serialize;
use std::fmt::Debug;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
/// Field name of the event-ID line.
const ID: &str = "id";
/// Field name of the event-type line.
const EVENT: &str = "event";
/// Field name of each data line.
const DATA: &str = "data";
/// Prefix of the retry line; the value follows the colon with no space.
const RETRY: &[u8] = b"retry:";
/// Event name used by `Message::json` when serialization fails.
const ERROR: &str = "error";
/// Line terminator written after every serialized line.
const NEW_LINE: u8 = b'\n';
/// Serialized form of a bare, empty comment line.
const EMPTY: &[u8] = b":\n";
pin_project! {
    /// A stream of Server-Sent Events messages.
    ///
    /// Thin wrapper around any `Stream<Item = Result<Message, Error>>`. It can be turned into
    /// a stream of serialized bytes with `into_bytes` or `into_byte_stream`.
    pub struct SseStream<S> {
        // The wrapped message stream; structurally pinned so it can be polled in place.
        #[pin]
        inner: S,
    }
}
impl<S> Debug for SseStream<S> {
    /// Prints an opaque placeholder, so the wrapped stream type never has to implement `Debug`.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SseStream(...)")
    }
}
impl SseStream<()> {
#[inline]
pub fn from_try_messages<T>(
stream: T,
) -> SseStream<impl Stream<Item = Result<Message, Error>> + Send + 'static>
where
T: TryStream<Ok = Message, Error = Error> + Send + 'static,
{
use futures_util::TryStreamExt;
SseStream::new(stream.into_stream())
}
}
impl<S> SseStream<S>
where
S: Stream<Item = Result<Message, Error>> + Send + 'static,
{
#[inline]
pub fn new(inner: S) -> Self {
Self { inner }
}
#[inline]
pub fn into_bytes(self) -> impl Stream<Item = Result<Bytes, Error>> + Send + 'static {
use futures_util::StreamExt;
self.map(|m| m.map(Bytes::from))
}
#[inline]
pub fn into_byte_stream(
self,
) -> ByteStream<impl Stream<Item = Result<Bytes, Error>> + Send + 'static> {
ByteStream::new(self.into_bytes())
}
#[inline]
pub fn into_inner(self) -> S {
self.inner
}
}
impl<S> Stream for SseStream<S>
where
    S: Stream<Item = Result<Message, Error>> + Send + 'static,
{
    type Item = Result<Message, Error>;

    /// Polls the wrapped stream directly; `SseStream` adds no buffering or transformation.
    #[inline]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }

    /// Forwards the wrapped stream's size hint.
    ///
    /// Without this override the default `(0, None)` would hide the length of bounded
    /// streams (e.g. `Message::once`) from downstream adapters.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Builds an `SseStream` from an async generator body.
///
/// The body is handed to `$crate::__async_stream::try_stream!` (presumably the `async-stream`
/// crate's `try_stream!`, re-exported by this crate), and the result goes through
/// `SseStream::from_try_messages`. The body must therefore `yield` `Message` values, and its
/// error type must be this crate's `Error`.
#[macro_export]
macro_rules! sse_stream {
    { $($tt:tt)* } => {{
        $crate::http::sse::SseStream::from_try_messages(
            $crate::__async_stream::try_stream! { $($tt)* }
        )
    }};
}
/// A single Server-Sent Events message, built up field by field.
///
/// Fields are kept in insertion order. Converting the message into `Bytes` writes them in
/// that order, followed by the blank line that terminates the event.
#[derive(Debug, Default, Clone)]
pub struct Message {
    // Already-encoded fields, in the order they were added.
    fields: Vec<SseField>,
}
/// Identifies which field an `SseField` holds.
///
/// The replacing setters use it to remove earlier values of the same kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldKind {
    /// A comment line (starts with `:`).
    Comment,
    /// One or more `data:` lines.
    Data,
    /// The `event:` line.
    Event,
    /// The `id:` line.
    Id,
    /// The `retry:` line.
    Retry,
}
/// One encoded field of a `Message`.
#[derive(Debug, Clone)]
struct SseField {
    // Which field this is, so setters can find and replace it.
    kind: FieldKind,
    // The serialized line(s), each terminated by `\n`, written to the wire as-is.
    bytes: Bytes,
}
impl Message {
    /// Creates a message with no fields.
    ///
    /// Converted to bytes on its own, it is just the blank line that terminates an event.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns an [`SseStream`] that yields this message exactly once and then ends.
    #[inline]
    pub fn once(self) -> SseStream<impl Stream<Item = Result<Message, Error>> + Send> {
        SseStream::new(futures_util::stream::iter([Ok(self)]))
    }

    /// Returns a never-ending [`SseStream`] that yields a clone of this message on every poll.
    #[inline]
    pub fn repeat(self) -> SseStream<impl Stream<Item = Result<Message, Error>> + Send> {
        SseStream::new(futures_util::stream::repeat_with(move || Ok(self.clone())))
    }

    /// Creates a message holding a single empty comment line, serialized as `:\n\n`.
    #[inline]
    pub fn empty() -> Self {
        let mut msg = Self::default();
        msg.push_field(FieldKind::Comment, Bytes::from_static(EMPTY));
        msg
    }

    /// Replaces all `data` fields with `value`.
    ///
    /// Each line of `value` becomes its own `data:` line, so clients reassemble the original
    /// text. LF, CRLF and a lone CR are all treated as line breaks, because SSE clients split
    /// the stream on all three. Otherwise an embedded CR could smuggle in extra fields.
    #[inline]
    pub fn data(mut self, value: impl AsRef<[u8]>) -> Self {
        self.remove_fields(FieldKind::Data);
        self.append(value)
    }

    /// Appends `value` as additional `data:` lines, keeping any existing data.
    ///
    /// Line breaks are handled the same way as in [`Message::data`].
    #[inline]
    pub fn append(mut self, value: impl AsRef<[u8]>) -> Self {
        let bytes = Self::multiline_field(DATA, value.as_ref());
        self.push_field(FieldKind::Data, bytes);
        self
    }

    /// Replaces the data with `value` serialized as JSON.
    ///
    /// If serialization fails, the message instead becomes an `error` event whose data is the
    /// serializer's error text.
    #[inline]
    pub fn json<T: Serialize>(self, value: T) -> Self {
        match serde_json::to_vec(&value) {
            Ok(v) => self.data(v),
            Err(err) => self.event(ERROR).data(err.to_string()),
        }
    }

    /// Sets the `event` field, replacing any previous one.
    ///
    /// An event name cannot span lines on the wire. Everything from the first CR or LF
    /// onwards is discarded rather than letting it be parsed as additional fields.
    #[inline]
    pub fn event(mut self, name: &str) -> Self {
        self.remove_fields(FieldKind::Event);
        let bytes = Self::field(EVENT, Self::first_line(name.as_bytes()));
        self.push_field(FieldKind::Event, bytes);
        self
    }

    /// Sets the `id` field, replacing any previous one.
    ///
    /// As with [`Message::event`], anything from the first CR or LF onwards is discarded.
    #[inline]
    pub fn id(mut self, value: impl AsRef<[u8]>) -> Self {
        self.remove_fields(FieldKind::Id);
        let bytes = Self::field(ID, Self::first_line(value.as_ref()));
        self.push_field(FieldKind::Id, bytes);
        self
    }

    /// Sets the `retry` field to `duration` in whole milliseconds, replacing any previous one.
    #[inline]
    pub fn retry(mut self, duration: Duration) -> Self {
        let mut buffer = BytesMut::new();
        buffer.extend_from_slice(RETRY);
        buffer.extend_from_slice(itoa::Buffer::new().format(duration.as_millis()).as_bytes());
        buffer.put_u8(NEW_LINE);
        self.remove_fields(FieldKind::Retry);
        self.push_field(FieldKind::Retry, buffer.freeze());
        self
    }

    /// Appends a comment (`: value`).
    ///
    /// A multi-line `value` is emitted as one comment line per line of input. Line breaks are
    /// handled the same way as in [`Message::data`], so comment text cannot turn into fields.
    #[inline]
    pub fn comment(mut self, value: impl AsRef<[u8]>) -> Self {
        let bytes = Self::multiline_field("", value.as_ref());
        self.push_field(FieldKind::Comment, bytes);
        self
    }

    /// Drops every field of the given kind.
    #[inline]
    fn remove_fields(&mut self, kind: FieldKind) {
        self.fields.retain(|field| field.kind != kind);
    }

    /// Appends an already-encoded field.
    #[inline]
    fn push_field(&mut self, kind: FieldKind, bytes: Bytes) {
        self.fields.push(SseField { kind, bytes });
    }

    /// Encodes a single `name: value\n` line.
    ///
    /// The caller must ensure `value` contains no CR/LF.
    #[inline]
    fn field(name: &str, value: impl AsRef<[u8]>) -> Bytes {
        let mut buffer = BytesMut::new();
        Self::write_line(&mut buffer, name, value.as_ref());
        buffer.freeze()
    }

    /// Encodes `value` as one `name: line\n` entry per line, splitting on LF, CRLF and CR.
    fn multiline_field(name: &str, value: &[u8]) -> Bytes {
        let mut buffer = BytesMut::new();
        let mut segments = memchr_split(NEW_LINE, value).peekable();
        while let Some(segment) = segments.next() {
            let segment: &[u8] = segment.as_ref();
            // A CR directly before an LF is the first half of a CRLF terminator, not a line
            // break of its own. Only segments followed by another segment were cut at an LF.
            let segment = if segments.peek().is_some() {
                segment.strip_suffix(b"\r").unwrap_or(segment)
            } else {
                segment
            };
            // A lone CR is a line terminator for SSE clients, so it must start a new line here.
            for line in segment.split(|&byte| byte == b'\r') {
                Self::write_line(&mut buffer, name, line);
            }
        }
        buffer.freeze()
    }

    /// Writes `name: value\n` into `buffer`.
    #[inline]
    fn write_line(buffer: &mut BytesMut, name: &str, value: &[u8]) {
        buffer.reserve(name.len() + value.len() + 3);
        buffer.extend_from_slice(name.as_bytes());
        buffer.extend_from_slice(b": ");
        buffer.extend_from_slice(value);
        buffer.put_u8(NEW_LINE);
    }

    /// Returns `value` up to (not including) its first CR or LF.
    #[inline]
    fn first_line(value: &[u8]) -> &[u8] {
        match value.iter().position(|&byte| byte == b'\r' || byte == NEW_LINE) {
            Some(end) => &value[..end],
            None => value,
        }
    }
}
impl<T: Serialize> From<T> for Message {
#[inline]
fn from(value: T) -> Self {
Self::default().json(value)
}
}
impl From<Message> for Bytes {
    /// Serializes a message into its wire form.
    ///
    /// Every field is written in insertion order, followed by the blank line that terminates
    /// an SSE event.
    #[inline]
    fn from(message: Message) -> Self {
        // Every field is already encoded, so the exact size is known: allocate once and
        // copy each field in bulk. `extend(Bytes)` would go through `Extend<u8>` and copy
        // byte by byte.
        let total = message
            .fields
            .iter()
            .map(|field| field.bytes.len())
            .sum::<usize>()
            + 1;
        let mut buffer = BytesMut::with_capacity(total);
        for field in &message.fields {
            buffer.extend_from_slice(&field.bytes);
        }
        buffer.put_u8(NEW_LINE);
        buffer.freeze()
    }
}
#[cfg(test)]
mod tests {
    use super::Message;
    use bytes::Bytes;
    use futures_util::{StreamExt, TryStreamExt, pin_mut};
    use serde::Serialize;
    use std::time::Duration;

    /// Wire form of `Message::new().data("hi!")`, shared by most stream tests.
    const HI: &str = "data: hi!\n\n";

    /// Serializes a message and decodes the bytes (lossily) as UTF-8 for comparison.
    fn render(message: Message) -> String {
        String::from_utf8_lossy(&Bytes::from(message)).into_owned()
    }

    #[derive(Serialize)]
    struct Test {
        value: String,
    }

    #[tokio::test]
    async fn it_creates_message_repeat_stream() {
        let stream = Message::new().data("hi!").repeat();
        pin_mut!(stream);
        for _ in 0..2 {
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(render(message), HI);
        }
    }

    #[tokio::test]
    async fn it_creates_message_once_stream() {
        let stream = Message::new().data("hi!").once();
        pin_mut!(stream);
        let message = stream.next().await.unwrap().unwrap();
        assert_eq!(render(message), HI);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn it_creates_sse_stream() {
        let stream = sse_stream! {
            yield Message::new().data("hi!");
            yield Message::new().data("hi!");
            yield Message::new().data("hi!");
        };
        pin_mut!(stream);
        for _ in 0..3 {
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(render(message), HI);
        }
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn it_creates_sse_stream_with_loop() {
        let stream = sse_stream! {
            loop {
                yield Message::new().data("hi!");
            }
        };
        pin_mut!(stream);
        for _ in 0..3 {
            let message = stream.next().await.unwrap().unwrap();
            assert_eq!(render(message), HI);
        }
    }

    #[tokio::test]
    async fn it_modifies_sse_stream() {
        let stream = sse_stream! {
            yield Message::new().data("hi!");
        };
        let stream = stream.map_ok(|msg| msg.comment("some comment"));
        pin_mut!(stream);
        let message = stream.next().await.unwrap().unwrap();
        assert_eq!(render(message), "data: hi!\n: some comment\n\n");
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn it_converts_into_bytes() {
        let stream = sse_stream! {
            yield Message::new().data("hi!");
        };
        let stream = stream.into_bytes();
        pin_mut!(stream);
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(String::from_utf8_lossy(&chunk), HI);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn it_converts_into_byte_stream() {
        let stream = sse_stream! {
            yield Message::new().data("hi!");
        };
        let stream = stream.into_byte_stream();
        pin_mut!(stream);
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(String::from_utf8_lossy(&chunk), HI);
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn it_creates_default_message() {
        assert_eq!(render(Message::default()), "\n");
    }

    #[test]
    fn it_creates_empty_message() {
        assert_eq!(render(Message::empty()), ":\n\n");
    }

    #[test]
    fn it_creates_data_message_with_comment() {
        let message = Message::new().comment("some comment").data("hi!");
        assert_eq!(render(message), ": some comment\ndata: hi!\n\n");
    }

    #[test]
    fn it_creates_data_message_with_multiple_comment() {
        let message = Message::new()
            .comment("some comment")
            .data("hi!")
            .comment("another comment")
            .comment("one more comment");
        assert_eq!(
            render(message),
            ": some comment\ndata: hi!\n: another comment\n: one more comment\n\n"
        );
    }

    #[test]
    fn it_creates_string_message() {
        assert_eq!(render(Message::new().data("hi!")), HI);
    }

    #[test]
    fn it_appends_string_data() {
        let message = Message::new().data("Hello").append("World");
        assert_eq!(render(message), "data: Hello\ndata: World\n\n");
    }

    #[test]
    fn it_creates_multiline_string_data() {
        let message = Message::new().data("Hello \nbeautiful \nworld!");
        assert_eq!(
            render(message),
            "data: Hello \ndata: beautiful \ndata: world!\n\n"
        );
    }

    #[test]
    fn it_creates_string_event() {
        let message = Message::new().event("greet").data("hi!");
        assert_eq!(render(message), "event: greet\ndata: hi!\n\n");
    }

    #[test]
    fn it_creates_string_event_with_id() {
        let message = Message::new().id("some id").event("greet").data("hi!");
        assert_eq!(render(message), "id: some id\nevent: greet\ndata: hi!\n\n");
    }

    #[test]
    fn it_creates_message_with_retry() {
        let message = Message::new().data("hi!").retry(Duration::from_secs(5));
        assert_eq!(render(message), "data: hi!\nretry:5000\n\n");
    }

    #[test]
    fn it_creates_json_event() {
        let message = Message::new().json(Test {
            value: "test".into(),
        });
        assert_eq!(render(message), "data: {\"value\":\"test\"}\n\n");
    }

    #[test]
    fn it_converts_json_into_event() {
        let payload = Test {
            value: "test".into(),
        };
        let message: Message = payload.into();
        assert_eq!(render(message), "data: {\"value\":\"test\"}\n\n");
    }
}