use std::{
collections::{BTreeMap, HashMap},
path::{Path, PathBuf},
sync::{Arc, Mutex},
time::Duration,
};
use clap::Parser;
use console::style;
use eyre::{Context, Result};
use indicatif::{MultiProgress, ProgressBar};
use sim_cli::{
engine::{
self, DrivableSession, ModificationsProvider, Outcome, SessionFailure, SessionObserver,
SlotRange, format_slot,
},
signals::spawn_ctrlc_cursor_fix,
};
use simulator_client::{
BacktestClient, CreateSession, backtest_ws_url,
managed::{
ManagedBacktestSession, ManagedEvent, ManagedParallelSession, ManagedSessionError,
ParallelSubSession, ReconnectCoordinator,
},
modify_program_via_rpc, split_range,
};
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_rpc_client::{http_sender::HttpSender, rpc_client::RpcClientConfig};
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::{
output::{OutputEvent, SessionStartedRecord, SimulationMetadata, SimulationSummary},
subscription::{RecordsStore, on_account_diff_notification, on_transaction_notification},
};
/// Largest number of sub-ranges `generate_plan` accepts before refusing to produce a plan.
const MAX_PARALLEL_SESSIONS: usize = 50;
/// Value passed to `capacity_wait_timeout_secs` when building session-create requests.
const CREATION_TIMEOUT_SECS: u16 = 900;
/// Value passed to `disconnect_timeout_secs` when building session-create requests.
const DISCONNECT_TIMEOUT_SECS: u16 = 180;
/// Per-slot duration (milliseconds) assumed by `generate_plan` for its runtime estimate.
const ESTIMATED_MS_PER_BLOCK: u64 = 250;
/// Passed to `timeout(..)` on the shared `reqwest` client built in `build_infra`.
const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Passed to `pool_idle_timeout(..)` on the shared `reqwest` client built in `build_infra`.
const HTTP_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// `(program id, ELF bytes)` pairs, filled by `build_infra` from `--program-id` / `--program-so`.
type PreloadedPrograms = Vec<(String, Vec<u8>)>;
/// Resources shared by every session driven in one run.
///
/// Cloning is cheap with respect to the program bytes: they sit behind an
/// `Arc`, so clones share the same `PreloadedPrograms` allocation.
#[derive(Clone)]
struct SessionInfra {
/// Programs (id + ELF bytes) to inject into each session; empty for a baseline run.
preloaded_programs: Arc<PreloadedPrograms>,
/// HTTP client handed to every RPC client created by `make_rpc_client`.
http_client: reqwest::Client,
}
/// Creates an [`RpcClient`] for `rpc_endpoint` whose requests go through a
/// clone of `http_client`, using the default [`RpcClientConfig`].
fn make_rpc_client(rpc_endpoint: &str, http_client: &reqwest::Client) -> RpcClient {
    RpcClient::new_sender(
        HttpSender::new_with_client(rpc_endpoint, http_client.clone()),
        RpcClientConfig::default(),
    )
}
impl ModificationsProvider for SessionInfra {
    /// Returns `true` when no programs were preloaded, i.e. there is nothing to inject.
    fn is_empty(&self) -> bool {
        self.preloaded_programs.is_empty()
    }

    /// Builds the account modifications that inject every preloaded program.
    ///
    /// Returns an empty map without creating an RPC client when there are no
    /// preloaded programs. Otherwise each program is processed in order via
    /// `modify_program_via_rpc` against `rpc_endpoint`, and the resulting
    /// entries are merged into a single map.
    ///
    /// # Errors
    /// Fails on the first program whose modification cannot be built; the
    /// error message names the offending program id.
    async fn build(
        &self,
        rpc_endpoint: &str,
    ) -> Result<BTreeMap<solana_address::Address, simulator_api::AccountData>> {
        let mut modifications = BTreeMap::new();
        if self.preloaded_programs.is_empty() {
            return Ok(modifications);
        }

        let rpc_client = make_rpc_client(rpc_endpoint, &self.http_client);
        for (program_id, elf_bytes) in self.preloaded_programs.iter() {
            let injected = match modify_program_via_rpc(&rpc_client, program_id, elf_bytes).await {
                Ok(accounts) => accounts,
                Err(err) => {
                    return Err(eyre::eyre!(
                        "building program injection for {id}: {e}",
                        id = program_id,
                        e = err
                    ));
                }
            };
            modifications.extend(injected);
        }
        Ok(modifications)
    }
}
// Values accepted by `--subscription`.
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// `ValueEnum` would be picked up as help text for the possible values.
#[derive(Clone, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum SubscriptionType {
// No variant-specific handling is visible in this file.
Logs,
// Makes `subscribe_and_drive` additionally subscribe to account diffs for the program ids.
AccountDiff,
// No variant-specific handling is visible in this file; transaction
// subscriptions are made whenever program ids are given.
Transaction,
}
/// Destination for everything a session produces: a shared `RecordsStore`
/// plus the channel that feeds the output-writer task.
#[derive(Clone)]
struct OutputSink {
/// Store handed to the notification handlers; also queried for the final
/// `total()` / `successes()` counts in `run`.
records: RecordsStore,
/// Sending half of the unbounded channel drained by `spawn_output_writer`.
output_tx: tokio::sync::mpsc::UnboundedSender<OutputEvent>,
}
impl SessionObserver for OutputSink {
    /// Forwards transaction and account-diff notifications to their handlers,
    /// giving each handler its own clone of the records store and the output
    /// sender. Every other event kind is ignored.
    fn on_data_event(&self, event: ManagedEvent) {
        match event {
            ManagedEvent::AccountDiff(diff) => {
                on_account_diff_notification(self.records.clone(), self.output_tx.clone(), diff);
            }
            ManagedEvent::Transaction(boxed_tx) => {
                on_transaction_notification(
                    self.records.clone(),
                    self.output_tx.clone(),
                    *boxed_tx,
                );
            }
            _ => {}
        }
    }
}
// Command-line arguments for a simulation run.
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive struct would change the generated `--help` text.
#[derive(Parser, Debug, Clone)]
pub struct RunArgs {
// First slot of the requested range.
#[arg(long, env = "SIMULATOR_START_SLOT")]
pub start_slot: u64,
// Last slot of the requested range (slot counts are computed as `end - start + 1`).
#[arg(long, env = "SIMULATOR_END_SLOT")]
pub end_slot: u64,
// Advance count forwarded to the session loop; when omitted, the size of
// the (sub-)range being driven is used.
#[arg(long, env = "SIMULATOR_ADVANCE_COUNT")]
pub advance_count: Option<u64>,
// Split the range into sub-ranges and drive them as a parallel session.
#[arg(long, env = "SIMULATOR_PARALLEL", default_value_t = false)]
pub parallel: bool,
// Forwarded to `engine::run_parallel_sessions`.
#[arg(long, env = "SIMULATOR_FAIL_FAST", default_value_t = false)]
pub fail_fast: bool,
// Program ids to subscribe to; paired positionally with `program_so` when
// overrides are supplied.
#[arg(long, env = "SIMULATOR_PROGRAM_ID")]
pub program_id: Vec<String>,
// Paths to program binaries read by `build_infra` and injected into sessions;
// the count must match `program_id` (see `validate_args`).
#[arg(long, env = "SIMULATOR_PROGRAM_SO")]
pub program_so: Vec<PathBuf>,
// Output path; defaults to `<start>_<end>_<override|baseline>.ndjson`.
#[arg(long, env = "SIMULATOR_OUTPUT_FILE")]
pub output_file: Option<PathBuf>,
// Forwarded to the session-create request when present.
#[arg(long, env = "SIMULATOR_EXTRA_COMPUTE_UNITS")]
pub extra_compute_units: Option<u32>,
// Optional subscription kind; clap requires `--program-id` alongside it.
#[arg(long, env = "SIMULATOR_SUBSCRIPTION", requires = "program_id")]
pub subscription: Option<SubscriptionType>,
// Installs a tracing subscriber and hides the progress bars.
#[arg(long, default_value_t = false)]
pub verbose: bool,
// Print the session plan for the range and exit without running anything.
#[arg(long, env = "SIMULATOR_PLAN", default_value_t = false)]
pub plan: bool,
}
/// Checks cross-argument constraints that clap's per-flag rules do not cover.
///
/// # Errors
/// - `--start-slot` is greater than `--end-slot` (an inverted range would
///   otherwise flow into later `end - start + 1` slot arithmetic).
/// - `--program-so` was given but its count differs from `--program-id`,
///   since the two lists are paired positionally.
fn validate_args(args: &RunArgs) -> Result<()> {
    eyre::ensure!(
        args.start_slot <= args.end_slot,
        "--start-slot ({}) must not be greater than --end-slot ({})",
        args.start_slot,
        args.end_slot,
    );
    // Baseline runs pass program ids without any overrides, so the count
    // check only applies once at least one override binary is supplied.
    if !args.program_so.is_empty() {
        eyre::ensure!(
            args.program_so.len() == args.program_id.len(),
            "--program-so count ({}) must match --program-id count ({})",
            args.program_so.len(),
            args.program_id.len(),
        );
    }
    Ok(())
}
/// Prints how the requested slot range would be split into sessions, without
/// starting any session.
///
/// The output consists of a headline (estimated runtime, session count, total
/// range) followed by a table with one row per sub-range. The runtime estimate
/// is based on the largest sub-range multiplied by `ESTIMATED_MS_PER_BLOCK`.
///
/// All slot arithmetic is saturating/checked so that extreme or inverted
/// inputs cannot panic or wrap while formatting the plan.
///
/// # Errors
/// - the available ranges cannot be fetched or the range cannot be split;
/// - the split needs more than `MAX_PARALLEL_SESSIONS` sessions;
/// - two consecutive sub-ranges are not contiguous (gap in the data).
async fn generate_plan(client: &BacktestClient, args: &RunArgs) -> Result<()> {
    let ranges = client
        .available_ranges()
        .await
        .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
    let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
        .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;

    if sub_ranges.len() > MAX_PARALLEL_SESSIONS {
        eyre::bail!(
            "unable to generate plan: requires {} sessions but max is {MAX_PARALLEL_SESSIONS}",
            sub_ranges.len(),
        );
    }

    // Each sub-range must start exactly one slot after the previous one ends.
    for window in sub_ranges.windows(2) {
        let (_, end) = window[0];
        let (next_start, _) = window[1];
        if end.checked_add(1) != Some(next_start) {
            eyre::bail!(
                "unable to generate plan: gap in data between slots {} and {}",
                format_slot(end),
                format_slot(next_start),
            );
        }
    }

    // The longest sub-range determines the estimate shown in the headline.
    let max_slots = sub_ranges
        .iter()
        .map(|(start, end)| end.saturating_sub(*start).saturating_add(1))
        .max()
        .unwrap_or(0);
    let estimated_ms = max_slots.saturating_mul(ESTIMATED_MS_PER_BLOCK);
    let estimated = if estimated_ms < 60 * 60 * 1000 {
        format!("~{}m", estimated_ms / 60_000)
    } else {
        format!("~{:.1}h", estimated_ms as f64 / 3_600_000.0)
    };

    // Same saturating form as the per-range counts above and below; a plain
    // `end - start + 1` would underflow on an inverted range.
    let total_slots = args
        .end_slot
        .saturating_sub(args.start_slot)
        .saturating_add(1);

    println!(
        "\n{} {estimated} runtime · {} session(s)\n{} [{} → {}] · {} slots",
        style(">").cyan(),
        sub_ranges.len(),
        style(">").cyan(),
        format_slot(args.start_slot),
        format_slot(args.end_slot),
        format_slot(total_slots),
    );
    println!();
    println!(
        "{:<5} {:<20} {:<20} {:<15}",
        "#", "Start Slot", "End Slot", "Slots"
    );
    println!("{}", "-".repeat(62));
    for (i, (start, end)) in sub_ranges.iter().enumerate() {
        let slots = end.saturating_sub(*start).saturating_add(1);
        println!(
            "{:<5} {:<20} {:<20} {:<15}",
            i + 1,
            format_slot(*start),
            format_slot(*end),
            format_slot(slots),
        );
    }
    Ok(())
}
/// Picks the output path for this run.
///
/// An explicit `--output-file` wins. Otherwise the name is
/// `<start>_<end>_<suffix>.ndjson`, where the suffix is `override` when
/// program binaries were supplied and `baseline` when none were.
fn resolve_output_file(args: &RunArgs) -> PathBuf {
    if let Some(explicit) = &args.output_file {
        return explicit.clone();
    }
    let suffix = if args.program_so.is_empty() {
        "baseline"
    } else {
        "override"
    };
    let file_name = format!(
        "{}_{}_{}.ndjson",
        args.start_slot, args.end_slot, suffix
    );
    PathBuf::from(file_name)
}
/// Opens the output file, starts the background writer task and queues the
/// run's metadata record as the first event.
///
/// Returns the sink that sessions write through together with the writer
/// task's join handle. The metadata timestamp falls back to `0` when the
/// system clock reads earlier than the Unix epoch. A failure to queue the
/// metadata record is deliberately ignored.
///
/// # Errors
/// Fails when the output file (or its parent directory) cannot be created.
fn setup_output(
    output_file: &Path,
    args: &RunArgs,
) -> Result<(OutputSink, tokio::task::JoinHandle<Result<()>>)> {
    let file = open_output_writer(output_file)?;
    let (output_tx, output_rx) = tokio::sync::mpsc::unbounded_channel::<OutputEvent>();
    let writer_task = spawn_output_writer(output_file.to_path_buf(), file, output_rx);

    let ran_at_unix_secs =
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
    let header = OutputEvent::Metadata(SimulationMetadata {
        start_slot: args.start_slot,
        end_slot: args.end_slot,
        program_ids: args.program_id.clone(),
        program_so: args
            .program_so
            .iter()
            .map(|so_path| so_path.display().to_string())
            .collect(),
        ran_at_unix_secs,
        session_ids: Vec::new(),
    });
    let _ = output_tx.send(header);

    let sink = OutputSink {
        records: RecordsStore::new(),
        output_tx,
    };
    Ok((sink, writer_task))
}
/// Loads the program binaries named on the command line and builds the HTTP
/// client shared by all sessions.
///
/// Program ids and binary paths are paired positionally; each binary is read
/// fully into memory.
///
/// # Errors
/// Fails when a binary cannot be read (the message names the path) or when
/// the HTTP client cannot be constructed.
fn build_infra(args: &RunArgs) -> Result<SessionInfra> {
    let mut programs: PreloadedPrograms = Vec::new();
    for (id, so_path) in args.program_id.iter().zip(&args.program_so) {
        let elf = std::fs::read(so_path)
            .with_context(|| format!("failed to read {}", so_path.display()))?;
        programs.push((id.clone(), elf));
    }
    let preloaded_programs: Arc<PreloadedPrograms> = Arc::new(programs);

    let builder = reqwest::Client::builder()
        .timeout(HTTP_REQUEST_TIMEOUT)
        .pool_idle_timeout(HTTP_POOL_IDLE_TIMEOUT)
        .hickory_dns(true);
    let http_client = builder
        .build()
        .context("failed to build shared HTTP client")?;

    Ok(SessionInfra {
        preloaded_programs,
        http_client,
    })
}
/// Entry point of the run command: validates arguments, then either prints a
/// plan (`--plan`) or drives the simulation and writes its output file.
///
/// Flow:
/// 1. Validate `args` and build the backtest client from `url` / `api_key`.
/// 2. With `--plan`, print the plan and return.
/// 3. Open the output file, start the writer task, load shared infrastructure.
/// 4. Drive either one parallel session (`--parallel`) or a single session
///    covering the whole range.
/// 5. On cancellation, close the output channel, wait for the writer and
///    return `Ok(())` without writing a summary.
/// 6. Otherwise append a summary record, wait for the writer to finish, and
///    report the outcome via `render_outcome`.
///
/// # Errors
/// Returns an error for invalid arguments, output-file or HTTP-client setup
/// failures, range lookup/splitting failures, session creation failures, a
/// failed single session, a failed or panicked writer task, and whatever
/// `render_outcome` reports (failed sub-ranges or zero transactions).
///
/// # Panics
/// Panics if the `session_ids` mutex is poisoned.
/// NOTE(review): `tracing_subscriber`'s `init()` presumably panics when a
/// global subscriber is already installed — confirm callers never set one
/// before calling this with `--verbose`.
pub async fn run(
args: RunArgs,
url: String,
api_key: String,
cancellation: CancellationToken,
) -> Result<()> {
validate_args(&args)?;
let base_url = backtest_ws_url(&url);
let client = BacktestClient::builder()
.url(base_url.clone())
.api_key(api_key.clone())
.build();
// Plan mode only prints; it must not create the output file or any session.
if args.plan {
return generate_plan(&client, &args).await;
}
let output_file = resolve_output_file(&args);
let (output, output_writer_handle) = setup_output(&output_file, &args)?;
// Filled by `record_session_started` as sessions come up; copied into the summary.
let session_ids: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let infra = build_infra(&args)?;
// Shared with the per-sub-session futures created in the parallel branch.
let args = Arc::new(args);
// NOTE(review): presumably restores the terminal cursor on Ctrl-C — confirm in `signals`.
spawn_ctrlc_cursor_fix();
// Verbose mode logs through tracing (env filter, defaulting to "info");
// the progress bars are hidden in that mode further below.
if args.verbose {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
}
let mut session_failures: Vec<SessionFailure> = Vec::new();
// Stays 1 for the single-session path; replaced by the sub-range count when parallel.
let mut session_count = 1;
if args.parallel {
let progress = MultiProgress::new();
let ranges = client
.available_ranges()
.await
.map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
.map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
session_count = sub_ranges.len();
info!("splitting into {session_count} parallel session(s)");
let _ = progress.println(format!(
"{} Splitting into {session_count} parallel session(s)",
style(">").cyan(),
));
// One progress bar per sub-range, keyed by start slot so that
// `drive_parallel_sub` can look its bar up from the session's range.
let mut bars_by_start: HashMap<u64, ProgressBar> = HashMap::new();
for &(start, end) in &sub_ranges {
let bar = if args.verbose {
ProgressBar::hidden()
} else {
let bar = ProgressBar::new(end.saturating_sub(start) + 1);
bar.set_style(engine::spinner_style(start, end));
let bar = progress.add(bar);
bar.enable_steady_tick(Duration::from_millis(100));
bar
};
bars_by_start.insert(start, bar);
}
// A single create request covers the whole range with `parallel(true)`.
let create = CreateSession::builder()
.start_slot(args.start_slot)
.end_slot(args.end_slot)
.parallel(true)
.disconnect_timeout_secs(DISCONNECT_TIMEOUT_SECS)
.capacity_wait_timeout_secs(CREATION_TIMEOUT_SECS)
.maybe_extra_compute_units(args.extra_compute_units)
.build()
.into_request()
.map_err(|e| eyre::eyre!("building parallel create request: {e}"))?;
// Cancellation during start-up is not an error: it maps to `None` and is
// handled by the shared cancellation check after this branch.
let parallel = match ManagedParallelSession::start_with_cancel(
base_url.clone(),
api_key.clone(),
create,
cancellation.clone(),
)
.await
{
Ok(p) => Some(p),
Err(ManagedSessionError::Cancelled) => None,
Err(e) => return Err(eyre::eyre!("parallel session create failed: {e}")),
};
if let Some(parallel) = parallel {
let bars_by_start = Arc::new(bars_by_start);
// Closure producing one driving future per sub-session. The shared state
// is cloned once into the closure and again for each future it creates,
// so every future owns its own handles.
let drive = {
let args = args.clone();
let output = output.clone();
let session_ids = session_ids.clone();
let infra = infra.clone();
let bars = bars_by_start.clone();
move |session: ParallelSubSession| {
let args = args.clone();
let output = output.clone();
let session_ids = session_ids.clone();
let infra = infra.clone();
let bars = bars.clone();
async move {
drive_parallel_sub(session, &args, &bars, &output, &session_ids, &infra)
.await
}
}
};
session_failures = engine::run_parallel_sessions(
parallel,
bars_by_start,
&progress,
args.fail_fast,
None,
&cancellation,
drive,
)
.await;
} else {
// Start-up was cancelled: mark every bar that is still running.
for bar in bars_by_start.values() {
if !bar.is_finished() {
bar.abandon_with_message("cancelled");
}
}
}
} else {
// Single-session path: one session and one bar for the entire range.
let range = SlotRange {
start: args.start_slot,
end: args.end_slot,
};
let bar = if args.verbose {
ProgressBar::hidden()
} else {
let bar = ProgressBar::new(range.end.saturating_sub(range.start) + 1);
bar.set_style(engine::spinner_style(range.start, range.end));
bar.set_message("starting session...");
bar.enable_steady_tick(Duration::from_millis(100));
bar
};
// A failure here returns immediately (after marking the bar), so no
// summary record is written on this path.
drive_single(
base_url,
api_key,
&args,
&bar,
range,
&output,
&session_ids,
&infra,
&cancellation,
None,
)
.await
.inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))?;
}
// Cancelled runs end quietly: dropping this sender closes the writer's
// channel once no other sender clones remain, and the writer's own result
// is intentionally ignored.
if cancellation.is_cancelled() {
eprintln!("\n{} Cancelled by user", style("!").yellow());
let OutputSink { output_tx, .. } = output;
drop(output_tx);
let _ = output_writer_handle.await;
return Ok(());
}
let OutputSink {
records, output_tx, ..
} = output;
let session_ids = session_ids.lock().unwrap().clone();
let total = records.total();
let successes = records.successes();
let summary = SimulationSummary {
total_transactions: total,
successes,
failures: total.saturating_sub(successes),
session_ids,
};
let _ = output_tx.send(OutputEvent::Summary(summary.clone()));
// Drop this sender before awaiting the writer: its receive loop only ends
// once the channel is closed.
drop(output_tx);
output_writer_handle
.await
.context("output writer task panicked")?
.context("output writer failed")?;
render_outcome(&session_failures, session_count, &summary, &output_file)
}
/// Prints the final report for a run and converts it into the command's result.
///
/// Three outcomes are distinguished, in this order:
/// 1. At least one sub-range failed: failures are grouped by reason (sorted by
///    reason text, ranges sorted within each group) and an error is returned.
/// 2. No transaction was captured: a warning is printed and an error returned.
/// 3. Otherwise the success summary is printed and `Ok(())` returned.
fn render_outcome(
    session_failures: &[SessionFailure],
    session_count: usize,
    summary: &SimulationSummary,
    output_file: &Path,
) -> Result<()> {
    if session_failures.is_empty() {
        if summary.total_transactions == 0 {
            println!(
                "\n{} Simulation finished with no transactions",
                style("⚠").yellow()
            );
            println!(" Output file : {}", output_file.display());
            return Err(eyre::eyre!(
                "no matching transactions found in the requested slot range"
            ));
        }

        println!("\n{} Simulation complete", style("✔").green());
        println!(" Transactions : {}", summary.total_transactions);
        println!(" Successes : {}", summary.successes);
        println!(" Failures : {}", summary.failures);
        println!(" Output file : {}", output_file.display());
        return Ok(());
    }

    // Group the failed ranges under their reason. A failure without a range
    // still creates the reason's entry so that the reason gets printed.
    let mut by_reason: BTreeMap<&str, Vec<(u64, u64)>> = BTreeMap::new();
    for failure in session_failures {
        let bucket = by_reason.entry(failure.reason.as_str()).or_default();
        if let Some(failed_range) = failure.range {
            bucket.push(failed_range);
        }
    }

    println!(
        "\n{} {} of {} sub-range(s) did not complete",
        style("✖").red(),
        session_failures.len(),
        session_count,
    );
    for (reason, mut ranges) in by_reason {
        ranges.sort_unstable();
        let mut slots = String::new();
        if !ranges.is_empty() {
            let rendered: Vec<String> = ranges
                .iter()
                .map(|(s, e)| format!("{s}–{e}"))
                .collect();
            let list = rendered.join(", ");
            slots = format!(" {}", style(format!("[slots {list}]")).dim());
        }
        println!(" {}{}", style(reason).yellow(), slots);
    }
    println!(
        "\n Captured {} transaction(s) from the sub-ranges that finished",
        summary.total_transactions,
    );
    println!(" Output file : {}", output_file.display());

    Err(eyre::eyre!(
        "summary: {} of {} parallel sub-ranges failed",
        session_failures.len(),
        session_count,
    ))
}
/// Registers a freshly created session: remembers its id for the summary,
/// logs it, and queues a `SessionStarted` record for the output file.
///
/// A failure to queue the record is deliberately ignored.
///
/// # Panics
/// Panics if the `session_ids` mutex is poisoned.
fn record_session_started<S: DrivableSession>(
    session: &S,
    range: SlotRange,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
) {
    let details = session.session_info();

    {
        // Keep the lock scope limited to the push itself.
        let mut known_ids = session_ids.lock().unwrap();
        known_ids.push(details.session_id.clone());
    }

    info!(
        session_id = %details.session_id,
        task_id = ?details.task_id,
        start = range.start,
        end = range.end,
        "session created"
    );

    let started = OutputEvent::SessionStarted(SessionStartedRecord {
        session_id: details.session_id.clone(),
        task_id: details.task_id.clone(),
        start_slot: range.start,
        end_slot: range.end,
    });
    let _ = output.output_tx.send(started);
}
/// Sets up the subscriptions for `session`, runs the session loop over
/// `range`, and shuts the session down whatever the loop's outcome was.
///
/// Transaction subscriptions are made whenever program ids were given;
/// account-diff subscriptions are added only for `--subscription account-diff`.
/// The loop's result is returned unchanged after shutdown.
async fn subscribe_and_drive<S: DrivableSession>(
    mut session: S,
    args: &RunArgs,
    pb: &ProgressBar,
    range: SlotRange,
    advance_count: u64,
    output: &OutputSink,
    infra: &SessionInfra,
) -> Result<Outcome> {
    pb.set_message("waiting for runtime");

    let has_programs = !args.program_id.is_empty();
    if has_programs {
        session.subscribe_transactions(args.program_id.clone());
    }
    if has_programs && args.subscription == Some(SubscriptionType::AccountDiff) {
        session.subscribe_account_diffs(args.program_id.clone());
    }

    let loop_result =
        engine::run_session_loop(&mut session, pb, range, advance_count, infra, output).await;
    // Always shut down, including when the loop returned an error.
    session.shutdown().await;
    loop_result
}
/// Creates one managed session covering `range`, drives it to completion and
/// finalizes the progress bar.
///
/// The advance count defaults to the number of slots in `range` when
/// `--advance-count` is not given. If session creation is cancelled through
/// `parent_cancel`, the bar is marked as cancelled and `Ok(())` is returned.
///
/// # Errors
/// Fails when the create request cannot be built, when session creation
/// fails for a reason other than cancellation, or when `engine::finish_bar`
/// reports a failed outcome.
#[allow(clippy::too_many_arguments)]
async fn drive_single(
    base_url: String,
    api_key: String,
    args: &RunArgs,
    pb: &ProgressBar,
    range: SlotRange,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
    infra: &SessionInfra,
    parent_cancel: &CancellationToken,
    reconnect_coordinator: Option<Arc<ReconnectCoordinator>>,
) -> Result<()> {
    let slots_in_range = range.end.saturating_sub(range.start) + 1;
    let advance_count = args.advance_count.unwrap_or(slots_in_range);

    let create = match CreateSession::builder()
        .start_slot(range.start)
        .end_slot(range.end)
        .disconnect_timeout_secs(DISCONNECT_TIMEOUT_SECS)
        .capacity_wait_timeout_secs(CREATION_TIMEOUT_SECS)
        .maybe_extra_compute_units(args.extra_compute_units)
        .build()
        .into_request()
    {
        Ok(request) => request,
        Err(err) => return Err(eyre::eyre!("building create request: {e}", e = err)),
    };

    let started = ManagedBacktestSession::start_with_cancel(
        base_url,
        api_key,
        create,
        parent_cancel.clone(),
        reconnect_coordinator,
    )
    .await;
    let session = match started {
        // Cancellation during start-up is a clean exit, not a failure.
        Err(ManagedSessionError::Cancelled) => {
            pb.abandon_with_message("cancelled");
            return Ok(());
        }
        Err(err) => return Err(eyre::eyre!("session create failed: {e}", e = err)),
        Ok(session) => session,
    };

    record_session_started(&session, range, output, session_ids);
    let outcome =
        subscribe_and_drive(session, args, pb, range, advance_count, output, infra).await;
    engine::finish_bar(pb, outcome)
}
/// Drives one sub-session of a parallel run and maps a failed outcome to a
/// [`SessionFailure`] tagged with the sub-session's slot range.
///
/// The progress bar is looked up in `bars` by the sub-range's start slot; a
/// hidden bar is used when none is registered. The advance count defaults to
/// the number of slots in the sub-range.
#[allow(clippy::too_many_arguments)]
async fn drive_parallel_sub(
    session: ParallelSubSession,
    args: &RunArgs,
    bars: &HashMap<u64, ProgressBar>,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
    infra: &SessionInfra,
) -> std::result::Result<(), SessionFailure> {
    let (start, end) = session.range();
    let range = SlotRange { start, end };

    let pb = match bars.get(&start) {
        Some(bar) => bar.clone(),
        None => ProgressBar::hidden(),
    };
    let slots_in_range = end.saturating_sub(start) + 1;
    let advance_count = args.advance_count.unwrap_or(slots_in_range);

    record_session_started(&session, range, output, session_ids);
    let outcome =
        subscribe_and_drive(session, args, &pb, range, advance_count, output, infra).await;

    if let Err(err) = engine::finish_bar(&pb, outcome) {
        return Err(SessionFailure {
            range: Some((start, end)),
            // Alternate formatting, as in the original message construction.
            reason: format!("{e:#}", e = err),
        });
    }
    Ok(())
}
fn open_output_writer(path: &std::path::Path) -> Result<std::fs::File> {
if let Some(parent) = path.parent()
&& !parent.as_os_str().is_empty()
&& !parent.exists()
{
std::fs::create_dir_all(parent)
.with_context(|| format!("could not create output directory '{}'", parent.display()))?;
}
std::fs::File::create(path).with_context(|| format!("could not create '{}'", path.display()))
}
/// Spawns the task that serializes every received [`OutputEvent`] as one JSON
/// line into `file`.
///
/// After waiting for an event, the task also drains whatever else is already
/// queued before flushing, so a burst of events costs a single flush. The
/// task ends once the channel is closed (all senders dropped), performing a
/// final flush.
///
/// The task's result is an error if serialization, writing or flushing fails;
/// the message names `path`.
fn spawn_output_writer(
    path: PathBuf,
    file: std::fs::File,
    mut rx: tokio::sync::mpsc::UnboundedReceiver<OutputEvent>,
) -> tokio::task::JoinHandle<Result<()>> {
    tokio::spawn(async move {
        use std::io::Write;

        let write_failed = || format!("could not write to '{}'", path.display());
        let mut out = std::io::BufWriter::new(file);

        while let Some(first) = rx.recv().await {
            // Write the event we waited for, then keep going for as long as
            // more events are immediately available.
            let mut current = first;
            loop {
                serde_json::to_writer(&mut out, &current).with_context(write_failed)?;
                out.write_all(b"\n").with_context(write_failed)?;
                match rx.try_recv() {
                    Ok(next) => current = next,
                    Err(_) => break,
                }
            }
            out.flush().with_context(write_failed)?;
        }

        out.flush()
            .with_context(|| format!("could not finalize '{}'", path.display()))?;
        Ok(())
    })
}