use std::{
collections::{HashMap, VecDeque},
sync::Arc,
};
use futures::StreamExt;
use simulator_api::{
BacktestRequest, BacktestResponse, ContinueParams, ContinueSessionRequestV1,
CreateBacktestSessionRequest, SessionEventKind,
};
use tokio::{
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::{
ConnectionStatus, ControlConnection, ControlEvent, HANDSHAKE_RESPONSE_TIMEOUT, HandshakeError,
InboundFrame, KEEPALIVE_INTERVAL, ManagedEvent, ManagedSessionError, MessageLoopExit,
ReconnectCoordinator, SessionInfo, SubscriptionHandle, Ws, classify_inbound, graceful_close,
handshake_error_for_response, is_terminal_backtest_error, resolve_rpc_url, run_control_loop,
send_keepalive_ping, send_request,
session::{
DrainOutcome, drain_subscriptions_until_complete, try_next_subscription_event,
wait_any_subscription_event, wait_connections_up,
},
spawn_account_diff_subscription_manager, spawn_transaction_subscription_manager,
};
use crate::{error::err_chain, urls::http_base_from_ws_url};
/// Timeout for each individual response during the create handshake.
/// `next_response` computes a fresh deadline per call, so this bounds the gap
/// between responses, not the handshake as a whole.
const CREATE_RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(900);
/// Capacity of the bounded channel carrying continue requests from the
/// sub-sessions to the control task.
const CONTINUE_CHANNEL_CAPACITY: usize = 256;
/// Idle timeout handed to the subscription drain that runs after a
/// sub-session's `Completed` event.
const COMPLETION_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
/// A continue request tagged with the id of the sub-session it targets.
type TaggedContinue = (String, ContinueParams);
/// Outcome of a successful create handshake, delivered once to the waiter.
struct ParallelCreated {
    /// Id of the multiplexed control session; used to re-attach on reconnect.
    control_session_id: String,
    /// Every sub-session announced by the server, in announcement order.
    sessions: Vec<CreatedSubSession>,
}
/// A sub-session as announced during the create handshake, before it is
/// wrapped into a public `ParallelSubSession`.
struct CreatedSubSession {
    info: SessionInfo,
    /// Receives the events the control task routes to this sub-session.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// Slot range reported by `SessionsCreatedV2`; stays 0 if the server did
    /// not list this session id.
    start_slot: u64,
    end_slot: u64,
}
/// Caller-side handle to the spawned parallel control manager task.
struct ParallelControlHandle {
    /// Sender for tagged continue requests; the control task exits its
    /// message loop with `SessionEnded` once every clone has been dropped.
    continues: mpsc::Sender<TaggedContinue>,
    /// Status of the control connection, as published by the control task.
    status: watch::Receiver<ConnectionStatus>,
    /// One-shot outcome of the initial create; `None` after `wait_created`.
    created: Option<oneshot::Receiver<Result<ParallelCreated, String>>>,
    /// Join handle of the task running `run_control_loop`.
    join: JoinHandle<()>,
}
impl ParallelControlHandle {
    /// Waits for the control manager to report the result of the initial
    /// create handshake.
    ///
    /// The result can be consumed only once. A second call, or a manager that
    /// exited without answering, yields a descriptive `Err`.
    async fn wait_created(&mut self) -> Result<ParallelCreated, String> {
        match self.created.take() {
            None => Err("parallel create already consumed".to_string()),
            Some(receiver) => match receiver.await {
                Ok(outcome) => outcome,
                Err(_) => Err("control manager exited before creating sessions".to_string()),
            },
        }
    }

    /// Releases this handle's continue sender and waits for the control task
    /// to finish. The task's own result is ignored.
    async fn join(self) {
        let Self {
            continues, join, ..
        } = self;
        drop(continues);
        let _ = join.await;
    }
}
/// Spawns the background task that owns the multiplexed control connection.
///
/// The task starts out in "create" mode with `create`. It reports the created
/// sub-sessions through the one-shot in the returned handle, and afterwards
/// re-attaches by control session id on reconnect. `cancel` stops it.
fn spawn_parallel_control_manager(
    url: String,
    api_key: String,
    create: CreateBacktestSessionRequest,
    cancel: CancellationToken,
) -> ParallelControlHandle {
    let (created_tx, created_rx) = oneshot::channel::<Result<ParallelCreated, String>>();
    let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
    let (continues_tx, continues_rx) = mpsc::channel::<TaggedContinue>(CONTINUE_CHANNEL_CAPACITY);

    let join = tokio::spawn(run_control_loop(ParallelControlTask {
        url,
        api_key,
        create: Some(create),
        control_session_id: None,
        event_txs: HashMap::new(),
        last_sequences: HashMap::new(),
        completed: Default::default(),
        continues_rx,
        status_tx,
        created_tx: Some(created_tx),
        cancel,
    }));

    ParallelControlHandle {
        continues: continues_tx,
        status: status_rx,
        created: Some(created_rx),
        join,
    }
}
/// State owned by the parallel control manager task across reconnects.
struct ParallelControlTask {
    url: String,
    api_key: String,
    /// Create request; `Some` until the sessions have been created, after
    /// which handshakes attach instead.
    create: Option<CreateBacktestSessionRequest>,
    /// Set once `SessionsCreatedV2` arrives; used to re-attach on reconnect.
    control_session_id: Option<String>,
    /// Per-sub-session event senders, keyed by session id.
    event_txs: HashMap<String, mpsc::UnboundedSender<ControlEvent>>,
    /// Highest `seq_id` seen per session id. Used to drop duplicates and sent
    /// with the attach request.
    last_sequences: HashMap<String, u64>,
    /// Session ids whose `Completed` event has been routed.
    completed: std::collections::HashSet<String>,
    /// Tagged continue requests coming from the sub-sessions.
    continues_rx: mpsc::Receiver<TaggedContinue>,
    /// Connection status sender exposed through `ControlConnection::status_tx`.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Answers the create waiter; taken on success or by `fail_pending`.
    created_tx: Option<oneshot::Sender<Result<ParallelCreated, String>>>,
    cancel: CancellationToken,
}
/// Hooks driven by `run_control_loop` for the single multiplexed control
/// connection shared by all parallel sub-sessions.
impl ControlConnection for ParallelControlTask {
    fn url(&self) -> &str {
        &self.url
    }
    fn api_key(&self) -> &str {
        &self.api_key
    }
    fn cancel(&self) -> &CancellationToken {
        &self.cancel
    }
    /// Fixed label for this connection kind (presumably used in logs).
    fn label(&self) -> &'static str {
        "parallel control"
    }
    fn status_tx(&self) -> &watch::Sender<ConnectionStatus> {
        &self.status_tx
    }
    /// Answers the create waiter with `Err(reason)` if it has not been
    /// answered yet; otherwise this is a no-op.
    fn fail_pending(&mut self, reason: String) {
        if let Some(tx) = self.created_tx.take() {
            // The waiter may already be gone; that is fine.
            let _ = tx.send(Err(reason));
        }
    }
    /// Chooses the handshake for a freshly opened socket.
    ///
    /// Re-attaches if a control session already exists, otherwise creates the
    /// sessions. Having neither is reported as a fatal handshake error.
    async fn handshake(&mut self, ws: Ws) -> Result<Ws, HandshakeError> {
        if let Some(control_session_id) = self.control_session_id.clone() {
            self.attach(ws, &control_session_id).await
        } else if let Some(create) = self.create.clone() {
            self.create_sessions(ws, create).await
        } else {
            Err(HandshakeError::Fatal(
                "no create request and no control_session_id".into(),
            ))
        }
    }
    /// Serves an established connection until it ends.
    ///
    /// The `biased` select polls, in order: cancellation, the keepalive ping
    /// timer, inbound frames, and outbound continue requests. Inbound text is
    /// handed to `handle_text`, which may end the loop. The loop returns
    /// `SessionEnded` when the continue channel closes. On `SessionEnded` and
    /// `Cancelled` the socket is closed gracefully; on other exits it is left
    /// to the caller.
    async fn message_loop(&mut self, mut ws: Ws) -> MessageLoopExit {
        let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
        ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // Time of the last `ws.next()` completion; passed to the keepalive
        // helper (presumably to detect a silent peer — see send_keepalive_ping).
        let mut last_inbound = std::time::Instant::now();
        let exit = loop {
            tokio::select! {
                biased;
                _ = self.cancel.cancelled() => break MessageLoopExit::Cancelled,
                _ = ping_timer.tick() => {
                    if let Some(why) = send_keepalive_ping(&mut ws, last_inbound).await {
                        break MessageLoopExit::ConnectionLost(why);
                    }
                }
                msg = ws.next() => {
                    last_inbound = std::time::Instant::now();
                    match classify_inbound(msg) {
                        InboundFrame::Text(t) => {
                            if let Some(exit) = self.handle_text(&t) {
                                break exit;
                            }
                        }
                        InboundFrame::Ignore => {}
                        InboundFrame::Lost(why) => break MessageLoopExit::ConnectionLost(why),
                    }
                }
                req = self.continues_rx.recv() => {
                    match req {
                        // Forward the continue, tagged with its target sub-session.
                        Some((session_id, request)) => {
                            let msg = BacktestRequest::ContinueSessionV1(ContinueSessionRequestV1 { session_id, request });
                            if let Err(e) = send_request(&mut ws, &msg).await {
                                break MessageLoopExit::ConnectionLost(format!("continue send: {e}"));
                            }
                        }
                        // Every continue sender was dropped: nobody can drive the session any more.
                        None => break MessageLoopExit::SessionEnded,
                    }
                }
            }
        };
        if matches!(
            exit,
            MessageLoopExit::SessionEnded | MessageLoopExit::Cancelled
        ) {
            graceful_close(&mut ws).await;
        }
        exit
    }
}
impl ParallelControlTask {
    /// Runs the create handshake on a freshly opened connection.
    ///
    /// Sends `CreateBacktestSession`, then consumes responses until the
    /// server reports `SessionsCreatedV2`:
    /// - every `SessionCreated` gets its own event channel;
    /// - `SessionEventV2` frames that arrive meanwhile are routed to those
    ///   channels.
    ///
    /// On success:
    /// - the control session id is stored, so later reconnects attach
    ///   instead of creating;
    /// - the create request is dropped;
    /// - the created sub-sessions are handed to the waiter on `created_tx`.
    ///
    /// # Errors
    /// - `HandshakeError::Transient` for send/receive failures.
    /// - `HandshakeError::Fatal` for a server `Error` response or when the
    ///   server does not report one slot range per sub-session.
    async fn create_sessions(
        &mut self,
        mut ws: Ws,
        create: CreateBacktestSessionRequest,
    ) -> Result<Ws, HandshakeError> {
        // An earlier create attempt may have failed part-way. That would leave
        // channels whose receivers were dropped together with that attempt's
        // local `sessions` vector. Start from a clean slate, so stale entries
        // can neither swallow events nor keep the all-completed check from
        // ever firing.
        // NOTE(review): sessions created server-side by a failed attempt are
        // not cleaned up here.
        self.event_txs.clear();
        self.last_sequences.clear();
        self.completed.clear();
        send_request(&mut ws, &BacktestRequest::CreateBacktestSession(create))
            .await
            .map_err(HandshakeError::Transient)?;
        let rpc_base = http_base_from_ws_url(&self.url);
        let mut sessions: Vec<CreatedSubSession> = Vec::new();
        loop {
            // Each response gets its own CREATE_RESPONSE_TIMEOUT budget.
            let response = next_response(&mut ws, CREATE_RESPONSE_TIMEOUT)
                .await
                .map_err(HandshakeError::Transient)?;
            match response {
                BacktestResponse::SessionCreated {
                    session_id,
                    rpc_endpoint,
                    task_id,
                } => {
                    let (event_tx, event_rx) = mpsc::unbounded_channel::<ControlEvent>();
                    self.event_txs.insert(session_id.clone(), event_tx);
                    sessions.push(CreatedSubSession {
                        info: SessionInfo {
                            rpc_endpoint: resolve_rpc_url(&rpc_base, &rpc_endpoint),
                            session_id,
                            task_id,
                        },
                        events: event_rx,
                        // Filled in from `SessionsCreatedV2` below.
                        start_slot: 0,
                        end_slot: 0,
                    });
                }
                BacktestResponse::SessionEventV2 {
                    session_id,
                    seq_id,
                    event,
                } => {
                    self.route_event(&session_id, seq_id, event);
                }
                BacktestResponse::SessionsCreatedV2 {
                    control_session_id,
                    session_ids,
                    start_slots,
                    end_slots,
                    ..
                } => {
                    info!(
                        %control_session_id,
                        sessions = sessions.len(),
                        "parallel sessions created"
                    );
                    // The multiplexed client needs one slot range per sub-session.
                    if session_ids.len() != start_slots.len()
                        || session_ids.len() != end_slots.len()
                    {
                        return Err(HandshakeError::Fatal(format!(
                            "server did not report per-sub-session ranges \
                             (session_ids={}, start_slots={}, end_slots={}); \
                             server is too old for the multiplexed parallel client",
                            session_ids.len(),
                            start_slots.len(),
                            end_slots.len(),
                        )));
                    }
                    for ((id, start), end) in session_ids.iter().zip(&start_slots).zip(&end_slots) {
                        if let Some(s) = sessions.iter_mut().find(|s| s.info.session_id == *id) {
                            s.start_slot = *start;
                            s.end_slot = *end;
                        }
                    }
                    self.control_session_id = Some(control_session_id.clone());
                    self.create = None;
                    if let Some(tx) = self.created_tx.take() {
                        let _ = tx.send(Ok(ParallelCreated {
                            control_session_id,
                            sessions,
                        }));
                    }
                    return Ok(ws);
                }
                BacktestResponse::Error(err) => {
                    return Err(HandshakeError::Fatal(format!(
                        "server error: {}",
                        err_chain(&err)
                    )));
                }
                _ => {}
            }
        }
    }

    /// Re-attaches to an existing control session after a reconnect.
    ///
    /// Sends the highest sequence number seen per sub-session. Events
    /// received before `ParallelSessionAttachedV2` are routed normally, and
    /// duplicates are dropped by `route_event`.
    ///
    /// # Errors
    /// - `HandshakeError::Transient` for send/receive failures.
    /// - For a server `Error`, whatever `handshake_error_for_response`
    ///   classifies it as.
    async fn attach(&mut self, mut ws: Ws, control_session_id: &str) -> Result<Ws, HandshakeError> {
        send_request(
            &mut ws,
            &BacktestRequest::AttachParallelControlSessionV2 {
                control_session_id: control_session_id.to_string(),
                last_sequences: self
                    .last_sequences
                    .iter()
                    .map(|(id, seq)| (id.clone(), *seq))
                    .collect(),
            },
        )
        .await
        .map_err(HandshakeError::Transient)?;
        loop {
            let response = next_response(&mut ws, HANDSHAKE_RESPONSE_TIMEOUT)
                .await
                .map_err(HandshakeError::Transient)?;
            match response {
                BacktestResponse::ParallelSessionAttachedV2 { .. } => {
                    debug!(%control_session_id, "parallel control reattached");
                    return Ok(ws);
                }
                BacktestResponse::SessionEventV2 {
                    session_id,
                    seq_id,
                    event,
                } => {
                    self.route_event(&session_id, seq_id, event);
                }
                BacktestResponse::Error(err) => {
                    return Err(handshake_error_for_response("attach", err));
                }
                _ => {}
            }
        }
    }

    /// Handles one inbound text frame while the connection is established.
    ///
    /// Returns the reason the message loop should stop, if any:
    /// - `SessionEnded` once every registered sub-session has completed;
    /// - `Terminal` for a terminal server error.
    ///
    /// Undeserializable frames and non-terminal errors are logged and ignored.
    fn handle_text(&mut self, text: &str) -> Option<MessageLoopExit> {
        let response = match serde_json::from_str::<BacktestResponse>(text) {
            Ok(r) => r,
            Err(e) => {
                warn!(error = %err_chain(&e), "discarding undeserializable parallel control message");
                return None;
            }
        };
        match response {
            BacktestResponse::SessionEventV2 {
                session_id,
                seq_id,
                event,
            } => {
                self.route_event(&session_id, seq_id, event);
                if self.all_sessions_completed() {
                    return Some(MessageLoopExit::SessionEnded);
                }
            }
            BacktestResponse::Error(err) => {
                if is_terminal_backtest_error(&err) {
                    return Some(MessageLoopExit::Terminal(format!(
                        "control session error: {}",
                        err_chain(&err)
                    )));
                }
                warn!(error = %err_chain(&err), "non-terminal parallel control error");
            }
            other => {
                debug!(?other, "ignoring unexpected parallel control response");
            }
        }
        None
    }

    /// True when at least one sub-session is registered and every registered
    /// sub-session has delivered `Completed`.
    ///
    /// Completions reported for unknown session ids are not counted, so they
    /// can neither end the loop early nor keep it from ending.
    fn all_sessions_completed(&self) -> bool {
        !self.event_txs.is_empty()
            && self.event_txs.keys().all(|id| self.completed.contains(id))
    }

    /// Delivers one session event to its sub-session's channel.
    ///
    /// Events whose `seq_id` is not greater than the last one seen for that
    /// session are dropped as duplicates (e.g. replays after re-attach).
    /// `Success` events are not forwarded. Events for sessions without a
    /// channel are dropped, but their sequence number is still recorded.
    fn route_event(&mut self, session_id: &str, seq_id: u64, event: SessionEventKind) {
        match self.last_sequences.get_mut(session_id) {
            Some(last) if seq_id <= *last => return,
            Some(last) => *last = seq_id,
            None => {
                self.last_sequences.insert(session_id.to_string(), seq_id);
            }
        }
        let is_completed = matches!(event, SessionEventKind::Completed { .. });
        let Some(control_event) = session_event_to_control(event) else {
            return;
        };
        if let Some(tx) = self.event_txs.get(session_id) {
            // The sub-session may have been dropped; its events are then discarded.
            let _ = tx.send(control_event);
        }
        if is_completed {
            self.completed.insert(session_id.to_string());
        }
    }
}
fn session_event_to_control(event: SessionEventKind) -> Option<ControlEvent> {
Some(match event {
SessionEventKind::ReadyForContinue => ControlEvent::ReadyForContinue,
SessionEventKind::SlotNotification(slot) => ControlEvent::Slot(slot),
SessionEventKind::Paused(event) => ControlEvent::Paused(event),
SessionEventKind::DiscoveryBatch(event) => ControlEvent::DiscoveryBatch(event),
SessionEventKind::Error(error) => ControlEvent::Error(error),
SessionEventKind::Completed { summary } => ControlEvent::Completed {
summary,
agent_stats: None,
},
SessionEventKind::Status { status } => ControlEvent::Status(status),
SessionEventKind::Success => return None,
})
}
/// Reads the next JSON `BacktestResponse` from the socket during a handshake.
///
/// Binary frames that are valid UTF-8 are treated like text. Other binary,
/// ping and pong frames are skipped. The whole call is bounded by `timeout`,
/// measured from entry.
///
/// # Errors
/// Returns a description on timeout, end of stream, read error, remote close,
/// or a frame that does not deserialize; the raw text is included in the
/// last case.
async fn next_response(
    ws: &mut Ws,
    timeout: std::time::Duration,
) -> Result<BacktestResponse, String> {
    let deadline = tokio::time::Instant::now() + timeout;
    let text = loop {
        let frame = match tokio::time::timeout_at(deadline, ws.next()).await {
            Err(_) => return Err(format!("handshake timeout after {timeout:?}")),
            Ok(None) => return Err("ws ended during handshake".into()),
            Ok(Some(Err(e))) => return Err(format!("ws read: {}", err_chain(&e))),
            Ok(Some(Ok(frame))) => frame,
        };
        match frame {
            Message::Text(t) => break t,
            Message::Binary(bytes) => {
                if let Ok(t) = String::from_utf8(bytes) {
                    break t;
                }
            }
            Message::Close(close) => {
                return Err(format!("remote close during handshake: {close:?}"));
            }
            _ => {}
        }
    };
    serde_json::from_str::<BacktestResponse>(&text)
        .map_err(|e| format!("deserialize: {}; raw={text}", err_chain(&e)))
}
/// A parallel backtest: one multiplexed control connection plus the
/// sub-sessions created over it.
pub struct ManagedParallelSession {
    control_session_id: String,
    /// Background control manager; taken by `shutdown` so it can be joined.
    control: Option<ParallelControlHandle>,
    /// Sub-sessions not yet handed out by `take_sub_sessions`.
    sub_sessions: Vec<ParallelSubSession>,
    /// Token for the whole parallel session; cancelled on shutdown and on drop.
    session_cancel: CancellationToken,
}
impl ManagedParallelSession {
pub async fn start_with_cancel(
url: String,
api_key: String,
create: CreateBacktestSessionRequest,
parent_cancel: CancellationToken,
) -> Result<Self, ManagedSessionError> {
let session_cancel = parent_cancel.child_token();
let mut control =
spawn_parallel_control_manager(url, api_key, create, session_cancel.clone());
let created = tokio::select! {
biased;
_ = parent_cancel.cancelled() => {
session_cancel.cancel();
control.join().await;
return Err(ManagedSessionError::Cancelled);
}
result = control.wait_created() => {
result.map_err(ManagedSessionError::Create)?
}
};
let reconnect_coordinator = Arc::new(ReconnectCoordinator::new());
let sub_sessions = created
.sessions
.into_iter()
.map(|s| ParallelSubSession {
session_info: s.info,
events: s.events,
continues: control.continues.clone(),
status: control.status.clone(),
subscriptions: Vec::new(),
session_cancel: session_cancel.child_token(),
post_completion: None,
post_completion_error: None,
reconnect_coordinator: Some(reconnect_coordinator.clone()),
start_slot: s.start_slot,
end_slot: s.end_slot,
})
.collect();
Ok(Self {
control_session_id: created.control_session_id,
control: Some(control),
sub_sessions,
session_cancel,
})
}
pub fn control_session_id(&self) -> &str {
&self.control_session_id
}
pub fn take_sub_sessions(&mut self) -> Vec<ParallelSubSession> {
std::mem::take(&mut self.sub_sessions)
}
pub async fn shutdown(mut self) {
self.session_cancel.cancel();
if let Some(control) = self.control.take() {
control.join().await;
}
}
}
/// Cancels the session token on drop. This signals the control manager
/// (which was given a clone of the token) and, through child tokens, every
/// sub-session created from this session. It does not wait for them.
impl Drop for ManagedParallelSession {
    fn drop(&mut self) {
        self.session_cancel.cancel();
    }
}
/// One sub-session of a parallel backtest, driven through the shared control
/// connection.
pub struct ParallelSubSession {
    session_info: SessionInfo,
    /// Events the control task routed to this sub-session.
    events: mpsc::UnboundedReceiver<ControlEvent>,
    /// Shared sender for continue requests, tagged with this session's id.
    continues: mpsc::Sender<TaggedContinue>,
    /// Status of the shared control connection.
    status: watch::Receiver<ConnectionStatus>,
    /// Subscription managers started by `subscribe_*`.
    subscriptions: Vec<SubscriptionHandle>,
    /// This sub-session's token; passed to its subscription managers.
    session_cancel: CancellationToken,
    /// Replay buffer filled once `Completed` has been seen; `Some` from then on.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Error returned after the replay buffer runs dry (set on a stalled drain).
    post_completion_error: Option<ManagedSessionError>,
    reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
    start_slot: u64,
    end_slot: u64,
}
impl ParallelSubSession {
    /// Server-assigned identity and RPC endpoint of this sub-session.
    pub fn session_info(&self) -> &SessionInfo {
        &self.session_info
    }
    /// `(start_slot, end_slot)` as reported for this sub-session by the server
    /// during creation.
    pub fn range(&self) -> (u64, u64) {
        (self.start_slot, self.end_slot)
    }
    /// Starts a transaction subscription for `program_ids` against this
    /// sub-session's RPC endpoint. It runs under this sub-session's token and
    /// reconnect coordinator.
    pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
        self.subscriptions
            .push(spawn_transaction_subscription_manager(
                self.session_info.rpc_endpoint.clone(),
                program_ids,
                self.session_cancel.clone(),
                self.reconnect_coordinator.clone(),
            ));
    }
    /// Starts an account-diff subscription for `program_ids` against this
    /// sub-session's RPC endpoint. It runs under this sub-session's token and
    /// reconnect coordinator.
    pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
        self.subscriptions
            .push(spawn_account_diff_subscription_manager(
                self.session_info.rpc_endpoint.clone(),
                program_ids,
                self.session_cancel.clone(),
                self.reconnect_coordinator.clone(),
            ));
    }
    /// Returns the next event for this sub-session, from the control channel
    /// or from a subscription.
    ///
    /// Subscription events that are already available are returned before
    /// waiting. When the control channel yields `Completed`, the subscriptions
    /// are drained first, with `COMPLETION_DRAIN_TIMEOUT` as the drain's idle
    /// timeout. The drained events are then replayed one per call:
    /// - on a complete drain, the replay ends with `Completed`;
    /// - on a stalled drain, it ends with a `SubscriptionFailed` error instead.
    ///
    /// Once the replay is exhausted, every call returns an error.
    ///
    /// # Errors
    /// - `Cancelled` if the sub-session is cancelled while waiting.
    /// - `ControlClosed` if the control channel closes, or after the replay
    ///   has been exhausted.
    /// - `SubscriptionFailed` after a stalled completion drain.
    pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
        // After completion, serve only the replay buffer.
        if let Some(buffered) = self.post_completion.as_mut() {
            if let Some(event) = buffered.pop_front() {
                return Ok(event);
            }
            return Err(self
                .post_completion_error
                .take()
                .unwrap_or(ManagedSessionError::ControlClosed));
        }
        if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
            return Ok(event);
        }
        let event = {
            // Borrow fields separately so `self.events` can be polled alongside them.
            let cancel = &self.session_cancel;
            let subscriptions = &mut self.subscriptions;
            tokio::select! {
                biased;
                _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
                event = self.events.recv() => {
                    event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
                }
                event = wait_any_subscription_event(subscriptions) => event,
            }
        };
        let ManagedEvent::Completed {
            summary,
            agent_stats,
        } = event
        else {
            return Ok(event);
        };
        // Completion: collect trailing subscription events so they are
        // delivered before `Completed` (or before the stall error).
        let (mut buffered, terminal): (VecDeque<ManagedEvent>, _) = match self
            .drain_until_subscriptions_complete(COMPLETION_DRAIN_TIMEOUT)
            .await
        {
            DrainOutcome::Complete(events) => (
                events.into(),
                Ok(ManagedEvent::Completed {
                    summary,
                    agent_stats,
                }),
            ),
            DrainOutcome::Stalled(events) => (
                events.into(),
                Err(ManagedSessionError::SubscriptionFailed(
                    "completion drain stalled: subscriptions did not deliver their \
                     end-of-stream terminals; the captured stream is incomplete"
                        .to_string(),
                )),
            ),
        };
        match terminal {
            Ok(completed) => buffered.push_back(completed),
            Err(err) => self.post_completion_error = Some(err),
        }
        let first = buffered.pop_front();
        self.post_completion = Some(buffered);
        match first {
            Some(event) => Ok(event),
            None => Err(self
                .post_completion_error
                .take()
                .unwrap_or(ManagedSessionError::ControlClosed)),
        }
    }
    /// Waits until the control connection and every subscription report up,
    /// then queues `params` for this sub-session on the shared control
    /// connection.
    ///
    /// # Errors
    /// - Whatever `wait_connections_up` returns.
    /// - `ContinueSend` if the control task's receiver is gone.
    pub async fn send_continue(
        &mut self,
        params: ContinueParams,
    ) -> Result<(), ManagedSessionError> {
        self.wait_all_up().await?;
        self.continues
            .send((self.session_info.session_id.clone(), params))
            .await
            .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
    }
    /// Drains this sub-session's subscriptions after completion. The exact
    /// meaning of `idle_timeout` is defined by `drain_subscriptions_until_complete`.
    async fn drain_until_subscriptions_complete(
        &mut self,
        idle_timeout: std::time::Duration,
    ) -> DrainOutcome {
        drain_subscriptions_until_complete(
            &mut self.subscriptions,
            &self.session_cancel,
            idle_timeout,
        )
        .await
    }
    /// Cancels this sub-session and waits for all of its subscription tasks.
    pub async fn shutdown(mut self) {
        self.session_cancel.cancel();
        for sub in std::mem::take(&mut self.subscriptions) {
            let _ = sub.join.await;
        }
    }
    /// Waits for the control connection and every subscription to be up.
    /// The waiting and error semantics are those of `wait_connections_up`.
    async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
        let subscriptions = self
            .subscriptions
            .iter()
            .map(|s| s.status.clone())
            .collect();
        wait_connections_up(self.status.clone(), subscriptions, &self.session_cancel).await
    }
}
/// Cancels this sub-session's token on drop, signalling the subscription
/// managers it was passed to. Does not wait for them; use `shutdown` for that.
impl Drop for ParallelSubSession {
    fn drop(&mut self) {
        self.session_cancel.cancel();
    }
}