use std::{collections::VecDeque, sync::Arc, time::Duration};
use simulator_api::{
AgentStatsReport, BacktestError, BacktestStatus, ContinueParams, ContinueToParams,
CreateBacktestSessionRequest, DiscoveryBatchEvent, PausedEvent, SessionSummary,
};
use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
use thiserror::Error;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use super::{
ConnectionStatus, ControlEvent, ControlHandle, ReconnectCoordinator, SessionInfo,
SubscriptionHandle, SubscriptionNotification, spawn_account_diff_subscription_manager,
spawn_control_manager, spawn_transaction_subscription_manager,
};
use crate::subscriptions::AccountDiffNotification;
/// Errors reported by [`ManagedBacktestSession`] and its connection helpers.
#[derive(Debug, Error)]
pub enum ManagedSessionError {
    /// The control manager reported that creating the session failed.
    #[error("session create failed: {0}")]
    Create(String),

    /// The control connection is unavailable, or the post-completion event buffer has
    /// been exhausted.
    #[error("control channel closed")]
    ControlClosed,

    /// The control connection reported a `Failed` status.
    #[error("control failed: {0}")]
    ControlFailed(String),

    /// A subscription reported a failure, or the post-completion drain stalled.
    #[error("subscription failed: {0}")]
    SubscriptionFailed(String),

    /// The session's cancellation token fired.
    #[error("cancelled")]
    Cancelled,

    /// Sending a continue request over the control connection failed.
    #[error("control closed while sending continue: {0}")]
    ContinueSend(String),
}
/// One item of a managed session's event stream: either an event forwarded from the
/// control connection or a notification from one of the subscriptions.
#[derive(Debug)]
pub enum ManagedEvent {
    /// Forwarded from `ControlEvent::ReadyForContinue`.
    ReadyForContinue,
    /// Forwarded from `ControlEvent::Paused`.
    Paused(PausedEvent),
    /// Forwarded from `ControlEvent::DiscoveryBatch`.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// Forwarded from `ControlEvent::Slot`.
    Slot(u64),
    /// Forwarded from `ControlEvent::Status`.
    Status(BacktestStatus),
    /// Forwarded from `ControlEvent::Completed`. `ManagedBacktestSession::next_event`
    /// yields it only after the post-completion subscription drain has finished cleanly.
    Completed {
        summary: Option<SessionSummary>,
        agent_stats: Option<Vec<AgentStatsReport>>,
    },
    /// Forwarded from `ControlEvent::Error`.
    Error(BacktestError),
    /// A transaction notification from a transaction subscription (boxed, presumably to
    /// keep this enum small).
    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
    /// An account-diff notification from an account-diff subscription.
    AccountDiff(AccountDiffNotification),
}
/// Idle timeout used by the post-completion subscription drain unless overridden via
/// `ManagedBacktestSession::set_completion_drain_timeout`.
const DEFAULT_COMPLETION_DRAIN_TIMEOUT: Duration = Duration::from_secs(60);

/// Result of draining subscriptions after the session completed; both variants carry
/// every event collected during the drain.
pub(super) enum DrainOutcome {
    /// The drain finished (or was cancelled) without detecting a stall.
    Complete(Vec<ManagedEvent>),
    /// The drain gave up: the captured subscription stream is incomplete.
    Stalled(Vec<ManagedEvent>),
}
/// A backtest session together with the background managers (one control connection
/// plus any number of subscriptions) that feed its event stream.
pub struct ManagedBacktestSession {
    /// Metadata reported by the control manager when the session was created.
    session_info: SessionInfo,
    /// Control-connection handle; taken (left `None`) by `shutdown`.
    control: Option<ControlHandle>,
    /// Subscription managers spawned through the `subscribe_*` methods.
    subscriptions: Vec<SubscriptionHandle>,
    /// Child of the caller's token; shared with every manager and cancelled on
    /// shutdown and drop.
    session_cancel: CancellationToken,
    /// Events buffered once the session completed; when `Some`, `next_event` serves only
    /// from this buffer.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Error to report once `post_completion` runs dry (set when the drain stalls).
    post_completion_error: Option<ManagedSessionError>,
    /// Idle timeout for the post-completion subscription drain.
    completion_drain_timeout: Duration,
    /// Handed to every subscription manager this session spawns.
    reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
}
impl ManagedBacktestSession {
    /// Starts a session with a fresh cancellation token and no reconnect coordinator.
    ///
    /// Shorthand for [`Self::start_with_cancel`]; see it for the error cases.
    pub async fn start(
        url: String,
        api_key: String,
        create: CreateBacktestSessionRequest,
    ) -> Result<Self, ManagedSessionError> {
        Self::start_with_cancel(url, api_key, create, CancellationToken::new(), None).await
    }

    /// Spawns the control manager for `create` and waits for it to report the new session.
    ///
    /// All background work of the session runs under a child token of `parent_cancel`, so
    /// cancelling the parent also cancels this session. `reconnect_coordinator` is handed to
    /// every subscription manager spawned later through the `subscribe_*` methods.
    ///
    /// # Errors
    ///
    /// - [`ManagedSessionError::Cancelled`] if `parent_cancel` fires before the session exists.
    /// - [`ManagedSessionError::Create`] if the control manager reports a creation failure.
    ///
    /// On either error the session token is cancelled and the control manager is joined
    /// before returning.
    pub async fn start_with_cancel(
        url: String,
        api_key: String,
        create: CreateBacktestSessionRequest,
        parent_cancel: CancellationToken,
        reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
    ) -> Result<Self, ManagedSessionError> {
        let session_cancel = parent_cancel.child_token();
        let mut control = spawn_control_manager(url, api_key, create, session_cancel.clone());
        let created = tokio::select! {
            biased;
            _ = parent_cancel.cancelled() => Err(ManagedSessionError::Cancelled),
            result = control.wait_for_session() => result.map_err(ManagedSessionError::Create),
        };
        let session_info = match created {
            Ok(session_info) => session_info,
            Err(err) => {
                // Tear the manager down on *any* failed start; otherwise it would keep
                // running on a token nobody will ever cancel.
                session_cancel.cancel();
                control.join().await;
                return Err(err);
            }
        };
        Ok(Self {
            session_info,
            control: Some(control),
            subscriptions: Vec::new(),
            session_cancel,
            post_completion: None,
            post_completion_error: None,
            completion_drain_timeout: DEFAULT_COMPLETION_DRAIN_TIMEOUT,
            reconnect_coordinator,
        })
    }

    /// Returns the metadata reported when the session was created.
    pub fn session_info(&self) -> &SessionInfo {
        &self.session_info
    }

    /// Sets the idle timeout used while draining subscriptions after the session completes
    /// (defaults to 60 seconds).
    pub fn set_completion_drain_timeout(&mut self, idle_timeout: Duration) {
        self.completion_drain_timeout = idle_timeout;
    }

    /// Spawns a transaction subscription for `program_ids` against the session's RPC
    /// endpoint; its notifications are surfaced by [`Self::next_event`].
    pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
        self.subscriptions
            .push(spawn_transaction_subscription_manager(
                self.session_info.rpc_endpoint.clone(),
                program_ids,
                self.session_cancel.clone(),
                self.reconnect_coordinator.clone(),
            ));
    }

    /// Spawns an account-diff subscription for `program_ids` against the session's RPC
    /// endpoint; its notifications are surfaced by [`Self::next_event`].
    pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
        self.subscriptions
            .push(spawn_account_diff_subscription_manager(
                self.session_info.rpc_endpoint.clone(),
                program_ids,
                self.session_cancel.clone(),
                self.reconnect_coordinator.clone(),
            ));
    }

    /// Drains this session's subscriptions with the given idle timeout.
    async fn drain_until_subscriptions_complete(&mut self, idle_timeout: Duration) -> DrainOutcome {
        drain_subscriptions_until_complete(
            &mut self.subscriptions,
            &self.session_cancel,
            idle_timeout,
        )
        .await
    }

    /// Takes the error recorded by a stalled completion drain, falling back to
    /// [`ManagedSessionError::ControlClosed`] once it has been reported.
    fn take_post_completion_error(&mut self) -> ManagedSessionError {
        self.post_completion_error
            .take()
            .unwrap_or(ManagedSessionError::ControlClosed)
    }

    /// Returns the next event from the control connection or any subscription.
    ///
    /// Notifications already buffered in a subscription are returned first. Otherwise this
    /// waits on cancellation, the control channel, and the subscriptions, preferring them in
    /// that order when several are ready. When the control channel delivers
    /// [`ManagedEvent::Completed`], the subscriptions are drained first, and the drained
    /// events are returned ahead of `Completed`. Later calls replay that buffer.
    ///
    /// # Errors
    ///
    /// - [`ManagedSessionError::Cancelled`] if the session token fires while waiting.
    /// - [`ManagedSessionError::ControlClosed`] if the control channel closes, or once the
    ///   post-completion buffer is exhausted.
    /// - [`ManagedSessionError::SubscriptionFailed`] after the buffered events if the
    ///   completion drain stalled; `Completed` is not delivered in that case.
    ///
    /// Dropping the returned future while a completion drain is running loses the drained
    /// events and the `Completed` event.
    pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
        if let Some(buffered) = self.post_completion.as_mut() {
            if let Some(event) = buffered.pop_front() {
                return Ok(event);
            }
            return Err(self.take_post_completion_error());
        }
        if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
            return Ok(event);
        }
        let event = {
            // Disjoint field borrows: no need to clone the token.
            let control = self
                .control
                .as_mut()
                .ok_or(ManagedSessionError::ControlClosed)?;
            let cancel = &self.session_cancel;
            let subscriptions = &mut self.subscriptions;
            tokio::select! {
                biased;
                _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
                event = control.events.recv() => {
                    event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
                }
                event = wait_any_subscription_event(subscriptions) => event,
            }
        };
        let ManagedEvent::Completed {
            summary,
            agent_stats,
        } = event
        else {
            return Ok(event);
        };
        let idle_timeout = self.completion_drain_timeout;
        let outcome = self.drain_until_subscriptions_complete(idle_timeout).await;
        let mut buffered = match outcome {
            DrainOutcome::Complete(events) => {
                let mut buffered = VecDeque::from(events);
                buffered.push_back(ManagedEvent::Completed {
                    summary,
                    agent_stats,
                });
                buffered
            }
            DrainOutcome::Stalled(events) => {
                self.post_completion_error = Some(ManagedSessionError::SubscriptionFailed(
                    "completion drain stalled: subscriptions did not deliver their \
                     end-of-stream terminals; the captured stream is incomplete"
                        .to_string(),
                ));
                VecDeque::from(events)
            }
        };
        let first = buffered.pop_front();
        self.post_completion = Some(buffered);
        match first {
            Some(event) => Ok(event),
            None => Err(self.take_post_completion_error()),
        }
    }

    /// Waits until the control connection and every subscription are up, then sends a
    /// continue request.
    ///
    /// # Errors
    ///
    /// Any error from waiting for the connections, or
    /// [`ManagedSessionError::ContinueSend`] if the send itself fails.
    pub async fn send_continue(
        &mut self,
        params: ContinueParams,
    ) -> Result<(), ManagedSessionError> {
        self.wait_all_up().await?;
        self.control_mut()?
            .send_continue(params)
            .await
            .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
    }

    /// Waits until the control connection and every subscription are up, then sends a
    /// continue-to request.
    ///
    /// # Errors
    ///
    /// Any error from waiting for the connections, or
    /// [`ManagedSessionError::ContinueSend`] if the send itself fails.
    pub async fn send_continue_to(
        &mut self,
        params: ContinueToParams,
    ) -> Result<(), ManagedSessionError> {
        self.wait_all_up().await?;
        self.control_mut()?
            .send_continue_to(params)
            .await
            .map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
    }

    /// Cancels the session token, then joins the control manager and every subscription
    /// manager. Subscription join errors are ignored.
    pub async fn shutdown(mut self) {
        self.session_cancel.cancel();
        if let Some(control) = self.control.take() {
            control.join().await;
        }
        for sub in std::mem::take(&mut self.subscriptions) {
            let _ = sub.join.await;
        }
    }

    /// Returns the control handle, or `ControlClosed` if it has been taken.
    fn control_mut(&mut self) -> Result<&mut ControlHandle, ManagedSessionError> {
        self.control
            .as_mut()
            .ok_or(ManagedSessionError::ControlClosed)
    }

    /// Waits on cloned status receivers until the control connection and every
    /// subscription report `Up`.
    async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
        let control = self
            .control
            .as_ref()
            .ok_or(ManagedSessionError::ControlClosed)?
            .status
            .clone();
        let subscriptions = self
            .subscriptions
            .iter()
            .map(|s| s.status.clone())
            .collect();
        wait_connections_up(control, subscriptions, &self.session_cancel).await
    }
}
/// Waits until the control connection and every subscription report
/// [`ConnectionStatus::Up`].
///
/// `Failed` statuses are checked first, control before subscriptions. A watch channel whose
/// sender was dropped can no longer change, so its last status is treated as final. Closed
/// channels are never polled again: `changed()` on a closed channel resolves immediately,
/// which would otherwise spin this loop.
///
/// # Errors
///
/// - [`ManagedSessionError::ControlFailed`] / [`ManagedSessionError::SubscriptionFailed`]
///   for a `Failed` status.
/// - [`ManagedSessionError::ControlClosed`] if the control channel closed while not `Up`.
/// - [`ManagedSessionError::SubscriptionFailed`] if a subscription channel closed while
///   not `Up`.
/// - [`ManagedSessionError::Cancelled`] if `cancel` fires while waiting.
pub(super) async fn wait_connections_up(
    mut control: watch::Receiver<ConnectionStatus>,
    mut subscriptions: Vec<watch::Receiver<ConnectionStatus>>,
    cancel: &CancellationToken,
) -> Result<(), ManagedSessionError> {
    loop {
        // Sample closure *before* reading the value: once the sender is gone no further
        // update can land, so the status read afterwards is final.
        let control_open = control.has_changed().is_ok();
        let control_status = control.borrow().clone();
        if let ConnectionStatus::Failed(why) = &control_status {
            return Err(ManagedSessionError::ControlFailed(why.clone()));
        }
        let control_up = control_status == ConnectionStatus::Up;

        let mut all_subscriptions_up = true;
        let mut closed_before_up = false;
        let mut still_open = Vec::with_capacity(subscriptions.len());
        for subscription in &subscriptions {
            let open = subscription.has_changed().is_ok();
            let up = match &*subscription.borrow() {
                ConnectionStatus::Failed(why) => {
                    return Err(ManagedSessionError::SubscriptionFailed(why.clone()));
                }
                ConnectionStatus::Up => true,
                _ => false,
            };
            all_subscriptions_up &= up;
            closed_before_up |= !open && !up;
            still_open.push(open);
        }

        if control_up && all_subscriptions_up {
            return Ok(());
        }
        if !control_up && !control_open {
            return Err(ManagedSessionError::ControlClosed);
        }
        if closed_before_up {
            return Err(ManagedSessionError::SubscriptionFailed(
                "subscription status channel closed before it came up".to_string(),
            ));
        }

        // Any subscription closed at this point ended `Up` for good: stop watching it.
        // `retain` visits elements in order, so the flags line up one-to-one.
        let mut keep = still_open.into_iter();
        subscriptions.retain(|_| keep.next().unwrap_or(true));

        tokio::select! {
            _ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
            _ = control.changed(), if control_open => {}
            _ = wait_any_subscription_change(&mut subscriptions) => {}
        }
    }
}
pub(super) async fn drain_subscriptions_until_complete(
subscriptions: &mut [SubscriptionHandle],
cancel: &CancellationToken,
idle_timeout: std::time::Duration,
) -> DrainOutcome {
let mut events = Vec::new();
if subscriptions.is_empty() {
return DrainOutcome::Complete(events);
}
loop {
while let Some(event) = try_next_subscription_event(subscriptions) {
events.push(event);
}
if subscriptions.iter().all(|s| s.notifications.is_closed()) {
let any_failed = subscriptions
.iter()
.any(|s| matches!(*s.status.borrow(), ConnectionStatus::Failed(_)));
return if any_failed {
DrainOutcome::Stalled(events)
} else {
DrainOutcome::Complete(events)
};
}
tokio::select! {
biased;
_ = cancel.cancelled() => return DrainOutcome::Complete(events),
_ = tokio::time::sleep(idle_timeout) => {
let any_up = subscriptions.iter().any(|s| {
!s.notifications.is_closed()
&& matches!(*s.status.borrow(), ConnectionStatus::Up)
});
if any_up {
return DrainOutcome::Stalled(events);
}
}
received = recv_any_open_subscription(subscriptions) => {
if let Some(event) = received {
events.push(event);
}
}
}
}
}
impl Drop for ManagedBacktestSession {
    /// Cancels the session token, whose clones were handed to the control and subscription
    /// managers, when the handle is dropped without an explicit `shutdown`.
    fn drop(&mut self) {
        let token = &self.session_cancel;
        token.cancel();
    }
}
/// Resolves as soon as any receiver in `subscriptions` reports a change, or reports that
/// its sender is gone.
///
/// With no receivers this never resolves, because `select_all` panics on an empty input.
pub(super) async fn wait_any_subscription_change(
    subscriptions: &mut [watch::Receiver<ConnectionStatus>],
) {
    if subscriptions.is_empty() {
        return std::future::pending().await;
    }
    let watchers = subscriptions
        .iter_mut()
        .map(|receiver| Box::pin(receiver.changed()));
    // Only the wake-up matters; the winner's result (including a closed error) is dropped.
    let _ = futures::future::select_all(watchers).await;
}
/// Waits for the next notification from any subscription and converts it into a
/// [`ManagedEvent`].
///
/// Notifications already buffered are returned first, scanning subscriptions in order.
/// Never resolves once every notification channel is closed and drained.
pub(super) async fn wait_any_subscription_event(
    subscriptions: &mut [SubscriptionHandle],
) -> ManagedEvent {
    loop {
        // Sample which channels are open *before* the non-blocking drain. A channel that
        // closes after this point still gets a `recv` below, so a final notification sent
        // just before its sender dropped is picked up instead of being skipped. A channel
        // already closed here has no senders left, so the drain below empties it for good.
        let open: Vec<bool> = subscriptions
            .iter()
            .map(|s| !s.notifications.is_closed())
            .collect();
        if let Some(event) = try_next_subscription_event(subscriptions) {
            return event;
        }
        let pending_recvs: Vec<_> = subscriptions
            .iter_mut()
            .zip(open)
            .filter(|(_, was_open)| *was_open)
            .map(|(s, _)| Box::pin(s.notifications.recv()))
            .collect();
        if pending_recvs.is_empty() {
            // `select_all` panics on an empty input, and nothing can arrive any more.
            return std::future::pending().await;
        }
        let (notification, _, _) = futures::future::select_all(pending_recvs).await;
        if let Some(notification) = notification {
            return notification.into();
        }
        // `None`: that channel just closed; re-sample and try again.
    }
}
pub(super) async fn recv_any_open_subscription(
subscriptions: &mut [SubscriptionHandle],
) -> Option<ManagedEvent> {
let futures: Vec<_> = subscriptions
.iter_mut()
.filter(|s| !s.notifications.is_closed())
.map(|s| Box::pin(s.notifications.recv()))
.collect();
if futures.is_empty() {
return None;
}
let (notification, _, _) = futures::future::select_all(futures).await;
notification.map(Into::into)
}
pub(super) fn try_next_subscription_event(
subscriptions: &mut [SubscriptionHandle],
) -> Option<ManagedEvent> {
for subscription in subscriptions {
if let Ok(notification) = subscription.notifications.try_recv() {
return Some(notification.into());
}
}
None
}
impl From<ControlEvent> for ManagedEvent {
fn from(event: ControlEvent) -> Self {
match event {
ControlEvent::ReadyForContinue => Self::ReadyForContinue,
ControlEvent::Paused(event) => Self::Paused(event),
ControlEvent::DiscoveryBatch(event) => Self::DiscoveryBatch(event),
ControlEvent::Slot(slot) => Self::Slot(slot),
ControlEvent::Status(status) => Self::Status(status),
ControlEvent::Completed {
summary,
agent_stats,
} => Self::Completed {
summary,
agent_stats,
},
ControlEvent::Error(error) => Self::Error(error),
}
}
}
impl From<SubscriptionNotification> for ManagedEvent {
    /// Wraps a subscription notification in the matching [`ManagedEvent`] variant.
    fn from(notification: SubscriptionNotification) -> Self {
        match notification {
            SubscriptionNotification::AccountDiff(account_diff) => Self::AccountDiff(account_diff),
            SubscriptionNotification::Transaction(boxed_tx) => Self::Transaction(boxed_tx),
        }
    }
}