use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use tokio::{sync::mpsc, time::timeout};
use tracing::{debug, error, info, warn};
use super::{config::SseConfig, protocol::SseProtocolHandler, types::SseEvent};
use crate::{
auth::Authentication,
error::{TransportError, TransportResult},
reconnect::{BackoffConfig, calculate_backoff},
};
/// Lifecycle states an SSE connection can be described with.
///
/// NOTE(review): nothing in this part of the file constructs or reads this enum;
/// the meaning of `Reconnecting::attempt` (presumably a retry counter) should be
/// confirmed against its users.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SseConnectionState {
    Disconnected,
    Connecting,
    Connected,
    Reconnecting { attempt: u32 },
    Closed,
}

impl SseConnectionState {
    /// Returns `true` only for [`SseConnectionState::Connected`].
    pub fn is_connected(&self) -> bool {
        *self == Self::Connected
    }

    /// Returns `true` only for [`SseConnectionState::Closed`].
    pub fn is_closed(&self) -> bool {
        *self == Self::Closed
    }
}
/// Control messages sent from an `SseHandle` to the background driver task.
#[derive(Debug)]
pub enum SseCommand {
    /// Stop the driver task and end the connection.
    Close,
    /// Drop the current stream and connect again; `reason` is only used for logging.
    Reconnect { reason: String },
}
/// An SSE connection: a command handle paired with the stream of received events.
///
/// The network work happens in a background task spawned by the constructors;
/// this value only holds the two channel endpoints that talk to it.
pub struct SseConnection {
    handle: SseHandle,
    stream: SseStream,
}

impl SseConnection {
    /// Starts an unauthenticated SSE connection.
    ///
    /// # Errors
    /// Returns a configuration error when `config.validate()` fails.
    pub async fn connect<H: SseProtocolHandler>(
        config: SseConfig,
        handler: H,
    ) -> TransportResult<Self> {
        Self::connect_inner(config, handler, None).await
    }

    /// Starts an SSE connection and hands `auth` to the driver task.
    ///
    /// # Errors
    /// Returns a configuration error when `config.validate()` fails.
    pub async fn connect_with_auth<H: SseProtocolHandler>(
        config: SseConfig,
        handler: H,
        auth: Box<dyn Authentication>,
    ) -> TransportResult<Self> {
        Self::connect_inner(config, handler, Some(auth)).await
    }

    /// Shared constructor: validates the config, creates the command and event
    /// channels, and spawns the driver task.
    ///
    /// This returns as soon as the task is spawned; it does not wait for the
    /// first HTTP connection attempt to finish.
    async fn connect_inner<H: SseProtocolHandler>(
        config: SseConfig,
        handler: H,
        auth: Option<Box<dyn Authentication>>,
    ) -> TransportResult<Self> {
        config.validate().map_err(TransportError::config)?;

        let shared_config = Arc::new(config);
        let (cmd_tx, cmd_rx) = mpsc::channel(shared_config.command_channel_capacity);
        let (event_tx, event_rx) = mpsc::channel(shared_config.event_channel_capacity);

        let driver = sse_connection_driver(
            Arc::clone(&shared_config),
            handler,
            auth,
            cmd_rx,
            event_tx,
        );
        tokio::spawn(driver);

        Ok(Self {
            handle: SseHandle { cmd_tx },
            stream: SseStream { rx: event_rx },
        })
    }

    /// Consumes the connection and returns its command handle and event stream.
    pub fn split(self) -> (SseHandle, SseStream) {
        let Self { handle, stream } = self;
        (handle, stream)
    }

    /// Borrows the command handle.
    pub fn handle(&self) -> &SseHandle {
        &self.handle
    }
}
/// Polling the connection polls its inner `SseStream`.
impl Stream for SseConnection {
    type Item = SseEvent;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.get_mut().stream).poll_next(cx)
    }
}
/// Cloneable handle for sending commands to the background driver task.
#[derive(Clone)]
pub struct SseHandle {
    cmd_tx: mpsc::Sender<SseCommand>,
}

impl SseHandle {
    /// Queues `command` for the driver task, mapping a closed channel to a
    /// connection-closed transport error.
    async fn dispatch(&self, command: SseCommand) -> TransportResult<()> {
        match self.cmd_tx.send(command).await {
            Ok(()) => Ok(()),
            Err(_) => Err(TransportError::connection_closed(Some(
                "SSE background task shut down".to_string(),
            ))),
        }
    }

    /// Asks the driver task to close the connection.
    ///
    /// # Errors
    /// Fails when the command channel is already closed.
    pub async fn close(&self) -> TransportResult<()> {
        self.dispatch(SseCommand::Close).await
    }

    /// Asks the driver task to drop the current stream and reconnect.
    ///
    /// # Errors
    /// Fails when the command channel is already closed.
    pub async fn reconnect(&self, reason: &str) -> TransportResult<()> {
        let command = SseCommand::Reconnect {
            reason: reason.to_string(),
        };
        self.dispatch(command).await
    }

    /// Returns `true` while the command channel is still open.
    ///
    /// This reports whether the receiving side of the channel is alive, not
    /// whether an HTTP stream is currently established.
    pub fn is_connected(&self) -> bool {
        !self.cmd_tx.is_closed()
    }
}
/// Receiving end of the event channel fed by the driver task.
pub struct SseStream {
    rx: mpsc::Receiver<SseEvent>,
}

impl SseStream {
    /// Waits for the next event; yields `None` once the channel is closed and drained.
    pub async fn next_event(&mut self) -> Option<SseEvent> {
        self.rx.recv().await
    }
}

/// `Stream` adapter over the underlying channel receiver.
impl Stream for SseStream {
    type Item = SseEvent;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().rx.poll_recv(cx)
    }
}
async fn establish_sse_connection(
config: &SseConfig,
auth: Option<&dyn Authentication>,
last_event_id: Option<&str>,
) -> TransportResult<super::parse::EventStream<impl Stream<Item = Result<Bytes, hpx::Error>> + use<>>>
{
let mut request_url = config.url.clone();
let client = hpx::Client::builder()
.connect_timeout(config.connect_timeout)
.build()
.map_err(|e| TransportError::config(format!("Failed to build HTTP client: {e}")))?;
let mut headers = config.headers.clone();
headers.insert(
http::header::ACCEPT,
http::HeaderValue::from_static("text/event-stream"),
);
headers.insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_static("no-cache"),
);
if let Some(id) = last_event_id
&& let Ok(value) = http::HeaderValue::from_str(id)
{
headers.insert(
http::header::HeaderName::from_static("last-event-id"),
value,
);
}
if config.auth_on_connect
&& let Some(a) = auth
{
let path = url_path(&config.url);
let body = config.body.as_deref();
if let Some(qs) = a.sign(&config.method, &path, &mut headers, body).await? {
append_query_string(&mut request_url, &qs);
debug!(query_string = %qs, "Applied auth query string to SSE URL");
}
}
let mut req = match config.method {
http::Method::POST => client.post(&request_url),
_ => client.get(&request_url),
};
req = req.headers(headers);
if let Some(body) = &config.body {
req = req.body(body.clone());
}
let resp = timeout(config.connect_timeout, req.send())
.await
.map_err(|_| TransportError::timeout(config.connect_timeout))?
.map_err(TransportError::Http)?;
let status = resp.status();
if !status.is_success() {
return Err(TransportError::sse_invalid_status(status));
}
if let Some(ct) = resp.headers().get(http::header::CONTENT_TYPE) {
let ct_str = ct.to_str().unwrap_or("");
if !ct_str.contains("text/event-stream") {
return Err(TransportError::sse_invalid_content_type(ct_str));
}
}
let body_stream = resp.bytes_stream();
Ok(super::parse::EventStream::new(body_stream))
}
/// Extracts the path portion (together with anything that follows it, such as
/// a query string) from an absolute URL.
///
/// * `scheme://host/a/b?x=1` → `/a/b?x=1`
/// * `scheme://host` → `/`
/// * `scheme://host?x=/y` → `/?x=/y` (the `/` inside the query is not a path)
/// * input without `://` → `/`
///
/// The authority is considered to end at the first `/`, `?` or `#`, so a slash
/// that only appears inside the query or fragment is never mistaken for the
/// start of the path.
fn url_path(url: &str) -> String {
    let after_scheme = match url.split_once("://") {
        Some((_, rest)) => rest,
        None => return "/".to_string(),
    };

    let authority_end = after_scheme.find(|c: char| matches!(c, '/' | '?' | '#'));
    let target = match authority_end {
        Some(idx) => &after_scheme[idx..],
        None => return "/".to_string(),
    };

    if target.starts_with('/') {
        target.to_string()
    } else if target.starts_with('?') {
        // Empty path followed by a query: the path is the root.
        format!("/{}", target)
    } else {
        // Only a fragment follows the authority.
        "/".to_string()
    }
}
/// Appends `query` to `url`, choosing `?` or `&` as the separator.
///
/// Leading `?` characters in `query` are stripped first; if nothing remains,
/// `url` is left untouched. `&` is used when `url` already contains a `?`,
/// otherwise `?`.
fn append_query_string(url: &mut String, query: &str) {
    let params = query.trim_start_matches('?');
    if !params.is_empty() {
        let separator = if url.contains('?') { '&' } else { '?' };
        url.push(separator);
        url.push_str(params);
    }
}
/// Waits out the next reconnect backoff while still listening for commands.
///
/// Returns `true` when the driver should try to connect again and `false` when
/// it should shut down: the attempt limit was reached, a `Close` command
/// arrived, or the command channel was closed. A `Reconnect` command cuts the
/// wait short. `attempt` is incremented once a delay has been computed.
async fn wait_for_reconnect_backoff(
    config: &SseConfig,
    attempt: &mut u32,
    cmd_rx: &mut mpsc::Receiver<SseCommand>,
) -> bool {
    if let Some(max) = config.reconnect_max_attempts
        && *attempt >= max
    {
        error!(attempts = max, "Max SSE reconnect attempts exceeded");
        return false;
    }

    let delay = calculate_backoff(
        BackoffConfig {
            initial_delay: config.reconnect_initial_delay,
            max_delay: config.reconnect_max_delay,
            factor: config.reconnect_backoff_factor,
            jitter: config.reconnect_jitter,
        },
        *attempt,
    );
    *attempt = (*attempt).saturating_add(1);
    warn!(
        attempt = *attempt,
        delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
        "SSE reconnecting after backoff"
    );

    // Race the sleep against the command channel so that a close request (or
    // every handle being dropped) is honoured during the backoff instead of
    // only after the next successful connection.
    tokio::select! {
        biased;
        cmd = cmd_rx.recv() => match cmd {
            Some(SseCommand::Close) | None => {
                info!("SSE connection closing (requested)");
                false
            }
            Some(SseCommand::Reconnect { reason }) => {
                warn!(reason = %reason, "SSE reconnect requested");
                true
            }
        },
        _ = tokio::time::sleep(delay) => true,
    }
}

/// Background task that owns the SSE connection for its whole lifetime.
///
/// Each iteration of the outer loop makes one connection attempt. On success
/// the inner loop forwards parsed events to `event_tx` and reacts to commands
/// from `cmd_rx`; on failure, stream error or stream end the task backs off and
/// tries again, subject to `handler.should_retry` and
/// `config.reconnect_max_attempts`.
///
/// The task ends when:
/// * a `Close` command arrives or the command channel is closed (while
///   streaming or during a backoff wait),
/// * the event receiver is dropped,
/// * the handler declines a retry, or
/// * the reconnect attempt limit is reached.
///
/// The id of the last event that carried a non-empty id is remembered and
/// passed to the next connection attempt.
async fn sse_connection_driver<H: SseProtocolHandler>(
    config: Arc<SseConfig>,
    handler: H,
    auth: Option<Box<dyn Authentication>>,
    mut cmd_rx: mpsc::Receiver<SseCommand>,
    event_tx: mpsc::Sender<SseEvent>,
) {
    let mut attempt: u32 = 0;
    let mut last_event_id: Option<String> = None;

    loop {
        info!(url = %config.url, attempt, "SSE connecting");
        let connection =
            establish_sse_connection(&config, auth.as_deref(), last_event_id.as_deref()).await;

        let mut event_stream = match connection {
            Ok(stream) => {
                info!(url = %config.url, "SSE connection established");
                handler.on_connect();
                // A successful connection resets the backoff sequence.
                attempt = 0;
                stream
            }
            Err(err) => {
                error!(url = %config.url, error = %err, "SSE connection failed");
                // The handler is notified even though this attempt never connected.
                handler.on_disconnect();
                if !handler.should_retry(&err) {
                    warn!("Handler says no retry — closing");
                    return;
                }
                if !wait_for_reconnect_backoff(&config, &mut attempt, &mut cmd_rx).await {
                    return;
                }
                continue;
            }
        };

        // Streaming phase: every exit from this loop via `break` means "reconnect".
        loop {
            tokio::select! {
                // `biased` polls the branches in order, so commands take priority over events.
                biased;
                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(SseCommand::Close) | None => {
                            info!("SSE connection closing (requested)");
                            handler.on_disconnect();
                            return;
                        }
                        Some(SseCommand::Reconnect { reason }) => {
                            warn!(reason = %reason, "SSE reconnect requested");
                            handler.on_disconnect();
                            break;
                        }
                    }
                }
                item = event_stream.next() => {
                    match item {
                        Some(Ok(raw_event)) => {
                            if !raw_event.id.is_empty() {
                                last_event_id = Some(raw_event.id.to_string());
                            }
                            let kind = handler.classify_event(&raw_event);
                            debug!(
                                event_type = %raw_event.event,
                                id = %raw_event.id,
                                kind = %kind,
                                "SSE event received",
                            );
                            let sse_event = SseEvent::new(raw_event, kind);
                            if event_tx.send(sse_event).await.is_err() {
                                info!("SSE consumer dropped, shutting down");
                                handler.on_disconnect();
                                return;
                            }
                        }
                        Some(Err(err)) => {
                            error!(error = %err, "SSE stream error");
                            handler.on_disconnect();
                            let transport_err = TransportError::sse_parse(err.to_string());
                            if !handler.should_retry(&transport_err) {
                                warn!("Handler says no retry after stream error — closing");
                                return;
                            }
                            break;
                        }
                        None => {
                            warn!("SSE stream ended");
                            handler.on_disconnect();
                            let transport_err = TransportError::sse_stream_ended();
                            if !handler.should_retry(&transport_err) {
                                return;
                            }
                            break;
                        }
                    }
                }
            }
        }

        if !wait_for_reconnect_backoff(&config, &mut attempt, &mut cmd_rx).await {
            return;
        }
    }
}