use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, watch};
use tokio::time::interval;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, error, info, warn};
use crate::actors::{DataMessage, ExchangeConnector};
use crate::error::{ExchangeError, Result};
use crate::ws::types::WsMessage;
/// Keep-alive and reconnect tuning for the WebSocket feed runner.
///
/// `Default` yields 20 s pings, a 5 s initial reconnect delay capped at 80 s,
/// and at most 10 consecutive reconnect attempts.
#[derive(Debug, Clone)]
pub struct WsRunnerConfig {
    /// Seconds between application-level pings sent during a session
    /// (should be non-zero).
    pub ping_interval_secs: u64,
    /// Initial delay, in seconds, before a reconnect attempt; it doubles with
    /// each further consecutive failed attempt.
    pub reconnect_delay_secs: u64,
    /// Upper bound, in seconds, on any single reconnect delay.
    pub max_reconnect_delay_secs: u64,
    /// Consecutive short-lived sessions tolerated before the feed gives up
    /// with an error.
    pub max_reconnect_attempts: u32,
}

impl Default for WsRunnerConfig {
    fn default() -> Self {
        WsRunnerConfig {
            max_reconnect_attempts: 10,
            max_reconnect_delay_secs: 80,
            reconnect_delay_secs: 5,
            ping_interval_secs: 20,
        }
    }
}

impl WsRunnerConfig {
    /// Default configuration, except for a caller-chosen ping interval.
    pub fn from_ping_interval(ping_interval_secs: u64) -> Self {
        WsRunnerConfig {
            ping_interval_secs,
            ..WsRunnerConfig::default()
        }
    }
}
/// Sliding-window throttle for outbound WebSocket frames.
///
/// Keeps the timestamps of recent sends. Once `max_msgs` of them fall inside
/// the trailing `window_dur`, the next send is delayed until the oldest one
/// ages out. `new()` configures 100 messages per 10 seconds.
struct WsMsgGuard {
    /// Send timestamps, oldest at the front.
    window: VecDeque<Instant>,
    /// Sends allowed within one window.
    max_msgs: usize,
    /// Length of the sliding window.
    window_dur: Duration,
}

impl WsMsgGuard {
    fn new() -> Self {
        const LIMIT: usize = 100;
        WsMsgGuard {
            window_dur: Duration::from_secs(10),
            max_msgs: LIMIT,
            window: VecDeque::with_capacity(LIMIT),
        }
    }

    /// Records one outbound send, first sleeping if the window is saturated.
    async fn check(&mut self) {
        let now = Instant::now();
        let window_dur = self.window_dur;

        // Forget sends strictly older than one window.
        while let Some(&first) = self.window.front() {
            if now - first <= window_dur {
                break;
            }
            self.window.pop_front();
        }

        // Only a full window forces a wait, measured from its oldest entry.
        let oldest = if self.window.len() >= self.max_msgs {
            self.window.front().copied()
        } else {
            None
        };
        if let Some(oldest) = oldest {
            let wait = window_dur.saturating_sub(now - oldest);
            if !wait.is_zero() {
                warn!(
                    wait_ms = wait.as_millis(),
                    "WS outbound rate limit reached (100/10s) — throttling"
                );
                tokio::time::sleep(wait).await;
            }
        }

        // Timestamp taken after any sleep, i.e. when the send actually happens.
        self.window.push_back(Instant::now());
    }
}
/// Drives one exchange WebSocket feed, reconnecting with exponential backoff.
///
/// Each loop iteration runs one session via `single_session`. A session
/// connects to `ws_url`, sends every entry of `subscriptions`, and forwards
/// whatever `connector.parse_message` yields into `tx` until it ends.
///
/// * A session that lasted at least `STABLE_SESSION_SECS` resets the attempt
///   counter, so the next connection is attempted without delay.
/// * Shorter sessions increment the counter. Before the next attempt the
///   feed waits `reconnect_delay_secs * 2^(attempts - 1)` seconds (saturating),
///   capped at `max_reconnect_delay_secs`.
/// * A `true` value on `shutdown` is honoured inside a live session, before
///   every connection attempt, and while waiting out a backoff delay.
///
/// Returns `Ok(())` on shutdown or when the `tx` receiver has been dropped.
///
/// # Errors
///
/// Returns `ExchangeError::WsDisconnected` once more than
/// `config.max_reconnect_attempts` consecutive short-lived sessions have ended.
pub async fn run_feed(
    ws_url: impl Into<String>,
    subscriptions: Vec<String>,
    connector: Arc<dyn ExchangeConnector>,
    tx: mpsc::Sender<DataMessage>,
    config: WsRunnerConfig,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    // A session that survives this long counts as healthy: its disconnect
    // starts a fresh backoff sequence instead of counting towards
    // `max_reconnect_attempts`.
    const STABLE_SESSION_SECS: u64 = 60;

    let url = ws_url.into();
    let mut attempts: u32 = 0;
    loop {
        // `changed()` only reports values not yet seen, so a shutdown that was
        // published (and observed) earlier must be checked explicitly here.
        if *shutdown.borrow() {
            break;
        }

        if attempts > 0 {
            let delay = backoff_delay_secs(&config, attempts);
            warn!(
                attempt = attempts,
                max = config.max_reconnect_attempts,
                delay_secs = delay,
                exchange = connector.exchange_name(),
                "WS reconnecting"
            );
            // The backoff can be long; don't make shutdown wait for it.
            if sleep_or_shutdown(Duration::from_secs(delay), &mut shutdown).await {
                break;
            }
        }

        let session_start = Instant::now();
        let outcome = single_session(
            &url,
            &subscriptions,
            connector.clone(),
            tx.clone(),
            &config,
            &mut shutdown,
            attempts,
        )
        .await;

        match outcome {
            SessionOutcome::ShutdownRequested => break,
            SessionOutcome::ReceiverDropped => {
                info!("DataMessage receiver dropped; stopping WS feed");
                return Ok(());
            }
            SessionOutcome::Disconnected => {
                let uptime_secs = session_start.elapsed().as_secs();
                if uptime_secs >= STABLE_SESSION_SECS {
                    info!(
                        exchange = connector.exchange_name(),
                        uptime_secs,
                        "WS stable session ended — resetting reconnect counter",
                    );
                    attempts = 0;
                } else {
                    attempts = attempts.saturating_add(1);
                    if attempts > config.max_reconnect_attempts {
                        error!(
                            max = config.max_reconnect_attempts,
                            exchange = connector.exchange_name(),
                            "WS max reconnect attempts exhausted"
                        );
                        return Err(ExchangeError::WsDisconnected { url, attempts });
                    }
                }
            }
        }
    }

    info!(
        exchange = connector.exchange_name(),
        "WS feed shut down cleanly"
    );
    Ok(())
}

/// Backoff before reconnect attempt number `attempt` (1-based):
/// `reconnect_delay_secs * 2^(attempt - 1)`, saturating, capped at
/// `max_reconnect_delay_secs`.
fn backoff_delay_secs(config: &WsRunnerConfig, attempt: u32) -> u64 {
    let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
    config
        .reconnect_delay_secs
        .saturating_mul(factor)
        .min(config.max_reconnect_delay_secs)
}

/// Sleeps for `delay`, waking early if `true` is published on `shutdown`.
///
/// Returns `true` when the wait ended because shutdown was requested, and
/// `false` once the full delay has elapsed.
async fn sleep_or_shutdown(delay: Duration, shutdown: &mut watch::Receiver<bool>) -> bool {
    let sleep = tokio::time::sleep(delay);
    tokio::pin!(sleep);
    loop {
        tokio::select! {
            _ = &mut sleep => return false,
            changed = shutdown.changed() => {
                if *shutdown.borrow() {
                    return true;
                }
                if changed.is_err() {
                    // The sender is gone, so no further shutdown signal can
                    // arrive; just wait out the remaining delay.
                    sleep.as_mut().await;
                    return false;
                }
            }
        }
    }
}
/// Why a single WebSocket session ended; `run_feed` decides what to do next
/// based on this.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SessionOutcome {
    /// `true` was observed on the shutdown channel; the feed stops cleanly.
    ShutdownRequested,
    /// The `DataMessage` receiver is gone, so there is nowhere to forward data.
    ReceiverDropped,
    /// Connecting, sending, or reading failed, or the server closed the
    /// connection; the feed may reconnect.
    Disconnected,
}
/// Runs one WebSocket session: connect, subscribe, then pump frames until the
/// session ends.
///
/// * Each entry of `subscriptions` is sent verbatim as a text frame. Those
///   sends, the application pings, and the final Close are throttled by
///   `WsMsgGuard`; pong replies are not.
/// * Text frames are parsed with `connector.parse_message`, and every resulting
///   message is forwarded to `tx`. Frames that fail to parse are logged and
///   skipped.
/// * Protocol-level pings are answered with a pong carrying the same payload.
/// * `WsMessage::ping_json()` is sent as a text frame every
///   `config.ping_interval_secs` seconds. A value of 0 is treated as 1, because
///   `tokio::time::interval` panics on a zero period.
///
/// `attempt` only affects logging: read errors are logged at debug level when
/// it is `0` and at warn level otherwise.
///
/// Returns `ShutdownRequested` when `true` is observed on `shutdown` (a Close
/// frame is sent first if already connected). Returns `ReceiverDropped` when
/// `tx` is closed. Returns `Disconnected` for every connect, send, or read
/// failure, and when the server closes the stream.
async fn single_session(
    url: &str,
    subscriptions: &[String],
    connector: Arc<dyn ExchangeConnector>,
    tx: mpsc::Sender<DataMessage>,
    config: &WsRunnerConfig,
    shutdown: &mut watch::Receiver<bool>,
    attempt: u32,
) -> SessionOutcome {
    // A shutdown published before this session started may already be marked
    // seen, in which case `changed()` below would never report it.
    if *shutdown.borrow() {
        return SessionOutcome::ShutdownRequested;
    }

    info!(url, exchange = connector.exchange_name(), "WS connecting");

    // The handshake is awaited without any timeout, so race it against
    // shutdown to keep a stalled connect from blocking shutdown indefinitely.
    let connect = connect_async(url);
    tokio::pin!(connect);
    let ws_stream = loop {
        tokio::select! {
            biased;
            Ok(()) = shutdown.changed() => {
                if *shutdown.borrow() {
                    return SessionOutcome::ShutdownRequested;
                }
            }
            result = &mut connect => match result {
                Ok((stream, _resp)) => break stream,
                Err(e) => {
                    warn!(error = %e, "WS connect failed");
                    return SessionOutcome::Disconnected;
                }
            },
        }
    };

    let (mut write, mut read) = ws_stream.split();
    let mut guard = WsMsgGuard::new();
    for sub in subscriptions {
        guard.check().await;
        if let Err(e) = write.send(Message::Text(sub.clone().into())).await {
            warn!(error = %e, "failed to send subscription");
            return SessionOutcome::Disconnected;
        }
        debug!(topic = ?sub, "subscribed");
    }
    info!(
        exchange = connector.exchange_name(),
        "WS connected and subscribed"
    );

    // `interval` panics on a zero period, so clamp to at least one second.
    let ping_period = Duration::from_secs(config.ping_interval_secs.max(1));
    let mut ping_tick = interval(ping_period);
    // After a stall (e.g. back-pressure on `tx`), send a single ping rather
    // than a burst of catch-up pings that would eat the outbound rate budget.
    ping_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // The first tick completes immediately; consume it so the first ping goes
    // out one full period after subscribing.
    ping_tick.tick().await;

    loop {
        tokio::select! {
            // Poll shutdown first so it wins over a busy data stream.
            biased;
            Ok(()) = shutdown.changed() => {
                if *shutdown.borrow() {
                    guard.check().await;
                    let _ = write.send(Message::Close(None)).await;
                    return SessionOutcome::ShutdownRequested;
                }
            }
            frame = read.next() => {
                match frame {
                    Some(Ok(Message::Text(text))) => {
                        match connector.parse_message(&text) {
                            Ok(msgs) => {
                                for msg in msgs {
                                    if tx.send(msg).await.is_err() {
                                        return SessionOutcome::ReceiverDropped;
                                    }
                                }
                            }
                            Err(e) => {
                                warn!(error = %e, raw = %text, "parse_message error — skipping frame");
                            }
                        }
                    }
                    Some(Ok(Message::Ping(data))) => {
                        if let Err(e) = write.send(Message::Pong(data)).await {
                            warn!(error = %e, "pong send failed");
                            return SessionOutcome::Disconnected;
                        }
                    }
                    Some(Ok(Message::Close(frame))) => {
                        info!(frame = ?frame, "server closed WS connection");
                        return SessionOutcome::Disconnected;
                    }
                    Some(Ok(Message::Binary(_))) => {
                        debug!("unexpected binary frame — ignored");
                    }
                    // Pong and raw frames need no handling.
                    Some(Ok(_)) => {}
                    Some(Err(e)) => {
                        if attempt == 0 {
                            debug!(error = %e, exchange = connector.exchange_name(), "WS read error");
                        } else {
                            warn!(error = %e, attempt, exchange = connector.exchange_name(), "WS read error");
                        }
                        return SessionOutcome::Disconnected;
                    }
                    None => {
                        debug!("WS stream closed");
                        return SessionOutcome::Disconnected;
                    }
                }
            }
            _ = ping_tick.tick() => {
                guard.check().await;
                if let Err(e) = write
                    .send(Message::Text(WsMessage::ping_json().into()))
                    .await
                {
                    warn!(error = %e, "ping send failed");
                    return SessionOutcome::Disconnected;
                }
                debug!(exchange = connector.exchange_name(), "sent ping");
            }
        }
    }
}