use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use tracing::trace;
use crate::error::Result;
/// Boxed, pinned stream of raw byte chunks that `SseFrameStream` consumes.
///
/// On non-`wasm32` targets the stream must also be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'static>>;
/// `wasm32` variant of the byte-stream alias: identical except that it drops the
/// `Send` bound (presumably because wasm byte streams are not `Send` — TODO confirm).
#[cfg(target_arch = "wasm32")]
pub(crate) type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static>>;
/// Splits an upstream byte stream into Server-Sent Events frames and yields the
/// joined `data` payload of each frame.
///
/// Frames are delimited by a blank line (`\n\n` or `\r\n\r\n`). When `sentinel`
/// is set, a frame whose payload equals it ends the stream.
pub(crate) struct SseFrameStream {
    /// Source of raw response bytes.
    upstream: ByteStream,
    /// Bytes received but not yet emitted as part of a complete frame.
    buffer: BytesMut,
    /// Length of the prefix of `buffer` already searched without finding a frame
    /// delimiter, so `take_frame` does not rescan it after every new chunk.
    /// It is clamped to `buffer.len()` on use, so clearing `buffer` is safe.
    scanned: usize,
    /// Set once upstream is exhausted or the sentinel payload has been seen.
    done: bool,
    /// Payload that marks the logical end of the stream, if any.
    sentinel: Option<&'static [u8]>,
    /// Prefix used in trace logging.
    label: &'static str,
}

impl SseFrameStream {
    /// Wraps `upstream`. `sentinel` optionally names an end-of-stream payload;
    /// `label` is only used in trace output.
    pub(crate) fn new(
        upstream: ByteStream,
        sentinel: Option<&'static [u8]>,
        label: &'static str,
    ) -> Self {
        Self {
            upstream,
            buffer: BytesMut::with_capacity(8 * 1024),
            scanned: 0,
            done: false,
            sentinel,
            label,
        }
    }

    /// Removes the first complete frame (including its delimiter) from the
    /// buffer and returns its `data` payload, or `None` if no delimiter has
    /// arrived yet.
    fn take_frame(&mut self) -> Option<Vec<u8>> {
        // The longest delimiter is 4 bytes, so a delimiter the previous scan could
        // not see (because it was incomplete) starts at most 3 bytes before where
        // that scan stopped. Resuming there keeps total scanning linear.
        let start = self.scanned.min(self.buffer.len()).saturating_sub(3);
        match Self::find_frame_end(&self.buffer[start..]) {
            Some(end) => {
                self.scanned = 0;
                let frame = self.buffer.split_to(start + end);
                Some(extract_data_payload(&frame))
            }
            None => {
                self.scanned = self.buffer.len();
                None
            }
        }
    }

    /// Drains everything left in the buffer as a final (possibly unterminated)
    /// frame; returns `None` when the buffer is empty.
    fn take_remaining(&mut self) -> Option<Vec<u8>> {
        self.scanned = 0;
        if self.buffer.is_empty() {
            return None;
        }
        let frame = self.buffer.split_to(self.buffer.len());
        Some(extract_data_payload(&frame))
    }

    /// Returns the index just past the earliest blank-line delimiter in `bytes`
    /// (`\r\n\r\n` or `\n\n`), or `None` if there is none.
    ///
    /// Mixed endings such as `\n\r\n`, and bare-CR line endings, are not
    /// recognised as delimiters.
    fn find_frame_end(bytes: &[u8]) -> Option<usize> {
        (0..bytes.len()).find_map(|i| {
            let rest = &bytes[i..];
            if rest.starts_with(b"\r\n\r\n") {
                Some(i + 4)
            } else if rest.starts_with(b"\n\n") {
                Some(i + 2)
            } else {
                None
            }
        })
    }
}
impl Stream for SseFrameStream {
    type Item = Result<Vec<u8>>;

    /// Yields the next non-empty `data` payload.
    ///
    /// Complete frames already in the buffer are always emitted before upstream is
    /// polled again. While upstream is live, empty payloads are skipped and a
    /// sentinel payload switches the stream into its finished state (discarding
    /// the buffer). Once finished, remaining buffered frames and any unterminated
    /// tail are drained, skipping empty and sentinel payloads, and then `None` is
    /// returned. Upstream errors are forwarded as-is.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        loop {
            if this.done {
                // Finished: flush complete frames first, then the trailing remainder.
                let payload = match this.take_frame().or_else(|| this.take_remaining()) {
                    Some(payload) => payload,
                    None => return Poll::Ready(None),
                };
                if payload.is_empty() || Some(payload.as_slice()) == this.sentinel {
                    continue;
                }
                return Poll::Ready(Some(Ok(payload)));
            }

            match this.take_frame() {
                Some(payload) if payload.is_empty() => continue,
                Some(payload) if Some(payload.as_slice()) == this.sentinel => {
                    // End-of-stream marker: drop anything after it and finish.
                    this.done = true;
                    this.buffer.clear();
                    continue;
                }
                Some(payload) => return Poll::Ready(Some(Ok(payload))),
                None => {}
            }

            // No complete frame buffered; pull more bytes from upstream.
            match this.upstream.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => this.done = true,
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(Some(Ok(chunk))) => {
                    trace!(len = chunk.len(), "{} sse bytes", this.label);
                    this.buffer.extend_from_slice(&chunk);
                }
            }
        }
    }
}
/// Joins the values of every `data` field in one SSE frame with `\n`.
///
/// Field handling for `data`:
/// - `data:value` and `data: value` are both accepted; a single leading space
///   after the colon is removed.
/// - A bare `data` line (no colon) contributes an empty value.
/// - Empty values still take part in joining, so `data:\ndata: x` yields `"\nx"`.
///
/// All other fields and `:` comment lines are ignored. Trailing `\r` characters
/// are stripped from each line so CRLF-framed input works.
///
/// Parsing is done on raw bytes, so a payload that is not valid UTF-8 is passed
/// through unchanged rather than silently dropped. The result is empty when the
/// frame has no `data` lines, or only a single empty one.
fn extract_data_payload(frame: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(frame.len());
    // Whether a `data` line has been emitted yet. Joining is keyed on this rather
    // than on `out` being non-empty, so leading empty values are not lost.
    let mut seen_data = false;
    for line in frame.split(|&b| b == b'\n') {
        let mut line = line;
        // Remove CRs left over from CRLF line endings.
        while let [head @ .., b'\r'] = line {
            line = head;
        }
        let value: &[u8] = if line == b"data" {
            &[]
        } else if let Some(rest) = line.strip_prefix(b"data:") {
            rest.strip_prefix(b" ").unwrap_or(rest)
        } else {
            continue;
        };
        if seen_data {
            out.push(b'\n');
        }
        seen_data = true;
        out.extend_from_slice(value);
    }
    out
}