use crate::error::{Error, Result};
use crate::shared::{Transport, TransportMessage};
use async_trait::async_trait;
use bytes::BytesMut;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, timeout};
use tracing::{debug, error, info, warn};
#[derive(Debug, Clone)]
pub struct OptimizedSseConfig {
pub url: String,
pub connection_timeout: Duration,
pub keepalive_interval: Duration,
pub max_reconnects: usize,
pub reconnect_delay: Duration,
pub buffer_size: usize,
pub flush_interval: Duration,
pub enable_pooling: bool,
pub max_connections: usize,
pub enable_compression: bool,
}
impl Default for OptimizedSseConfig {
fn default() -> Self {
Self {
url: "http://localhost:8080/sse".to_string(),
connection_timeout: Duration::from_secs(30),
keepalive_interval: Duration::from_secs(15),
max_reconnects: 5,
reconnect_delay: Duration::from_secs(1),
buffer_size: 100,
flush_interval: Duration::from_millis(100),
enable_pooling: true,
max_connections: 10,
enable_compression: false,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
Disconnected,
Connecting,
Connected,
Reconnecting,
}
pub struct OptimizedSseTransport {
config: OptimizedSseConfig,
client: reqwest::Client,
state: Arc<RwLock<ConnectionState>>,
event_buffer: Arc<RwLock<VecDeque<TransportMessage>>>,
send_tx: mpsc::Sender<TransportMessage>,
recv_rx: Arc<RwLock<mpsc::Receiver<TransportMessage>>>,
reconnect_count: Arc<RwLock<usize>>,
last_event_id: Arc<RwLock<Option<String>>>,
}
impl std::fmt::Debug for OptimizedSseTransport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OptimizedSseTransport")
.field("config", &self.config)
.field("state", &self.state)
.field("reconnect_count", &self.reconnect_count)
.finish()
}
}
impl OptimizedSseTransport {
#[allow(clippy::needless_pass_by_value)]
pub fn new(config: OptimizedSseConfig) -> Self {
let (send_tx, send_rx) = mpsc::channel(config.buffer_size);
let (recv_tx, recv_rx) = mpsc::channel(config.buffer_size);
let client = reqwest::Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(90)))
.pool_max_idle_per_host(config.max_connections)
.tcp_keepalive(Some(Duration::from_secs(60)))
.timeout(config.connection_timeout)
.build()
.expect("Failed to build HTTP client");
let transport = Self {
config: config.clone(),
client,
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
event_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(config.buffer_size))),
send_tx,
recv_rx: Arc::new(RwLock::new(recv_rx)),
reconnect_count: Arc::new(RwLock::new(0)),
last_event_id: Arc::new(RwLock::new(None)),
};
transport.start_background_tasks(send_rx, recv_tx);
transport
}
fn start_background_tasks(
&self,
mut send_rx: mpsc::Receiver<TransportMessage>,
recv_tx: mpsc::Sender<TransportMessage>,
) {
let config = self.config.clone();
let config2 = self.config.clone();
let config3 = self.config.clone();
let client = self.client.clone();
let client2 = self.client.clone();
let client3 = self.client.clone();
let state = self.state.clone();
let state2 = self.state.clone();
let state3 = self.state.clone();
let _event_buffer = self.event_buffer.clone();
let event_buffer2 = self.event_buffer.clone();
let reconnect_count = self.reconnect_count.clone();
let last_event_id = self.last_event_id.clone();
tokio::spawn(async move {
loop {
match Self::connect_sse(&config, &client, &state, &recv_tx, &last_event_id).await {
Ok(()) => {
info!("SSE connection closed normally");
*reconnect_count.write().await = 0;
},
Err(e) => {
error!("SSE connection error: {}", e);
let mut count = reconnect_count.write().await;
*count += 1;
if *count >= config.max_reconnects {
error!("Max reconnect attempts reached");
break;
}
*state.write().await = ConnectionState::Reconnecting;
tokio::time::sleep(config.reconnect_delay).await;
},
}
}
});
tokio::spawn(async move {
let mut flush_ticker = interval(config2.flush_interval);
loop {
tokio::select! {
Some(msg) = send_rx.recv() => {
event_buffer2.write().await.push_back(msg);
if event_buffer2.read().await.len() >= config2.buffer_size {
Self::flush_events(
&event_buffer2,
&client2,
&config2,
&state2,
).await;
}
}
_ = flush_ticker.tick() => {
if !event_buffer2.read().await.is_empty() {
Self::flush_events(
&event_buffer2,
&client2,
&config2,
&state2,
).await;
}
}
}
}
});
tokio::spawn(async move {
let mut ticker = interval(config3.keepalive_interval);
loop {
ticker.tick().await;
if *state3.read().await == ConnectionState::Connected {
if let Err(e) = Self::send_keepalive(&client3, &config3).await {
warn!("Keepalive failed: {}", e);
}
}
}
});
}
#[allow(clippy::cognitive_complexity)]
async fn connect_sse(
config: &OptimizedSseConfig,
client: &reqwest::Client,
state: &Arc<RwLock<ConnectionState>>,
recv_tx: &mpsc::Sender<TransportMessage>,
last_event_id: &Arc<RwLock<Option<String>>>,
) -> Result<()> {
*state.write().await = ConnectionState::Connecting;
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
headers.insert("Cache-Control", HeaderValue::from_static("no-cache"));
if let Some(ref id) = *last_event_id.read().await {
headers.insert(
"Last-Event-ID",
HeaderValue::from_str(id).unwrap_or_else(|_| HeaderValue::from_static("0")),
);
}
let response = timeout(
config.connection_timeout,
client.get(&config.url).headers(headers).send(),
)
.await
.map_err(|_| Error::internal("SSE connection timeout"))?
.map_err(|e| Error::internal(format!("SSE connection failed: {}", e)))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"SSE connection failed with status: {}",
response.status()
)));
}
*state.write().await = ConnectionState::Connected;
info!("SSE connection established");
match response.text().await {
Ok(text) => {
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if let Ok(msg) = serde_json::from_str::<TransportMessage>(data) {
if let Err(e) = recv_tx.send(msg).await {
error!("Failed to queue received message: {}", e);
return Err(Error::internal("Receiver channel closed"));
}
}
}
}
},
Err(e) => {
error!("Response error: {}", e);
return Err(Error::internal("Response error"));
},
}
*state.write().await = ConnectionState::Disconnected;
Ok(())
}
#[allow(dead_code, clippy::unnecessary_wraps)]
fn parse_sse_event(buffer: &mut BytesMut) -> Result<Option<SseEvent>> {
if let Some(pos) = buffer.windows(2).position(|w| w == b"\n\n") {
let event_data = buffer.split_to(pos + 2);
let event_str = String::from_utf8_lossy(&event_data);
let mut event = SseEvent::default();
for line in event_str.lines() {
if let Some(data) = line.strip_prefix("data: ") {
event.data.push_str(data);
event.data.push('\n');
} else if let Some(event_type) = line.strip_prefix("event: ") {
event.event = Some(event_type.to_string());
} else if let Some(id) = line.strip_prefix("id: ") {
event.id = Some(id.to_string());
} else if let Some(retry) = line.strip_prefix("retry: ") {
if let Ok(ms) = retry.parse::<u64>() {
event.retry = Some(Duration::from_millis(ms));
}
}
}
if event.data.ends_with('\n') {
event.data.pop();
}
if !event.data.is_empty() {
return Ok(Some(event));
}
}
Ok(None)
}
#[allow(dead_code, clippy::unnecessary_wraps)]
fn parse_message(event: &SseEvent) -> Result<Option<TransportMessage>> {
if event.data.is_empty() {
return Ok(None);
}
match serde_json::from_str::<TransportMessage>(&event.data) {
Ok(msg) => Ok(Some(msg)),
Err(e) => {
warn!("Failed to parse SSE message: {}", e);
Ok(None)
},
}
}
async fn flush_events(
buffer: &Arc<RwLock<VecDeque<TransportMessage>>>,
client: &reqwest::Client,
config: &OptimizedSseConfig,
state: &Arc<RwLock<ConnectionState>>,
) {
if *state.read().await != ConnectionState::Connected {
return;
}
let mut events = buffer.write().await;
if events.is_empty() {
return;
}
let batch: Vec<TransportMessage> = events.drain(..).collect();
for msg in batch {
if let Err(e) = Self::send_event(client, config, &msg).await {
error!("Failed to send event: {}", e);
events.push_back(msg);
}
}
}
async fn send_event(
client: &reqwest::Client,
config: &OptimizedSseConfig,
msg: &TransportMessage,
) -> Result<()> {
let json = serde_json::to_string(msg)
.map_err(|e| Error::internal(format!("Failed to serialize message: {}", e)))?;
let response = client
.post(&config.url)
.header(CONTENT_TYPE, "application/json")
.body(json)
.send()
.await
.map_err(|e| Error::internal(format!("Failed to send event: {}", e)))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"Event send failed with status: {}",
response.status()
)));
}
Ok(())
}
async fn send_keepalive(client: &reqwest::Client, config: &OptimizedSseConfig) -> Result<()> {
let response = client
.get(format!("{}/ping", config.url))
.send()
.await
.map_err(|e| Error::internal(format!("Keepalive failed: {}", e)))?;
if !response.status().is_success() {
return Err(Error::internal("Keepalive ping failed"));
}
debug!("Keepalive ping successful");
Ok(())
}
}
#[derive(Debug, Default)]
#[allow(dead_code)]
struct SseEvent {
data: String,
event: Option<String>,
id: Option<String>,
retry: Option<Duration>,
}
#[async_trait]
impl Transport for OptimizedSseTransport {
async fn send(&mut self, message: TransportMessage) -> Result<()> {
self.send_tx
.send(message)
.await
.map_err(|_| Error::internal("Send channel closed"))
}
async fn receive(&mut self) -> Result<TransportMessage> {
let mut rx = self.recv_rx.write().await;
rx.recv()
.await
.ok_or_else(|| Error::internal("Receive channel closed"))
}
async fn close(&mut self) -> Result<()> {
*self.state.write().await = ConnectionState::Disconnected;
info!("SSE transport closed");
Ok(())
}
fn is_connected(&self) -> bool {
futures::executor::block_on(async {
*self.state.read().await == ConnectionState::Connected
})
}
fn transport_type(&self) -> &'static str {
"sse-optimized"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let config = OptimizedSseConfig::default();
assert_eq!(config.buffer_size, 100);
assert_eq!(config.max_connections, 10);
assert!(config.enable_pooling);
}
#[test]
fn test_sse_event_parsing() {
use bytes::BufMut;
let mut buffer = BytesMut::new();
buffer.put(&b"data: test message\nid: 123\n\n"[..]);
let event = OptimizedSseTransport::parse_sse_event(&mut buffer)
.unwrap()
.unwrap();
assert_eq!(event.data, "test message");
assert_eq!(event.id, Some("123".to_string()));
}
}