#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![warn(rustdoc::missing_crate_level_docs)]
use std::{
fmt,
future::Future,
num::NonZeroUsize,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
time::Duration,
};
use bytes::Bytes;
use futures_core::stream::Stream;
use reqwest::{RequestBuilder, StatusCode, header::HeaderValue};
use thiserror::Error;
use tokio::time::{Instant, Sleep, sleep};
pub use sse_core::SseRetryConfig;
use sse_core::{
MessageEvent, PayloadTooLargeError, SseDecoder, SseEvent as SseEventCore, SseStream,
SseStreamError,
};
/// Boxed, type-erased stream of response body chunks; this is the inner
/// stream type wrapped by the [`SseStream`] held in [`EventSource`].
type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync>>;
/// Boxed, type-erased future of an in-flight connection attempt, resolving to
/// the HTTP response or the `reqwest` error that prevented it.
type ConnectFuture =
Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Send + Sync>>;
/// Result alias whose error type defaults to this crate's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Fatal errors yielded by [`EventSource`].
///
/// Every code path in this file that produces one of these also closes the
/// event source, so the stream ends after the error has been yielded.
#[derive(Debug, Error)]
pub enum Error {
/// The server answered with a status other than `200 OK` that was not
/// handled as a retryable transient error.
#[error("unexpected HTTP status code: {0}")]
Status(StatusCode),
/// `RequestBuilder::try_clone` returned `None`, so the request cannot be
/// sent (or re-sent on reconnection).
#[error("request builder could not be cloned (e.g., non-restartable body stream)")]
UncloneableRequest,
/// The response `Content-Type` header is present but is not
/// `text/event-stream`.
#[error("invalid response HTTP Content-Type")]
InvalidContentType,
/// The response carries no `Content-Type` header at all.
#[error("response HTTP Content-Type missing")]
MissingContentType,
/// The retry configuration produced no further backoff delay.
///
/// Carries the number of attempts counted so far and the error that caused
/// the last attempt to fail.
#[error("couldn't reconnect to SSE server in {0} attempts: {1}")]
Timeout(u32, SseErrorEvent),
/// The decoder reported a payload larger than the configured limit.
#[error("server sent an oversized payload exceeding the allotted buffer")]
PayloadTooLarge(#[from] PayloadTooLargeError),
/// The current last event ID could not be turned into a `Last-Event-ID`
/// header value.
#[error("Last-Event-ID cannot be converted to a valid HTTP header: {0}")]
InvalidLastEventId(reqwest::header::InvalidHeaderValue),
}
/// Coarse connection state reported by [`EventSource::ready_state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ReadyState {
/// No connection is open yet: the source is disconnected, has a request in
/// flight, or is waiting out a reconnection delay.
Connecting = 0,
/// A connection is established and events are being read from it.
Open = 1,
/// The source has been closed and will not reconnect on its own.
Closed = 2,
}
/// Internal state machine driven by `EventSource::poll_next`.
enum State {
/// No request in flight; the next poll clones the request and sends it.
Disconnected,
/// A request has been sent and its response is being awaited.
Connecting(ConnectFuture),
/// The response body is attached to the SSE stream and is being decoded.
Open,
/// Waiting for the reconnection delay to elapse before going back to
/// `Disconnected`.
Sleeping(Pin<Box<Sleep>>),
/// Terminal state: polling yields `None`.
Closed,
}
/// Manual `Debug` implementation: the connection future held by
/// `State::Connecting` is rendered as the placeholder `_` instead of being
/// formatted.
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Sleeping` is the only variant whose payload is printed; every other
        // variant boils down to a fixed label.
        let label = match self {
            State::Sleeping(timer) => {
                return f.debug_tuple("Sleeping").field(timer).finish();
            }
            State::Disconnected => "Disconnected",
            State::Connecting(_) => "Connecting(_)",
            State::Open => "Open",
            State::Closed => "Closed",
        };
        f.write_str(label)
    }
}
/// Conditions that make [`EventSource`] schedule a reconnection attempt.
///
/// Surfaced either as [`SseEvent::Error`] while retries remain, or as the
/// cause inside [`Error::Timeout`] once the retry configuration gives up.
#[derive(Debug, Error)]
pub enum SseErrorEvent {
/// The response body stream ended.
#[error("server cleanly closed the connection (EOF)")]
Eof,
/// The server answered with one of the HTTP statuses treated as transient
/// (only produced when retrying transient errors is enabled).
#[error("transient HTTP error: {0}")]
Http(StatusCode),
/// `reqwest` reported an error while connecting or while reading the body.
#[error("network or transport error: {0}")]
Network(#[from] reqwest::Error),
}
/// Items successfully yielded by the [`EventSource`] stream.
#[derive(Debug)]
pub enum SseEvent {
/// A connection was established: the response passed the status and
/// `Content-Type` checks and its body is now being read.
Open,
/// A message decoded from the event stream.
Message(MessageEvent),
/// A non-fatal failure; the source waits for the backoff delay and then
/// tries to reconnect.
Error(SseErrorEvent),
}
impl SseEvent {
    /// Consumes the event and returns its payload if it is a
    /// [`SseEvent::Message`]; `None` for [`SseEvent::Open`] and
    /// [`SseEvent::Error`].
    pub fn into_message(self) -> Option<MessageEvent> {
        if let Self::Message(message) = self {
            Some(message)
        } else {
            None
        }
    }

    /// Borrows the payload if this is a [`SseEvent::Message`]; `None` for
    /// [`SseEvent::Open`] and [`SseEvent::Error`].
    pub fn as_message(&self) -> Option<&MessageEvent> {
        if let Self::Message(message) = self {
            Some(message)
        } else {
            None
        }
    }

    /// Mutably borrows the payload if this is a [`SseEvent::Message`]; `None`
    /// for [`SseEvent::Open`] and [`SseEvent::Error`].
    pub fn as_message_mut(&mut self) -> Option<&mut MessageEvent> {
        if let Self::Message(message) = self {
            Some(message)
        } else {
            None
        }
    }
}
impl From<MessageEvent> for SseEvent {
fn from(event: MessageEvent) -> Self {
Self::Message(event)
}
}
impl From<SseErrorEvent> for SseEvent {
fn from(err: SseErrorEvent) -> Self {
Self::Error(err)
}
}
/// Error returned when an [`SseEvent`] that is not a message is converted
/// into a [`MessageEvent`].
///
/// The rejected event is handed back in field `0`; the error message names
/// which variant it was.
#[derive(Debug, Error)]
#[error("couldn't convert Event::{} into a MessageEvent", match .0 {
SseEvent::Open => "Open",
SseEvent::Message(_) => "Message",
SseEvent::Error(_) => "Error"
})]
pub struct FromMessageEventError(pub SseEvent);
impl TryFrom<SseEvent> for MessageEvent {
type Error = FromMessageEventError;
fn try_from(ev: SseEvent) -> Result<Self, Self::Error> {
match ev {
SseEvent::Message(msg) => Ok(msg),
ev => Err(FromMessageEventError(ev)),
}
}
}
/// Builder that configures an [`EventSource`] before it is created.
#[derive(Debug)]
pub struct EventSourceBuilder {
/// Request template; cloned for every connection attempt.
req: RequestBuilder,
/// Policy used to compute the backoff delay between reconnection attempts.
retry_config: SseRetryConfig,
/// Initial reconnection time in milliseconds, fed to the backoff calculation.
reconnection_time_ms: u32,
/// Optional size limit handed to the decoder; `None` uses the decoder's default constructor.
max_payload_size: Option<NonZeroUsize>,
/// Event ID the decoder is seeded with before the first connection.
last_event_id: Option<Arc<str>>,
/// Whether transient HTTP statuses trigger a retry instead of a fatal error.
retry_transient_errors: bool,
/// How long a connection must have lasted for the attempt counter to be
/// reset when it drops.
successful_connection_threshold: Duration,
}
impl EventSourceBuilder {
    /// Creates a builder around `req` with the default configuration.
    ///
    /// Defaults: 3000 ms initial reconnection time, [`SseRetryConfig::new`] as
    /// retry policy, no payload size limit, no initial last event ID, transient
    /// HTTP errors are not retried, and a 5 second successful-connection
    /// threshold.
    #[must_use]
    pub fn new(req: RequestBuilder) -> Self {
        let retry_config = SseRetryConfig::new();
        let successful_connection_threshold = Duration::from_secs(5);
        Self {
            req,
            retry_config,
            reconnection_time_ms: 3000,
            max_payload_size: None,
            last_event_id: None,
            retry_transient_errors: false,
            successful_connection_threshold,
        }
    }

    /// Replaces the retry policy used to compute reconnection delays.
    #[inline]
    #[must_use]
    pub fn retry_config(self, retry_config: SseRetryConfig) -> Self {
        Self {
            retry_config,
            ..self
        }
    }

    /// Sets the reconnection time used until the server overrides it with a
    /// `retry` field.
    ///
    /// # Panics
    ///
    /// Panics if `reconnection_time`, expressed in milliseconds, does not fit
    /// in a `u32`.
    #[inline]
    #[must_use]
    pub fn initial_reconnection_time(self, reconnection_time: Duration) -> Self {
        let millis = u32::try_from(reconnection_time.as_millis()).expect("Read duration too long");
        Self {
            reconnection_time_ms: millis,
            ..self
        }
    }

    /// Limits the decoder to `max_payload_size` by constructing it with
    /// `SseDecoder::with_limit`.
    #[inline]
    #[must_use]
    pub fn max_payload_size(self, max_payload_size: NonZeroUsize) -> Self {
        Self {
            max_payload_size: Some(max_payload_size),
            ..self
        }
    }

    /// Seeds the decoder with `id` as its last event ID before the first
    /// connection is made.
    #[inline]
    #[must_use]
    pub fn last_event_id(self, id: impl Into<Arc<str>>) -> Self {
        Self {
            last_event_id: Some(id.into()),
            ..self
        }
    }

    /// Chooses whether transient HTTP statuses (408, 429, 502, 503, 504) lead
    /// to a reconnection attempt rather than a fatal [`Error::Status`].
    #[inline]
    #[must_use]
    pub fn retry_transient_errors(self, retry: bool) -> Self {
        Self {
            retry_transient_errors: retry,
            ..self
        }
    }

    /// Sets how long a connection must have been open for the reconnection
    /// attempt counter to be reset once that connection is lost.
    #[inline]
    #[must_use]
    pub fn successful_connection_threshold(self, threshold: Duration) -> Self {
        Self {
            successful_connection_threshold: threshold,
            ..self
        }
    }

    /// Finishes configuration and returns a disconnected [`EventSource`].
    ///
    /// The request gets `Accept: text/event-stream` and
    /// `Cache-Control: no-store` headers added. Nothing is sent until the
    /// returned stream is polled.
    #[must_use]
    pub fn build(self) -> EventSource {
        let Self {
            req,
            retry_config,
            reconnection_time_ms,
            max_payload_size,
            last_event_id,
            retry_transient_errors,
            successful_connection_threshold,
        } = self;

        let mut decoder = if let Some(limit) = max_payload_size {
            SseDecoder::with_limit(limit)
        } else {
            SseDecoder::new()
        };
        decoder.reconnect_with_id(last_event_id);

        let req = req
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .header(reqwest::header::CACHE_CONTROL, "no-store");

        EventSource {
            req,
            reconnection_time_ms,
            connection_attempt: 0,
            connected_since: None,
            retry_config,
            retry_transient_errors,
            successful_connection_threshold,
            stream: SseStream::with_decoder(decoder),
            state: State::Disconnected,
        }
    }
}
/// A Server-Sent Events client built on a `reqwest` request.
///
/// Implements [`Stream`] with `Result<SseEvent>` items and reconnects after
/// non-fatal failures according to its retry configuration.
pub struct EventSource {
/// Request template; cloned for every connection attempt.
req: RequestBuilder,
/// Current reconnection time in milliseconds; replaced when the stream
/// delivers a retry value.
reconnection_time_ms: u32,
/// Number of reconnection attempts counted since the last reset.
connection_attempt: u32,
/// When the current connection was opened; taken when it is lost.
connected_since: Option<Instant>,
/// Policy used to compute the backoff delay between reconnection attempts.
retry_config: SseRetryConfig,
/// Whether transient HTTP statuses trigger a retry instead of a fatal error.
retry_transient_errors: bool,
/// How long a connection must have lasted for `connection_attempt` to be
/// reset when it drops.
successful_connection_threshold: Duration,
/// SSE decoding stream that the response body is attached to.
stream: SseStream<ByteStream>,
/// Current position in the connection state machine.
state: State,
}
/// Manual `Debug` implementation: prints the request, retry settings and
/// state machine position plus a summary of the stream (last event ID and
/// closed flag), and marks the output as non-exhaustive.
impl fmt::Debug for EventSource {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The stream is summarised through its accessors rather than formatted directly.
        let last_event_id: Option<&str> = self.stream.last_event_id().map(|id| &**id);
        let stream_closed = self.stream.is_closed();

        let mut out = f.debug_struct("EventSource");
        out.field("req", &self.req);
        out.field("reconnection_time_ms", &self.reconnection_time_ms);
        out.field("connection_attempt", &self.connection_attempt);
        out.field("retry_config", &self.retry_config);
        out.field("retry_transient_errors", &self.retry_transient_errors);
        out.field("state", &self.state);
        out.field("stream.last_event_id()", &last_event_id);
        out.field("stream.is_closed()", &stream_closed);
        out.finish_non_exhaustive()
    }
}
impl EventSource {
    /// Creates an event source for `req` with the default configuration.
    #[must_use]
    pub fn new(req: RequestBuilder) -> Self {
        EventSourceBuilder::new(req).build()
    }

    /// Starts configuring an event source for `req`.
    #[must_use]
    pub fn builder(req: RequestBuilder) -> EventSourceBuilder {
        EventSourceBuilder::new(req)
    }

    /// Closes the underlying stream and moves to the terminal closed state;
    /// polling afterwards yields `None`.
    pub fn close(&mut self) {
        self.stream.close();
        self.state = State::Closed;
    }

    /// Reports the coarse connection state.
    ///
    /// Everything that is neither open nor closed (disconnected, request in
    /// flight, waiting for a retry) is reported as [`ReadyState::Connecting`].
    #[inline]
    #[must_use]
    pub fn ready_state(&self) -> ReadyState {
        match &self.state {
            State::Open => ReadyState::Open,
            State::Closed => ReadyState::Closed,
            State::Disconnected | State::Connecting(_) | State::Sleeping(_) => {
                ReadyState::Connecting
            }
        }
    }

    /// Returns the last event ID currently tracked by the stream, if any.
    #[inline]
    #[must_use]
    pub fn last_event_id(&self) -> Option<&Arc<str>> {
        self.stream.last_event_id()
    }

    /// Drops the current connection (if any), resets the attempt counter and
    /// schedules a fresh connection on the next poll.
    #[inline]
    pub fn force_reconnect(&mut self) {
        self.stream.close();
        self.connection_attempt = 0;
        self.state = State::Disconnected;
    }

    /// Like [`EventSource::force_reconnect`], but closes the stream through
    /// `close_with_id`, handing it `id`.
    #[inline]
    pub fn force_reconnect_with_id(&mut self, id: Option<Arc<str>>) {
        self.stream.close_with_id(id);
        self.connection_attempt = 0;
        self.state = State::Disconnected;
    }

    /// Schedules a reconnection after `cause`, or gives up.
    ///
    /// Returns `Ok(SseEvent::Error(cause))` when the retry configuration
    /// yields a delay (the source is then sleeping), and closes the source and
    /// returns [`Error::Timeout`] when it does not.
    fn go_to_sleep(&mut self, cause: SseErrorEvent) -> Result<SseEvent> {
        // A connection that stayed up long enough counts as a success, so the
        // backoff starts over. The timestamp is consumed either way.
        let threshold = self.successful_connection_threshold;
        let was_stable = self
            .connected_since
            .take()
            .is_some_and(|since| threshold <= since.elapsed());
        if was_stable {
            self.connection_attempt = 0;
        }

        let backoff = self
            .retry_config
            .calculate_backoff(self.reconnection_time_ms, self.connection_attempt);
        self.connection_attempt += 1;

        match backoff {
            Some(delay) => {
                self.state = State::Sleeping(Box::pin(sleep(delay)));
                Ok(SseEvent::Error(cause))
            }
            None => {
                self.close();
                Err(Error::Timeout(self.connection_attempt, cause))
            }
        }
    }
}
impl Stream for EventSource {
type Item = Result<SseEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let slf = &mut *self;
loop {
match &mut slf.state {
State::Disconnected => {
let Some(mut req) = slf.req.try_clone() else {
slf.close();
return Poll::Ready(Some(Err(Error::UncloneableRequest)));
};
if let Some(last_event_id) = slf.stream.last_event_id() {
match HeaderValue::from_str(last_event_id) {
Ok(val) => req = req.header("Last-Event-ID", val),
Err(err) => {
slf.close();
return Poll::Ready(Some(Err(Error::InvalidLastEventId(err))));
}
}
}
let fut = Box::pin(req.send());
slf.state = State::Connecting(fut);
}
State::Connecting(fut) => match ready!(fut.as_mut().poll(cx)) {
Ok(res) => {
let status = res.status();
if matches!(status, StatusCode::NO_CONTENT) {
slf.close();
return Poll::Ready(None);
}
let is_transient_error = matches!(
status,
StatusCode::REQUEST_TIMEOUT
| StatusCode::TOO_MANY_REQUESTS
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE
| StatusCode::GATEWAY_TIMEOUT
);
if slf.retry_transient_errors && is_transient_error {
return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Http(status))));
} else if status != StatusCode::OK {
slf.close();
return Poll::Ready(Some(Err(Error::Status(status))));
}
let Some(content_type) = res
.headers()
.get(reqwest::header::CONTENT_TYPE)
.map(|v| v.as_bytes())
else {
slf.close();
return Poll::Ready(Some(Err(Error::MissingContentType)));
};
const MIME_EVENT_STREAM: &str = "text/event-stream";
if !(content_type.starts_with(MIME_EVENT_STREAM.as_bytes())
&& matches!(
content_type.get(MIME_EVENT_STREAM.len()),
None | Some(b';' | b' ' | b'\t')
))
{
slf.close();
return Poll::Ready(Some(Err(Error::InvalidContentType)));
}
slf.state = State::Open;
slf.connected_since = Some(Instant::now());
slf.stream.attach(Box::pin(res.bytes_stream()));
return Poll::Ready(Some(Ok(SseEvent::Open)));
}
Err(err) => {
slf.close();
return Poll::Ready(Some(slf.go_to_sleep(err.into())));
}
},
State::Open => match ready!(Pin::new(&mut slf.stream).poll_next(cx)) {
Some(Ok(raw_event)) => match raw_event {
SseEventCore::Retry(ms) => slf.reconnection_time_ms = ms,
SseEventCore::Message(event) => return Poll::Ready(Some(Ok(event.into()))),
},
Some(Err(SseStreamError::PayloadTooLarge(err))) => {
slf.close();
return Poll::Ready(Some(Err(Error::PayloadTooLarge(err))));
}
Some(Err(SseStreamError::Inner(err))) => {
return Poll::Ready(Some(slf.go_to_sleep(err.into())));
}
None => return Poll::Ready(Some(slf.go_to_sleep(SseErrorEvent::Eof))),
},
State::Sleeping(sleep_fut) => {
ready!(sleep_fut.as_mut().poll(cx));
slf.state = State::Disconnected;
}
State::Closed => return Poll::Ready(None),
}
}
}
}
// Private module holding the supertrait of `RequestBuilderExt`; because the
// module is not public, code outside this crate cannot implement that trait.
mod sealed {
pub trait Sealed {}
}
/// Extension methods that turn a request builder into an SSE client.
///
/// The trait is sealed (its supertrait lives in a private module), so it can
/// only be implemented inside this crate.
pub trait RequestBuilderExt: sealed::Sealed {
/// Converts the request into an [`EventSource`].
fn into_event_source(self) -> EventSource;
/// Converts the request into an [`EventSourceBuilder`] for further configuration.
fn into_event_source_builder(self) -> EventSourceBuilder;
}
// Lets `RequestBuilder` satisfy the sealed supertrait of `RequestBuilderExt`.
impl sealed::Sealed for RequestBuilder {}
/// Lets a `reqwest` request builder be turned into an SSE client directly.
impl RequestBuilderExt for RequestBuilder {
    /// Builds an [`EventSource`] for this request with the default configuration.
    fn into_event_source(self) -> EventSource {
        EventSourceBuilder::new(self).build()
    }

    /// Wraps this request in an [`EventSourceBuilder`] with the default configuration.
    fn into_event_source_builder(self) -> EventSourceBuilder {
        EventSource::builder(self)
    }
}