use std::{
collections::{BTreeMap, HashMap},
future::Future,
sync::Arc,
};
use console::style;
use eyre::Result;
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use simulator_api::{
AccountData, AccountModifications, AgentStatsReport, ContinueParams, SessionSummary,
};
use simulator_client::managed::{
ManagedBacktestSession, ManagedEvent, ManagedParallelSession, ManagedSessionError,
ParallelSubSession, SessionInfo,
};
use solana_address::Address;
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
/// A span of slots handled by one session.
///
/// `start` is the slot the progress bar counts from (reaching it shows
/// position 1 in `update_progress_position`); both bounds are shown in the
/// bar's range label. NOTE(review): whether `end` is inclusive is not visible
/// in this module — confirm against the caller that sets the bar length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotRange {
    /// First slot of the range.
    pub start: u64,
    /// Upper bound of the range (see the inclusivity note above).
    pub end: u64,
}
/// How a session driven by [`run_session_loop`] ended, when it did not fail.
pub enum Outcome {
    /// The simulator reported completion. Both fields are forwarded unchanged
    /// from the [`ManagedEvent::Completed`] event.
    Completed {
        summary: Option<SessionSummary>,
        agent_stats: Option<Vec<AgentStatsReport>>,
    },
    /// Waiting for the next event yielded [`ManagedSessionError::Cancelled`].
    Cancelled,
}
/// A sub-session that did not complete successfully.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionFailure {
    /// Slot range of the failed sub-session, presumably `(start, end)`, or
    /// `None` when the failure cannot be attributed to a range (for example a
    /// panicked task in `run_parallel_sessions`).
    pub range: Option<(u64, u64)>,
    /// Human-readable description of what went wrong.
    pub reason: String,
}
/// Source of account-state overrides sent with the first continue request of a
/// session driven by [`run_session_loop`].
pub trait ModificationsProvider {
    /// Reports whether there are no modifications to inject. [`run_session_loop`]
    /// only uses this to decide whether to show an "injecting program" message;
    /// it calls [`ModificationsProvider::build`] either way.
    fn is_empty(&self) -> bool;
    /// Builds the account overrides, keyed by address.
    ///
    /// `rpc_endpoint` is the endpoint from the driven session's `SessionInfo`.
    /// The returned future must be `Send` so the driving loop can run on a
    /// multi-threaded runtime.
    fn build(
        &self,
        rpc_endpoint: &str,
    ) -> impl Future<Output = Result<BTreeMap<Address, AccountData>>> + Send;
}
/// Receives events from a session driven by [`run_session_loop`].
pub trait SessionObserver {
    /// Called with every [`ManagedEvent::Transaction`] and
    /// [`ManagedEvent::AccountDiff`] event, one at a time in arrival order.
    fn on_data_event(&self, event: ManagedEvent);
    /// Called for every slot event, after the progress bar has been moved.
    /// Does nothing by default.
    fn on_slot(&self, _slot: u64) {}
}
/// The operations [`run_session_loop`] needs from a managed simulator session.
///
/// It is implemented for both [`ManagedBacktestSession`] and
/// [`ParallelSubSession`], so the same loop can drive either kind.
pub trait DrivableSession {
    /// Session metadata; the loop reads its `rpc_endpoint`.
    fn session_info(&self) -> &SessionInfo;
    /// Subscribes to transaction events for `program_ids`.
    fn subscribe_transactions(&mut self, program_ids: Vec<String>);
    /// Subscribes to account-diff events for `program_ids`.
    fn subscribe_account_diffs(&mut self, program_ids: Vec<String>);
    /// Waits for the next event, or the error that ended the stream.
    fn next_event(
        &mut self,
    ) -> impl Future<Output = std::result::Result<ManagedEvent, ManagedSessionError>> + Send;
    /// Sends a continue request built from `params`.
    fn send_continue(
        &mut self,
        params: ContinueParams,
    ) -> impl Future<Output = std::result::Result<(), ManagedSessionError>> + Send;
    /// Consumes the session and shuts it down.
    fn shutdown(self) -> impl Future<Output = ()> + Send;
}
/// [`DrivableSession`] for a standalone backtest session.
///
/// Every method forwards to the method of the same name on
/// `ManagedBacktestSession`. NOTE(review): the qualified calls below are
/// assumed to resolve to inherent methods, which take precedence over trait
/// methods in path resolution. If one of those inherent methods were removed,
/// the call would resolve to this trait's own method instead — confirm when
/// upgrading `simulator_client`.
impl DrivableSession for ManagedBacktestSession {
    fn session_info(&self) -> &SessionInfo {
        ManagedBacktestSession::session_info(self)
    }
    fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
        ManagedBacktestSession::subscribe_transactions(self, program_ids);
    }
    fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
        ManagedBacktestSession::subscribe_account_diffs(self, program_ids);
    }
    async fn next_event(&mut self) -> std::result::Result<ManagedEvent, ManagedSessionError> {
        ManagedBacktestSession::next_event(self).await
    }
    async fn send_continue(
        &mut self,
        params: ContinueParams,
    ) -> std::result::Result<(), ManagedSessionError> {
        ManagedBacktestSession::send_continue(self, params).await
    }
    async fn shutdown(self) {
        ManagedBacktestSession::shutdown(self).await
    }
}
/// [`DrivableSession`] for one sub-session of a parallel backtest.
///
/// Every method forwards to the method of the same name on
/// `ParallelSubSession`. NOTE(review): the qualified calls below are assumed
/// to resolve to inherent methods, which take precedence over trait methods in
/// path resolution. If one of those inherent methods were removed, the call
/// would resolve to this trait's own method instead — confirm when upgrading
/// `simulator_client`.
impl DrivableSession for ParallelSubSession {
    fn session_info(&self) -> &SessionInfo {
        ParallelSubSession::session_info(self)
    }
    fn subscribe_transactions(&mut self, program_ids: Vec<String>) {
        ParallelSubSession::subscribe_transactions(self, program_ids);
    }
    fn subscribe_account_diffs(&mut self, program_ids: Vec<String>) {
        ParallelSubSession::subscribe_account_diffs(self, program_ids);
    }
    async fn next_event(&mut self) -> std::result::Result<ManagedEvent, ManagedSessionError> {
        ParallelSubSession::next_event(self).await
    }
    async fn send_continue(
        &mut self,
        params: ContinueParams,
    ) -> std::result::Result<(), ManagedSessionError> {
        ParallelSubSession::send_continue(self, params).await
    }
    async fn shutdown(self) {
        ParallelSubSession::shutdown(self).await
    }
}
pub async fn run_session_loop<S: DrivableSession>(
session: &mut S,
pb: &ProgressBar,
range: SlotRange,
advance_count: u64,
modifications: &impl ModificationsProvider,
observer: &impl SessionObserver,
) -> Result<Outcome> {
let SlotRange { start, end } = range;
let rpc_endpoint = session.session_info().rpc_endpoint.clone();
let mut first_ready_seen = false;
loop {
let event = match session.next_event().await {
Ok(e) => e,
Err(ManagedSessionError::Cancelled) => {
pb.set_message("closing session…");
return Ok(Outcome::Cancelled);
}
Err(e) => return Err(eyre::eyre!("session failed: {e}")),
};
match event {
ManagedEvent::Completed {
summary,
agent_stats,
} => {
return Ok(Outcome::Completed {
summary,
agent_stats,
});
}
ManagedEvent::Error(e) => return Err(eyre::eyre!("simulator error: {e}")),
ManagedEvent::Slot(s) => {
update_progress_position(pb, start, s);
observer.on_slot(s);
}
ManagedEvent::Status(status) => {
if !first_ready_seen {
pb.set_message(status.to_string());
}
}
ManagedEvent::ReadyForContinue => {
let params = if !first_ready_seen {
if !modifications.is_empty() {
pb.set_message("injecting program");
}
let modifications = modifications.build(&rpc_endpoint).await?;
pb.set_style(bar_style(start, end));
pb.reset_elapsed();
pb.set_position(0);
pb.set_message("running");
first_ready_seen = true;
ContinueParams {
advance_count,
transactions: Vec::new(),
modify_account_states: AccountModifications(modifications),
}
} else {
ContinueParams {
advance_count,
transactions: Vec::new(),
modify_account_states: AccountModifications(Default::default()),
}
};
session
.send_continue(params)
.await
.map_err(|e| eyre::eyre!("send_continue: {e}"))?;
}
event @ (ManagedEvent::Transaction(_) | ManagedEvent::AccountDiff(_)) => {
observer.on_data_event(event);
}
ManagedEvent::Paused(_) | ManagedEvent::DiscoveryBatch(_) => {}
}
}
}
/// Drives every sub-session of `parallel` concurrently and collects failures.
///
/// Each sub-session is passed to `drive`, and the resulting future runs as its
/// own Tokio task. When `concurrency` is set, a task holds a semaphore permit
/// for its whole run, which caps how many sub-sessions are active at once.
///
/// Normally the function returns once every task has finished. It stops early
/// in two cases:
///
/// - `cancellation` fires;
/// - `fail_fast` is set and a sub-session fails (including a panicked task).
///   This also cancels `cancellation`.
///
/// In either early-exit case the remaining tasks are aborted and awaited
/// until they have stopped. Every bar in `bars` that has not finished is then
/// abandoned with a message explaining why. `parallel` is always shut down
/// before returning.
///
/// A panicked task is reported as a [`SessionFailure`] with `range: None`.
pub async fn run_parallel_sessions<F, Fut>(
    mut parallel: ManagedParallelSession,
    bars: Arc<HashMap<u64, ProgressBar>>,
    progress: &MultiProgress,
    fail_fast: bool,
    concurrency: Option<Arc<Semaphore>>,
    cancellation: &CancellationToken,
    drive: F,
) -> Vec<SessionFailure>
where
    F: Fn(ParallelSubSession) -> Fut,
    Fut: Future<Output = std::result::Result<(), SessionFailure>> + Send + 'static,
{
    let mut failures = Vec::new();
    let mut join_set = tokio::task::JoinSet::new();
    for session in parallel.take_sub_sessions() {
        let future = drive(session);
        let concurrency = concurrency.clone();
        join_set.spawn(async move {
            // The permit is held until the sub-session finishes. If the semaphore
            // has been closed, `acquire_owned` fails and the task runs without a
            // permit rather than not at all.
            let _permit = match concurrency {
                Some(semaphore) => semaphore.acquire_owned().await.ok(),
                None => None,
            };
            future.await
        });
    }
    loop {
        tokio::select! {
            // Check cancellation first so a cancelled run stops promptly even
            // while tasks keep completing.
            biased;
            _ = cancellation.cancelled() => {
                abort_and_abandon(&mut join_set, &bars, "cancelled").await;
                break;
            }
            result = join_set.join_next() => {
                let Some(result) = result else { break };
                let failure = match result {
                    Ok(Ok(())) => None,
                    Ok(Err(failure)) => Some(failure),
                    // A cancelled task is not a failure of the sub-session itself.
                    Err(e) if e.is_cancelled() => None,
                    Err(e) => Some(SessionFailure {
                        range: None,
                        reason: format!("session task panicked: {e}"),
                    }),
                };
                if let Some(failure) = failure {
                    failures.push(failure);
                    if fail_fast {
                        let _ = progress.println(format!(
                            "\n{} --fail-fast: a sub-range failed, aborting the rest",
                            style("✗").red(),
                        ));
                        cancellation.cancel();
                        abort_and_abandon(&mut join_set, &bars, "aborted (--fail-fast)").await;
                        break;
                    }
                }
            }
        }
    }
    parallel.shutdown().await;
    failures
}

/// Aborts every task left in `join_set` and waits for all of them to stop.
/// Then abandons each bar in `bars` that has not already finished, using
/// `message`.
async fn abort_and_abandon<T: 'static>(
    join_set: &mut tokio::task::JoinSet<T>,
    bars: &HashMap<u64, ProgressBar>,
    message: &'static str,
) {
    join_set.abort_all();
    // `abort_all` only requests cancellation: a task being polled on another
    // worker keeps running until its next await point. Draining guarantees that
    // no task still touches its bar or sub-session once we return, and that bars
    // are abandoned only after their tasks have stopped. Results of tasks that
    // happened to finish in the meantime are deliberately ignored.
    while join_set.join_next().await.is_some() {}
    for bar in bars.values().filter(|bar| !bar.is_finished()) {
        bar.abandon_with_message(message);
    }
}
/// Puts `pb` into its final state for a session outcome.
///
/// A completed session finishes the bar with "done", and a cancelled one
/// abandons it with "cancelled". An error is passed back unchanged and the bar
/// is left untouched, so the caller decides how to present the failure.
///
/// # Errors
///
/// Returns the error carried by `outcome`, if any.
pub fn finish_bar(pb: &ProgressBar, outcome: Result<Outcome>) -> Result<()> {
    let finished = outcome?;
    match finished {
        Outcome::Completed { .. } => pb.finish_with_message("done"),
        Outcome::Cancelled => pb.abandon_with_message("cancelled"),
    }
    Ok(())
}
/// Moves `pb` to show that `slot` has been reached in a range starting at `start_slot`.
///
/// Positions are 1-based: reaching `start_slot` itself counts as one slot done.
/// A slot before `start_slot` is clamped to position 1, and the arithmetic
/// saturates instead of overflowing.
fn update_progress_position(pb: &ProgressBar, start_slot: u64, slot: u64) {
    let done = slot
        .checked_sub(start_slot)
        .map_or(1, |offset| offset.saturating_add(1));
    pb.set_position(done);
}
/// Style for a bar whose session has not started running yet. It shows the
/// `start → end` slot label, a cyan spinner and the current message.
///
/// Falls back to indicatif's default spinner style if the template fails to parse.
pub fn spinner_style(start: u64, end: u64) -> ProgressStyle {
    // The label never changes, so format it once instead of on every redraw.
    let label = format!("{} → {}", format_slot(start), format_slot(end));
    let base = match ProgressStyle::with_template("[{range}] {spinner:.cyan} {msg}") {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_spinner(),
    };
    base.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = out.write_str(&label);
        },
    )
}
/// Style for a running session. It shows the `start → end` slot label, a
/// 40-column bar, the `pos/len` slot count, the elapsed time and the current
/// message.
///
/// Falls back to indicatif's default bar style if the template fails to parse.
fn bar_style(start: u64, end: u64) -> ProgressStyle {
    // The label never changes, so format it once instead of on every redraw.
    let label = format!("{} → {}", format_slot(start), format_slot(end));
    let template = "[{range}] {bar:40.cyan/blue} {pos}/{len} slots ({elapsed}) {msg}";
    let base = match ProgressStyle::with_template(template) {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_bar(),
    };
    base.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = out.write_str(&label);
        },
    )
}
/// Formats a slot number with `,` between groups of three digits, e.g.
/// `362270659` → `"362,270,659"`.
pub fn format_slot(n: u64) -> String {
    let digits = n.to_string();
    // The leading group holds 1–3 digits, so every following group is exactly 3.
    let lead = match digits.len() % 3 {
        0 => 3,
        rem => rem,
    };
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(&digits[..lead]);
    for at in (lead..digits.len()).step_by(3) {
        grouped.push(',');
        grouped.push_str(&digits[at..at + 3]);
    }
    grouped
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn update_progress_position_tracks_slot_offset() {
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 362_270_659, 362_270_659);
        assert_eq!(pb.position(), 1);
        update_progress_position(&pb, 362_270_659, 362_270_700);
        assert_eq!(pb.position(), 42);
    }

    #[test]
    fn update_progress_position_clamps_slots_before_start() {
        // A slot earlier than the range start must not underflow; it shows as position 1.
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 362_270_659, 362_270_000);
        assert_eq!(pb.position(), 1);
    }

    #[test]
    fn format_slot_groups_thousands() {
        assert_eq!(format_slot(0), "0");
        assert_eq!(format_slot(999), "999");
        assert_eq!(format_slot(1_000), "1,000");
        assert_eq!(format_slot(123_456), "123,456");
        assert_eq!(format_slot(362_270_659), "362,270,659");
        assert_eq!(format_slot(u64::MAX), "18,446,744,073,709,551,615");
    }
}