use std::time::Instant;
use futures::{SinkExt, StreamExt};
use simulator_api::{
BacktestError, BacktestRequest, BacktestResponse, BacktestStatus, ContinueParams,
ContinueToParams, CreateBacktestSessionRequest, DiscoveryBatchEvent, PausedEvent,
SequencedResponse,
};
use tokio::{
net::TcpStream,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Message, client::IntoClientRequest, http::HeaderValue},
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::{
CONNECT_TIMEOUT, ConnectionStatus, GRACEFUL_CLOSE_TIMEOUT, HANDSHAKE_RESPONSE_TIMEOUT,
KEEPALIVE_INTERVAL, KEEPALIVE_MISS_DEADLINE, RECONNECT_UPTIME_RESET, ReconnectBudget,
SessionInfo, cancellable_sleep,
};
use crate::{error::err_chain, urls::http_base_from_ws_url};
/// Events the control task forwards to the consumer of `ControlHandle::events`.
///
/// Each variant mirrors a `BacktestResponse` received on the control websocket.
#[derive(Debug)]
pub enum ControlEvent {
/// The server sent `BacktestResponse::ReadyForContinue`.
ReadyForContinue,
/// The server sent `BacktestResponse::Paused` with the given payload.
Paused(PausedEvent),
/// The server sent `BacktestResponse::DiscoveryBatch` with the given payload.
DiscoveryBatch(DiscoveryBatchEvent),
/// The server sent `BacktestResponse::SlotNotification` carrying this slot.
Slot(u64),
/// The server sent `BacktestResponse::Status` with this status.
Status(BacktestStatus),
/// The server sent `BacktestResponse::Completed`; the control task ends the
/// session after forwarding this event from its message loop.
Completed,
/// The server sent `BacktestResponse::Error`. `SimulationError` values are
/// only logged by the message loop and never surface here.
Error(BacktestError),
}
/// Caller-side handle to the spawned control task.
///
/// Holds the sending halves of the command channels and the receiving halves
/// of the event/status channels, plus the task's `JoinHandle`.
pub struct ControlHandle {
// Sends `Continue` requests to the control task.
continues: mpsc::Sender<ContinueParams>,
// Sends `ContinueTo` requests to the control task.
continue_tos: mpsc::Sender<ContinueToParams>,
/// Stream of events forwarded from the control websocket.
pub events: mpsc::Receiver<ControlEvent>,
/// Latest connection status published by the control task.
pub status: watch::Receiver<ConnectionStatus>,
// One-shot result of session creation; taken by `wait_for_session`.
session_info: Option<oneshot::Receiver<Result<SessionInfo, String>>>,
// Join handle of the spawned control task; awaited by `join`.
join: JoinHandle<()>,
}
impl ControlHandle {
    /// Waits for the outcome of session creation.
    ///
    /// # Errors
    ///
    /// Returns `Err` when called a second time (the one-shot receiver has
    /// already been taken), when the control task exits before reporting a
    /// result, or with the failure reason the control task reported.
    pub async fn wait_for_session(&mut self) -> Result<SessionInfo, String> {
        let rx = self
            .session_info
            .take()
            .ok_or_else(|| "session_info already consumed".to_string())?;
        rx.await
            .map_err(|_| "control manager exited before creating session".to_string())?
    }

    /// Hands a `Continue` request to the control task.
    ///
    /// # Errors
    ///
    /// Returns the params back inside `SendError` when the control task has
    /// dropped its receiver.
    pub async fn send_continue(
        &self,
        params: ContinueParams,
    ) -> Result<(), mpsc::error::SendError<ContinueParams>> {
        self.continues.send(params).await
    }

    /// Hands a `ContinueTo` request to the control task.
    ///
    /// # Errors
    ///
    /// Returns the params back inside `SendError` when the control task has
    /// dropped its receiver.
    pub async fn send_continue_to(
        &self,
        params: ContinueToParams,
    ) -> Result<(), mpsc::error::SendError<ContinueToParams>> {
        self.continue_tos.send(params).await
    }

    /// Closes the command channels and waits for the control task to finish.
    ///
    /// The task treats a closed command channel as "session ended" once it is
    /// inside its message loop. A panic or abort of the task is ignored.
    pub async fn join(self) {
        drop(self.continues);
        drop(self.continue_tos);
        // The control task awaits `events_tx.send(..)` on a bounded channel.
        // Nobody can drain `events` once `self` has been consumed here, so
        // keeping the receiver alive could block the task (and this await)
        // forever on a full channel. Dropping it makes those sends fail fast;
        // the task already ignores send errors.
        drop(self.events);
        let _ = self.join.await;
    }
}
/// Spawns the control task with `tokio::spawn` and returns the handle used to
/// drive it.
///
/// * `url` - websocket URL of the control endpoint.
/// * `api_key` - value sent in the `X-API-Key` header on connect.
/// * `create` - request used to create the backtest session.
/// * `cancel` - token that stops the task when cancelled.
///
/// The initial published status is `ConnectionStatus::Down`.
pub fn spawn_control_manager(
    url: String,
    api_key: String,
    create: CreateBacktestSessionRequest,
    cancel: CancellationToken,
) -> ControlHandle {
    // Command channels buffer a single request each; events buffer up to 256.
    let (continues, continues_rx) = mpsc::channel::<ContinueParams>(1);
    let (continue_tos, continue_tos_rx) = mpsc::channel::<ContinueToParams>(1);
    let (events_tx, events) = mpsc::channel::<ControlEvent>(256);
    let (status_tx, status) = watch::channel(ConnectionStatus::Down);
    let (session_tx, session_rx) = oneshot::channel::<Result<SessionInfo, String>>();

    let task = ControlTask {
        url,
        api_key,
        create: Some(create),
        session_info: None,
        session_tx: Some(session_tx),
        last_sequence: None,
        continues_rx,
        continue_tos_rx,
        events_tx,
        status_tx,
        cancel,
    };

    ControlHandle {
        continues,
        continue_tos,
        events,
        status,
        session_info: Some(session_rx),
        join: tokio::spawn(task.run()),
    }
}
/// Websocket stream type produced by `connect_async` and used for the control connection.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// State owned by the spawned control task.
struct ControlTask {
// Websocket URL to connect to; also used to derive the HTTP base for RPC endpoints.
url: String,
// Sent as the `X-API-Key` header on every connect.
api_key: String,
// Session-create request; taken (and thereby consumed) by the first handshake
// that reaches the create branch.
create: Option<CreateBacktestSessionRequest>,
// Set once the session has been created; when present, the handshake
// attaches to and resumes that session instead of creating one.
session_info: Option<SessionInfo>,
// One-shot used to report the session-creation outcome; consumed on first use.
session_tx: Option<oneshot::Sender<Result<SessionInfo, String>>>,
// Most recent `seq_id` seen on a `SequencedResponse`; sent along with attach.
last_sequence: Option<u64>,
// Incoming `Continue` requests from the handle.
continues_rx: mpsc::Receiver<ContinueParams>,
// Incoming `ContinueTo` requests from the handle.
continue_tos_rx: mpsc::Receiver<ContinueToParams>,
// Outgoing events to the handle.
events_tx: mpsc::Sender<ControlEvent>,
// Publishes connection status changes.
status_tx: watch::Sender<ConnectionStatus>,
// Cancellation signal checked in the run loop, backoff sleeps and message loop.
cancel: CancellationToken,
}
/// Reason the message loop stopped; decides what the run loop does next.
enum MessageLoopExit {
// Session finished (server `Completed`, or a command channel was closed); the task exits.
SessionEnded,
// The cancellation token fired; the task exits.
Cancelled,
// The websocket failed or went quiet; the run loop may reconnect.
ConnectionLost(String),
// The server reported an error treated as terminal; the task fails without reconnecting.
Terminal(String),
}
impl ControlTask {
    /// Runs the control connection until the session ends, cancellation is
    /// observed, or the reconnect budget is exhausted.
    ///
    /// Each iteration connects, performs the handshake (session create on the
    /// first pass, attach + resume once a session exists) and then pumps
    /// messages. Retryable failures take the next delay from the
    /// [`ReconnectBudget`]; once `next_backoff` yields `None` the task
    /// publishes `ConnectionStatus::Failed` and exits.
    async fn run(mut self) {
        let mut budget = ReconnectBudget::new();
        loop {
            if self.cancel.is_cancelled() {
                self.fail_session_info_if_pending("cancelled before session created");
                return;
            }
            self.publish(ConnectionStatus::Down);
            let ws = match self.connect().await {
                Ok(ws) => ws,
                Err(why) => {
                    if let Some(delay) = budget.next_backoff() {
                        warn!(attempt = budget.attempt(), error = %why, ?delay, "control connect failed, retrying");
                        // A `false` result stops the task. Returning drops
                        // `session_tx`, so a pending `wait_for_session`
                        // resolves with its "exited" error.
                        if !cancellable_sleep(delay, &self.cancel).await {
                            return;
                        }
                        continue;
                    }
                    self.finish_failed(format!("connect: {why}"));
                    return;
                }
            };
            let ws = match self.handshake(ws).await {
                Ok(ws) => ws,
                Err(HandshakeError::Fatal(why)) => {
                    self.finish_failed(format!("handshake: {why}"));
                    return;
                }
                Err(HandshakeError::Transient(why)) => {
                    if let Some(delay) = budget.next_backoff() {
                        warn!(attempt = budget.attempt(), error = %why, ?delay, "control handshake failed, retrying");
                        if !cancellable_sleep(delay, &self.cancel).await {
                            return;
                        }
                        continue;
                    }
                    self.finish_failed(format!("handshake: {why}"));
                    return;
                }
            };
            self.publish(ConnectionStatus::Up);
            let connected_at = Instant::now();
            let exit = self.message_loop(ws).await;
            match exit {
                MessageLoopExit::SessionEnded => return,
                MessageLoopExit::Cancelled => return,
                MessageLoopExit::ConnectionLost(why) => {
                    // A connection that stayed up long enough earns a fresh
                    // reconnect budget.
                    if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
                        budget.reset();
                    }
                    if let Some(delay) = budget.next_backoff() {
                        warn!(attempt = budget.attempt(), reason = %why, ?delay, "control connection lost, reconnecting");
                        if !cancellable_sleep(delay, &self.cancel).await {
                            return;
                        }
                        continue;
                    }
                    self.finish_failed(format!("connection lost: {why}"));
                    return;
                }
                MessageLoopExit::Terminal(why) => {
                    self.finish_failed(why);
                    return;
                }
            }
        }
    }

    /// Publishes `status` on the watch channel, notifying receivers only when
    /// the value actually changes.
    fn publish(&self, status: ConnectionStatus) {
        self.status_tx.send_if_modified(|current| {
            if *current == status {
                false
            } else {
                *current = status;
                true
            }
        });
    }

    /// Reports `reason` as the session-creation failure if nobody has been
    /// told the outcome yet; otherwise does nothing.
    fn fail_session_info_if_pending(&mut self, reason: &str) {
        if let Some(tx) = self.session_tx.take() {
            let _ = tx.send(Err(reason.to_string()));
        }
    }

    /// Fails any pending session waiter and publishes
    /// `ConnectionStatus::Failed(reason)`.
    fn finish_failed(&mut self, reason: String) {
        self.fail_session_info_if_pending(&reason);
        self.publish(ConnectionStatus::Failed(reason));
    }

    /// Opens the websocket, sending the API key in the `X-API-Key` header.
    ///
    /// # Errors
    ///
    /// Returns a description when the request cannot be built, the API key is
    /// not a valid header value, the connect exceeds `CONNECT_TIMEOUT`, or the
    /// websocket connect itself fails.
    async fn connect(&self) -> Result<Ws, String> {
        let mut request = self
            .url
            .clone()
            .into_client_request()
            .map_err(|e| format!("build request: {}", err_chain(&e)))?;
        request.headers_mut().insert(
            "X-API-Key",
            HeaderValue::from_str(&self.api_key)
                .map_err(|e| format!("api key header: {}", err_chain(&e)))?,
        );
        let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
            .await
            .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
            .map_err(|e| format!("connect: {}", err_chain(&e)))?;
        Ok(connect.0)
    }

    /// Brings a freshly connected websocket into a usable state.
    ///
    /// With a known session this attaches (passing `last_sequence`) and then
    /// resumes it. Otherwise it consumes the stored create request, creates
    /// the session and reports the resulting `SessionInfo` to the waiter.
    ///
    /// # Errors
    ///
    /// `Transient` errors may be retried by the caller; `Fatal` errors must
    /// not. Any failure while creating the session is reported as `Fatal`:
    /// the create request is consumed by the attempt, so it cannot be
    /// replayed on a later connection.
    async fn handshake(&mut self, mut ws: Ws) -> Result<Ws, HandshakeError> {
        if let Some(info) = &self.session_info {
            let info = info.clone();
            attach(
                &mut ws,
                &info.session_id,
                self.last_sequence,
                &mut self.events_tx,
                &mut self.last_sequence,
            )
            .await?;
            resume(&mut ws, &mut self.events_tx, &mut self.last_sequence).await?;
            debug!(session_id = info.session_id, "control reattached");
        } else if let Some(create) = self.create.take() {
            let info = create_session(
                &mut ws,
                create,
                &self.url,
                &mut self.events_tx,
                &mut self.last_sequence,
            )
            .await
            // The request was moved into this attempt. Letting a transient
            // failure through would trigger a reconnect that ends in
            // "no create request and no session_id", hiding the real cause.
            // Fail now with the original reason instead.
            .map_err(|err| match err {
                HandshakeError::Transient(why) => HandshakeError::Fatal(format!(
                    "session create failed and cannot be retried: {why}"
                )),
                fatal => fatal,
            })?;
            info!(session_id = info.session_id, "control session created");
            self.session_info = Some(info.clone());
            if let Some(tx) = self.session_tx.take() {
                let _ = tx.send(Ok(info));
            }
        } else {
            return Err(HandshakeError::Fatal(
                "no create request and no session_id".into(),
            ));
        }
        Ok(ws)
    }

    /// Pumps the established connection until something ends it.
    ///
    /// Branches are polled in priority order (`biased`): cancellation, the
    /// keepalive timer, inbound websocket frames, then `Continue` and
    /// `ContinueTo` commands. On `SessionEnded` the websocket is closed
    /// gracefully before returning.
    async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
        let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
        ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        let mut last_inbound = Instant::now();
        let exit = loop {
            tokio::select! {
                biased;
                _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
                _ = ping_timer.tick() => {
                    // Give up if nothing has arrived within the miss deadline;
                    // otherwise send a ping.
                    if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
                        break MessageLoopExit::ConnectionLost(format!(
                            "no traffic for {:?}", last_inbound.elapsed()
                        ));
                    }
                    if let Err(e) = ws.send(Message::Ping(vec![])).await {
                        break MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
                    }
                }
                msg = ws.next() => {
                    // Any frame, including pong, counts as liveness.
                    last_inbound = Instant::now();
                    match msg {
                        Some(Ok(Message::Text(t))) => {
                            if let Err(exit) = self.handle_text(&t).await {
                                break exit;
                            }
                        }
                        Some(Ok(Message::Binary(b))) => {
                            // Binary frames are handled as text when valid
                            // UTF-8 and silently skipped otherwise.
                            if let Ok(t) = std::str::from_utf8(&b)
                                && let Err(exit) = self.handle_text(t).await {
                                break exit;
                            }
                        }
                        Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
                        Some(Ok(Message::Close(frame))) => {
                            break MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
                        }
                        Some(Ok(Message::Frame(_))) => {}
                        Some(Err(e)) => {
                            break MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e)));
                        }
                        None => break MessageLoopExit::ConnectionLost("ws stream ended".into()),
                    }
                }
                req = self.continues_rx.recv() => {
                    match req {
                        Some(params) => {
                            if let Err(e) = send_request(&mut ws, &BacktestRequest::Continue(params)).await {
                                break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
                            }
                        }
                        // The handle dropped its sender: the caller is done.
                        None => {
                            break MessageLoopExit::SessionEnded;
                        }
                    }
                }
                req = self.continue_tos_rx.recv() => {
                    match req {
                        Some(params) => {
                            if let Err(e) = send_request(&mut ws, &BacktestRequest::ContinueTo(params)).await {
                                break MessageLoopExit::ConnectionLost(format!("continue_to send: {e}"));
                            }
                        }
                        None => break MessageLoopExit::SessionEnded,
                    }
                }
            }
        };
        if matches!(exit, MessageLoopExit::SessionEnded) {
            graceful_close(&mut ws).await;
        }
        exit
    }

    /// Decodes one text payload and forwards it as a [`ControlEvent`].
    ///
    /// The payload is parsed as a `SequencedResponse` first (recording its
    /// `seq_id` in `last_sequence`) and as a bare `BacktestResponse`
    /// otherwise. Payloads that match neither are logged and dropped.
    ///
    /// # Errors
    ///
    /// Returns `Err(SessionEnded)` after forwarding `Completed`, and
    /// `Err(Terminal)` after forwarding an error treated as terminal.
    /// Failures to deliver an event to the consumer are ignored.
    async fn handle_text(&mut self, text: &str) -> Result<(), MessageLoopExit> {
        let (seq, response) = match serde_json::from_str::<SequencedResponse>(text) {
            Ok(s) => (Some(s.seq_id), s.response),
            Err(_) => match serde_json::from_str::<BacktestResponse>(text) {
                Ok(r) => (None, r),
                Err(e) => {
                    warn!(error = %err_chain(&e), "discarding undeserializable control message");
                    return Ok(());
                }
            },
        };
        if let Some(s) = seq {
            self.last_sequence = Some(s);
        }
        match response {
            BacktestResponse::ReadyForContinue => {
                let _ = self.events_tx.send(ControlEvent::ReadyForContinue).await;
            }
            BacktestResponse::Paused(event) => {
                let _ = self.events_tx.send(ControlEvent::Paused(event)).await;
            }
            BacktestResponse::DiscoveryBatch(event) => {
                let _ = self
                    .events_tx
                    .send(ControlEvent::DiscoveryBatch(event))
                    .await;
            }
            BacktestResponse::SlotNotification(slot) => {
                let _ = self.events_tx.send(ControlEvent::Slot(slot)).await;
            }
            BacktestResponse::Completed { .. } => {
                let _ = self.events_tx.send(ControlEvent::Completed).await;
                return Err(MessageLoopExit::SessionEnded);
            }
            BacktestResponse::Error(err) => {
                // Simulation errors are logged only; they are not forwarded
                // and do not end the session.
                if matches!(&err, BacktestError::SimulationError { .. }) {
                    warn!(error = %err_chain(&err), "simulation error");
                    return Ok(());
                }
                let terminal = matches!(
                    &err,
                    BacktestError::NoMoreBlocks
                        | BacktestError::AdvanceSlotFailed { .. }
                        | BacktestError::FinalizeSlotFailed { .. }
                        | BacktestError::Internal { .. }
                );
                let _ = self.events_tx.send(ControlEvent::Error(err)).await;
                if terminal {
                    return Err(MessageLoopExit::Terminal(
                        "server reported terminal error".into(),
                    ));
                }
            }
            BacktestResponse::Status { status } => {
                let _ = self.events_tx.send(ControlEvent::Status(status)).await;
            }
            BacktestResponse::Success => {}
            other => {
                debug!(?other, "ignoring unexpected control response");
            }
        }
        Ok(())
    }
}
/// Failure while creating, attaching to, or resuming a session on a new connection.
enum HandshakeError {
// The run loop may back off and retry the connection.
Transient(String),
// The run loop fails the task without retrying.
Fatal(String),
}
async fn create_session(
ws: &mut Ws,
request: CreateBacktestSessionRequest,
url: &str,
events: &mut mpsc::Sender<ControlEvent>,
last_sequence: &mut Option<u64>,
) -> Result<SessionInfo, HandshakeError> {
send_request(ws, &BacktestRequest::CreateBacktestSession(request))
.await
.map_err(HandshakeError::Transient)?;
let rpc_base = http_base_from_ws_url(url);
loop {
let response = next_response_with_timeout(ws, events, last_sequence)
.await
.map_err(HandshakeError::Transient)?;
match response {
BacktestResponse::SessionCreated {
session_id,
rpc_endpoint,
task_id,
} => {
let rpc_endpoint = resolve_rpc_url(&rpc_base, &rpc_endpoint);
return Ok(SessionInfo {
session_id,
rpc_endpoint,
task_id,
});
}
BacktestResponse::Error(err) => {
return Err(HandshakeError::Fatal(format!(
"server error: {}",
err_chain(&err)
)));
}
_ => {
}
}
}
}
async fn attach(
ws: &mut Ws,
session_id: &str,
last_sequence: Option<u64>,
events: &mut mpsc::Sender<ControlEvent>,
last_seq_state: &mut Option<u64>,
) -> Result<(), HandshakeError> {
send_request(
ws,
&BacktestRequest::AttachBacktestSession {
session_id: session_id.to_string(),
last_sequence,
},
)
.await
.map_err(HandshakeError::Transient)?;
loop {
let response = next_response_with_timeout(ws, events, last_seq_state)
.await
.map_err(HandshakeError::Transient)?;
match response {
BacktestResponse::SessionAttached { .. } => return Ok(()),
BacktestResponse::Error(err) => {
return Err(handshake_error_for_response("attach", err));
}
_ => {}
}
}
}
async fn resume(
ws: &mut Ws,
events: &mut mpsc::Sender<ControlEvent>,
last_seq_state: &mut Option<u64>,
) -> Result<(), HandshakeError> {
send_request(ws, &BacktestRequest::ResumeAttachedSession)
.await
.map_err(HandshakeError::Transient)?;
loop {
let response = next_response_with_timeout(ws, events, last_seq_state)
.await
.map_err(HandshakeError::Transient)?;
match response {
BacktestResponse::Success => return Ok(()),
BacktestResponse::Error(err) => {
return Err(handshake_error_for_response("resume", err));
}
_ => {}
}
}
}
/// Classifies a server error received during the `stage` step of the handshake.
///
/// `SessionOwnershipBusy` is `Transient` (worth retrying); every other error
/// is `Fatal`.
fn handshake_error_for_response(stage: &'static str, err: BacktestError) -> HandshakeError {
    if matches!(err, BacktestError::SessionOwnershipBusy { .. }) {
        return HandshakeError::Transient(format!("{stage} contended: {}", err_chain(&err)));
    }
    HandshakeError::Fatal(format!("{stage} rejected: {}", err_chain(&err)))
}
/// Serializes `req` to JSON and sends it as a websocket text frame.
///
/// # Errors
///
/// Returns a `serialize: …` message when encoding fails, or a `send: …`
/// message when the websocket write fails.
async fn send_request(ws: &mut Ws, req: &BacktestRequest) -> Result<(), String> {
    let payload = match serde_json::to_string(req) {
        Ok(json) => json,
        Err(e) => return Err(format!("serialize: {}", err_chain(&e))),
    };
    match ws.send(Message::Text(payload)).await {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("send: {}", err_chain(&e))),
    }
}
/// Reads frames until a response arrives that the handshake must interpret.
///
/// Streamed notifications (slot, ready-for-continue, paused, discovery batch,
/// completed) are forwarded to `events` and reading continues. Any other
/// response is returned to the caller. `last_sequence` is updated whenever a
/// sequenced response is decoded. Non-text frames and non-UTF-8 binary frames
/// are skipped.
///
/// A single deadline of `HANDSHAKE_RESPONSE_TIMEOUT` covers the whole call,
/// however many frames are skipped or forwarded.
///
/// # Errors
///
/// Returns a description on timeout, stream end, read error, remote close, or
/// a payload that deserializes as neither response form.
async fn next_response_with_timeout(
    ws: &mut Ws,
    events: &mut mpsc::Sender<ControlEvent>,
    last_sequence: &mut Option<u64>,
) -> Result<BacktestResponse, String> {
    let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
    loop {
        let frame = match tokio::time::timeout_at(deadline, ws.next()).await {
            Err(_) => {
                return Err(format!(
                    "handshake timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"
                ));
            }
            Ok(None) => return Err("ws ended during handshake".into()),
            Ok(Some(Err(e))) => return Err(format!("ws read: {}", err_chain(&e))),
            Ok(Some(Ok(frame))) => frame,
        };

        let text = match frame {
            Message::Text(t) => t,
            Message::Binary(bytes) => {
                let Ok(decoded) = std::str::from_utf8(&bytes) else {
                    continue;
                };
                decoded.to_string()
            }
            Message::Close(frame) => {
                return Err(format!("remote close during handshake: {frame:?}"));
            }
            _ => continue,
        };

        // Prefer the sequenced envelope; fall back to a bare response.
        let (seq, response) =
            if let Ok(sequenced) = serde_json::from_str::<SequencedResponse>(&text) {
                (Some(sequenced.seq_id), sequenced.response)
            } else {
                match serde_json::from_str::<BacktestResponse>(&text) {
                    Ok(plain) => (None, plain),
                    Err(e) => {
                        return Err(format!("deserialize: {}; raw={text}", err_chain(&e)));
                    }
                }
            };
        if seq.is_some() {
            *last_sequence = seq;
        }

        let forwarded = match response {
            BacktestResponse::SlotNotification(slot) => ControlEvent::Slot(slot),
            BacktestResponse::ReadyForContinue => ControlEvent::ReadyForContinue,
            BacktestResponse::Paused(event) => ControlEvent::Paused(event),
            BacktestResponse::DiscoveryBatch(event) => ControlEvent::DiscoveryBatch(event),
            BacktestResponse::Completed { .. } => ControlEvent::Completed,
            other => return Ok(other),
        };
        // Delivery failures are ignored, as in the steady-state message loop.
        let _ = events.send(forwarded).await;
    }
}
/// Best-effort shutdown: sends `CloseBacktestSession`, then closes the websocket.
///
/// Each step is bounded by `GRACEFUL_CLOSE_TIMEOUT`; failures and timeouts
/// are ignored.
async fn graceful_close(ws: &mut Ws) {
    let close_request = BacktestRequest::CloseBacktestSession;
    let notify = send_request(ws, &close_request);
    let _ = tokio::time::timeout(GRACEFUL_CLOSE_TIMEOUT, notify).await;

    let shutdown = ws.close(None);
    let _ = tokio::time::timeout(GRACEFUL_CLOSE_TIMEOUT, shutdown).await;
}
/// Resolves the RPC endpoint returned by the server into an absolute URL.
///
/// An `endpoint` that already starts with `http://` or `https://` is returned
/// unchanged. Otherwise it is treated as a path and joined onto `base` with
/// exactly one `/` between them, whether or not `base` ends with a slash or
/// `endpoint` starts with one.
fn resolve_rpc_url(base: &str, endpoint: &str) -> String {
    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        return endpoint.to_string();
    }
    // Trim both sides of the seam so a base with a trailing slash does not
    // produce `host//path`.
    format!(
        "{}/{}",
        base.trim_end_matches('/'),
        endpoint.trim_start_matches('/')
    )
}