use std::{collections::VecDeque, time::Duration};
use simulator_api::{
BacktestError, BacktestStatus, ContinueParams, ContinueToParams, CreateBacktestSessionRequest,
DiscoveryBatchEvent, PausedEvent,
};
use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
use thiserror::Error;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use super::{
ConnectionStatus, ControlEvent, ControlHandle, SessionInfo, SubscriptionHandle,
SubscriptionNotification, spawn_account_diff_subscription_manager, spawn_control_manager,
spawn_transaction_subscription_manager,
};
use crate::subscriptions::AccountDiffNotification;
/// Errors produced while starting or driving a [`ManagedBacktestSession`].
#[derive(Debug, Error)]
pub enum ManagedSessionError {
    /// Creating the session failed; carries the error text reported while waiting
    /// for the session to be established.
    #[error("session create failed: {0}")]
    Create(String),
    /// The control channel is unavailable: the handle is gone, its event stream has
    /// ended, or every buffered post-completion event has already been returned.
    #[error("control channel closed")]
    ControlClosed,
    /// The control connection reported a failed status; carries the reason.
    #[error("control failed: {0}")]
    ControlFailed(String),
    /// A subscription could not be brought up; carries the reason.
    #[error("subscription failed: {0}")]
    SubscriptionFailed(String),
    /// The session's cancellation token fired while an operation was waiting.
    #[error("cancelled")]
    Cancelled,
    /// Sending a continue / continue-to request through the control handle failed;
    /// carries the underlying error message.
    #[error("control closed while sending continue: {0}")]
    ContinueSend(String),
}
/// One item yielded by [`ManagedBacktestSession::next_event`].
///
/// Most variants are forwarded unchanged from the control stream (`ControlEvent`);
/// `Transaction` and `AccountDiff` are forwarded from subscription managers
/// (`SubscriptionNotification`).
#[derive(Debug)]
pub enum ManagedEvent {
    /// Forwarded from `ControlEvent::ReadyForContinue`.
    ReadyForContinue,
    /// Forwarded from `ControlEvent::Paused`.
    Paused(PausedEvent),
    /// Forwarded from `ControlEvent::DiscoveryBatch`.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// Forwarded from `ControlEvent::Slot`.
    Slot(u64),
    /// Forwarded from `ControlEvent::Status`.
    Status(BacktestStatus),
    /// Forwarded from `ControlEvent::Completed`; delivered after any notifications
    /// drained from subscriptions at completion time.
    Completed,
    /// Forwarded from `ControlEvent::Error`.
    Error(BacktestError),
    /// Forwarded from `SubscriptionNotification::Transaction`.
    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
    /// Forwarded from `SubscriptionNotification::AccountDiff`.
    AccountDiff(AccountDiffNotification),
}
/// How long `next_event` keeps collecting subscription notifications after the
/// control stream reports completion (one minute).
const DEFAULT_COMPLETION_DRAIN_TIMEOUT: Duration = Duration::from_secs(60);
/// A backtest session bundled with its control connection and any subscription
/// managers, all tied to a single cancellation token.
pub struct ManagedBacktestSession {
    /// Session details obtained when the session was created.
    session_info: SessionInfo,
    /// Control connection handle; `None` once taken by `shutdown`.
    control: Option<ControlHandle>,
    /// Subscription managers started through the `subscribe_*` methods.
    subscriptions: Vec<SubscriptionHandle>,
    /// Child of the caller-supplied token; a clone is handed to every spawned manager.
    session_cancel: CancellationToken,
    /// `None` until `Completed` is observed; afterwards holds the events that
    /// `next_event` replays one per call.
    post_completion: Option<VecDeque<ManagedEvent>>,
    /// Upper bound on how long to wait for subscriptions to finish after completion.
    completion_drain_timeout: Duration,
}
impl ManagedBacktestSession {
pub async fn start(
url: String,
api_key: String,
create: CreateBacktestSessionRequest,
) -> Result<Self, ManagedSessionError> {
Self::start_with_cancel(url, api_key, create, CancellationToken::new()).await
}
pub async fn start_with_cancel(
url: String,
api_key: String,
create: CreateBacktestSessionRequest,
parent_cancel: CancellationToken,
) -> Result<Self, ManagedSessionError> {
let session_cancel = parent_cancel.child_token();
let mut control = spawn_control_manager(url, api_key, create, session_cancel.clone());
let session_info = tokio::select! {
biased;
_ = parent_cancel.cancelled() => {
session_cancel.cancel();
control.join().await;
return Err(ManagedSessionError::Cancelled);
}
result = control.wait_for_session() => {
result.map_err(ManagedSessionError::Create)?
}
};
Ok(Self {
session_info,
control: Some(control),
subscriptions: Vec::new(),
session_cancel,
post_completion: None,
completion_drain_timeout: DEFAULT_COMPLETION_DRAIN_TIMEOUT,
})
}
pub fn session_info(&self) -> &SessionInfo {
&self.session_info
}
pub fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
self.subscriptions
.push(spawn_transaction_subscription_manager(
self.session_info.rpc_endpoint.clone(),
program_ids,
self.session_cancel.clone(),
));
}
pub fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
self.subscriptions
.push(spawn_account_diff_subscription_manager(
self.session_info.rpc_endpoint.clone(),
program_ids,
self.session_cancel.clone(),
));
}
async fn drain_until_subscriptions_complete(
&mut self,
timeout: std::time::Duration,
) -> Vec<ManagedEvent> {
let mut events = Vec::new();
if self.subscriptions.is_empty() {
return events;
}
let deadline = tokio::time::Instant::now() + timeout;
loop {
while let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
events.push(event);
}
if self
.subscriptions
.iter()
.all(|s| s.notifications.is_closed())
{
return events;
}
tokio::select! {
biased;
_ = self.session_cancel.cancelled() => return events,
_ = tokio::time::sleep_until(deadline) => return events,
received = recv_any_open_subscription(&mut self.subscriptions) => {
if let Some(event) = received {
events.push(event);
}
}
}
}
}
pub async fn next_event(&mut self) -> Result<ManagedEvent, ManagedSessionError> {
if let Some(buffered) = self.post_completion.as_mut() {
return buffered
.pop_front()
.ok_or(ManagedSessionError::ControlClosed);
}
if let Some(event) = try_next_subscription_event(&mut self.subscriptions) {
return Ok(event);
}
let event = {
let cancel = self.session_cancel.clone();
let control = self
.control
.as_mut()
.ok_or(ManagedSessionError::ControlClosed)?;
let subscriptions = &mut self.subscriptions;
tokio::select! {
biased;
_ = cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
event = control.events.recv() => {
event.map(ManagedEvent::from).ok_or(ManagedSessionError::ControlClosed)?
}
event = wait_any_subscription_event(subscriptions) => event,
}
};
if matches!(event, ManagedEvent::Completed) {
let mut buffered: VecDeque<ManagedEvent> = self
.drain_until_subscriptions_complete(self.completion_drain_timeout)
.await
.into();
buffered.push_back(ManagedEvent::Completed);
let first = buffered.pop_front().expect("buffer contains Completed");
self.post_completion = Some(buffered);
return Ok(first);
}
Ok(event)
}
pub async fn send_continue(
&mut self,
params: ContinueParams,
) -> Result<(), ManagedSessionError> {
self.wait_all_up().await?;
self.control_mut()?
.send_continue(params)
.await
.map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
}
pub async fn send_continue_to(
&mut self,
params: ContinueToParams,
) -> Result<(), ManagedSessionError> {
self.wait_all_up().await?;
self.control_mut()?
.send_continue_to(params)
.await
.map_err(|e| ManagedSessionError::ContinueSend(e.to_string()))
}
pub async fn shutdown(mut self) {
self.session_cancel.cancel();
if let Some(control) = self.control.take() {
control.join().await;
}
for sub in std::mem::take(&mut self.subscriptions) {
let _ = sub.join.await;
}
}
fn control_mut(&mut self) -> Result<&mut ControlHandle, ManagedSessionError> {
self.control
.as_mut()
.ok_or(ManagedSessionError::ControlClosed)
}
async fn wait_all_up(&self) -> Result<(), ManagedSessionError> {
let mut control = self
.control
.as_ref()
.ok_or(ManagedSessionError::ControlClosed)?
.status
.clone();
let mut subscriptions: Vec<watch::Receiver<ConnectionStatus>> = self
.subscriptions
.iter()
.map(|s| s.status.clone())
.collect();
loop {
let control_status = control.borrow().clone();
if let ConnectionStatus::Failed(why) = &control_status {
return Err(ManagedSessionError::ControlFailed(why.clone()));
}
let mut all_subscriptions_up = true;
for subscription in &subscriptions {
match &*subscription.borrow() {
ConnectionStatus::Failed(why) => {
return Err(ManagedSessionError::SubscriptionFailed(why.clone()));
}
ConnectionStatus::Up => {}
_ => all_subscriptions_up = false,
}
}
if control_status == ConnectionStatus::Up && all_subscriptions_up {
return Ok(());
}
tokio::select! {
_ = self.session_cancel.cancelled() => return Err(ManagedSessionError::Cancelled),
_ = control.changed() => {}
_ = wait_any_subscription_change(&mut subscriptions) => {}
}
}
}
}
/// Dropping the session cancels its token, so every manager that was handed a clone
/// observes cancellation even when `shutdown` was never called. Tasks are not joined here.
impl Drop for ManagedBacktestSession {
    fn drop(&mut self) {
        let token = &self.session_cancel;
        token.cancel();
    }
}
/// Resolves once any live receiver in `subscriptions` sees a new status value.
///
/// Receivers whose sender has been dropped are skipped. For those,
/// `watch::Receiver::changed` resolves immediately with an error, which would turn a
/// caller's wait loop into a busy spin. Their final value stays readable via `borrow()`.
/// When nothing live is left (including an empty slice), this never resolves, leaving
/// sibling `select!` branches to make progress.
async fn wait_any_subscription_change(subscriptions: &mut [watch::Receiver<ConnectionStatus>]) {
    let live: Vec<_> = subscriptions
        .iter_mut()
        .filter(|receiver| receiver.has_changed().is_ok())
        .map(|receiver| Box::pin(receiver.changed()))
        .collect();
    if live.is_empty() {
        // `select_all` panics on an empty input, so park instead.
        std::future::pending::<()>().await;
        return;
    }
    let _ = futures::future::select_all(live).await;
}
async fn wait_any_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> ManagedEvent {
loop {
if let Some(event) = try_next_subscription_event(subscriptions) {
return event;
}
let futures: Vec<_> = subscriptions
.iter_mut()
.filter(|s| !s.notifications.is_closed())
.map(|s| Box::pin(s.notifications.recv()))
.collect();
if futures.is_empty() {
std::future::pending::<()>().await;
}
let (notification, _, _) = futures::future::select_all(futures).await;
if let Some(notification) = notification {
return notification.into();
}
}
}
async fn recv_any_open_subscription(
subscriptions: &mut [SubscriptionHandle],
) -> Option<ManagedEvent> {
let futures: Vec<_> = subscriptions
.iter_mut()
.filter(|s| !s.notifications.is_closed())
.map(|s| Box::pin(s.notifications.recv()))
.collect();
if futures.is_empty() {
return None;
}
let (notification, _, _) = futures::future::select_all(futures).await;
notification.map(Into::into)
}
fn try_next_subscription_event(subscriptions: &mut [SubscriptionHandle]) -> Option<ManagedEvent> {
for subscription in subscriptions {
if let Ok(notification) = subscription.notifications.try_recv() {
return Some(notification.into());
}
}
None
}
impl From<ControlEvent> for ManagedEvent {
fn from(event: ControlEvent) -> Self {
match event {
ControlEvent::ReadyForContinue => Self::ReadyForContinue,
ControlEvent::Paused(event) => Self::Paused(event),
ControlEvent::DiscoveryBatch(event) => Self::DiscoveryBatch(event),
ControlEvent::Slot(slot) => Self::Slot(slot),
ControlEvent::Status(status) => Self::Status(status),
ControlEvent::Completed => Self::Completed,
ControlEvent::Error(error) => Self::Error(error),
}
}
}
impl From<SubscriptionNotification> for ManagedEvent {
fn from(notification: SubscriptionNotification) -> Self {
match notification {
SubscriptionNotification::Transaction(transaction) => Self::Transaction(transaction),
SubscriptionNotification::AccountDiff(diff) => Self::AccountDiff(diff),
}
}
}