use std::{
collections::BTreeMap,
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};
use clap::Parser;
use console::style;
use dashmap::DashMap;
use eyre::{Context, Result};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use simulator_api::{AccountModifications, ContinueParams};
use simulator_client::{BacktestClient, CreateSession, modify_program_via_rpc, split_range};
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_rpc_client::{http_sender::HttpSender, rpc_client::RpcClientConfig};
use tokio::sync::{Semaphore, watch};
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::{
manager::{
AccountDiff, ConnectionStatus, ControlEvent, ControlHandle, SubscriptionHandle,
Transaction as TransactionKind, spawn_control_manager, spawn_subscription_manager,
},
output::{
OutputEvent, SessionStartedRecord, SimulationMetadata, SimulationSummary, Transaction,
format_slot,
},
signals::spawn_ctrlc_cursor_fix,
subscription::{on_account_diff_notification, on_transaction_notification},
};
/// Upper bound on sessions driven at once in `--parallel` mode (semaphore
/// size); `--plan` refuses splits that would need more sessions than this.
const MAX_PARALLEL_SESSIONS: usize = 50;
/// Capacity-wait timeout, in seconds, sent with each session-create request
/// (narrowed to `u16` at the call site).
const CREATION_TIMEOUT_SECS: u64 = 900;
/// Rough per-slot cost used only for the `--plan` runtime estimate.
const ESTIMATED_MS_PER_BLOCK: u64 = 250;
/// Per-request timeout of the shared HTTP client used for RPC calls.
const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Idle-connection timeout of the shared HTTP client's connection pool.
const HTTP_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// `(program id, ELF bytes)` pairs read from `--program-id` / `--program-so`.
type PreloadedPrograms = Vec<(String, Vec<u8>)>;
/// Resources shared by every session of one run; `run` builds it once and
/// each session task receives a clone.
#[derive(Clone)]
struct SessionInfra {
/// Programs injected on each session's first continue; empty for a
/// baseline run (no `--program-so`).
preloaded_programs: Arc<PreloadedPrograms>,
/// Shared HTTP client (timeouts configured in `run`) used to build the
/// per-session RPC clients.
http_client: reqwest::Client,
}
/// Builds a nonblocking Solana RPC client for `rpc_endpoint` whose HTTP
/// transport is the shared `http_client`, rather than a freshly created one.
fn make_rpc_client(rpc_endpoint: &str, http_client: &reqwest::Client) -> RpcClient {
    let config = RpcClientConfig::default();
    let transport = HttpSender::new_with_client(rpc_endpoint, http_client.clone());
    RpcClient::new_sender(transport, config)
}
// Which notification streams a run subscribes to (`--subscription`). Only
// consulted when at least one `--program-id` is given. Plain `//` comments
// are used on purpose: `///` on a `ValueEnum` feeds clap's generated help.
#[derive(Clone, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum SubscriptionType {
// Currently handled exactly like `Transaction`: transaction stream only.
Logs,
// Transaction stream plus a separate account-diff stream.
AccountDiff,
// Transaction stream only (same as leaving the flag unset).
Transaction,
}
/// Where session tasks deliver results. Cloned into every session and
/// subscription handler.
#[derive(Clone)]
struct OutputSink {
/// Transactions collected by the subscription handlers; `run` counts them
/// (and their successes) for the final summary. Key is a string id,
/// presumably the transaction signature — TODO confirm in `subscription`.
records: Arc<DashMap<String, Transaction>>,
/// Sender side of the channel drained by the NDJSON writer task.
output_tx: tokio::sync::mpsc::UnboundedSender<OutputEvent>,
}
/// Inclusive slot range `[start, end]` handled by one session.
#[derive(Clone, Copy)]
struct SlotRange {
/// First slot (inclusive).
start: u64,
/// Last slot (inclusive).
end: u64,
}
// Arguments for the `run` command. Most flags can also be set through the
// environment variable named in their `env` attribute. The `///` field docs
// below are the `--help` text; the struct itself uses plain comments so it
// does not set clap's `about` text for whatever command embeds it.
#[derive(Parser, Debug, Clone)]
pub struct RunArgs {
    /// Simulator host or ws:// / wss:// URL (the /backtest path is appended)
    #[arg(
        long,
        env = "SIMULATOR_URL",
        default_value = "simulator.termina.technology"
    )]
    pub url: String,
    /// API key passed to the simulator client
    #[arg(long, env = "SIMULATOR_API_KEY")]
    pub api_key: String,
    /// First slot to simulate (inclusive)
    #[arg(long, env = "SIMULATOR_START_SLOT")]
    pub start_slot: u64,
    /// Last slot to simulate (inclusive)
    #[arg(long, env = "SIMULATOR_END_SLOT")]
    pub end_slot: u64,
    /// Slots to advance per continue request (defaults to each session's whole range)
    #[arg(long, env = "SIMULATOR_ADVANCE_COUNT")]
    pub advance_count: Option<u64>,
    /// Split the range along the server's available data ranges and run the pieces as concurrent sessions
    #[arg(long, env = "SIMULATOR_PARALLEL", default_value_t = false)]
    pub parallel: bool,
    /// Program id to subscribe to (repeatable; paired in order with --program-so when that is given)
    #[arg(long, env = "SIMULATOR_PROGRAM_ID")]
    pub program_id: Vec<String>,
    /// Program ELF (.so) to inject for the --program-id at the same position (repeatable)
    #[arg(long, env = "SIMULATOR_PROGRAM_SO")]
    pub program_so: Vec<PathBuf>,
    /// NDJSON output file (defaults to START_END_baseline.ndjson, or START_END_override.ndjson with --program-so)
    #[arg(long, env = "SIMULATOR_OUTPUT_FILE")]
    pub output_file: Option<PathBuf>,
    /// Extra compute units forwarded in each session's create request
    #[arg(long, env = "SIMULATOR_EXTRA_COMPUTE_UNITS")]
    pub extra_compute_units: Option<u32>,
    /// Subscription mode; account-diff also streams account diffs alongside transactions
    #[arg(long, env = "SIMULATOR_SUBSCRIPTION", requires = "program_id")]
    pub subscription: Option<SubscriptionType>,
    /// Emit tracing logs (filtered by RUST_LOG, default info) instead of progress bars
    #[arg(long, default_value_t = false)]
    pub verbose: bool,
    /// Print the session split and estimated runtime, then exit without simulating
    #[arg(long, env = "SIMULATOR_PLAN", default_value_t = false)]
    pub plan: bool,
}
/// Checks argument combinations that clap's attributes cannot express.
///
/// # Errors
/// - `--start-slot` is greater than `--end-slot`. Ranges are inclusive, so
///   equal slots are allowed; a reversed range would otherwise underflow the
///   `end - start + 1` slot arithmetic used when planning and reporting.
/// - `--program-so` is given but its count differs from `--program-id`'s; the
///   two lists are paired positionally when loading programs.
fn validate_args(args: &RunArgs) -> Result<()> {
    eyre::ensure!(
        args.start_slot <= args.end_slot,
        "--start-slot ({}) must not be greater than --end-slot ({})",
        args.start_slot,
        args.end_slot,
    );
    if !args.program_so.is_empty() {
        eyre::ensure!(
            args.program_so.len() == args.program_id.len(),
            "--program-so count ({}) must match --program-id count ({})",
            args.program_so.len(),
            args.program_id.len(),
        );
    }
    Ok(())
}
/// Prints the `--plan` view: how `[start_slot, end_slot]` would be split into
/// sessions, plus a rough runtime estimate. Nothing is simulated.
///
/// The estimate is driven by the longest sub-range at
/// `ESTIMATED_MS_PER_BLOCK` per slot, since sessions run side by side.
///
/// # Errors
/// Fails if the available ranges cannot be fetched or split, if the split
/// needs more than `MAX_PARALLEL_SESSIONS` sessions, or if two consecutive
/// sub-ranges leave a gap in the data.
async fn generate_plan(client: &BacktestClient, args: &RunArgs) -> Result<()> {
    let ranges = client
        .available_ranges()
        .await
        .map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
    let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
        .map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
    if sub_ranges.len() > MAX_PARALLEL_SESSIONS {
        eyre::bail!(
            "unable to generate plan: requires {} sessions but max is {MAX_PARALLEL_SESSIONS}",
            sub_ranges.len(),
        );
    }
    for window in sub_ranges.windows(2) {
        let (_, end) = window[0];
        let (next_start, _) = window[1];
        if end + 1 != next_start {
            eyre::bail!(
                "unable to generate plan: gap in data between slots {} and {}",
                format_slot(end),
                format_slot(next_start),
            );
        }
    }
    let max_slots = sub_ranges
        .iter()
        .map(|&(start, end)| inclusive_slot_count(start, end))
        .max()
        .unwrap_or(0);
    let estimated = format_runtime_estimate(max_slots.saturating_mul(ESTIMATED_MS_PER_BLOCK));
    println!(
        "\n{} {estimated} runtime · {} session(s)\n{} [{} → {}] · {} slots",
        style(">").cyan(),
        sub_ranges.len(),
        style(">").cyan(),
        format_slot(args.start_slot),
        format_slot(args.end_slot),
        format_slot(inclusive_slot_count(args.start_slot, args.end_slot)),
    );
    println!();
    println!(
        "{:<5} {:<20} {:<20} {:<15}",
        "#", "Start Slot", "End Slot", "Slots"
    );
    println!("{}", "-".repeat(62));
    for (i, &(start, end)) in sub_ranges.iter().enumerate() {
        println!(
            "{:<5} {:<20} {:<20} {:<15}",
            i + 1,
            format_slot(start),
            format_slot(end),
            format_slot(inclusive_slot_count(start, end)),
        );
    }
    Ok(())
}

/// Number of slots in the inclusive range `[start, end]`. Saturates (yielding
/// 1) instead of underflowing when `end < start`.
fn inclusive_slot_count(start: u64, end: u64) -> u64 {
    end.saturating_sub(start).saturating_add(1)
}

/// Renders a millisecond estimate as `~Nm` below one hour or `~N.Nh` above.
/// Minutes are rounded up so a short but non-empty plan never shows `~0m`.
fn format_runtime_estimate(estimated_ms: u64) -> String {
    const MINUTE_MS: u64 = 60_000;
    const HOUR_MS: u64 = 60 * MINUTE_MS;
    if estimated_ms < HOUR_MS {
        format!("~{}m", estimated_ms.div_ceil(MINUTE_MS))
    } else {
        format!("~{:.1}h", estimated_ms as f64 / HOUR_MS as f64)
    }
}
pub async fn run(args: RunArgs, cancellation: CancellationToken) -> Result<()> {
validate_args(&args)?;
let base_url = if args.url.starts_with("ws://") || args.url.starts_with("wss://") {
format!("{}/backtest", args.url.trim_end_matches('/'))
} else {
format!("wss://{}/backtest", args.url)
};
let client = BacktestClient::builder()
.url(base_url.clone())
.api_key(args.api_key.clone())
.build();
if args.plan {
return generate_plan(&client, &args).await;
}
let output_file = args.output_file.clone().unwrap_or_else(|| {
let suffix = if !args.program_so.is_empty() {
"override"
} else {
"baseline"
};
PathBuf::from(format!(
"{}_{}_{}.ndjson",
args.start_slot, args.end_slot, suffix
))
});
let output_writer = open_output_writer(&output_file)?;
let (output_tx, output_rx) = tokio::sync::mpsc::unbounded_channel::<OutputEvent>();
let output_writer_handle = spawn_output_writer(output_file.clone(), output_writer, output_rx);
let metadata = SimulationMetadata {
start_slot: args.start_slot,
end_slot: args.end_slot,
program_ids: args.program_id.clone(),
program_so: args
.program_so
.iter()
.map(|p| p.display().to_string())
.collect(),
ran_at_unix_secs: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
session_ids: Vec::new(),
};
let _ = output_tx.send(OutputEvent::Metadata(metadata));
let output = OutputSink {
records: Arc::new(DashMap::new()),
output_tx,
};
let session_ids: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let preloaded_programs: Arc<PreloadedPrograms> = Arc::new(
args.program_id
.iter()
.zip(&args.program_so)
.map(|(id, path)| {
let elf = std::fs::read(path)
.with_context(|| format!("failed to read {}", path.display()))?;
Ok((id.clone(), elf))
})
.collect::<Result<_>>()?,
);
let http_client = reqwest::Client::builder()
.timeout(HTTP_REQUEST_TIMEOUT)
.pool_idle_timeout(HTTP_POOL_IDLE_TIMEOUT)
.hickory_dns(true)
.build()
.context("failed to build shared HTTP client")?;
let infra = SessionInfra {
preloaded_programs,
http_client,
};
let args = Arc::new(args);
spawn_ctrlc_cursor_fix();
if args.verbose {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
}
if args.parallel {
let progress = MultiProgress::new();
let ranges = client
.available_ranges()
.await
.map_err(|e| eyre::eyre!("failed to fetch available ranges: {e}"))?;
let sub_ranges = split_range(&ranges, args.start_slot, args.end_slot)
.map_err(|e| eyre::eyre!("cannot split range: {e}"))?;
info!("splitting into {} parallel session(s)", sub_ranges.len());
let _ = progress.println(format!(
"{} Splitting into {} parallel session(s)",
style(">").cyan(),
sub_ranges.len(),
));
let bars: Vec<ProgressBar> = sub_ranges
.iter()
.map(|&(start, end)| {
if args.verbose {
return ProgressBar::hidden();
}
let bar = ProgressBar::new(end.saturating_sub(start) + 1);
bar.set_style(spinner_style(start, end));
let bar = progress.add(bar);
bar.enable_steady_tick(Duration::from_millis(100));
bar
})
.collect();
let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_SESSIONS));
let mut join_set = tokio::task::JoinSet::new();
let all_bars: Vec<ProgressBar> = bars.to_vec();
for ((start, end), bar) in sub_ranges.into_iter().zip(bars) {
let base_url = base_url.clone();
let api_key = args.api_key.clone();
let args = args.clone();
let output = output.clone();
let session_ids = session_ids.clone();
let cancellation = cancellation.clone();
let semaphore = semaphore.clone();
let infra = infra.clone();
join_set.spawn(async move {
let _permit = tokio::select! {
biased;
_ = cancellation.cancelled() => return Ok(()),
p = semaphore.acquire_owned() => p?,
};
drive_session(
base_url,
api_key,
&args,
&bar,
SlotRange { start, end },
&output,
&session_ids,
&infra,
&cancellation,
)
.await
.inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))
});
}
let mut first_error: Option<eyre::Report> = None;
loop {
tokio::select! {
biased;
_ = cancellation.cancelled() => {
join_set.abort_all();
for bar in &all_bars {
if !bar.is_finished() {
bar.abandon_with_message("cancelled");
}
}
let _ = progress.println(format!(
"\n{} Cancelled by user",
style("!").yellow(),
));
break;
}
result = join_set.join_next() => {
let Some(result) = result else { break };
match result {
Ok(Ok(())) => {}
Ok(Err(e)) if first_error.is_none() => first_error = Some(e),
Err(e) if !e.is_cancelled() => {
return Err(eyre::eyre!("session task panicked: {e}"));
}
_ => {}
}
}
}
}
if let Some(e) = first_error
&& !cancellation.is_cancelled()
{
return Err(e);
}
} else {
let range = SlotRange {
start: args.start_slot,
end: args.end_slot,
};
let bar = if args.verbose {
ProgressBar::hidden()
} else {
let bar = ProgressBar::new(range.end.saturating_sub(range.start) + 1);
bar.set_style(spinner_style(range.start, range.end));
bar.set_message("starting session...");
bar.enable_steady_tick(Duration::from_millis(100));
bar
};
drive_session(
base_url,
args.api_key.clone(),
&args,
&bar,
range,
&output,
&session_ids,
&infra,
&cancellation,
)
.await
.inspect_err(|e| bar.abandon_with_message(format!("failed: {e}")))?;
}
if cancellation.is_cancelled() {
eprintln!("\n{} Cancelled by user", style("!").yellow());
let OutputSink { output_tx, .. } = output;
drop(output_tx);
let _ = output_writer_handle.await;
return Ok(());
}
let OutputSink {
records, output_tx, ..
} = output;
let session_ids = session_ids.lock().unwrap().clone();
let total = records.len();
let successes = records.iter().filter(|t| t.success).count();
let summary = SimulationSummary {
total_transactions: total,
successes,
failures: total.saturating_sub(successes),
session_ids,
};
let _ = output_tx.send(OutputEvent::Summary(summary.clone()));
drop(output_tx);
output_writer_handle
.await
.context("output writer task panicked")?
.context("output writer failed")?;
if summary.total_transactions == 0 {
println!(
"\n{} Simulation finished with no transactions",
style("⚠").yellow()
);
println!(" Output file : {}", output_file.display());
return Err(eyre::eyre!(
"no matching transactions found in the requested slot range"
));
}
println!("\n{} Simulation complete", style("✔").green());
println!(" Transactions : {}", summary.total_transactions);
println!(" Successes : {}", summary.successes);
println!(" Failures : {}", summary.failures);
println!(" Output file : {}", output_file.display());
Ok(())
}
/// Creates one simulator session for `range` and drives it to completion.
///
/// Records the session id in `session_ids` and emits a `SessionStarted`
/// event. It then spawns the configured subscriptions and runs
/// [`driver_loop`]. Session-scoped tasks are signalled through a child of
/// `parent_cancel` and, once the session exists, joined before returning.
///
/// Returns `Ok(())` when the session completes (bar finished with "done") or
/// is cancelled (bar abandoned with "cancelled").
///
/// # Errors
/// Fails if the create request cannot be built, session creation fails, or
/// the driver loop fails. `pb` is not updated on these paths; callers in this
/// file abandon it with the error message.
#[allow(clippy::too_many_arguments)]
async fn drive_session(
    base_url: String,
    api_key: String,
    args: &RunArgs,
    pb: &ProgressBar,
    range: SlotRange,
    output: &OutputSink,
    session_ids: &Arc<Mutex<Vec<String>>>,
    infra: &SessionInfra,
    parent_cancel: &CancellationToken,
) -> Result<()> {
    let SlotRange { start, end } = range;
    let advance_count = args.advance_count.unwrap_or(end.saturating_sub(start) + 1);
    // Child token: fires on parent cancellation or on this session's own
    // teardown, without affecting sibling sessions.
    let session_cancel = parent_cancel.child_token();
    let create = CreateSession::builder()
        .start_slot(start)
        .end_slot(end)
        .disconnect_timeout_secs(180)
        .capacity_wait_timeout_secs(CREATION_TIMEOUT_SECS as u16)
        .maybe_extra_compute_units(args.extra_compute_units)
        .build()
        .into_request()
        .map_err(|e| eyre::eyre!("building create request: {e}"))?;
    let mut control = spawn_control_manager(base_url, api_key, create, session_cancel.clone());
    let session_info = tokio::select! {
        biased;
        _ = parent_cancel.cancelled() => {
            session_cancel.cancel();
            pb.abandon_with_message("cancelled");
            return Ok(());
        }
        result = control.wait_for_session() => match result {
            Ok(info) => info,
            Err(e) => {
                // The control manager was handed `session_cancel`; stop it so
                // nothing started for this failed session keeps running.
                session_cancel.cancel();
                return Err(eyre::eyre!("session create failed: {e}"));
            }
        },
    };
    session_ids
        .lock()
        .expect("session id list mutex poisoned")
        .push(session_info.session_id.clone());
    info!(
        session_id = %session_info.session_id,
        task_id = ?session_info.task_id,
        start,
        end,
        "session created"
    );
    let _ = output
        .output_tx
        .send(OutputEvent::SessionStarted(SessionStartedRecord {
            session_id: session_info.session_id.clone(),
            task_id: session_info.task_id.clone(),
            start_slot: start,
            end_slot: end,
        }));
    pb.set_message("waiting for runtime");
    let subscriptions = spawn_subscription(
        &session_info.rpc_endpoint,
        args,
        output,
        session_cancel.clone(),
        parent_cancel.clone(),
    );
    let result = driver_loop(
        &mut control,
        &subscriptions,
        pb,
        range,
        advance_count,
        &session_info.rpc_endpoint,
        infra,
        &session_cancel,
    )
    .await;
    // On failure the control connection may still be live (e.g. a
    // subscription failed while the simulator waits for a continue). Cancel
    // first so joining control cannot block on a session nobody will drive
    // anymore. Otherwise control is joined before the subscriptions are
    // stopped.
    if result.is_err() {
        session_cancel.cancel();
    }
    control.join().await;
    session_cancel.cancel();
    for sub in subscriptions {
        let _ = sub.join.await;
    }
    match result {
        Ok(Outcome::Completed) => {
            pb.finish_with_message("done");
            Ok(())
        }
        Ok(Outcome::Cancelled) => {
            pb.abandon_with_message("cancelled");
            Ok(())
        }
        Err(e) => Err(e),
    }
}
/// How [`driver_loop`] ended when it did not fail.
enum Outcome {
/// The control manager reported `ControlEvent::Completed`.
Completed,
/// The session's cancellation token fired.
Cancelled,
}
/// Reacts to one session's control events until it completes, fails, or is
/// cancelled through `cancel`.
///
/// On each `ReadyForContinue`, waits for the control connection and every
/// subscription to be up, then sends a continue that advances
/// `advance_count` slots. The first continue carries the program injections
/// from [`build_modifications`] and switches `pb` from the spinner to a slot
/// bar. Later continues carry no account modifications.
///
/// # Errors
/// Fails if the control channel closes, the simulator reports an error, a
/// connection fails, program injection cannot be built, or a continue cannot
/// be sent. Cancellation is reported as `Ok(Outcome::Cancelled)`, never as an
/// error.
#[allow(clippy::too_many_arguments)]
async fn driver_loop(
    control: &mut ControlHandle,
    subscriptions: &[SubscriptionHandle],
    pb: &ProgressBar,
    range: SlotRange,
    advance_count: u64,
    rpc_endpoint: &str,
    infra: &SessionInfra,
    cancel: &CancellationToken,
) -> Result<Outcome> {
    let SlotRange { start, end } = range;
    let mut first_ready_seen = false;
    loop {
        tokio::select! {
            biased;
            _ = cancel.cancelled() => return Ok(Outcome::Cancelled),
            event = control.events.recv() => match event {
                None => return Err(eyre::eyre!("control channel closed")),
                Some(ControlEvent::Completed) => return Ok(Outcome::Completed),
                Some(ControlEvent::Error(e)) => return Err(eyre::eyre!("simulator error: {e}")),
                Some(ControlEvent::Slot(s)) => update_progress_position(pb, start, s),
                Some(ControlEvent::Status(status)) => {
                    if !first_ready_seen {
                        pb.set_message(status.to_string());
                    }
                }
                Some(ControlEvent::ReadyForContinue) => {
                    if let Err(e) = wait_all_up(&control.status, subscriptions, cancel).await {
                        // `wait_all_up` reports cancellation as an error; map it
                        // to the same clean outcome as the select arm above.
                        if cancel.is_cancelled() {
                            return Ok(Outcome::Cancelled);
                        }
                        return Err(e);
                    }
                    let modifications = if first_ready_seen {
                        BTreeMap::new()
                    } else {
                        if !infra.preloaded_programs.is_empty() {
                            pb.set_message("injecting program");
                        }
                        let injected = build_modifications(rpc_endpoint, infra).await?;
                        pb.set_style(bar_style(start, end));
                        pb.reset_elapsed();
                        pb.set_position(0);
                        pb.set_message("running");
                        first_ready_seen = true;
                        injected
                    };
                    let params = ContinueParams {
                        advance_count,
                        transactions: Vec::new(),
                        modify_account_states: AccountModifications(modifications),
                    };
                    if let Err(e) = control.send_continue(params).await {
                        return Err(eyre::eyre!("control closed while sending continue: {e}"));
                    }
                }
            }
        }
    }
}
/// Waits until the control connection and every subscription report
/// `ConnectionStatus::Up`.
///
/// # Errors
/// - A connection reports `Failed`.
/// - A status channel's sender is dropped before everything is up. Its value
///   can then never change, and polling `changed()` on it would resolve
///   immediately forever, turning this loop into a busy spin.
/// - `cancel` fires while waiting (error message "cancelled").
async fn wait_all_up(
    control_status: &watch::Receiver<ConnectionStatus>,
    subscriptions: &[SubscriptionHandle],
    cancel: &CancellationToken,
) -> Result<()> {
    let mut c = control_status.clone();
    let mut subs: Vec<watch::Receiver<ConnectionStatus>> =
        subscriptions.iter().map(|s| s.status.clone()).collect();
    loop {
        // Sample closure *before* reading the values. Once a sender is gone
        // its value is final, so the statuses read below are the last it
        // will ever report.
        let control_closed = c.has_changed().is_err();
        let sub_closed = subs.iter().any(|s| s.has_changed().is_err());
        let cv = c.borrow().clone();
        if let ConnectionStatus::Failed(why) = &cv {
            return Err(eyre::eyre!("control failed: {why}"));
        }
        let mut all_subs_up = true;
        for sub in &subs {
            match &*sub.borrow() {
                ConnectionStatus::Failed(why) => {
                    return Err(eyre::eyre!("subscription failed: {why}"));
                }
                ConnectionStatus::Up => {}
                _ => all_subs_up = false,
            }
        }
        if cv == ConnectionStatus::Up && all_subs_up {
            return Ok(());
        }
        if control_closed {
            return Err(eyre::eyre!(
                "control status channel closed while waiting for connections"
            ));
        }
        if sub_closed {
            return Err(eyre::eyre!(
                "subscription status channel closed while waiting for connections"
            ));
        }
        tokio::select! {
            _ = cancel.cancelled() => return Err(eyre::eyre!("cancelled")),
            _ = c.changed() => {}
            _ = wait_any_sub_change(&mut subs) => {}
        }
    }
}
/// Resolves as soon as any receiver in `subs` observes a change, or its
/// sender closes. With no receivers it never resolves, which guards
/// `select_all`, since that function panics on an empty input.
async fn wait_any_sub_change(subs: &mut [watch::Receiver<ConnectionStatus>]) {
    if !subs.is_empty() {
        let waits = subs.iter_mut().map(|rx| Box::pin(rx.changed()));
        let _ = futures::future::select_all(waits).await;
    } else {
        std::future::pending::<()>().await;
    }
}
/// Builds the account modifications that inject every preloaded program
/// into the session reachable at `rpc_endpoint`.
///
/// Returns an empty map without creating an RPC client when no programs were
/// preloaded. Results from each program are merged in order, so a later
/// program's entries overwrite an earlier one's on key collisions.
///
/// # Errors
/// Fails on the first program whose injection cannot be built, naming its id.
async fn build_modifications(
    rpc_endpoint: &str,
    infra: &SessionInfra,
) -> Result<BTreeMap<solana_address::Address, simulator_api::AccountData>> {
    let mut merged = BTreeMap::new();
    if infra.preloaded_programs.is_empty() {
        return Ok(merged);
    }
    let rpc = make_rpc_client(rpc_endpoint, &infra.http_client);
    for (program_id, elf_bytes) in infra.preloaded_programs.iter() {
        let injected = modify_program_via_rpc(&rpc, program_id, elf_bytes)
            .await
            .map_err(|err| eyre::eyre!("building program injection for {program_id}: {err}"))?;
        merged.extend(injected);
    }
    Ok(merged)
}
/// Spawns the subscriptions requested by `--subscription` for one session.
///
/// Nothing is spawned without `--program-id`. Otherwise the transaction
/// stream always comes first. `account-diff` adds an account-diff stream
/// after it; every other mode, or none, uses the transaction stream alone.
fn spawn_subscription(
    rpc_endpoint: &str,
    args: &RunArgs,
    output: &OutputSink,
    cancel: CancellationToken,
    parent_cancel: CancellationToken,
) -> Vec<SubscriptionHandle> {
    if args.program_id.is_empty() {
        return Vec::new();
    }
    let with_account_diff = match args.subscription {
        Some(SubscriptionType::AccountDiff) => true,
        Some(SubscriptionType::Logs | SubscriptionType::Transaction) | None => false,
    };
    let mut handles = Vec::with_capacity(2);
    handles.push(spawn_transaction_subscription(
        rpc_endpoint,
        args,
        output,
        cancel.clone(),
        parent_cancel.clone(),
    ));
    if with_account_diff {
        handles.push(spawn_account_diff_subscription(
            rpc_endpoint,
            args,
            output,
            cancel,
            parent_cancel,
        ));
    }
    handles
}
fn spawn_transaction_subscription(
rpc_endpoint: &str,
args: &RunArgs,
output: &OutputSink,
cancel: CancellationToken,
parent_cancel: CancellationToken,
) -> SubscriptionHandle {
let records = output.records.clone();
let output_tx = output.output_tx.clone();
spawn_subscription_manager::<TransactionKind, _, _>(
rpc_endpoint.to_string(),
args.program_id.clone(),
move |notification| {
let records = records.clone();
let output_tx = output_tx.clone();
async move {
on_transaction_notification(records, output_tx, notification);
}
},
cancel,
parent_cancel,
)
}
fn spawn_account_diff_subscription(
rpc_endpoint: &str,
args: &RunArgs,
output: &OutputSink,
cancel: CancellationToken,
parent_cancel: CancellationToken,
) -> SubscriptionHandle {
let records = output.records.clone();
let output_tx = output.output_tx.clone();
spawn_subscription_manager::<AccountDiff, _, _>(
rpc_endpoint.to_string(),
args.program_id.clone(),
move |notification| {
let records = records.clone();
let output_tx = output_tx.clone();
async move {
on_account_diff_notification(records, output_tx, notification);
}
},
cancel,
parent_cancel,
)
}
fn open_output_writer(path: &std::path::Path) -> Result<std::fs::File> {
if let Some(parent) = path.parent()
&& !parent.as_os_str().is_empty()
&& !parent.exists()
{
std::fs::create_dir_all(parent)
.with_context(|| format!("could not create output directory '{}'", parent.display()))?;
}
std::fs::File::create(path).with_context(|| format!("could not create '{}'", path.display()))
}
/// Spawns the task that serialises every [`OutputEvent`] as one JSON line
/// into `file` (NDJSON).
///
/// The task waits for an event, then also writes every event already queued,
/// and flushes once per such burst. It finishes when every sender has been
/// dropped. Its result reports the first write or flush error, naming `path`.
fn spawn_output_writer(
    path: PathBuf,
    file: std::fs::File,
    mut rx: tokio::sync::mpsc::UnboundedReceiver<OutputEvent>,
) -> tokio::task::JoinHandle<Result<()>> {
    use std::io::Write;

    tokio::spawn(async move {
        let describe = || format!("could not write to '{}'", path.display());
        let mut out = std::io::BufWriter::new(file);
        while let Some(first) = rx.recv().await {
            let mut next = Some(first);
            while let Some(event) = next {
                serde_json::to_writer(&mut out, &event).with_context(describe)?;
                out.write_all(b"\n").with_context(describe)?;
                next = rx.try_recv().ok();
            }
            out.flush().with_context(describe)?;
        }
        out.flush()
            .with_context(|| format!("could not finalize '{}'", path.display()))?;
        Ok(())
    })
}
/// Moves `pb` to reflect that `slot` has been reached. Positions count slots
/// from 1: `start_slot` itself is position 1. Slots before the start clamp
/// to 1, and the arithmetic saturates instead of overflowing.
fn update_progress_position(pb: &ProgressBar, start_slot: u64, slot: u64) {
    let offset = slot.saturating_sub(start_slot);
    let position = offset.saturating_add(1);
    pb.set_position(position);
}
/// Spinner style shown while a session is starting: `[start → end]`, a cyan
/// spinner, and the current message. Falls back to indicatif's default
/// spinner if the template fails to parse.
fn spinner_style(start: u64, end: u64) -> ProgressStyle {
    let base = match ProgressStyle::with_template("[{range}] {spinner:.cyan} {msg}") {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_spinner(),
    };
    base.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = write!(out, "{} → {}", format_slot(start), format_slot(end));
        },
    )
}
/// Bar style shown once a session is running. It shows `[start → end]`, a
/// 40-column bar, processed/total slots, elapsed time, and the message.
/// Falls back to indicatif's default bar if the template fails to parse.
fn bar_style(start: u64, end: u64) -> ProgressStyle {
    let base = match ProgressStyle::with_template(
        "[{range}] {bar:40.cyan/blue} {pos}/{len} slots ({elapsed}) {msg}",
    ) {
        Ok(parsed) => parsed,
        Err(_) => ProgressStyle::default_bar(),
    };
    base.with_key(
        "range",
        move |_: &ProgressState, out: &mut dyn std::fmt::Write| {
            let _ = write!(out, "{} → {}", format_slot(start), format_slot(end));
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Positions count slots from 1, relative to the session's start slot.
    #[test]
    fn update_progress_position_tracks_slot_offset() {
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 362_270_659, 362_270_659);
        assert_eq!(pb.position(), 1);
        update_progress_position(&pb, 362_270_659, 362_270_700);
        assert_eq!(pb.position(), 42);
    }

    /// A slot reported before the start must clamp to 1 rather than underflow.
    #[test]
    fn update_progress_position_clamps_slots_before_start() {
        let pb = ProgressBar::new(100);
        update_progress_position(&pb, 362_270_659, 362_270_600);
        assert_eq!(pb.position(), 1);
    }
}