use super::types::{ConnectionState, ReconnectionConfig};
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, error, info, warn};
/// A single telemetry payload queued for delivery over the engine WebSocket.
///
/// When the message is sent, `prefix` and `data` are concatenated (prefix
/// first) into one binary WebSocket frame.
pub struct OtelMessage {
/// Static byte tag written at the start of the frame, ahead of `data`.
pub prefix: &'static [u8],
/// Payload bytes appended after `prefix`.
pub data: Vec<u8>,
}
/// Handle to a background task that owns the engine WebSocket connection.
///
/// The handle itself never touches the socket: it only talks to the spawned
/// connection task through the channels below.
pub struct SharedEngineConnection {
/// Connection state shared with (and written by) the background task.
state: Arc<RwLock<ConnectionState>>,
/// Bounded queue of outgoing messages consumed by the background task.
tx: mpsc::Sender<OtelMessage>,
/// Shutdown signal sender; wrapped in `Option` so `shutdown` can take it out
/// and signal at most once.
shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
/// Flush requests; each carries a oneshot sender that the background task
/// uses to acknowledge the flush.
flush_tx: mpsc::Sender<oneshot::Sender<()>>,
}
impl SharedEngineConnection {
pub fn new(ws_url: String, config: ReconnectionConfig) -> Self {
Self::with_channel_capacity(ws_url, config, 10_000)
}
pub fn with_channel_capacity(
ws_url: String,
config: ReconnectionConfig,
channel_capacity: usize,
) -> Self {
let (tx, rx) = mpsc::channel(channel_capacity);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let (flush_tx, flush_rx) = mpsc::channel(16);
let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
let connection = Self {
state: state.clone(),
tx,
shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
flush_tx,
};
tokio::spawn(connection_loop(
ws_url,
config,
state,
rx,
shutdown_rx,
flush_rx,
));
connection
}
pub fn send(&self, prefix: &'static [u8], data: Vec<u8>) -> Result<(), String> {
self.tx
.try_send(OtelMessage { prefix, data })
.map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => {
tracing::warn!("Telemetry channel full, dropping message");
"Channel full".to_string()
}
mpsc::error::TrySendError::Closed(_) => "Connection closed".to_string(),
})
}
pub async fn state(&self) -> ConnectionState {
*self.state.read().await
}
pub async fn flush(&self) {
let (done_tx, done_rx) = oneshot::channel();
if self.flush_tx.try_send(done_tx).is_ok() {
let _ = done_rx.await;
}
}
pub async fn shutdown(&self) {
if let Some(tx) = self.shutdown_tx.lock().await.take() {
let _ = tx.send(()).await;
}
}
}
/// Moves every message currently waiting in `rx` into `pending`, keeping at
/// most `max_pending` entries in `pending`.
///
/// Messages that arrive once `pending` is at capacity are still removed from
/// the channel but discarded; a single warning with the discard count is
/// logged if any were dropped.
fn collect_pending(
    rx: &mut mpsc::Receiver<OtelMessage>,
    pending: &mut Vec<OtelMessage>,
    max_pending: usize,
) {
    let mut overflow: u64 = 0;

    loop {
        // Stop as soon as the channel has nothing immediately available.
        let Ok(msg) = rx.try_recv() else {
            break;
        };

        if pending.len() >= max_pending {
            overflow += 1;
            continue;
        }
        pending.push(msg);
    }

    if overflow != 0 {
        warn!(
            dropped = overflow,
            "Pending message queue full, dropped messages"
        );
    }
}
async fn connection_loop(
ws_url: String,
config: ReconnectionConfig,
state: Arc<RwLock<ConnectionState>>,
mut rx: mpsc::Receiver<OtelMessage>,
mut shutdown_rx: mpsc::Receiver<()>,
mut flush_rx: mpsc::Receiver<oneshot::Sender<()>>,
) {
let mut retry_count: u64 = 0;
let max_pending = config.max_pending_messages;
let mut pending_messages: Vec<OtelMessage> = Vec::new();
let nanos_seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0x517cc1b727220a95);
let mut rng_state: u64 = nanos_seed ^ (std::process::id() as u64) ^ 0x6a09e667f3bcc908;
loop {
let is_reconnect = retry_count > 0;
*state.write().await = if is_reconnect {
ConnectionState::Reconnecting
} else {
ConnectionState::Connecting
};
debug!(
"Attempting to connect to engine WebSocket: {} (attempt {})",
ws_url,
retry_count + 1
);
let ws_result = tokio::select! {
result = connect_async(&ws_url) => result,
_ = shutdown_rx.recv() => {
info!("Shutting down OTEL connection during connect");
*state.write().await = ConnectionState::Disconnected;
return;
}
};
match ws_result {
Ok((ws_stream, _)) => {
info!("Connected to engine WebSocket");
*state.write().await = ConnectionState::Connected;
retry_count = 0;
let (mut write, mut read) = ws_stream.split();
debug!("Flushing {} pending messages", pending_messages.len());
for msg in pending_messages.drain(..) {
if let Err(e) = send_message(&mut write, msg).await {
error!("Failed to flush pending message: {}", e);
break;
}
}
loop {
tokio::select! {
Some(msg) = rx.recv() => {
if let Err(e) = send_message(&mut write, msg).await {
error!("Failed to send message: {}", e);
break;
}
}
Some(done_tx) = flush_rx.recv() => {
let mut flush_ok = true;
while let Ok(msg) = rx.try_recv() {
if let Err(e) = send_message(&mut write, msg).await {
error!("Failed to send message during flush: {}", e);
flush_ok = false;
break;
}
}
let _ = done_tx.send(());
if !flush_ok {
break;
}
}
result = read.next() => {
match result {
Some(Ok(Message::Ping(data))) => {
if let Err(e) = write.send(Message::Pong(data)).await {
error!("Failed to send pong: {}", e);
break;
}
}
Some(Ok(Message::Close(_))) | None => {
info!("WebSocket connection closed");
break;
}
Some(Err(e)) => {
error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
_ = shutdown_rx.recv() => {
info!("Shutdown signal received");
let _ = write.send(Message::Close(None)).await;
*state.write().await = ConnectionState::Disconnected;
return;
}
}
}
collect_pending(&mut rx, &mut pending_messages, max_pending);
}
Err(e) => {
error!("Failed to connect to engine WebSocket: {}", e);
retry_count += 1;
collect_pending(&mut rx, &mut pending_messages, max_pending);
}
}
if let Some(max) = config.max_retries
&& retry_count >= max
{
error!("Max retries exceeded, giving up");
*state.write().await = ConnectionState::Failed;
break;
}
let initial = config.effective_initial_delay_ms() as f64;
let max_delay = config.max_delay_ms as f64;
let exponent = retry_count.saturating_sub(1).min(63) as i32;
let base_delay = (initial * config.backoff_multiplier.powi(exponent))
.min(max_delay)
.max(0.0);
rng_state ^= rng_state << 13;
rng_state ^= rng_state >> 7;
rng_state ^= rng_state << 17;
let jitter = base_delay * config.jitter_factor * (rng_state as f64 / u64::MAX as f64);
let delay = (base_delay + jitter).min(max_delay).max(0.0) as u64;
debug!("Reconnecting in {}ms", delay);
tokio::select! {
_ = tokio::time::sleep(tokio::time::Duration::from_millis(delay)) => {}
_ = shutdown_rx.recv() => {
info!("Shutting down OTEL connection during backoff");
*state.write().await = ConnectionState::Disconnected;
return;
}
}
}
}
async fn send_message<S>(
write: &mut S,
msg: OtelMessage,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
S: SinkExt<Message> + Unpin,
S::Error: std::error::Error + Send + Sync + 'static,
{
let mut frame = Vec::with_capacity(msg.prefix.len() + msg.data.len());
frame.extend_from_slice(msg.prefix);
frame.extend_from_slice(&msg.data);
write.send(Message::Binary(frame.into())).await?;
Ok(())
}