use std::{
collections::BTreeMap,
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};
use clap::Parser;
use console::style;
use dashmap::DashMap;
use eyre::{Context, Result};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use simulator_api::{AccountModifications, ContinueParams};
use simulator_client::{BacktestClient, CreateSession, modify_program_via_rpc, split_range};
use solana_client::nonblocking::rpc_client::RpcClient;
use tokio::sync::{Semaphore, watch};
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::{
manager::{
AccountDiff, ConnectionStatus, ControlEvent, ControlHandle, ProgramLog, SubscriptionHandle,
spawn_control_manager, spawn_subscription_manager,
},
output::{AccountDiffRow, SimulationMetadata, SimulationOutput, Transaction, format_slot},
signals::spawn_ctrlc_cursor_fix,
subscription::{on_account_diff_notification, on_log_notification},
};
/// Size of the semaphore that caps concurrently running sessions in `--parallel` mode,
/// and the largest number of sessions `--plan` will accept.
const MAX_PARALLEL_SESSIONS: usize = 50;
/// Passed as each session's `capacity_wait_timeout_secs` (15 minutes).
const CREATION_TIMEOUT_SECS: u64 = 15 * 60;
/// Per-slot cost used only for the rough runtime estimate printed by `--plan`.
const ESTIMATED_MS_PER_BLOCK: u64 = 250;
// Notification stream attached to each session when `--program-id` values are given.
#[derive(clap::ValueEnum, Debug, Clone, Eq, PartialEq)]
pub enum SubscriptionType {
    // Program log notifications; also used when no subscription type is chosen.
    Logs,
    // Account-diff notifications; the only kind forwarded to `--stream-file`.
    AccountDiff,
}
// Shared destination for everything the subscription callbacks collect. It is cloned into
// every session, so parallel sessions all feed the same map and channel.
#[derive(Clone)]
struct OutputSink {
    // Live feed of account-diff rows for the `--stream-file` writer, when one is running.
    stream_tx: Option<tokio::sync::mpsc::UnboundedSender<AccountDiffRow>>,
    // Collected transactions; the string key is presumably the signature — TODO confirm.
    records: Arc<DashMap<String, Transaction>>,
}
// Inclusive `[start, end]` window of slots handled by one session.
#[derive(Copy, Clone)]
struct SlotRange {
    start: u64, // first slot of the window
    end: u64,   // last slot of the window (inclusive)
}
// Command-line options for a backtest run. Most options can also be supplied through the
// `SIMULATOR_*` environment variable named in their `env` attribute. The `///` comments on
// the fields are rendered by clap as `--help` text.
#[derive(Parser, Debug, Clone)]
pub struct RunArgs {
    /// Simulator host (wss:// is assumed) or a full ws:// or wss:// URL; `/backtest` is appended
    #[arg(
        long,
        env = "SIMULATOR_URL",
        default_value = "simulator.termina.technology"
    )]
    pub url: String,
    /// API key passed to the backtest client and to every session's control connection
    #[arg(long, env = "SIMULATOR_API_KEY")]
    pub api_key: String,
    /// First slot of the backtest range (inclusive)
    #[arg(long, env = "SIMULATOR_START_SLOT")]
    pub start_slot: u64,
    /// Last slot of the backtest range (inclusive)
    #[arg(long, env = "SIMULATOR_END_SLOT")]
    pub end_slot: u64,
    /// `advance_count` sent with each continue request; defaults to the number of slots in the session's range
    #[arg(long, env = "SIMULATOR_ADVANCE_COUNT")]
    pub advance_count: Option<u64>,
    /// Split the range along the simulator's available data ranges and run one session per piece concurrently
    #[arg(long, env = "SIMULATOR_PARALLEL", default_value_t = false)]
    pub parallel: bool,
    /// Program ID to subscribe to (repeatable); paired by position with `--program-so`
    #[arg(long, env = "SIMULATOR_PROGRAM_ID")]
    pub program_id: Vec<String>,
    /// Replacement program .so for the `--program-id` at the same position (counts must match)
    #[arg(long, env = "SIMULATOR_PROGRAM_SO")]
    pub program_so: Vec<PathBuf>,
    /// JSON report path; defaults to `<start>_<end>_baseline.json`, or `_override.json` when `--program-so` is set
    #[arg(long, env = "SIMULATOR_OUTPUT_FILE")]
    pub output_file: Option<PathBuf>,
    /// Extra compute units forwarded when creating each session
    #[arg(long, env = "SIMULATOR_EXTRA_COMPUTE_UNITS")]
    pub extra_compute_units: Option<u32>,
    /// Notification stream to subscribe to for the given program IDs (logs when omitted)
    #[arg(long, env = "SIMULATOR_SUBSCRIPTION", requires = "program_id")]
    pub subscription: Option<SubscriptionType>,
    /// Write account-diff rows to this file as newline-delimited JSON while the run progresses
    #[arg(long, env = "SIMULATOR_STREAM_FILE", requires = "subscription")]
    pub stream_file: Option<PathBuf>,
    /// Show tracing logs (filtered by RUST_LOG, default `info`) instead of progress bars
    #[arg(long, default_value_t = false)]
    pub verbose: bool,
    /// Print how the range would be split into sessions, with a runtime estimate, and exit
    #[arg(long, env = "SIMULATOR_PLAN", default_value_t = false)]
    pub plan: bool,
}
/// Checks argument combinations that clap's attributes cannot express on their own.
///
/// # Errors
///
/// Returns an error when `--start-slot` is greater than `--end-slot`, or when `--program-so`
/// is given and its count differs from the `--program-id` count (the two lists are paired
/// by position).
fn validate_args(args: &RunArgs) -> Result<()> {
    // Slot spans are computed as `end - start + 1` throughout (the plan summary does so
    // unchecked), so an inverted range must be rejected before any of that runs.
    eyre::ensure!(
        args.start_slot <= args.end_slot,
        "--start-slot ({}) must not be greater than --end-slot ({})",
        args.start_slot,
        args.end_slot,
    );
    if !args.program_so.is_empty() {
        eyre::ensure!(
            args.program_so.len() == args.program_id.len(),
            "--program-so count ({}) must match --program-id count ({})",
            args.program_so.len(),
            args.program_id.len(),
        );
    }
    Ok(())
}
/// Prints the `--plan` output: how `[start_slot, end_slot]` would be split across the
/// simulator's available ranges, one row per session, plus a rough runtime estimate.
///
/// # Errors
///
/// Fails if the available ranges cannot be fetched or split, if the split needs more than
/// `MAX_PARALLEL_SESSIONS` sessions, or if two consecutive sub-ranges do not abut.
async fn generate_plan(client: &BacktestClient, args: &RunArgs) -> Result<()> {
    let ranges = client
        .available_ranges()
        .await
        .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
    let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
        .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
    if sub_ranges.len() > MAX_PARALLEL_SESSIONS {
        eyre::bail!(
            "unable to generate plan: requires {} sessions but max is {MAX_PARALLEL_SESSIONS}",
            sub_ranges.len(),
        );
    }
    // Consecutive sub-ranges must abut exactly; any mismatch is reported as a gap.
    // `saturating_add` keeps a sub-range ending at `u64::MAX` from overflowing.
    for window in sub_ranges.windows(2) {
        let (_, end) = window[0];
        let (next_start, _) = window[1];
        if end.saturating_add(1) != next_start {
            eyre::bail!(
                "unable to generate plan: gap in data between slots {} and {}",
                format_slot(end),
                format_slot(next_start),
            );
        }
    }
    // The estimate uses the longest sub-range, i.e. it assumes all sessions run concurrently.
    let max_slots = sub_ranges
        .iter()
        .map(|(start, end)| end.saturating_sub(*start) + 1)
        .max()
        .unwrap_or(0);
    let estimated = format_runtime_estimate(max_slots.saturating_mul(ESTIMATED_MS_PER_BLOCK));
    // Saturating so an inverted range cannot underflow here.
    let total_slots = args.end_slot.saturating_sub(args.start_slot) + 1;
    println!(
        "\n{} {estimated} runtime · {} session(s)\n{} [{} → {}] · {} slots",
        style(">").cyan(),
        sub_ranges.len(),
        style(">").cyan(),
        format_slot(args.start_slot),
        format_slot(args.end_slot),
        format_slot(total_slots),
    );
    println!();
    println!(
        "{:<5} {:<20} {:<20} {:<15}",
        "#", "Start Slot", "End Slot", "Slots"
    );
    println!("{}", "-".repeat(62));
    for (i, (start, end)) in sub_ranges.iter().enumerate() {
        let slots = end.saturating_sub(*start) + 1;
        println!(
            "{:<5} {:<20} {:<20} {:<15}",
            i + 1,
            format_slot(*start),
            format_slot(*end),
            format_slot(slots),
        );
    }
    Ok(())
}

/// Renders an estimated runtime given in milliseconds.
///
/// Below one hour it prints whole minutes, rounded up so that short runs never read as
/// `~0m`. From one hour on it prints hours with one decimal place.
fn format_runtime_estimate(estimated_ms: u64) -> String {
    const MS_PER_MINUTE: u64 = 60_000;
    const MS_PER_HOUR: u64 = 60 * MS_PER_MINUTE;
    if estimated_ms < MS_PER_HOUR {
        format!("~{}m", estimated_ms.div_ceil(MS_PER_MINUTE))
    } else {
        format!("~{:.1}h", estimated_ms as f64 / MS_PER_HOUR as f64)
    }
}
/// Entry point for a backtest run.
///
/// Drives one session over the whole range, or one session per sub-range with
/// `--parallel`. Transactions gathered by the subscription callbacks are written to a JSON
/// report. With `--plan` it only prints the session plan and returns.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the arguments are invalid;
/// - a session fails (in parallel mode, the first failure is returned after all sessions
///   finish);
/// - a session task panics;
/// - the stream writer or the report cannot be written;
/// - the finished run matched no transactions (the report is still written first).
///
/// Cancellation through `cancellation` is not an error. The run stops, prints a single
/// notice and returns `Ok(())` without writing the report.
pub async fn run(args: RunArgs, cancellation: CancellationToken) -> Result<()> {
    validate_args(&args)?;
    // Accept a bare host (wss:// assumed) or an explicit ws:// / wss:// URL.
    let base_url = if args.url.starts_with("ws://") || args.url.starts_with("wss://") {
        format!("{}/backtest", args.url.trim_end_matches('/'))
    } else {
        format!("wss://{}/backtest", args.url)
    };
    let client = BacktestClient::builder()
        .url(base_url.clone())
        .api_key(args.api_key.clone())
        .build();
    if args.plan {
        return generate_plan(&client, &args).await;
    }
    // Default report name encodes the range and whether program overrides were injected.
    let output_file = args.output_file.clone().unwrap_or_else(|| {
        let suffix = if !args.program_so.is_empty() {
            "override"
        } else {
            "baseline"
        };
        PathBuf::from(format!(
            "{}_{}_{}.json",
            args.start_slot, args.end_slot, suffix
        ))
    });
    let (stream_tx, stream_writer_handle) = if let Some(ref stream_path) = args.stream_file {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<AccountDiffRow>();
        let handle = spawn_stream_writer(stream_path.clone(), rx);
        (Some(tx), Some(handle))
    } else {
        (None, None)
    };
    let output = OutputSink {
        records: Arc::new(DashMap::new()),
        stream_tx,
    };
    let session_ids: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let args = Arc::new(args);
    spawn_ctrlc_cursor_fix();
    if args.verbose {
        tracing_subscriber::fmt()
            .with_env_filter(
                tracing_subscriber::EnvFilter::try_from_default_env()
                    .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
            )
            .init();
    }
    if args.parallel {
        let progress = MultiProgress::new();
        let ranges = client
            .available_ranges()
            .await
            .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
        let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
            .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
        info!("splitting into {} parallel session(s)", sub_ranges.len());
        let _ = progress.println(format!(
            "{} Splitting into {} parallel session(s)",
            style(">").cyan(),
            sub_ranges.len(),
        ));
        // One bar per sub-range; hidden in verbose mode so they don't interleave with logs.
        let bars: Vec<ProgressBar> = sub_ranges
            .iter()
            .map(|&(start, end)| {
                if args.verbose {
                    return ProgressBar::hidden();
                }
                let bar = ProgressBar::new(end.saturating_sub(start) + 1);
                bar.set_style(spinner_style(start, end));
                let bar = progress.add(bar);
                bar.enable_steady_tick(Duration::from_millis(100));
                bar
            })
            .collect();
        // Caps how many sessions are live at once; extra tasks wait for a permit.
        let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_SESSIONS));
        let mut join_set = tokio::task::JoinSet::new();
        // Kept so the cancellation path can mark every unfinished bar.
        let all_bars: Vec<ProgressBar> = bars.clone();
        for ((start, end), bar) in sub_ranges.into_iter().zip(bars) {
            let base_url = base_url.clone();
            let api_key = args.api_key.clone();
            let args = args.clone();
            let output = output.clone();
            let session_ids = session_ids.clone();
            let cancellation = cancellation.clone();
            let semaphore = semaphore.clone();
            join_set.spawn(async move {
                let _permit = tokio::select! {
                    biased;
                    _ = cancellation.cancelled() => return Ok(()),
                    p = semaphore.acquire_owned() => p?,
                };
                drive_session(
                    base_url,
                    api_key,
                    &args,
                    &bar,
                    SlotRange { start, end },
                    &output,
                    &session_ids,
                    &cancellation,
                )
                .await
                .inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))
            });
        }
        let mut first_error: Option<eyre::Report> = None;
        loop {
            tokio::select! {
                biased;
                _ = cancellation.cancelled() => {
                    join_set.abort_all();
                    for bar in &all_bars {
                        if !bar.is_finished() {
                            bar.abandon_with_message("cancelled");
                        }
                    }
                    // The "Cancelled by user" notice is printed once below, on the path
                    // shared with sequential mode.
                    break;
                }
                result = join_set.join_next() => {
                    let Some(result) = result else { break };
                    match result {
                        Ok(Ok(())) => {}
                        // Keep the first failure; later ones are still shown on their bars.
                        Ok(Err(e)) if first_error.is_none() => first_error = Some(e),
                        Err(e) if !e.is_cancelled() => {
                            return Err(eyre::eyre!("session task panicked: {e}"));
                        }
                        _ => {}
                    }
                }
            }
        }
        if let Some(e) = first_error
            && !cancellation.is_cancelled()
        {
            return Err(e);
        }
    } else {
        let range = SlotRange {
            start: args.start_slot,
            end: args.end_slot,
        };
        let bar = if args.verbose {
            ProgressBar::hidden()
        } else {
            let bar = ProgressBar::new(range.end.saturating_sub(range.start) + 1);
            bar.set_style(spinner_style(range.start, range.end));
            bar.set_message("starting session...");
            bar.enable_steady_tick(Duration::from_millis(100));
            bar
        };
        drive_session(
            base_url,
            args.api_key.clone(),
            &args,
            &bar,
            range,
            &output,
            &session_ids,
            &cancellation,
        )
        .await
        .inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))?;
    }
    if cancellation.is_cancelled() {
        eprintln!("\n{} Cancelled by user", style("!").yellow());
        return Ok(());
    }
    let OutputSink { records, stream_tx } = output;
    // The writer loop only ends once every sender is gone, so release ours before awaiting it.
    drop(stream_tx);
    if let Some(handle) = stream_writer_handle {
        handle
            .await
            .context("stream writer task panicked")?
            .context("stream writer failed")?;
    }
    let session_ids = session_ids.lock().unwrap().clone();
    // Other clones of the map may still be alive; fall back to copying it in that case.
    let records_map = Arc::try_unwrap(records).unwrap_or_else(|arc| (*arc).clone());
    let mut records: Vec<Transaction> = records_map.into_iter().map(|(_, tx)| tx).collect();
    records.sort_by(|a, b| a.slot.cmp(&b.slot).then(a.signature.cmp(&b.signature)));
    let metadata = SimulationMetadata {
        start_slot: args.start_slot,
        end_slot: args.end_slot,
        program_ids: args.program_id.clone(),
        program_so: args
            .program_so
            .iter()
            .map(|p| p.display().to_string())
            .collect(),
        ran_at_unix_secs: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0),
        session_ids,
    };
    let output = SimulationOutput::build(metadata, records);
    let json =
        serde_json::to_string_pretty(&output).context("failed to serialize output to JSON")?;
    std::fs::write(&output_file, &json)
        .with_context(|| format!("failed to write {}", output_file.display()))?;
    // The report is written even when empty, but an empty run is still reported as an error.
    if output.summary.total_transactions == 0 {
        println!(
            "\n{} Simulation finished with no transactions",
            style("⚠").yellow()
        );
        println!(" Output file : {}", output_file.display());
        return Err(eyre::eyre!(
            "no matching transactions found in the requested slot range"
        ));
    }
    println!("\n{} Simulation complete", style("✔").green());
    println!(" Transactions : {}", output.summary.total_transactions);
    println!(" Successes : {}", output.summary.successes);
    println!(" Failures : {}", output.summary.failures);
    println!(" Output file : {}", output_file.display());
    Ok(())
}
/// Runs one backtest session over `range`, from creation until it completes, fails or is
/// cancelled.
///
/// The steps are:
/// 1. Create the session and record its id in `session_ids`.
/// 2. Build any `--program-so` overrides.
/// 3. Start the optional subscription that feeds `output`.
/// 4. Hand off to `driver_loop`.
///
/// `pb` is finished or abandoned to reflect the outcome.
///
/// Cancelling `parent_cancel` makes this return `Ok(())`. Every return path (including
/// early `?` exits) cancels the per-session token, so the spawned control manager is
/// always told to stop.
///
/// # Errors
///
/// Returns an error if any of these fail:
/// - building the create request;
/// - creating the session;
/// - building the program overrides;
/// - running the session in `driver_loop`.
// Every argument is a distinct per-session input shared by both run modes; bundling them
// would only introduce a single-use struct.
#[allow(clippy::too_many_arguments)]
async fn drive_session(
    base_url: String,
    api_key: String,
    args: &RunArgs,
    pb: &ProgressBar,
    range: SlotRange,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
    parent_cancel: &CancellationToken,
) -> Result<()> {
    let SlotRange { start, end } = range;
    // Defaults to the number of slots in this session's range.
    let advance_count = args.advance_count.unwrap_or(end.saturating_sub(start) + 1);
    let session_cancel = parent_cancel.child_token();
    // Cancels the session token when this function returns by any path, so an early error
    // cannot leave the control manager running in the background.
    let _cancel_on_exit = session_cancel.clone().drop_guard();
    let create = CreateSession::builder()
        .start_slot(start)
        .end_slot(end)
        .disconnect_timeout_secs(180)
        .capacity_wait_timeout_secs(CREATION_TIMEOUT_SECS as u16)
        .maybe_extra_compute_units(args.extra_compute_units)
        .build()
        .into_request()
        .map_err(|e| eyre::eyre!("building create request: {e}"))?;
    let mut control = spawn_control_manager(base_url, api_key, create, session_cancel.clone());
    let session_info = tokio::select! {
        biased;
        _ = parent_cancel.cancelled() => {
            session_cancel.cancel();
            pb.abandon_with_message("cancelled");
            return Ok(());
        }
        result = control.wait_for_session() => result.map_err(|e| eyre::eyre!("session create failed: {e}"))?,
    };
    session_ids
        .lock()
        .unwrap()
        .push(session_info.session_id.clone());
    info!(
        session_id = %session_info.session_id,
        start,
        end,
        "session created"
    );
    pb.set_message("waiting for runtime");
    let modifications = match build_modifications(&session_info.rpc_endpoint, args).await {
        Ok(modifications) => modifications,
        Err(e) => {
            // The session already exists; stop its control manager and wait for it before
            // reporting the failure.
            session_cancel.cancel();
            control.join().await;
            return Err(e);
        }
    };
    let subscription = spawn_subscription(
        &session_info.rpc_endpoint,
        args,
        output,
        session_cancel.clone(),
    );
    // The first continue carries the program overrides; `driver_loop` refills this with
    // empty modifications after each send.
    let mut pending: Option<ContinueParams> = Some(ContinueParams {
        advance_count,
        transactions: Vec::new(),
        modify_account_states: AccountModifications(modifications),
    });
    let result = driver_loop(
        &mut control,
        subscription.as_ref(),
        pb,
        range,
        &mut pending,
        &session_cancel,
    )
    .await;
    control.join().await;
    session_cancel.cancel();
    if let Some(sub) = subscription {
        let _ = sub.join.await;
    }
    match result {
        Ok(Outcome::Completed) => {
            pb.finish_with_message("done");
            Ok(())
        }
        Ok(Outcome::Cancelled) => {
            pb.abandon_with_message("cancelled");
            Ok(())
        }
        Err(e) => Err(e),
    }
}
/// How `driver_loop` ended when it did not fail.
enum Outcome {
    /// The control manager reported `ControlEvent::Completed`.
    Completed,
    /// The session's cancellation token fired.
    Cancelled,
}

/// Pumps control events for a running session until it completes, fails or is cancelled.
///
/// Event handling:
/// - `Slot` events move `pb`.
/// - Status messages are shown on `pb` until the first `ReadyForContinue`, which switches
///   `pb` to the bar style.
/// - On every `ReadyForContinue`, the loop waits for the control connection (and the
///   subscription, if any) to be up, then sends the parameters held in `pending`.
///
/// After each send, `pending` is refilled with the same `advance_count` and no
/// transactions or account modifications. Any account modifications in the initial
/// `pending` are therefore sent only once.
///
/// # Errors
///
/// Returns an error if any of the following happen:
/// - the control channel closes;
/// - the simulator reports an error;
/// - a connection fails while waiting to continue;
/// - a continue cannot be sent.
///
/// Cancellation is always reported as `Ok(Outcome::Cancelled)`, including when it happens
/// while waiting for the connections.
async fn driver_loop(
    control: &mut ControlHandle,
    subscription: Option<&SubscriptionHandle>,
    pb: &ProgressBar,
    range: SlotRange,
    pending: &mut Option<ContinueParams>,
    cancel: &CancellationToken,
) -> Result<Outcome> {
    let SlotRange { start, end } = range;
    let mut first_ready_seen = false;
    loop {
        tokio::select! {
            biased;
            _ = cancel.cancelled() => return Ok(Outcome::Cancelled),
            event = control.events.recv() => match event {
                None => return Err(eyre::eyre!("control channel closed")),
                Some(ControlEvent::Completed) => return Ok(Outcome::Completed),
                Some(ControlEvent::Error(e)) => return Err(eyre::eyre!("simulator error: {e}")),
                Some(ControlEvent::Slot(s)) => update_progress_position(pb, start, s),
                Some(ControlEvent::Status(status)) => {
                    if !first_ready_seen {
                        pb.set_message(status.to_string());
                    }
                }
                Some(ControlEvent::ReadyForContinue) => {
                    if !first_ready_seen {
                        pb.set_style(bar_style(start, end));
                        pb.reset_elapsed();
                        pb.set_position(0);
                        pb.set_message("running");
                        first_ready_seen = true;
                    }
                    // `wait_both_up` reports cancellation as an error; map it back to a clean
                    // cancel so the caller does not treat Ctrl-C as a session failure.
                    if let Err(e) =
                        wait_both_up(&control.status, subscription.map(|s| &s.status), cancel)
                            .await
                    {
                        if cancel.is_cancelled() {
                            return Ok(Outcome::Cancelled);
                        }
                        return Err(e);
                    }
                    if let Some(params) = pending.take() {
                        let advance_count = params.advance_count;
                        if let Err(e) = control.send_continue(params).await {
                            return Err(eyre::eyre!("control closed while sending continue: {e}"));
                        }
                        *pending = Some(ContinueParams {
                            advance_count,
                            transactions: Vec::new(),
                            modify_account_states: AccountModifications(Default::default()),
                        });
                    }
                }
            }
        }
    }
}
/// Waits until the control connection reports `ConnectionStatus::Up`, and the
/// subscription connection does too when one is given.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - either connection reports `Failed`;
/// - either status channel closes (its sender was dropped) before reaching `Up`;
/// - `cancel` fires, which is reported as a `"cancelled"` error.
async fn wait_both_up(
    control_status: &watch::Receiver<ConnectionStatus>,
    sub_status: Option<&watch::Receiver<ConnectionStatus>>,
    cancel: &CancellationToken,
) -> Result<()> {
    let mut c = control_status.clone();
    let mut s = sub_status.cloned();
    loop {
        let cv = c.borrow().clone();
        if let ConnectionStatus::Failed(why) = &cv {
            return Err(eyre::eyre!("control failed: {why}"));
        }
        let sv = s.as_ref().map(|s| s.borrow().clone());
        if let Some(ConnectionStatus::Failed(why)) = &sv {
            return Err(eyre::eyre!("subscription failed: {why}"));
        }
        let control_up = cv == ConnectionStatus::Up;
        let sub_up = sv.as_ref().is_none_or(|v| *v == ConnectionStatus::Up);
        if control_up && sub_up {
            return Ok(());
        }
        // `changed()` errors only once the sender is gone. Ignoring that error made this loop
        // spin forever re-reading a status that can no longer change, so surface it instead.
        // Cancellation is checked first so a shutdown that also closes the channels is
        // reported as a cancellation.
        tokio::select! {
            biased;
            _ = cancel.cancelled() => return Err(eyre::eyre!("cancelled")),
            changed = c.changed() => {
                changed.map_err(|_| eyre::eyre!("control status channel closed"))?;
            }
            changed = async {
                match s.as_mut() {
                    Some(s) => s.changed().await,
                    // No subscription to watch: this branch never completes.
                    None => std::future::pending().await,
                }
            } => {
                changed.map_err(|_| eyre::eyre!("subscription status channel closed"))?;
            }
        }
    }
}
async fn build_modifications(
rpc_endpoint: &str,
args: &RunArgs,
) -> Result<BTreeMap<solana_address::Address, simulator_api::AccountData>> {
if args.program_so.is_empty() {
return Ok(BTreeMap::new());
}
let rpc = RpcClient::new(rpc_endpoint.to_string());
let mut mods = BTreeMap::new();
for (id, path) in args.program_id.iter().zip(&args.program_so) {
let elf =
std::fs::read(path).with_context(|| format!("failed to read {}", path.display()))?;
let m = modify_program_via_rpc(&rpc, id, &elf)
.await
.map_err(|e| eyre::eyre!("building program injection for {id}: {e}"))?;
mods.extend(m);
}
Ok(mods)
}
/// Starts the notification subscription for a session's `--program-id` values.
///
/// Returns `None` when no program IDs were given. Account-diff notifications are routed to
/// `on_account_diff_notification` together with the stream sender. Log notifications (also
/// the default when no subscription type is set) go to `on_log_notification`. Both write
/// into `output.records` and stop when `cancel` fires.
fn spawn_subscription(
    rpc_endpoint: &str,
    args: &RunArgs,
    output: &OutputSink,
    cancel: CancellationToken,
) -> Option<SubscriptionHandle> {
    if args.program_id.is_empty() {
        return None;
    }
    let endpoint = rpc_endpoint.to_string();
    let shared_rpc = Arc::new(RpcClient::new(endpoint.clone()));
    let tx_records = output.records.clone();
    let programs = args.program_id.clone();
    let wants_account_diffs = matches!(args.subscription, Some(SubscriptionType::AccountDiff));
    let handle = if wants_account_diffs {
        let diff_sink = output.stream_tx.clone();
        spawn_subscription_manager::<AccountDiff, _, _>(
            endpoint,
            programs,
            move |notification| {
                let rpc = shared_rpc.clone();
                let records = tx_records.clone();
                let stream_tx = diff_sink.clone();
                async move {
                    on_account_diff_notification(rpc, records, stream_tx, notification).await;
                }
            },
            cancel,
        )
    } else {
        spawn_subscription_manager::<ProgramLog, _, _>(
            endpoint,
            programs,
            move |notification| {
                let rpc = shared_rpc.clone();
                let records = tx_records.clone();
                async move {
                    on_log_notification(rpc, records, notification).await;
                }
            },
            cancel,
        )
    };
    Some(handle)
}
fn spawn_stream_writer(
path: PathBuf,
mut rx: tokio::sync::mpsc::UnboundedReceiver<AccountDiffRow>,
) -> tokio::task::JoinHandle<Result<()>> {
tokio::spawn(async move {
use std::io::Write;
let write_err = || format!("could not write to '{}'", path.display());
let file = std::fs::File::create(&path).with_context(|| {
format!(
"could not create '{}' — check that the directory exists and is writable",
path.display()
)
})?;
let mut writer = std::io::BufWriter::new(file);
while let Some(row) = rx.recv().await {
serde_json::to_writer(&mut writer, &row).with_context(write_err)?;
writer.write_all(b"\n").with_context(write_err)?;
while let Ok(row) = rx.try_recv() {
serde_json::to_writer(&mut writer, &row).with_context(write_err)?;
writer.write_all(b"\n").with_context(write_err)?;
}
writer.flush().with_context(write_err)?;
}
writer
.flush()
.with_context(|| format!("could not finalize '{}'", path.display()))?;
Ok(())
})
}
/// Moves `pb` to the 1-based position of `slot` within a range that begins at `start_slot`.
/// Slots before the start clamp to position 1.
fn update_progress_position(pb: &ProgressBar, start_slot: u64, slot: u64) {
    let offset = slot.saturating_sub(start_slot);
    let one_based = offset.saturating_add(1);
    pb.set_position(one_based);
}
/// Style used before a session starts running: the slot range, a spinner and the current
/// status message.
fn spinner_style(start: u64, end: u64) -> ProgressStyle {
    with_range_key(
        ProgressStyle::with_template("[{range}] {spinner:.cyan} {msg}")
            .unwrap_or_else(|_| ProgressStyle::default_spinner()),
        start,
        end,
    )
}

/// Style used once a session is running: the slot range, a 40-column bar with `pos/len`
/// slots, the elapsed time and the current message.
fn bar_style(start: u64, end: u64) -> ProgressStyle {
    with_range_key(
        ProgressStyle::with_template(
            "[{range}] {bar:40.cyan/blue} {pos}/{len} slots ({elapsed}) {msg}",
        )
        .unwrap_or_else(|_| ProgressStyle::default_bar()),
        start,
        end,
    )
}

/// Registers the `{range}` template key used by both styles, rendered as `start → end`.
fn with_range_key(base: ProgressStyle, start: u64, end: u64) -> ProgressStyle {
    base.with_key(
        "range",
        move |_: &ProgressState, f: &mut dyn std::fmt::Write| {
            let _ = write!(f, "{} → {}", format_slot(start), format_slot(end));
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Position is the 1-based offset of the reported slot from the range start.
    #[test]
    fn update_progress_position_tracks_slot_offset() {
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 362_270_659, 362_270_659);
        assert_eq!(pb.position(), 1);
        update_progress_position(&pb, 362_270_659, 362_270_700);
        assert_eq!(pb.position(), 42);
    }

    /// Slots reported before the range start clamp to position 1 instead of underflowing.
    #[test]
    fn update_progress_position_clamps_slots_before_start() {
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 1_000, 999);
        assert_eq!(pb.position(), 1);
        update_progress_position(&pb, 1_000, 0);
        assert_eq!(pb.position(), 1);
    }
}