use std::{
future::Future,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::{Duration, Instant},
};
use futures::SinkExt;
use rand::Rng;
use simulator_api::{BacktestError, BacktestRequest};
use tokio::{
net::TcpStream,
sync::{Notify, OwnedSemaphorePermit, Semaphore, watch},
};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Error as WsError, Message, client::IntoClientRequest, http::HeaderValue},
};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use crate::error::err_chain;
mod control;
mod parallel;
mod session;
mod subscription;
pub use control::{ControlEvent, ControlHandle, spawn_control_manager};
pub use parallel::{ManagedParallelSession, ParallelSubSession};
pub use session::{ManagedBacktestSession, ManagedEvent, ManagedSessionError};
pub use subscription::{
SubscriptionHandle, SubscriptionNotification, spawn_account_diff_subscription_manager,
spawn_transaction_subscription_manager,
};
/// Upper bound on establishing the WebSocket connection in `connect_ws`.
pub const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Deadline for a handshake response.
/// NOTE(review): not referenced in the visible part of this file — presumably
/// used by the submodules' handshake implementations; confirm.
pub const HANDSHAKE_RESPONSE_TIMEOUT: Duration = Duration::from_secs(120);
/// Keepalive ping cadence.
/// NOTE(review): not referenced in the visible part of this file — presumably
/// the interval at which message loops call `send_keepalive_ping`; confirm.
pub const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);
/// Longest time without inbound traffic before `send_keepalive_ping` reports
/// the connection as dead instead of sending another ping.
pub const KEEPALIVE_MISS_DEADLINE: Duration = Duration::from_secs(45);
/// Total time `graceful_close` allows for the close request plus the
/// WebSocket close before giving up.
pub const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);
/// Un-jittered delay of the first retry handed out by `ReconnectBudget`.
pub const RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Cap on the un-jittered retry delay. Jitter is applied after capping, so an
/// individual delay can exceed this by up to `RECONNECT_JITTER`.
pub const RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Factor the un-jittered retry delay grows by after each attempt.
pub const RECONNECT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Relative jitter: each delay is scaled by a random factor drawn from
/// `1 - RECONNECT_JITTER .. 1 + RECONNECT_JITTER`.
pub const RECONNECT_JITTER: f64 = 0.2;
/// Wall-clock window after which `ReconnectBudget::next_backoff` stops
/// granting retries (measured from construction/reset, minus discounted time).
pub const RECONNECT_MAX_TOTAL: Duration = Duration::from_secs(5 * 60);
/// Maximum number of retries `ReconnectBudget` grants between resets.
pub const RECONNECT_MAX_ATTEMPTS: u32 = 20;
/// NOTE(review): not referenced in the visible part of this file — presumably
/// the number of attempts made before reconnects are gated through
/// `ReconnectCoordinator::reconnect_slot`; confirm against the submodules.
pub const RECONNECT_UNGATED_ATTEMPTS: u32 = 5;
/// Minimum time a connection must have stayed up for `run_control_loop` to
/// reset the reconnect budget when that connection is lost.
pub const RECONNECT_UPTIME_RESET: Duration = Duration::from_secs(30);
/// Connection state published through a `watch` channel (see `publish_status`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConnectionStatus {
/// Published by `run_control_loop` once connect and handshake both succeeded.
Up,
/// Published by `run_control_loop` at the start of every connect attempt.
Down,
/// Published by `ControlConnection::finish_failed`; carries the failure reason.
/// `run_control_loop` returns right after publishing it.
Failed(String),
}
/// Identifying details of a session.
/// NOTE(review): this type is not used in the visible part of this file; the
/// field descriptions below are inferred from the names only — confirm against
/// the code that constructs it.
#[derive(Clone, Debug)]
pub struct SessionInfo {
// Presumably the server-assigned session identifier.
pub session_id: String,
// Presumably the RPC endpoint for this session (compare `resolve_rpc_url`).
pub rpc_endpoint: String,
// Optional task identifier; semantics not visible here.
pub task_id: Option<String>,
}
pub(crate) struct ReconnectBudget {
attempts: u32,
started_at: std::time::Instant,
current_backoff: Duration,
}
impl ReconnectBudget {
pub fn new() -> Self {
Self {
attempts: 0,
started_at: std::time::Instant::now(),
current_backoff: RECONNECT_INITIAL_BACKOFF,
}
}
pub fn reset(&mut self) {
self.attempts = 0;
self.started_at = std::time::Instant::now();
self.current_backoff = RECONNECT_INITIAL_BACKOFF;
}
pub fn attempt(&self) -> u32 {
self.attempts
}
pub fn discount_parked(&mut self, parked: Duration) {
self.started_at += parked;
}
pub fn next_backoff(&mut self) -> Option<Duration> {
if self.attempts >= RECONNECT_MAX_ATTEMPTS
|| self.started_at.elapsed() >= RECONNECT_MAX_TOTAL
{
return None;
}
self.attempts += 1;
let backoff = with_jitter(self.current_backoff);
self.current_backoff = std::cmp::min(
RECONNECT_MAX_BACKOFF,
Duration::from_secs_f64(
self.current_backoff.as_secs_f64() * RECONNECT_BACKOFF_MULTIPLIER,
),
);
Some(backoff)
}
}
/// Coordinates reconnect handshakes with sibling connections that are
/// currently streaming: `reconnect_slot` only hands out its single permit
/// while no `StreamingGuard` is alive.
pub struct ReconnectCoordinator {
// Number of live `StreamingGuard`s (incremented in `enter`, decremented in
// the guard's `Drop`).
streaming: AtomicUsize,
// Signalled every time a `StreamingGuard` is dropped.
drained: Notify,
// Single-permit semaphore: at most one caller holds a reconnect slot.
handshake: Arc<Semaphore>,
}
impl Default for ReconnectCoordinator {
/// Same as [`ReconnectCoordinator::new`].
fn default() -> Self {
Self::new()
}
}
impl ReconnectCoordinator {
/// Creates a coordinator with no streaming participants and one free
/// handshake permit.
pub fn new() -> Self {
Self {
streaming: AtomicUsize::new(0),
drained: Notify::new(),
handshake: Arc::new(Semaphore::new(1)),
}
}
/// Registers the caller as streaming. The count is decremented again when
/// the returned guard is dropped.
pub fn enter(self: &Arc<Self>) -> StreamingGuard {
self.streaming.fetch_add(1, Ordering::SeqCst);
StreamingGuard(self.clone())
}
/// Waits until no participant is streaming, then acquires the handshake
/// permit.
///
/// Returns `None` if `cancel` fires while waiting, or if the semaphore has
/// been closed. Otherwise returns the permit, which the caller holds for as
/// long as it needs the slot.
pub async fn reconnect_slot(&self, cancel: &CancellationToken) -> Option<OwnedSemaphorePermit> {
loop {
// Phase 1: park until the streaming count reads zero.
loop {
// The `Notified` future is created BEFORE the counter is read so that a
// guard dropped between the read and the await is not missed.
// NOTE(review): this relies on tokio's documented guarantee that a
// `Notified` future receives `notify_waiters()` wakeups from the moment
// it is created — confirm for the tokio version in use.
let drained = self.drained.notified();
if self.streaming.load(Ordering::SeqCst) == 0 {
break;
}
tokio::select! {
biased;
_ = cancel.cancelled() => return None,
_ = drained => {}
}
}
// Phase 2: take the single permit (cancellation is checked first).
let permit = tokio::select! {
biased;
_ = cancel.cancelled() => return None,
p = self.handshake.clone().acquire_owned() => p.ok()?,
};
// Phase 3: someone may have started streaming while we waited for the
// permit; if so the permit is dropped at the end of this iteration and
// we park again.
if self.streaming.load(Ordering::SeqCst) == 0 {
return Some(permit);
}
}
}
}
/// Marks one participant as streaming on a [`ReconnectCoordinator`] for as
/// long as the value is alive. Obtained from `ReconnectCoordinator::enter`.
pub struct StreamingGuard(Arc<ReconnectCoordinator>);

impl Drop for StreamingGuard {
    /// Decrements the streaming count, then wakes every task currently
    /// waiting on the coordinator's `drained` notification.
    fn drop(&mut self) {
        let Self(coordinator) = self;
        // Decrement first so woken waiters observe the updated count.
        coordinator.streaming.fetch_sub(1, Ordering::SeqCst);
        coordinator.drained.notify_waiters();
    }
}
/// Scales `d` by a random factor drawn from
/// `1 - RECONNECT_JITTER .. 1 + RECONNECT_JITTER` (upper bound exclusive),
/// clamping the result at zero seconds.
fn with_jitter(d: Duration) -> Duration {
    let factor = 1.0 + rand::rng().random_range(-RECONNECT_JITTER..RECONNECT_JITTER);
    let jittered_secs = d.as_secs_f64() * factor;
    Duration::from_secs_f64(jittered_secs.max(0.0))
}
/// Sleeps for `delay` unless `cancel` fires first.
///
/// Returns `true` when the full delay elapsed and `false` when the sleep was
/// cut short by cancellation. The `select!` is not `biased`, so neither branch
/// is given priority when both are ready at the same time.
pub(crate) async fn cancellable_sleep(delay: Duration, cancel: &CancellationToken) -> bool {
    let timer = tokio::time::sleep(delay);
    tokio::select! {
        () = cancel.cancelled() => false,
        () = timer => true,
    }
}
/// WebSocket stream type produced by `connect_ws`: a tungstenite stream over a
/// TCP socket that may or may not be TLS-wrapped.
pub(super) type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Opens a WebSocket connection to `url`, authenticating with an `X-API-Key`
/// header.
///
/// # Errors
///
/// Returns a human-readable description when the request cannot be built from
/// `url`, when `api_key` is not a valid header value, when the connection is
/// not established within `CONNECT_TIMEOUT`, or when the connection attempt
/// itself fails.
pub(super) async fn connect_ws(url: &str, api_key: &str) -> Result<Ws, String> {
    let mut request = match url.into_client_request() {
        Ok(request) => request,
        Err(e) => return Err(format!("build request: {}", err_chain(&e))),
    };

    let key_header = HeaderValue::from_str(api_key)
        .map_err(|e| format!("api key header: {}", err_chain(&e)))?;
    request.headers_mut().insert("X-API-Key", key_header);

    match tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request)).await {
        Err(_elapsed) => Err(format!("connect timeout after {CONNECT_TIMEOUT:?}")),
        Ok(Err(e)) => Err(format!("connect: {}", err_chain(&e))),
        // The HTTP upgrade response is not needed by callers.
        Ok(Ok((stream, _response))) => Ok(stream),
    }
}
/// Serializes `req` to JSON and sends it over `ws` as a text frame.
///
/// # Errors
///
/// Returns a description prefixed with `serialize:` when JSON encoding fails
/// and with `send:` when the WebSocket write fails.
pub(super) async fn send_request(ws: &mut Ws, req: &BacktestRequest) -> Result<(), String> {
    let payload = match serde_json::to_string(req) {
        Ok(json) => json,
        Err(e) => return Err(format!("serialize: {}", err_chain(&e))),
    };

    match ws.send(Message::Text(payload)).await {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("send: {}", err_chain(&e))),
    }
}
/// Resolves `endpoint` against `base` into a full RPC URL.
///
/// * If `endpoint` already starts with `http://` or `https://` it is returned
///   unchanged and `base` is ignored.
/// * Otherwise the two are joined with exactly one `/` between them:
///   trailing slashes on `base` and leading slashes on `endpoint` are both
///   stripped first, so `("http://h/", "/rpc")` yields `http://h/rpc` rather
///   than `http://h//rpc`.
pub(super) fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}
/// Outcome classification for a failed handshake, as consumed by
/// `run_control_loop`.
pub(super) enum HandshakeError {
/// Retried with backoff while the reconnect budget lasts; once the budget is
/// exhausted it is reported as a failure.
Transient(String),
/// Reported as a failure immediately, without retrying.
Fatal(String),
}
/// Classifies an error response received during the handshake `stage`.
///
/// `SessionOwnershipBusy` is treated as transient (message
/// `"<stage> contended: …"`); every other error is fatal (message
/// `"<stage> rejected: …"`).
pub(super) fn handshake_error_for_response(
    stage: &'static str,
    err: BacktestError,
) -> HandshakeError {
    let contended = matches!(err, BacktestError::SessionOwnershipBusy { .. });
    let detail = err_chain(&err);
    if contended {
        HandshakeError::Transient(format!("{stage} contended: {detail}"))
    } else {
        HandshakeError::Fatal(format!("{stage} rejected: {detail}"))
    }
}
/// Reason a `ControlConnection::message_loop` returned, as consumed by
/// `run_control_loop`.
pub(super) enum MessageLoopExit {
/// The control loop returns without publishing a failure.
SessionEnded,
/// The control loop returns without publishing a failure.
Cancelled,
/// The control loop tries to reconnect (subject to the reconnect budget);
/// the payload is the reason used in logs and in the failure message.
ConnectionLost(String),
/// The control loop reports the payload as the failure reason and returns
/// without reconnecting.
Terminal(String),
}
/// Stores `status` in the watch channel, notifying receivers only when it
/// differs from the value currently held.
pub(super) fn publish_status(
    status_tx: &watch::Sender<ConnectionStatus>,
    status: ConnectionStatus,
) {
    status_tx.send_if_modified(|current| {
        let changed = *current != status;
        if changed {
            *current = status;
        }
        changed
    });
}
pub(super) fn is_terminal_backtest_error(err: &BacktestError) -> bool {
matches!(
err,
BacktestError::NoMoreBlocks
| BacktestError::AdvanceSlotFailed { .. }
| BacktestError::FinalizeSlotFailed { .. }
| BacktestError::Internal { .. }
)
}
/// Best-effort shutdown: sends `CloseBacktestSession`, then closes the
/// WebSocket. Every error is ignored, and the whole sequence is abandoned
/// after `GRACEFUL_CLOSE_TIMEOUT`.
pub(super) async fn graceful_close(ws: &mut Ws) {
    let farewell = async {
        // Failures here are deliberately ignored; the connection is going away.
        let _ = send_request(ws, &BacktestRequest::CloseBacktestSession).await;
        let _ = ws.close(None).await;
    };
    let _ = tokio::time::timeout(GRACEFUL_CLOSE_TIMEOUT, farewell).await;
}
/// Result of `classify_inbound`: what a message loop should do with one item
/// read from the WebSocket.
pub(super) enum InboundFrame {
/// Payload of a text frame, or of a binary frame that was valid UTF-8.
Text(String),
/// Nothing to process (ping/pong/raw frames and non-UTF-8 binary frames).
Ignore,
/// The connection is gone (remote close, read error, or end of stream);
/// carries a description of the cause.
Lost(String),
}
/// Maps one item read from the WebSocket stream to an [`InboundFrame`].
///
/// * end of stream, read errors and close frames become `Lost`;
/// * text frames, and binary frames holding valid UTF-8, become `Text`;
/// * everything else (ping, pong, raw frames, non-UTF-8 binary) is `Ignore`.
pub(super) fn classify_inbound(msg: Option<Result<Message, WsError>>) -> InboundFrame {
    // Transport-level outcomes first.
    let message = match msg {
        None => return InboundFrame::Lost("ws stream ended".into()),
        Some(Err(e)) => return InboundFrame::Lost(format!("ws read: {}", err_chain(&e))),
        Some(Ok(message)) => message,
    };

    match message {
        Message::Text(t) => InboundFrame::Text(t),
        Message::Binary(b) => {
            String::from_utf8(b).map_or(InboundFrame::Ignore, InboundFrame::Text)
        }
        Message::Close(frame) => InboundFrame::Lost(format!("remote close: {frame:?}")),
        Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => InboundFrame::Ignore,
    }
}
/// Keepalive tick: checks for a stale connection and otherwise sends a ping.
///
/// Returns `Some(reason)` when the connection should be considered lost —
/// either nothing has been received for longer than
/// `KEEPALIVE_MISS_DEADLINE`, or the ping could not be sent — and `None` when
/// the ping went out.
pub(super) async fn send_keepalive_ping(ws: &mut Ws, last_inbound: Instant) -> Option<String> {
    let stale = last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE;
    if stale {
        return Some(format!("no traffic for {:?}", last_inbound.elapsed()));
    }

    ws.send(Message::Ping(Vec::new()))
        .await
        .err()
        .map(|e| format!("ping send: {}", err_chain(&e)))
}
/// Behaviour `run_control_loop` needs from a connection task: where to
/// connect, how to handshake, how to pump messages, and how to report status.
pub(super) trait ControlConnection: Send + 'static {
/// WebSocket URL passed to `connect_ws`.
fn url(&self) -> &str;
/// API key passed to `connect_ws`.
fn api_key(&self) -> &str;
/// Token that stops the control loop: checked before each connect attempt
/// and during backoff sleeps.
fn cancel(&self) -> &CancellationToken;
/// Short name used as the prefix of the control loop's log messages.
fn label(&self) -> &'static str;
/// Channel the connection status is published on.
fn status_tx(&self) -> &watch::Sender<ConnectionStatus>;
/// Called with a reason when the loop gives up or is cancelled before
/// connecting.
/// NOTE(review): presumably fails any requests still awaiting a response —
/// confirm against the implementors.
fn fail_pending(&mut self, reason: String);
/// Performs the handshake on a freshly connected socket and hands the socket
/// back on success.
fn handshake(&mut self, ws: Ws) -> impl Future<Output = Result<Ws, HandshakeError>> + Send;
/// Runs the connection until it ends; the return value tells the control
/// loop whether to stop, reconnect, or fail.
fn message_loop(&mut self, ws: Ws) -> impl Future<Output = MessageLoopExit> + Send;
/// Publishes `status` on `status_tx` (only if it changed; see `publish_status`).
fn publish(&self, status: ConnectionStatus) {
publish_status(self.status_tx(), status);
}
/// Terminal failure: calls `fail_pending` with `reason`, then publishes
/// `ConnectionStatus::Failed(reason)`.
fn finish_failed(&mut self, reason: String) {
self.fail_pending(reason.clone());
self.publish(ConnectionStatus::Failed(reason));
}
}
/// Drives `task` through connect → handshake → message loop, reconnecting
/// with backoff until the session ends, the task is cancelled, a terminal
/// error occurs, or the reconnect budget is exhausted.
///
/// Behaviour per stage:
/// * Cancellation seen at the top of an iteration calls `fail_pending` and
///   returns; cancellation during a backoff sleep returns directly.
/// * Connect failures and transient handshake failures are retried while the
///   budget lasts; fatal handshake failures and an exhausted budget end the
///   loop via `finish_failed`.
/// * A lost connection that had been up for at least
///   `RECONNECT_UPTIME_RESET` resets the budget before the retry is drawn.
pub(super) async fn run_control_loop<T: ControlConnection>(mut task: T) {
    let mut budget = ReconnectBudget::new();

    loop {
        if task.cancel().is_cancelled() {
            task.fail_pending("cancelled before session created".to_string());
            return;
        }
        task.publish(ConnectionStatus::Down);

        // Stage 1: transport connect.
        let connected = connect_ws(task.url(), task.api_key()).await;
        let ws = match connected {
            Ok(ws) => ws,
            Err(why) => {
                let Some(delay) = budget.next_backoff() else {
                    task.finish_failed(format!("connect: {why}"));
                    return;
                };
                warn!(attempt = budget.attempt(), error = %why, ?delay, "{} connect failed, retrying", task.label());
                if cancellable_sleep(delay, task.cancel()).await {
                    continue;
                }
                return;
            }
        };

        // Stage 2: protocol handshake.
        let ws = match task.handshake(ws).await {
            Ok(ws) => ws,
            Err(HandshakeError::Transient(why)) => {
                let Some(delay) = budget.next_backoff() else {
                    task.finish_failed(format!("handshake: {why}"));
                    return;
                };
                warn!(attempt = budget.attempt(), error = %why, ?delay, "{} handshake failed, retrying", task.label());
                if cancellable_sleep(delay, task.cancel()).await {
                    continue;
                }
                return;
            }
            Err(HandshakeError::Fatal(why)) => {
                task.finish_failed(format!("handshake: {why}"));
                return;
            }
        };

        // Stage 3: steady state.
        task.publish(ConnectionStatus::Up);
        let connected_at = Instant::now();
        let lost_reason = match task.message_loop(ws).await {
            MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled => return,
            MessageLoopExit::Terminal(why) => {
                task.finish_failed(why);
                return;
            }
            MessageLoopExit::ConnectionLost(why) => why,
        };

        // A connection that stayed up long enough earns a fresh budget.
        if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
            budget.reset();
        }
        let Some(delay) = budget.next_backoff() else {
            task.finish_failed(format!("connection lost: {lost_reason}"));
            return;
        };
        warn!(attempt = budget.attempt(), reason = %lost_reason, ?delay, "{} connection lost, reconnecting", task.label());
        if !cancellable_sleep(delay, task.cancel()).await {
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exactly `RECONNECT_MAX_ATTEMPTS` delays are granted, then none.
    #[test]
    fn budget_exhausts_after_max_attempts() {
        let mut budget = ReconnectBudget::new();
        (0..RECONNECT_MAX_ATTEMPTS).for_each(|_| assert!(budget.next_backoff().is_some()));
        assert!(budget.next_backoff().is_none());
    }

    /// `reset` clears the attempt counter.
    #[test]
    fn budget_reset_restores_full_budget() {
        let mut budget = ReconnectBudget::new();
        for _ in 0..2 {
            budget.next_backoff();
        }
        budget.reset();
        assert_eq!(budget.attempt(), 0);
    }

    /// Entering bumps the streaming count; dropping the guard restores it.
    #[test]
    fn streaming_guard_balances_the_count() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let live = || coord.streaming.load(Ordering::SeqCst);

        assert_eq!(live(), 0);
        let guard = coord.enter();
        assert_eq!(live(), 1);
        drop(guard);
        assert_eq!(live(), 0);
    }

    /// With nobody streaming, a slot is granted.
    #[tokio::test]
    async fn reconnect_slot_available_when_link_is_quiet() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let cancel = CancellationToken::new();

        let slot = coord.reconnect_slot(&cancel).await;
        assert!(slot.is_some());
    }

    /// A parked waiter is released once the only streaming guard is dropped.
    #[tokio::test]
    async fn reconnect_slot_unparks_when_last_sibling_leaves() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let cancel = CancellationToken::new();
        let sibling = coord.enter();

        let parked = {
            let (coord, cancel) = (Arc::clone(&coord), cancel.clone());
            tokio::spawn(async move { coord.reconnect_slot(&cancel).await.is_some() })
        };

        // Give the waiter a chance to run; it must still be parked.
        tokio::task::yield_now().await;
        assert!(!parked.is_finished());

        drop(sibling);
        assert!(parked.await.unwrap());
    }

    /// Cancellation wins while a sibling is still streaming.
    #[tokio::test]
    async fn reconnect_slot_returns_none_on_cancel_while_parked() {
        let coord = Arc::new(ReconnectCoordinator::new());
        let _sibling = coord.enter();
        let cancel = CancellationToken::new();
        cancel.cancel();

        let slot = coord.reconnect_slot(&cancel).await;
        assert!(slot.is_none());
    }

    /// Discounting parked time leaves every attempt available.
    #[test]
    fn discount_parked_does_not_consume_the_budget() {
        let mut budget = ReconnectBudget::new();
        budget.discount_parked(RECONNECT_MAX_TOTAL * 2);
        (0..RECONNECT_MAX_ATTEMPTS).for_each(|_| assert!(budget.next_backoff().is_some()));
    }
}