use std::{
collections::BTreeMap,
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};
use clap::Parser;
use console::style;
use eyre::{Context, Result};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use simulator_api::{AccountModifications, ContinueParams};
use simulator_client::{
BacktestClient, CreateSession,
managed::{ManagedBacktestSession, ManagedEvent, ManagedSessionError},
modify_program_via_rpc, split_range,
};
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_rpc_client::{http_sender::HttpSender, rpc_client::RpcClientConfig};
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::{
output::{
OutputEvent, SessionStartedRecord, SimulationMetadata, SimulationSummary, format_slot,
},
signals::spawn_ctrlc_cursor_fix,
subscription::{RecordsStore, on_account_diff_notification, on_transaction_notification},
};
/// Upper bound on concurrently running sessions: sizes the semaphore in `run`
/// and is the maximum number of sub-ranges `generate_plan` accepts.
const MAX_PARALLEL_SESSIONS: usize = 50;
/// Value passed (narrowed to `u16`) as `capacity_wait_timeout_secs` when a session is created.
const CREATION_TIMEOUT_SECS: u64 = 900;
/// Per-slot cost, in milliseconds, used by `generate_plan` for its runtime estimate.
const ESTIMATED_MS_PER_BLOCK: u64 = 250;
/// Request timeout configured on the shared `reqwest` client.
const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Idle-connection pool timeout configured on the shared `reqwest` client.
const HTTP_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// `(program id, file contents)` pairs built in `run` from `--program-id` / `--program-so`.
type PreloadedPrograms = Vec<(String, Vec<u8>)>;
/// Resources built once in `run` and shared (by clone) with every session task.
#[derive(Clone)]
struct SessionInfra {
/// Program files read from disk up front, keyed by program id; shared behind an `Arc`.
preloaded_programs: Arc<PreloadedPrograms>,
/// HTTP client handed to `make_rpc_client` for each session's RPC endpoint.
http_client: reqwest::Client,
}
/// Builds an RPC client for `rpc_endpoint` whose HTTP sender uses a clone of the
/// supplied `reqwest` client and the default RPC client configuration.
fn make_rpc_client(rpc_endpoint: &str, http_client: &reqwest::Client) -> RpcClient {
    let shared_http = http_client.clone();
    RpcClient::new_sender(
        HttpSender::new_with_client(rpc_endpoint, shared_http),
        RpcClientConfig::default(),
    )
}
// Value accepted by `--subscription`.
// In the code visible in this file only `AccountDiff` changes behaviour: it adds an
// account-diff subscription on top of the transaction subscription in `drive_session`.
// `Logs` and `Transaction` are not matched anywhere in the visible code.
// NOTE: plain `//` comments are used on purpose; clap's `ValueEnum` derive turns `///`
// doc comments into user-visible help text.
#[derive(Clone, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum SubscriptionType {
Logs,
AccountDiff,
Transaction,
}
/// Destinations for session results; cloned into every session task.
#[derive(Clone)]
struct OutputSink {
/// Shared record store; `run` reads `total()` and `successes()` from it for the summary.
records: RecordsStore,
/// Sending half of the channel drained by the task started in `spawn_output_writer`.
output_tx: tokio::sync::mpsc::UnboundedSender<OutputEvent>,
}
/// A slot range handled by one session. Slot counts in this file are computed as
/// `end - start + 1`, i.e. both bounds are treated as inclusive.
#[derive(Clone, Copy)]
struct SlotRange {
/// First slot of the range.
start: u64,
/// Last slot of the range.
end: u64,
}
// Command-line arguments for a simulation run. Most flags can also be supplied through the
// `SIMULATOR_*` environment variable named in their `env` attribute (`--verbose` has none).
// NOTE: plain `//` comments are used on purpose; clap turns `///` doc comments into
// `--help` text, which would change the program's output.
#[derive(Parser, Debug, Clone)]
pub struct RunArgs {
// First slot of the requested range.
#[arg(long, env = "SIMULATOR_START_SLOT")]
pub start_slot: u64,
// Last slot of the requested range.
#[arg(long, env = "SIMULATOR_END_SLOT")]
pub end_slot: u64,
// `advance_count` sent with every continue request; when absent, `drive_session`
// uses the number of slots in the session's range.
#[arg(long, env = "SIMULATOR_ADVANCE_COUNT")]
pub advance_count: Option<u64>,
// Split the range with `split_range` and run one session per sub-range concurrently.
#[arg(long, env = "SIMULATOR_PARALLEL", default_value_t = false)]
pub parallel: bool,
// Program ids to subscribe to; paired positionally with `program_so` for overrides.
#[arg(long, env = "SIMULATOR_PROGRAM_ID")]
pub program_id: Vec<String>,
// Files read at startup and injected for the matching `program_id`;
// when non-empty, the count must equal the `program_id` count (see `validate_args`).
#[arg(long, env = "SIMULATOR_PROGRAM_SO")]
pub program_so: Vec<PathBuf>,
// Output path; defaults to `<start>_<end>_<override|baseline>.ndjson`.
#[arg(long, env = "SIMULATOR_OUTPUT_FILE")]
pub output_file: Option<PathBuf>,
// Forwarded to the session-creation request when set.
#[arg(long, env = "SIMULATOR_EXTRA_COMPUTE_UNITS")]
pub extra_compute_units: Option<u32>,
// Optional subscription kind; clap requires `--program-id` to be given with it.
#[arg(long, env = "SIMULATOR_SUBSCRIPTION", requires = "program_id")]
pub subscription: Option<SubscriptionType>,
// Installs a tracing subscriber (default filter "info") and hides the progress bars.
#[arg(long, default_value_t = false)]
pub verbose: bool,
// Print the session plan and exit without creating any session.
#[arg(long, env = "SIMULATOR_PLAN", default_value_t = false)]
pub plan: bool,
}
/// Checks cross-argument invariants that clap cannot express on its own.
///
/// # Errors
///
/// Returns an error when
/// * `--start-slot` is greater than `--end-slot` (later code computes
///   `end_slot - start_slot + 1` and builds sessions from the pair), or
/// * `--program-so` was given but its count differs from `--program-id`
///   (the two lists are zipped positionally when programs are preloaded).
fn validate_args(args: &RunArgs) -> Result<()> {
    // Reject inverted ranges up front instead of letting slot arithmetic underflow later.
    eyre::ensure!(
        args.start_slot <= args.end_slot,
        "--start-slot ({}) must not be greater than --end-slot ({})",
        args.start_slot,
        args.end_slot,
    );
    // Program ids without overrides are allowed; overrides without a matching id are not.
    if !args.program_so.is_empty() {
        eyre::ensure!(
            args.program_so.len() == args.program_id.len(),
            "--program-so count ({}) must match --program-id count ({})",
            args.program_so.len(),
            args.program_id.len(),
        );
    }
    Ok(())
}
/// Prints how the requested slot range would be split into sessions, without running anything.
///
/// Fetches the available ranges from `client`, splits `[start_slot, end_slot]` with
/// `split_range`, and prints a runtime estimate followed by one table row per sub-range.
/// The estimate is based on the largest sub-range at `ESTIMATED_MS_PER_BLOCK` per slot.
///
/// # Errors
///
/// Fails when the available ranges cannot be fetched, the range cannot be split, more than
/// `MAX_PARALLEL_SESSIONS` sub-ranges are required, or two consecutive sub-ranges are not
/// contiguous.
async fn generate_plan(client: &BacktestClient, args: &RunArgs) -> Result<()> {
    let ranges = client
        .available_ranges()
        .await
        .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
    let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
        .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;

    if sub_ranges.len() > MAX_PARALLEL_SESSIONS {
        eyre::bail!(
            "unable to generate plan: requires {} sessions but max is {MAX_PARALLEL_SESSIONS}",
            sub_ranges.len(),
        );
    }

    // Consecutive sub-ranges must touch: each one ends exactly one slot before the next starts.
    for window in sub_ranges.windows(2) {
        let (_, end) = window[0];
        let (next_start, _) = window[1];
        // `checked_add` keeps a range ending at `u64::MAX` from overflowing the comparison.
        if end.checked_add(1) != Some(next_start) {
            eyre::bail!(
                "unable to generate plan: gap in data between slots {} and {}",
                format_slot(end),
                format_slot(next_start),
            );
        }
    }

    // Sessions run in parallel, so the longest sub-range determines the estimate.
    let max_slots = sub_ranges
        .iter()
        .map(|(start, end)| end.saturating_sub(*start).saturating_add(1))
        .max()
        .unwrap_or(0);
    let estimated_ms = max_slots.saturating_mul(ESTIMATED_MS_PER_BLOCK);
    let estimated = if estimated_ms < 60_000 {
        // Integer division would otherwise report a misleading "~0m".
        String::from("<1m")
    } else if estimated_ms < 60 * 60 * 1000 {
        format!("~{}m", estimated_ms / 60_000)
    } else {
        format!("~{:.1}h", estimated_ms as f64 / 3_600_000.0)
    };

    // Saturating arithmetic matches the rest of the file and cannot panic on an inverted range.
    let total_slots = args
        .end_slot
        .saturating_sub(args.start_slot)
        .saturating_add(1);
    println!(
        "\n{} {estimated} runtime · {} session(s)\n{} [{} → {}] · {} slots",
        style(">").cyan(),
        sub_ranges.len(),
        style(">").cyan(),
        format_slot(args.start_slot),
        format_slot(args.end_slot),
        format_slot(total_slots),
    );
    println!();
    println!(
        "{:<5} {:<20} {:<20} {:<15}",
        "#", "Start Slot", "End Slot", "Slots"
    );
    println!("{}", "-".repeat(62));
    for (i, (start, end)) in sub_ranges.iter().enumerate() {
        let slots = end.saturating_sub(*start).saturating_add(1);
        println!(
            "{:<5} {:<20} {:<20} {:<15}",
            i + 1,
            format_slot(*start),
            format_slot(*end),
            format_slot(slots),
        );
    }
    Ok(())
}
/// Entry point of the `run` command: executes a backtest over `[start_slot, end_slot]`
/// and writes every emitted record to an NDJSON file.
///
/// Steps, in order:
/// 1. validate arguments and derive the websocket URL (`…/backtest`);
/// 2. with `--plan`, print the plan and return;
/// 3. open the output file, start the writer task and emit the metadata record;
/// 4. read the `--program-so` files and build the shared HTTP client;
/// 5. drive either one session, or (with `--parallel`) one session per sub-range;
/// 6. emit the summary, wait for the writer, and print the totals.
///
/// `url` may be a full `ws://` / `wss://` URL or a bare host, which is prefixed with
/// `wss://`. Trailing slashes are ignored in both forms.
///
/// # Errors
///
/// Returns an error when argument validation, file access, HTTP client construction or any
/// session fails, when a session task panics, when the output writer fails, or when the run
/// finishes without recording a single transaction. A cancelled run returns `Ok(())`.
pub async fn run(
    args: RunArgs,
    url: String,
    api_key: String,
    cancellation: CancellationToken,
) -> Result<()> {
    validate_args(&args)?;

    // Trim trailing slashes for bare hosts too, so "host/" does not become
    // "wss://host//backtest".
    let trimmed_url = url.trim_end_matches('/');
    let base_url = if url.starts_with("ws://") || url.starts_with("wss://") {
        format!("{trimmed_url}/backtest")
    } else {
        format!("wss://{trimmed_url}/backtest")
    };
    let client = BacktestClient::builder()
        .url(base_url.clone())
        .api_key(api_key.clone())
        .build();
    if args.plan {
        return generate_plan(&client, &args).await;
    }

    // Default file name encodes the range and whether program overrides are in effect.
    let output_file = args.output_file.clone().unwrap_or_else(|| {
        let suffix = if !args.program_so.is_empty() {
            "override"
        } else {
            "baseline"
        };
        PathBuf::from(format!(
            "{}_{}_{}.ndjson",
            args.start_slot, args.end_slot, suffix
        ))
    });
    let output_writer = open_output_writer(&output_file)?;
    let (output_tx, output_rx) = tokio::sync::mpsc::unbounded_channel::<OutputEvent>();
    let output_writer_handle = spawn_output_writer(output_file.clone(), output_writer, output_rx);

    // The metadata record is the first event sent to the writer; session ids are not known
    // yet and are reported in the summary instead.
    let metadata = SimulationMetadata {
        start_slot: args.start_slot,
        end_slot: args.end_slot,
        program_ids: args.program_id.clone(),
        program_so: args
            .program_so
            .iter()
            .map(|p| p.display().to_string())
            .collect(),
        ran_at_unix_secs: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0),
        session_ids: Vec::new(),
    };
    let _ = output_tx.send(OutputEvent::Metadata(metadata));
    let output = OutputSink {
        records: RecordsStore::new(),
        output_tx,
    };
    let session_ids: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

    // Read every override file once; sessions share the bytes through the `Arc`.
    let preloaded_programs: Arc<PreloadedPrograms> = Arc::new(
        args.program_id
            .iter()
            .zip(&args.program_so)
            .map(|(id, path)| {
                let elf = std::fs::read(path)
                    .with_context(|| format!("failed to read {}", path.display()))?;
                Ok((id.clone(), elf))
            })
            .collect::<Result<_>>()?,
    );
    let http_client = reqwest::Client::builder()
        .timeout(HTTP_REQUEST_TIMEOUT)
        .pool_idle_timeout(HTTP_POOL_IDLE_TIMEOUT)
        .hickory_dns(true)
        .build()
        .context("failed to build shared HTTP client")?;
    let infra = SessionInfra {
        preloaded_programs,
        http_client,
    };
    let args = Arc::new(args);
    spawn_ctrlc_cursor_fix();
    if args.verbose {
        tracing_subscriber::fmt()
            .with_env_filter(
                tracing_subscriber::EnvFilter::try_from_default_env()
                    .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
            )
            .init();
    }

    if args.parallel {
        let progress = MultiProgress::new();
        let ranges = client
            .available_ranges()
            .await
            .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
        let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
            .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
        info!("splitting into {} parallel session(s)", sub_ranges.len());
        let _ = progress.println(format!(
            "{} Splitting into {} parallel session(s)",
            style(">").cyan(),
            sub_ranges.len(),
        ));

        // One bar per sub-range; hidden bars in verbose mode so log lines stay readable.
        let bars: Vec<ProgressBar> = sub_ranges
            .iter()
            .map(|&(start, end)| {
                if args.verbose {
                    return ProgressBar::hidden();
                }
                let bar = ProgressBar::new(end.saturating_sub(start) + 1);
                bar.set_style(spinner_style(start, end));
                let bar = progress.add(bar);
                bar.enable_steady_tick(Duration::from_millis(100));
                bar
            })
            .collect();

        let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_SESSIONS));
        let mut join_set = tokio::task::JoinSet::new();
        // Keep handles to every bar so they can be marked as cancelled from this task.
        let all_bars: Vec<ProgressBar> = bars.to_vec();
        for ((start, end), bar) in sub_ranges.into_iter().zip(bars) {
            let base_url = base_url.clone();
            let api_key = api_key.clone();
            let args = args.clone();
            let output = output.clone();
            let session_ids = session_ids.clone();
            let cancellation = cancellation.clone();
            let semaphore = semaphore.clone();
            let infra = infra.clone();
            join_set.spawn(async move {
                // Wait for a session slot, but give up immediately if the run is cancelled.
                let _permit = tokio::select! {
                    biased;
                    _ = cancellation.cancelled() => return Ok(()),
                    p = semaphore.acquire_owned() => p?,
                };
                drive_session(
                    base_url,
                    api_key,
                    &args,
                    &bar,
                    SlotRange { start, end },
                    &output,
                    &session_ids,
                    &infra,
                    &cancellation,
                )
                .await
                .inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))
            });
        }

        // Only the first session error is reported; the remaining sessions keep running.
        let mut first_error: Option<eyre::Report> = None;
        loop {
            tokio::select! {
                biased;
                _ = cancellation.cancelled() => {
                    join_set.abort_all();
                    for bar in &all_bars {
                        if !bar.is_finished() {
                            bar.abandon_with_message("cancelled");
                        }
                    }
                    // The "Cancelled by user" notice is printed once by the shared
                    // cancellation path below; printing it here as well duplicated it.
                    break;
                }
                result = join_set.join_next() => {
                    let Some(result) = result else { break };
                    match result {
                        Ok(Ok(())) => {}
                        Ok(Err(e)) if first_error.is_none() => first_error = Some(e),
                        Err(e) if !e.is_cancelled() => {
                            return Err(eyre::eyre!("session task panicked: {e}"));
                        }
                        _ => {}
                    }
                }
            }
        }
        if let Some(e) = first_error
            && !cancellation.is_cancelled()
        {
            return Err(e);
        }
    } else {
        let range = SlotRange {
            start: args.start_slot,
            end: args.end_slot,
        };
        let bar = if args.verbose {
            ProgressBar::hidden()
        } else {
            let bar = ProgressBar::new(range.end.saturating_sub(range.start) + 1);
            bar.set_style(spinner_style(range.start, range.end));
            bar.set_message("starting session...");
            bar.enable_steady_tick(Duration::from_millis(100));
            bar
        };
        drive_session(
            base_url,
            api_key,
            &args,
            &bar,
            range,
            &output,
            &session_ids,
            &infra,
            &cancellation,
        )
        .await
        .inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))?;
    }

    if cancellation.is_cancelled() {
        eprintln!("\n{} Cancelled by user", style("!").yellow());
        // Dropping the sender closes the channel so the writer can drain and finish;
        // no summary is written for a cancelled run.
        let OutputSink { output_tx, .. } = output;
        drop(output_tx);
        let _ = output_writer_handle.await;
        return Ok(());
    }

    let OutputSink {
        records, output_tx, ..
    } = output;
    let session_ids = session_ids.lock().unwrap().clone();
    let total = records.total();
    let successes = records.successes();
    let summary = SimulationSummary {
        total_transactions: total,
        successes,
        failures: total.saturating_sub(successes),
        session_ids,
    };
    let _ = output_tx.send(OutputEvent::Summary(summary.clone()));
    // Close the channel, then wait until the writer has written everything it received.
    drop(output_tx);
    output_writer_handle
        .await
        .context("output writer task panicked")?
        .context("output writer failed")?;

    if summary.total_transactions == 0 {
        println!(
            "\n{} Simulation finished with no transactions",
            style("⚠").yellow()
        );
        println!("  Output file  : {}", output_file.display());
        return Err(eyre::eyre!(
            "no matching transactions found in the requested slot range"
        ));
    }
    println!("\n{} Simulation complete", style("✔").green());
    println!("  Transactions : {}", summary.total_transactions);
    println!("  Successes    : {}", summary.successes);
    println!("  Failures     : {}", summary.failures);
    println!("  Output file  : {}", output_file.display());
    Ok(())
}
/// Creates one managed backtest session for `range`, runs it until it completes, fails or is
/// cancelled, and always shuts the session down afterwards.
///
/// The new session id is appended to `session_ids` and a `SessionStarted` record is sent to
/// the output channel. When `--program-id` is set, transactions (and, for the `AccountDiff`
/// subscription type, account diffs) are subscribed for those programs.
///
/// Cancellation is not an error: the bar is abandoned with "cancelled" and `Ok(())` is
/// returned. Other failures are returned to the caller, which owns the failure message on
/// the bar.
#[allow(clippy::too_many_arguments)]
async fn drive_session(
    base_url: String,
    api_key: String,
    args: &RunArgs,
    pb: &ProgressBar,
    range: SlotRange,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
    infra: &SessionInfra,
    parent_cancel: &CancellationToken,
) -> Result<()> {
    let (start, end) = (range.start, range.end);
    // Without an explicit --advance-count, advance through the whole range at once.
    let whole_range = end.saturating_sub(start) + 1;
    let advance_count = args.advance_count.unwrap_or(whole_range);

    let request = CreateSession::builder()
        .start_slot(start)
        .end_slot(end)
        .disconnect_timeout_secs(180)
        .capacity_wait_timeout_secs(CREATION_TIMEOUT_SECS as u16)
        .maybe_extra_compute_units(args.extra_compute_units)
        .build()
        .into_request()
        .map_err(|e| eyre::eyre!("building create request: {e}"))?;

    let started = ManagedBacktestSession::start_with_cancel(
        base_url,
        api_key,
        request,
        parent_cancel.clone(),
    )
    .await;
    let mut session = match started {
        Err(ManagedSessionError::Cancelled) => {
            pb.abandon_with_message("cancelled");
            return Ok(());
        }
        Err(e) => return Err(eyre::eyre!("session create failed: {e}")),
        Ok(session) => session,
    };

    let session_info = session.session_info().clone();
    {
        let mut ids = session_ids.lock().unwrap();
        ids.push(session_info.session_id.clone());
    }
    info!(
        session_id = %session_info.session_id,
        task_id = ?session_info.task_id,
        start,
        end,
        "session created"
    );
    let started_record = SessionStartedRecord {
        session_id: session_info.session_id.clone(),
        task_id: session_info.task_id.clone(),
        start_slot: start,
        end_slot: end,
    };
    // A closed output channel is not fatal for the session itself.
    let _ = output
        .output_tx
        .send(OutputEvent::SessionStarted(started_record));
    pb.set_message("waiting for runtime");

    if !args.program_id.is_empty() {
        session.subscribe_transactions(args.program_id.clone());
        let wants_account_diffs = matches!(args.subscription, Some(SubscriptionType::AccountDiff));
        if wants_account_diffs {
            session.subscribe_account_diffs(args.program_id.clone());
        }
    }

    let outcome = run_session_loop(
        &mut session,
        pb,
        range,
        advance_count,
        &session_info.rpc_endpoint,
        infra,
        output,
    )
    .await;
    // Shut down before inspecting the outcome so the session is released on every path.
    session.shutdown().await;

    match outcome? {
        Outcome::Completed => pb.finish_with_message("done"),
        Outcome::Cancelled => pb.abandon_with_message("cancelled"),
    }
    Ok(())
}
/// How `run_session_loop` ended when it did not fail.
enum Outcome {
/// The session emitted `ManagedEvent::Completed`.
Completed,
/// Waiting for the next event returned `ManagedSessionError::Cancelled`.
Cancelled,
}
/// Consumes events from `session` until it completes, is cancelled, or fails.
///
/// * `Slot` events move the progress bar.
/// * `Status` events are shown as the bar message only before the first continue.
/// * `ReadyForContinue` sends a continue request with `advance_count`; the first one also
///   carries the program overrides built by `build_modifications` and switches the bar from
///   spinner to progress style.
/// * `Transaction` and `AccountDiff` events are forwarded to `dispatch_event`.
/// * `Paused` and `DiscoveryBatch` events are ignored.
///
/// # Errors
///
/// Returns an error for any session failure other than cancellation, for a simulator
/// `Error` event, and when building the overrides or sending the continue request fails.
#[allow(clippy::too_many_arguments)]
async fn run_session_loop(
    session: &mut ManagedBacktestSession,
    pb: &ProgressBar,
    range: SlotRange,
    advance_count: u64,
    rpc_endpoint: &str,
    infra: &SessionInfra,
    output: &OutputSink,
) -> Result<Outcome> {
    let SlotRange { start, end } = range;
    let mut first_ready_seen = false;

    loop {
        let event = match session.next_event().await {
            Err(ManagedSessionError::Cancelled) => return Ok(Outcome::Cancelled),
            Err(e) => return Err(eyre::eyre!("session failed: {e}")),
            Ok(event) => event,
        };

        match event {
            ManagedEvent::Completed => return Ok(Outcome::Completed),
            ManagedEvent::Error(e) => return Err(eyre::eyre!("simulator error: {e}")),
            ManagedEvent::Slot(s) => update_progress_position(pb, start, s),
            // Status text is only interesting while the session is still starting up.
            ManagedEvent::Status(status) if !first_ready_seen => {
                pb.set_message(status.to_string());
            }
            ManagedEvent::ReadyForContinue => {
                let modify_account_states = if first_ready_seen {
                    AccountModifications(Default::default())
                } else {
                    if !infra.preloaded_programs.is_empty() {
                        pb.set_message("injecting program");
                    }
                    let modifications = build_modifications(rpc_endpoint, infra).await?;
                    // The runtime is ready: switch from spinner to a real progress bar.
                    pb.set_style(bar_style(start, end));
                    pb.reset_elapsed();
                    pb.set_position(0);
                    pb.set_message("running");
                    first_ready_seen = true;
                    AccountModifications(modifications)
                };
                let params = ContinueParams {
                    advance_count,
                    transactions: Vec::new(),
                    modify_account_states,
                };
                session
                    .send_continue(params)
                    .await
                    .map_err(|e| eyre::eyre!("send_continue: {e}"))?;
            }
            event @ (ManagedEvent::Transaction(_) | ManagedEvent::AccountDiff(_)) => {
                dispatch_event(event, output);
            }
            ManagedEvent::Status(_)
            | ManagedEvent::Paused(_)
            | ManagedEvent::DiscoveryBatch(_) => {}
        }
    }
}
fn dispatch_event(event: ManagedEvent, output: &OutputSink) {
match event {
ManagedEvent::Transaction(notification) => on_transaction_notification(
output.records.clone(),
output.output_tx.clone(),
*notification,
),
ManagedEvent::AccountDiff(notification) => on_account_diff_notification(
output.records.clone(),
output.output_tx.clone(),
notification,
),
_ => {}
}
}
/// Builds the account modifications that inject every preloaded program.
///
/// Returns an empty map, without creating an RPC client, when there are no preloaded
/// programs. Otherwise each program is passed to `modify_program_via_rpc` against
/// `rpc_endpoint`, and the resulting entries are merged into one map in program order.
///
/// # Errors
///
/// Fails on the first program whose modification cannot be built.
async fn build_modifications(
    rpc_endpoint: &str,
    infra: &SessionInfra,
) -> Result<BTreeMap<solana_address::Address, simulator_api::AccountData>> {
    let mut mods = BTreeMap::new();
    if infra.preloaded_programs.is_empty() {
        return Ok(mods);
    }

    let rpc = make_rpc_client(rpc_endpoint, &infra.http_client);
    for (id, elf) in infra.preloaded_programs.iter() {
        match modify_program_via_rpc(&rpc, id, elf).await {
            Ok(entries) => mods.extend(entries),
            Err(e) => return Err(eyre::eyre!("building program injection for {id}: {e}")),
        }
    }
    Ok(mods)
}
/// Creates (or truncates) the output file at `path`, first creating its parent directory
/// when the path names one that does not exist yet.
///
/// # Errors
///
/// Fails when the parent directory or the file itself cannot be created.
fn open_output_writer(path: &std::path::Path) -> Result<std::fs::File> {
    // A bare file name has an empty parent; there is nothing to create in that case.
    let missing_parent = path
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty() && !dir.exists());
    if let Some(dir) = missing_parent {
        std::fs::create_dir_all(dir)
            .with_context(|| format!("could not create output directory '{}'", dir.display()))?;
    }
    std::fs::File::create(path).with_context(|| format!("could not create '{}'", path.display()))
}
/// Spawns the task that serialises every received `OutputEvent` as one JSON line in `file`.
///
/// The task waits for an event, writes it together with everything else already queued,
/// and flushes once per such batch. It ends when all senders are dropped and the channel is
/// empty, flushing one final time. `path` is only used in error messages.
fn spawn_output_writer(
    path: PathBuf,
    file: std::fs::File,
    mut rx: tokio::sync::mpsc::UnboundedReceiver<OutputEvent>,
) -> tokio::task::JoinHandle<Result<()>> {
    tokio::spawn(async move {
        use std::io::Write;

        let write_err = || format!("could not write to '{}'", path.display());
        let mut writer = std::io::BufWriter::new(file);

        while let Some(first) = rx.recv().await {
            // Drain whatever is already queued so a burst of events costs a single flush.
            let queued = std::iter::from_fn(|| rx.try_recv().ok());
            for event in std::iter::once(first).chain(queued) {
                serde_json::to_writer(&mut writer, &event).with_context(write_err)?;
                writer.write_all(b"\n").with_context(write_err)?;
            }
            writer.flush().with_context(write_err)?;
        }

        writer
            .flush()
            .with_context(|| format!("could not finalize '{}'", path.display()))
    })
}
/// Sets the bar position to the 1-based offset of `slot` from `start_slot`.
/// A slot before `start_slot` yields position 1, because the subtraction saturates at zero.
fn update_progress_position(pb: &ProgressBar, start_slot: u64, slot: u64) {
    let offset = slot.saturating_sub(start_slot);
    let slots_done = offset.saturating_add(1);
    pb.set_position(slots_done);
}
/// Style used before a session starts running: the slot range, a spinner and the message.
/// Falls back to indicatif's default spinner style if the template is rejected.
fn spinner_style(start: u64, end: u64) -> ProgressStyle {
    let template = match ProgressStyle::with_template("[{range}] {spinner:.cyan} {msg}") {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_spinner(),
    };
    // `{range}` is a custom key rendered from the captured slot bounds.
    template.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = write!(out, "{} → {}", format_slot(start), format_slot(end));
        },
    )
}
/// Style used while a session is running: the slot range, a 40-column bar, position out of
/// length, elapsed time and the message. Falls back to indicatif's default bar style if the
/// template is rejected.
fn bar_style(start: u64, end: u64) -> ProgressStyle {
    let template = match ProgressStyle::with_template(
        "[{range}] {bar:40.cyan/blue} {pos}/{len} slots ({elapsed}) {msg}",
    ) {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_bar(),
    };
    // `{range}` is a custom key rendered from the captured slot bounds.
    template.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = write!(out, "{} → {}", format_slot(start), format_slot(end));
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bar position is the 1-based offset of the reported slot from the start slot.
    #[test]
    fn update_progress_position_tracks_slot_offset() {
        const START_SLOT: u64 = 362_270_659;
        let pb = ProgressBar::new(100);

        // (reported slot, expected bar position)
        let cases: [(u64, u64); 2] = [(START_SLOT, 1), (362_270_700, 42)];
        for (slot, expected) in cases {
            update_progress_position(&pb, START_SLOT, slot);
            assert_eq!(pb.position(), expected);
        }
    }
}