use std::{
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use crate::{
BoxError,
body::{Body, codec::BodyDecoder},
};
pub use sse_stream::{Error as SseError, Sse};
/// Stream type returned by [`SseDecoder`]: an `sse_stream::SseStream` reading
/// from an [`SseBody`].
pub type SseEventStream = sse_stream::SseStream<SseBody>;
/// Adapter around the crate's [`Body`] that implements `http_body::Body`
/// with `Bytes` data and [`SseBodyError`] as its error type.
///
/// It is the body type plugged into [`SseEventStream`].
// NOTE(review): presumably this wrapper exists because `SseStream` needs an
// error type implementing `std::error::Error` — confirm against `sse_stream`.
pub struct SseBody {
// The wrapped body; every frame is pulled from it.
inner: Body,
}
impl SseBody {
fn new(inner: Body) -> Self {
Self { inner }
}
}
/// Error produced by [`SseBody`] when the wrapped body fails.
///
/// Newtype over [`BoxError`]: its `Display` output is the wrapped error's
/// message (`"{0}"`), and the wrapped error is also reported as the
/// `source()` of this error.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct SseBodyError(#[source] BoxError);
impl http_body::Body for SseBody {
type Data = Bytes;
type Error = SseBodyError;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.inner) };
match inner.poll_frame(cx) {
Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(SseBodyError(e)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
/// [`BodyDecoder`] that turns a body into an [`SseEventStream`].
///
/// Unit struct with no fields, so it carries no configuration or state.
#[derive(Clone, Debug, Default)]
pub struct SseDecoder;
impl BodyDecoder<SseEventStream> for SseDecoder {
    /// `decode` only wraps the body and always returns `Ok`.
    type Error = std::convert::Infallible;

    /// Wraps `body` in an [`SseEventStream`].
    ///
    /// The body is first converted into the crate's [`Body`], then into an
    /// [`SseBody`], and finally handed to `sse_stream::SseStream::new`.
    /// This method never returns `Err`.
    async fn decode<B>(&self, body: B) -> Result<SseEventStream, Self::Error>
    where
        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
        B::Error: Into<BoxError>,
    {
        let wrapped = Body::new(body);
        let sse_body = SseBody::new(wrapped);
        let events = sse_stream::SseStream::new(sse_body);
        Ok(events)
    }
}