use std::error::Error as StdError;
use std::io::ErrorKind;
use std::time::Duration;
use std::time::Instant;
use async_stream::stream;
use futures_util::StreamExt;
use http::header::{
HeaderName,
HeaderValue,
CONTENT_TYPE,
};
use tokio_util::sync::CancellationToken;
use super::{
SseEventStream,
SseReconnectOptions,
DEFAULT_SSE_MAX_RECONNECT_DELAY,
};
use crate::{
HttpClient,
HttpError,
HttpErrorKind,
HttpRequest,
HttpResponse,
HttpResult,
RetryDelay,
RetryHint,
RetryOptions,
};
/// Name of the request header that carries the most recently seen SSE event id
/// when reconnecting. Kept lowercase because it is passed to
/// `HeaderName::from_static`.
const LAST_EVENT_ID_HEADER: &str = "last-event-id";
/// Outcome of checking whether another reconnect attempt may be made.
enum ReconnectDecision {
/// Neither the reconnect count limit nor the elapsed-time budget blocks another attempt.
Allowed,
/// The number of reconnects already performed has reached the allowed maximum.
MaxReconnectsReached,
/// The total elapsed-time budget is already spent, or would be spent by the pending sleep.
MaxElapsedExceeded {
/// Time elapsed since the runner started, sampled when the decision was made.
elapsed: Duration,
/// The configured limit on total elapsed time.
max_elapsed: Duration,
},
}
/// Drives an SSE request and transparently reconnects it according to
/// [`SseReconnectOptions`].
pub(crate) struct SseReconnectRunner {
/// Client used to execute each connection attempt.
client: HttpClient,
/// Request that is cloned afresh for every connection attempt.
request_template: HttpRequest,
/// Reconnect behaviour (retry policy, server `retry` handling, EOF handling).
options: SseReconnectOptions,
}
impl SseReconnectRunner {
/// Creates a runner that will issue `request` through `client` and reconnect
/// according to `options`.
pub(crate) fn new(
client: HttpClient,
request: HttpRequest,
options: SseReconnectOptions,
) -> Self {
Self {
client,
request_template: request,
options,
}
}
/// Consumes the runner and returns a boxed stream of SSE events.
///
/// Each loop iteration clones the request template, executes it once, checks
/// that the response is `text/event-stream`, and forwards the parsed events.
/// A reconnect is attempted after:
/// - a connect error or stream error for which `should_reconnect_sse_error`
///   returns `true`, or
/// - a clean end of stream when `options.reconnect_on_eof` is set.
///
/// The stream ends when:
/// - an error is not reconnectable, or the reconnect limit is reached after
///   an error: that error is yielded;
/// - the elapsed-time budget is exceeded: a max-elapsed error is yielded;
/// - cancellation fires during the reconnect sleep: a cancelled error is yielded;
/// - the response fails content-type validation: that error is yielded;
/// - a clean EOF occurs with no reconnect allowed: the stream ends silently.
pub(crate) fn run(self) -> SseEventStream {
let client = self.client;
let request_template = self.request_template;
let options = self.options;
let output = stream! {
let retry_options = options.retry.clone();
// The first connection counts as one attempt, so reconnects are capped at
// `max_attempts - 1` (saturating at zero).
let max_reconnects = retry_options.max_attempts().saturating_sub(1);
// URL resolution failures are discarded here; the URL is only used to
// annotate errors produced by this runner.
let request_url = request_template.resolved_url_with_query().ok();
let request_method = request_template.method().clone();
let cancellation_token = request_template.cancellation_token().cloned();
let started_at = Instant::now();
// Number of reconnects performed so far; never reset inside the loop.
let mut count: u32 = 0;
let mut backoff_delay = initial_reconnect_delay(&retry_options);
// Delay requested by the server via an event's `retry` field; consumed by
// the next reconnect and then cleared.
let mut pending_server_retry_delay: Option<Duration> = None;
// Most recent event id seen; sent back on subsequent connection attempts.
let mut last_event_id: Option<String> = None;
loop {
let mut request = request_template.clone();
// NOTE(review): `force_disable()` presumably turns off the client's own
// retry handling so that this loop is the only retry layer — confirm.
let retry_override = request.retry_override().clone().force_disable();
request.set_retry_override(retry_override);
if let Some(last_event_id) = last_event_id.as_deref() {
if let Err(error) = apply_last_event_id_header(&mut request, last_event_id) {
yield Err(error);
return;
}
}
// Phase 1: establish the connection.
let response = match client.execute_once(request).await {
Ok(response) => response,
Err(error) => {
if should_reconnect_sse_error(&error) {
let sleep_delay = reconnect_sleep_delay(
backoff_delay,
pending_server_retry_delay,
&retry_options,
&options,
);
match reconnect_decision(
count,
max_reconnects,
started_at,
&retry_options,
sleep_delay,
) {
ReconnectDecision::Allowed => {
count += 1;
if let Err(cancelled) = sleep_reconnect_delay(
sleep_delay,
cancellation_token.as_ref(),
&request_method,
request_url.as_ref(),
)
.await
{
yield Err(cancelled);
return;
}
backoff_delay = next_reconnect_delay(&retry_options, backoff_delay);
pending_server_retry_delay = None;
continue;
}
ReconnectDecision::MaxElapsedExceeded {
elapsed,
max_elapsed,
} => {
let error = max_elapsed_exceeded_error_with_last_error(
error,
elapsed,
max_elapsed,
&request_method,
request_url.as_ref(),
);
yield Err(error);
return;
}
// Fall through so the original error is yielded below.
ReconnectDecision::MaxReconnectsReached => {}
}
}
yield Err(error);
return;
}
};
// A response with the wrong content type is yielded as an error and ends
// the stream without a reconnect.
if let Err(error) = validate_sse_response_content_type(&response) {
yield Err(error);
return;
}
// Phase 2: forward events until the stream ends or fails.
let mut events = response.sse_events();
let mut stream_error: Option<HttpError> = None;
while let Some(item) = events.next().await {
match item {
Ok(event) => {
// Remember the id so the next connection attempt can send it back.
if let Some(id) = event.id.clone() {
last_event_id = Some(id);
}
if options.honor_server_retry {
if let Some(retry_ms) = event.retry {
pending_server_retry_delay =
Some(server_retry_delay(retry_ms, &retry_options, &options));
}
}
yield Ok(event);
}
Err(error) => {
stream_error = Some(error);
break;
}
}
}
// Phase 3a: the stream failed mid-flight.
if let Some(error) = stream_error {
if should_reconnect_sse_error(&error) {
let sleep_delay = reconnect_sleep_delay(
backoff_delay,
pending_server_retry_delay,
&retry_options,
&options,
);
match reconnect_decision(
count,
max_reconnects,
started_at,
&retry_options,
sleep_delay,
) {
ReconnectDecision::Allowed => {
count += 1;
if let Err(cancelled) = sleep_reconnect_delay(
sleep_delay,
cancellation_token.as_ref(),
&request_method,
request_url.as_ref(),
)
.await
{
yield Err(cancelled);
return;
}
backoff_delay = next_reconnect_delay(&retry_options, backoff_delay);
pending_server_retry_delay = None;
continue;
}
ReconnectDecision::MaxElapsedExceeded {
elapsed,
max_elapsed,
} => {
let error = max_elapsed_exceeded_error_with_last_error(
error,
elapsed,
max_elapsed,
&request_method,
request_url.as_ref(),
);
yield Err(error);
return;
}
// Fall through so the original error is yielded below.
ReconnectDecision::MaxReconnectsReached => {}
}
}
yield Err(error);
return;
}
// Phase 3b: the stream ended cleanly; reconnect only if configured to.
if options.reconnect_on_eof {
let sleep_delay =
reconnect_sleep_delay(
backoff_delay,
pending_server_retry_delay,
&retry_options,
&options,
);
match reconnect_decision(
count,
max_reconnects,
started_at,
&retry_options,
sleep_delay,
) {
ReconnectDecision::Allowed => {
count += 1;
if let Err(cancelled) = sleep_reconnect_delay(
sleep_delay,
cancellation_token.as_ref(),
&request_method,
request_url.as_ref(),
)
.await
{
yield Err(cancelled);
return;
}
backoff_delay = next_reconnect_delay(&retry_options, backoff_delay);
pending_server_retry_delay = None;
continue;
}
ReconnectDecision::MaxElapsedExceeded {
elapsed,
max_elapsed,
} => {
// There is no underlying error on a clean EOF, so the plain
// max-elapsed error is used.
let error = max_elapsed_exceeded_error(
elapsed,
max_elapsed,
request_template.method(),
request_url.as_ref(),
);
yield Err(error);
return;
}
// Reconnect limit reached after a clean EOF: end without an error.
ReconnectDecision::MaxReconnectsReached => return,
}
}
return;
}
};
Box::pin(output)
}
}
fn apply_last_event_id_header(request: &mut HttpRequest, last_event_id: &str) -> HttpResult<()> {
let header_value = HeaderValue::from_str(last_event_id).map_err(|error| {
HttpError::other(format!(
"Invalid Last-Event-ID header value '{last_event_id}': {error}"
))
})?;
request.set_typed_header(HeaderName::from_static(LAST_EVENT_ID_HEADER), header_value);
Ok(())
}
fn should_reconnect_sse_error(error: &HttpError) -> bool {
if error.kind == HttpErrorKind::Cancelled {
return false;
}
matches!(error.retry_hint(), RetryHint::Retryable) || is_unexpected_eof_error(error)
}
/// Advances the backoff delay from `current` using the retry policy, never
/// returning less than one millisecond.
fn next_reconnect_delay(retry_options: &RetryOptions, current: Duration) -> Duration {
    let floor = Duration::from_millis(1);
    let advanced = retry_options.next_base_delay_from_current(current);
    advanced.max(floor)
}
/// Checks whether another reconnect is permitted.
///
/// The reconnect count limit is checked first. If the retry policy also sets a
/// total elapsed-time limit, the attempt is refused when that limit has been
/// reached or would be reached after sleeping for `sleep_delay`.
fn reconnect_decision(
    count: u32,
    max_reconnects: u32,
    started_at: Instant,
    retry_options: &RetryOptions,
    sleep_delay: Duration,
) -> ReconnectDecision {
    if count >= max_reconnects {
        return ReconnectDecision::MaxReconnectsReached;
    }
    // Without a total-time limit, only the count limit applies.
    let Some(max_elapsed) = retry_options.max_total_elapsed() else {
        return ReconnectDecision::Allowed;
    };
    let elapsed = started_at.elapsed();
    let budget_spent = elapsed >= max_elapsed;
    if budget_spent || will_exceed_elapsed(elapsed, sleep_delay, max_elapsed) {
        ReconnectDecision::MaxElapsedExceeded {
            elapsed,
            max_elapsed,
        }
    } else {
        ReconnectDecision::Allowed
    }
}
/// Returns `true` when `elapsed + sleep_delay` reaches or passes `max_elapsed`.
///
/// An overflowing sum is treated as exceeding the limit.
fn will_exceed_elapsed(elapsed: Duration, sleep_delay: Duration, max_elapsed: Duration) -> bool {
    match elapsed.checked_add(sleep_delay) {
        Some(projected) => projected >= max_elapsed,
        None => true,
    }
}
/// Returns the retry policy's base delay for attempt 1, with a floor of one
/// millisecond.
fn initial_reconnect_delay(retry_options: &RetryOptions) -> Duration {
    let floor = Duration::from_millis(1);
    let first = retry_options.base_delay_for_attempt(1);
    first.max(floor)
}
/// Picks how long to sleep before the next reconnect.
///
/// A pending server-provided delay takes precedence over the backoff delay and
/// is jittered only when `options.apply_jitter_to_server_retry` is set. The
/// backoff delay is always jittered. The result is at least one millisecond.
fn reconnect_sleep_delay(
    backoff_delay: Duration,
    pending_server_retry_delay: Option<Duration>,
    retry_options: &RetryOptions,
    options: &SseReconnectOptions,
) -> Duration {
    let floor = Duration::from_millis(1);
    let chosen = match pending_server_retry_delay {
        Some(server_delay) if !options.apply_jitter_to_server_retry => server_delay,
        Some(server_delay) => retry_options.jittered_delay(server_delay),
        None => retry_options.jittered_delay(backoff_delay),
    };
    chosen.max(floor)
}
/// Sleeps for `delay`, returning early with a cancelled error if
/// `cancellation_token` fires first.
///
/// # Errors
///
/// Returns a cancelled [`HttpError`] annotated with `request_method` and, when
/// available, `request_url` if the token is cancelled before the sleep ends.
async fn sleep_reconnect_delay(
    delay: Duration,
    cancellation_token: Option<&CancellationToken>,
    request_method: &http::Method,
    request_url: Option<&url::Url>,
) -> HttpResult<()> {
    // With no token there is nothing to race against; just sleep.
    let Some(token) = cancellation_token else {
        tokio::time::sleep(delay).await;
        return Ok(());
    };
    tokio::select! {
        _ = token.cancelled() => {
            let cancelled = HttpError::cancelled(
                "SSE reconnect cancelled while waiting before next attempt",
            )
            .with_method(request_method);
            Err(match request_url {
                Some(url) => cancelled.with_url(url),
                None => cancelled,
            })
        }
        _ = tokio::time::sleep(delay) => Ok(()),
    }
}
/// Converts a server-provided `retry` value in milliseconds into a delay
/// bounded below by one millisecond and above by the server-retry cap.
fn server_retry_delay(
    retry_ms: u64,
    retry_options: &RetryOptions,
    options: &SseReconnectOptions,
) -> Duration {
    let floor = Duration::from_millis(1);
    let requested = Duration::from_millis(retry_ms.max(1));
    let ceiling = server_retry_max_delay(retry_options, options);
    let capped = requested.min(ceiling);
    capped.max(floor)
}
/// Returns the cap for server-provided retry delays: the explicitly configured
/// value if present, otherwise a default derived from the retry policy. The
/// cap is never below one millisecond.
fn server_retry_max_delay(retry_options: &RetryOptions, options: &SseReconnectOptions) -> Duration {
    let configured = match options.server_retry_max_delay {
        Some(limit) => limit,
        None => default_server_retry_max_delay(retry_options),
    };
    configured.max(Duration::from_millis(1))
}
/// Derives the default cap for server-provided retry delays from the retry
/// policy's delay strategy.
///
/// Strategies that carry their own `max` use it. Strategies without one fall
/// back to `DEFAULT_SSE_MAX_RECONNECT_DELAY`. The result is at least one
/// millisecond.
fn default_server_retry_max_delay(retry_options: &RetryOptions) -> Duration {
    let ceiling = match retry_options.delay() {
        RetryDelay::Random { max, .. } | RetryDelay::Exponential { max, .. } => *max,
        RetryDelay::None | RetryDelay::Fixed(_) => DEFAULT_SSE_MAX_RECONNECT_DELAY,
    };
    ceiling.max(Duration::from_millis(1))
}
/// Builds the error reported when the total reconnect time budget is
/// exhausted, annotated with the request method and, when known, the URL.
fn max_elapsed_exceeded_error(
    elapsed: Duration,
    max_elapsed: Duration,
    request_method: &http::Method,
    request_url: Option<&url::Url>,
) -> HttpError {
    let base = HttpError::retry_max_elapsed_exceeded(format!(
        "SSE reconnect max duration exceeded: {elapsed:?}/{max_elapsed:?}"
    ))
    .with_method(request_method);
    match request_url {
        Some(url) => base.with_url(url),
        None => base,
    }
}
/// Builds a max-elapsed error that also carries the context of the last
/// retryable error.
///
/// The method, URL and status recorded on `last_error`, when present, replace
/// the ones derived from the request. The message is extended with the last
/// error's message and status, and `last_error` becomes the new error's
/// `source`.
fn max_elapsed_exceeded_error_with_last_error(
    last_error: HttpError,
    elapsed: Duration,
    max_elapsed: Duration,
    request_method: &http::Method,
    request_url: Option<&url::Url>,
) -> HttpError {
    let mut combined =
        max_elapsed_exceeded_error(elapsed, max_elapsed, request_method, request_url);
    let last_status = last_error.status;

    // Prefer whatever request context the failing attempt recorded.
    if let Some(method) = &last_error.method {
        combined = combined.with_method(method);
    }
    if let Some(url) = &last_error.url {
        combined = combined.with_url(url);
    }
    if let Some(status) = last_status {
        combined = combined.with_status(status);
    }

    let message = format!(
        "{}; last retryable error: {}",
        combined.message, last_error.message
    );
    let message = match last_status {
        Some(status) => format!("{message} (status: {status})"),
        None => message,
    };

    combined.message = message;
    combined.source = Some(Box::new(last_error));
    combined
}
fn validate_sse_response_content_type(response: &HttpResponse) -> HttpResult<()> {
let method = response.meta.method.clone();
let url = response.request_url().clone();
let Some(value) = response.headers().get(CONTENT_TYPE) else {
return Err(
HttpError::sse_protocol("Missing Content-Type header for SSE response")
.with_status(response.status())
.with_method(&method)
.with_url(&url),
);
};
let content_type = value.to_str().map_err(|_| {
HttpError::sse_protocol("Invalid non-UTF8 Content-Type header for SSE response")
.with_status(response.status())
.with_method(&method)
.with_url(&url)
})?;
let media_type = content_type
.split(';')
.next()
.map(str::trim)
.unwrap_or_default();
if media_type.eq_ignore_ascii_case("text/event-stream") {
return Ok(());
}
Err(HttpError::sse_protocol(format!(
"Expected Content-Type 'text/event-stream' for SSE response, got '{content_type}'"
))
.with_status(response.status())
.with_method(&method)
.with_url(&url))
}
/// Heuristically detects whether `error` represents an unexpected end of
/// stream.
///
/// It matches when the error message contains "unexpected eof" (ASCII
/// case-insensitive). It also matches when the error's source has an
/// `UnexpectedEof` I/O error in its chain, or when the source's `Display` or
/// `Debug` output contains that phrase.
fn is_unexpected_eof_error(error: &HttpError) -> bool {
    fn mentions_unexpected_eof(text: &str) -> bool {
        text.to_ascii_lowercase().contains("unexpected eof")
    }

    if mentions_unexpected_eof(&error.message) {
        return true;
    }
    let Some(source) = error.source.as_ref() else {
        return false;
    };
    has_unexpected_eof_in_error_chain(source.as_ref())
        || mentions_unexpected_eof(&source.to_string())
        || mentions_unexpected_eof(&format!("{source:?}"))
}
/// Walks `error` and its `source()` chain, returning `true` if any link is a
/// `std::io::Error` whose kind is `UnexpectedEof`.
fn has_unexpected_eof_in_error_chain(error: &(dyn StdError + 'static)) -> bool {
    let chain = std::iter::successors(Some(error), |link| (*link).source());
    for link in chain {
        let is_eof = match link.downcast_ref::<std::io::Error>() {
            Some(io_error) => io_error.kind() == ErrorKind::UnexpectedEof,
            None => false,
        };
        if is_eof {
            return true;
        }
    }
    false
}