#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![warn(rustdoc::missing_crate_level_docs)]
use std::{
future::Future,
num::NonZeroUsize,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
time::Duration,
};
use bytes::Bytes;
use futures_core::stream::Stream;
use reqwest::{RequestBuilder, StatusCode};
use thiserror::Error;
use tokio::time::{Sleep, sleep};
pub use sse_core::SseRetryConfig;
use sse_core::{
MessageEvent, PayloadTooLargeError, SseDecoder, SseEvent as SseEventCore, SseStream,
SseStreamError,
};
type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync>>;
type ConnectFuture =
Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Send + Sync>>;
/// Result type used throughout this crate; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Fatal errors yielded by [`EventSource`].
///
/// Every variant is terminal: the event source is closed before the error is
/// yielded (its ready state becomes [`ReadyState::Closed`]) and the stream ends
/// afterwards. Recoverable failures are reported as [`SseEvent::Error`] instead.
#[derive(Debug, Error)]
pub enum Error {
    /// The server answered with a status other than `200 OK` that is not retried.
    ///
    /// (`204 No Content` ends the stream without an error.)
    #[error("Unexpected HTTP status code: {0}")]
    Status(StatusCode),
    /// The request could not be cloned for (re)sending, e.g. because its body is
    /// a non-restartable stream.
    #[error("Request builder could not be cloned (e.g., non-restartable body stream)")]
    UncloneableRequest,
    /// The response `Content-Type` is not `text/event-stream`.
    #[error("Invalid response HTTP Content-Type")]
    InvalidContentType,
    /// The response has no `Content-Type` header.
    #[error("Response HTTP Content-Type missing")]
    MissingContentType,
    /// The retry configuration allowed no further reconnection attempts.
    ///
    /// Holds the number of attempts made and the failure that ended the last one.
    #[error("Couldn't reconnect to SSE server in {0} attempts: {1}")]
    Timeout(u32, SseErrorEvent),
    /// The server sent a payload exceeding the decoder's size limit (see
    /// [`EventSourceBuilder::max_payload_size`]).
    #[error("Server sent an oversized payload exceeding the allotted buffer")]
    PayloadTooLarge(#[from] PayloadTooLargeError),
}
/// Connection state of an [`EventSource`], as reported by
/// [`EventSource::ready_state`].
///
/// The discriminants equal the `CONNECTING`, `OPEN` and `CLOSED` constants of
/// the HTML `EventSource` interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ReadyState {
    /// Connecting, or waiting to reconnect after a recoverable failure.
    Connecting = 0,
    /// Connected; events are being received.
    Open = 1,
    /// Closed for good; the stream yields no more items.
    Closed = 2,
}
/// Internal state machine advanced by `EventSource::poll_next`.
enum State {
    /// No request in flight; the next poll starts a connection attempt.
    Disconnected,
    /// Waiting for the response to the in-flight request.
    Connecting(ConnectFuture),
    /// Response accepted; events are decoded from the attached body stream.
    Open,
    /// Back-off timer; once it fires the state returns to `Disconnected`.
    Sleeping(Pin<Box<Sleep>>),
    /// Closed; polling yields `None`.
    Closed,
}
/// Recoverable connection failure, reported as [`SseEvent::Error`].
///
/// After yielding one of these the event source waits for the back-off delay
/// and reconnects. When the retry configuration allows no further attempts, the
/// failure is instead wrapped in [`Error::Timeout`] and the stream ends.
#[derive(Debug, Error)]
pub enum SseErrorEvent {
    /// The response body ended.
    #[error("Server cleanly closed the connection (EOF)")]
    Eof,
    /// The server answered with a transient status (408, 429, 502, 503 or 504)
    /// and [`EventSourceBuilder::retry_transient_errors`] is enabled.
    #[error("Transient HTTP error: {0}")]
    Http(StatusCode),
    /// Sending the request or reading the response body failed.
    #[error("Network or transport error: {0}")]
    Network(#[from] reqwest::Error),
}
/// Successful item yielded by [`EventSource`].
#[derive(Debug)]
pub enum SseEvent {
    /// A connection was established and its response accepted.
    Open,
    /// A message dispatched by the server.
    Message(MessageEvent),
    /// A recoverable failure; a reconnection is scheduled after a back-off delay.
    Error(SseErrorEvent),
}
impl SseEvent {
    /// Consumes the event and returns its message, or `None` for
    /// [`SseEvent::Open`] and [`SseEvent::Error`].
    #[must_use]
    pub fn into_message(self) -> Option<MessageEvent> {
        match self {
            Self::Message(msg) => Some(msg),
            Self::Open | Self::Error(_) => None,
        }
    }

    /// Borrows the message, or returns `None` if this is not [`SseEvent::Message`].
    #[must_use]
    pub fn as_message(&self) -> Option<&MessageEvent> {
        match self {
            Self::Message(msg) => Some(msg),
            Self::Open | Self::Error(_) => None,
        }
    }

    /// Mutably borrows the message, or returns `None` if this is not
    /// [`SseEvent::Message`].
    #[must_use]
    pub fn as_message_mut(&mut self) -> Option<&mut MessageEvent> {
        match self {
            Self::Message(msg) => Some(msg),
            Self::Open | Self::Error(_) => None,
        }
    }
}
impl From<MessageEvent> for SseEvent {
fn from(event: MessageEvent) -> Self {
Self::Message(event)
}
}
impl From<SseErrorEvent> for SseEvent {
fn from(err: SseErrorEvent) -> Self {
Self::Error(err)
}
}
#[derive(Debug, Error)]
#[error("Couldn't convert Event::{} into a MessageEvent", match .0 {
SseEvent::Open => "Open",
SseEvent::Message(_) => "Message",
SseEvent::Error(_) => "Error"
})]
pub struct FromMessageEventError(pub SseEvent);
impl TryFrom<SseEvent> for MessageEvent {
type Error = FromMessageEventError;
fn try_from(ev: SseEvent) -> Result<Self, Self::Error> {
match ev {
SseEvent::Message(msg) => Ok(msg),
ev => Err(FromMessageEventError(ev)),
}
}
}
/// Configures an [`EventSource`] before it is created.
///
/// Obtain one from [`EventSource::builder`], [`EventSourceBuilder::new`] or
/// [`RequestBuilderExt::into_event_source_builder`], then call
/// [`EventSourceBuilder::build`].
pub struct EventSourceBuilder {
    /// Request template; the built event source clones it for every attempt.
    req: RequestBuilder,
    /// Policy turning the reconnection time and attempt count into a delay.
    retry_config: SseRetryConfig,
    /// Initial base reconnection delay in milliseconds.
    reconnection_time_ms: u32,
    /// Optional decoder payload limit; `None` builds an unlimited decoder.
    max_payload_size: Option<NonZeroUsize>,
    /// ID seeded into the decoder for the first `Last-Event-ID` header.
    last_event_id: Option<Arc<str>>,
    /// Whether transient HTTP statuses are retried instead of being fatal.
    retry_transient_errors: bool,
}
impl EventSourceBuilder {
    /// Starts configuring an event source for `req`.
    ///
    /// Defaults: an initial reconnection time of 3 seconds, `SseRetryConfig::new()`,
    /// no payload size limit, no initial last event ID, and transient HTTP errors
    /// are not retried.
    #[must_use]
    pub fn new(req: RequestBuilder) -> Self {
        Self {
            req,
            retry_config: SseRetryConfig::new(),
            reconnection_time_ms: 3000,
            max_payload_size: None,
            last_event_id: None,
            retry_transient_errors: false,
        }
    }

    /// Sets the retry policy.
    ///
    /// The policy turns the current reconnection time and the attempt count into
    /// a back-off delay. When it yields no delay, the event source gives up with
    /// [`Error::Timeout`].
    #[inline]
    #[must_use]
    pub fn retry_config(mut self, retry_config: SseRetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Sets the initial reconnection time (default: 3 seconds).
    ///
    /// This is the base delay handed to the retry policy. The server can replace
    /// it at any time with a `retry:` field. Durations longer than `u32::MAX`
    /// milliseconds (about 49.7 days) are clamped to that maximum.
    #[inline]
    #[must_use]
    pub fn initial_reconnection_time(mut self, reconnection_time: Duration) -> Self {
        // Clamp instead of panicking: an out-of-range delay is a configuration
        // extreme, not a bug worth aborting the caller for.
        self.reconnection_time_ms =
            u32::try_from(reconnection_time.as_millis()).unwrap_or(u32::MAX);
        self
    }

    /// Limits how large a payload the decoder will buffer.
    ///
    /// A server exceeding the limit ends the stream with
    /// [`Error::PayloadTooLarge`]. By default there is no limit.
    #[inline]
    #[must_use]
    pub fn max_payload_size(mut self, max_payload_size: NonZeroUsize) -> Self {
        self.max_payload_size = Some(max_payload_size);
        self
    }

    /// Seeds the event ID sent in the `Last-Event-ID` header, e.g. to resume a
    /// stream from a previous session.
    #[inline]
    #[must_use]
    pub fn last_event_id(mut self, id: impl Into<Arc<str>>) -> Self {
        self.last_event_id = Some(id.into());
        self
    }

    /// Chooses how transient HTTP statuses (408, 429, 502, 503, 504) are handled.
    ///
    /// When `true`, they are retried with back-off and reported as
    /// [`SseErrorEvent::Http`]. When `false` (the default), they end the stream
    /// with [`Error::Status`].
    #[inline]
    #[must_use]
    pub fn retry_transient_errors(mut self, retry: bool) -> Self {
        self.retry_transient_errors = retry;
        self
    }

    /// Builds the event source.
    ///
    /// No request is sent until the returned stream is first polled.
    #[must_use]
    pub fn build(self) -> EventSource {
        let mut decoder = match self.max_payload_size {
            Some(max_payload_size) => SseDecoder::with_limit(max_payload_size),
            None => SseDecoder::new(),
        };
        decoder.reconnect_with_id(self.last_event_id);
        EventSource {
            req: self.req,
            reconnection_time_ms: self.reconnection_time_ms,
            connection_attempt: 0,
            retry_config: self.retry_config,
            retry_transient_errors: self.retry_transient_errors,
            stream: SseStream::with_decoder(decoder),
            state: State::Disconnected,
        }
    }
}
/// Server-sent events client that reconnects automatically.
///
/// `EventSource` is a [`Stream`] of [`Result<SseEvent>`](Result):
/// - `Ok` items are connection openings, messages, and recoverable errors; a
///   recoverable error is followed by a reconnection.
/// - An `Err` item is fatal; the event source is closed and the stream ends.
///
/// No request is sent until the stream is first polled.
pub struct EventSource {
    /// Request template, cloned for every connection attempt.
    req: RequestBuilder,
    /// Base reconnection delay in milliseconds; the server's `retry:` field updates it.
    reconnection_time_ms: u32,
    /// Attempts since the last accepted response; reset on open and on forced reconnects.
    connection_attempt: u32,
    /// Back-off policy consulted after every recoverable failure.
    retry_config: SseRetryConfig,
    /// Whether transient HTTP statuses are retried instead of being fatal.
    retry_transient_errors: bool,
    /// Decoder together with the body stream of the current response, if any.
    stream: SseStream<ByteStream>,
    /// Current position in the connection state machine.
    state: State,
}
impl EventSource {
#[must_use]
pub fn new(req: RequestBuilder) -> Self {
Self::builder(req).build()
}
#[must_use]
pub fn builder(req: RequestBuilder) -> EventSourceBuilder {
EventSourceBuilder::new(req)
}
pub fn close(&mut self) {
self.stream.close();
self.state = State::Closed;
}
#[inline]
#[must_use]
pub fn ready_state(&self) -> ReadyState {
match &self.state {
State::Disconnected | State::Connecting(_) | State::Sleeping(_) => {
ReadyState::Connecting
}
State::Open => ReadyState::Open,
State::Closed => ReadyState::Closed,
}
}
#[inline]
#[must_use]
pub fn last_event_id(&self) -> Option<&Arc<str>> {
self.stream.last_event_id()
}
#[inline]
pub fn force_reconnect(&mut self) {
self.stream.close();
self.connection_attempt = 0;
self.state = State::Disconnected;
}
#[inline]
pub fn force_reconnect_with_id(&mut self, id: Option<Arc<str>>) {
self.stream.close();
self.connection_attempt = 0;
self.state = State::Disconnected;
self.stream.close_with_id(id)
}
fn go_to_sleep(&mut self, cause: SseErrorEvent) -> Result<SseEvent> {
let wait_dur = (self.retry_config)
.calculate_backoff(self.reconnection_time_ms, self.connection_attempt);
self.connection_attempt += 1;
if let Some(dur) = wait_dur {
self.state = State::Sleeping(Box::pin(sleep(dur)));
Ok(SseEvent::Error(cause))
} else {
self.close();
Err(Error::Timeout(self.connection_attempt, cause))
}
}
}
impl Stream for EventSource {
type Item = Result<SseEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let slf = &mut *self;
loop {
match &mut slf.state {
State::Disconnected => {
let Some(mut req) = slf.req.try_clone() else {
slf.close();
return Poll::Ready(Some(Err(Error::UncloneableRequest)));
};
req = req
.header(reqwest::header::ACCEPT, "text/event-stream")
.header(reqwest::header::CACHE_CONTROL, "no-store");
if let Some(last_event_id) = slf.stream.last_event_id() {
req = req.header("Last-Event-ID", &**last_event_id);
}
let fut = Box::pin(req.send());
slf.state = State::Connecting(fut);
}
State::Connecting(fut) => match ready!(fut.as_mut().poll(cx)) {
Ok(res) => {
let status = res.status();
if matches!(status, StatusCode::NO_CONTENT) {
slf.close();
return Poll::Ready(None);
}
let is_transient_error = matches!(
status,
StatusCode::REQUEST_TIMEOUT
| StatusCode::TOO_MANY_REQUESTS
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE
| StatusCode::GATEWAY_TIMEOUT
);
if slf.retry_transient_errors && is_transient_error {
return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Http(status))));
} else if status != StatusCode::OK {
slf.close();
return Poll::Ready(Some(Err(Error::Status(status))));
}
let Some(content_type) = res
.headers()
.get(reqwest::header::CONTENT_TYPE)
.map(|v| v.as_bytes())
else {
slf.close();
return Poll::Ready(Some(Err(Error::MissingContentType)));
};
const MIME_EVENT_STREAM: &str = "text/event-stream";
if !(content_type.starts_with(MIME_EVENT_STREAM.as_bytes())
&& matches!(
content_type.get(MIME_EVENT_STREAM.len()),
None | Some(b';' | b' ' | b'\t')
))
{
slf.close();
return Poll::Ready(Some(Err(Error::InvalidContentType)));
}
slf.state = State::Open;
slf.connection_attempt = 0;
slf.stream.attach(Box::pin(res.bytes_stream()));
return Poll::Ready(Some(Ok(SseEvent::Open)));
}
Err(err) => {
slf.close();
return Poll::Ready(Some(slf.go_to_sleep(err.into())));
}
},
State::Open => match ready!(Pin::new(&mut slf.stream).poll_next(cx)) {
Some(Ok(raw_event)) => match raw_event {
SseEventCore::Retry(ms) => slf.reconnection_time_ms = ms,
SseEventCore::Message(event) => return Poll::Ready(Some(Ok(event.into()))),
},
Some(Err(SseStreamError::PayloadTooLarge(err))) => {
slf.close();
return Poll::Ready(Some(Err(Error::PayloadTooLarge(err))));
}
Some(Err(SseStreamError::Inner(err))) => {
return Poll::Ready(Some(slf.go_to_sleep(err.into())));
}
None => return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Eof))),
},
State::Sleeping(sleep_fut) => {
ready!(sleep_fut.as_mut().poll(cx));
slf.state = State::Disconnected;
}
State::Closed => return Poll::Ready(None),
}
}
}
}
/// Private home of the supertrait that seals [`RequestBuilderExt`].
///
/// The module is not exported, so only this crate can implement `Sealed`, and
/// therefore only this crate can implement the extension trait.
mod sealed {
    /// Marker trait implemented solely for types this crate extends.
    pub trait Sealed {}
}
/// Extension methods that turn a [`RequestBuilder`] into an [`EventSource`].
///
/// This trait is sealed: it cannot be implemented outside this crate.
pub trait RequestBuilderExt: sealed::Sealed {
    /// Wraps the request in an [`EventSource`] with default settings.
    #[must_use]
    fn into_event_source(self) -> EventSource;

    /// Wraps the request in an [`EventSourceBuilder`] for further configuration.
    #[must_use]
    fn into_event_source_builder(self) -> EventSourceBuilder;
}
// `RequestBuilder` is the only type allowed to use the extension trait.
impl sealed::Sealed for RequestBuilder {}

impl RequestBuilderExt for RequestBuilder {
    #[inline]
    fn into_event_source(self) -> EventSource {
        EventSource::builder(self).build()
    }

    #[inline]
    fn into_event_source_builder(self) -> EventSourceBuilder {
        EventSource::builder(self)
    }
}