use std::{
collections::{BTreeSet, HashMap},
path::PathBuf,
};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use clap::{Parser, Subcommand};
use console::style;
use eyre::{Context, Result};
use solana_address::Address;
use crate::output::{SimulationOutput, Transaction};
/// Program id of the token metadata program (Metaplex Token Metadata). A mint's metadata
/// account is the PDA of this program seeded with `"metadata"`, this program id, and the mint.
const METADATA_PROGRAM: &str = "metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s";
/// JSON-RPC endpoint queried (best effort) for token metadata accounts when resolving symbols.
const RPC_URL: &str = "https://api.mainnet-beta.solana.com";
/// Mint address of wrapped SOL; balance rows for this mint are labelled `WSOL` without a lookup.
const WSOL_MINT: &str = "So11111111111111111111111111111111111111112";
/// Number of rows printed per section in the default report (no subcommand given).
const DEFAULT_LIMIT: usize = 10;
/// One account whose SOL balance change differs between the baseline and experiment runs.
struct SolDiff {
    /// Account whose balance changed in at least one of the runs.
    pubkey: String,
    /// Balance change in the baseline run (0 when the account was not listed there).
    baseline_delta: i64,
    /// Balance change in the experiment run (0 when the account was not listed there).
    experiment_delta: i64,
}
/// One (token account, mint) pair whose SPL balance change differs between the two runs.
struct SplDiff {
    /// Token account whose balance changed in at least one of the runs.
    pubkey: String,
    /// Mint of the token held by `pubkey`.
    mint: String,
    /// Decimal places of the mint, used to scale the raw deltas for display.
    decimals: u8,
    /// Raw (unscaled) balance change in the baseline run (0 when absent there).
    baseline_delta: i64,
    /// Raw (unscaled) balance change in the experiment run (0 when absent there).
    experiment_delta: i64,
}
/// All balance-change discrepancies between the baseline and experiment runs of one transaction.
struct PnlDiff<'a> {
    /// The baseline run's transaction; its slot and signature identify the row when printing.
    base: &'a Transaction,
    /// Accounts whose SOL balance change differs between the runs.
    sol: Vec<SolDiff>,
    /// (token account, mint) pairs whose SPL balance change differs between the runs.
    tokens: Vec<SplDiff>,
    /// Aggregate size of the discrepancies; larger means more different. Used for ranking.
    score: f64,
}
impl<'a> PnlDiff<'a> {
    /// Compares the SOL and SPL token balance changes of one transaction across two runs.
    ///
    /// SOL changes are keyed by account pubkey, token changes by (account pubkey, mint). A key
    /// present in only one run counts as a delta of 0 in the other, and only keys whose deltas
    /// differ are kept. Token decimals come from the baseline entry when present, otherwise
    /// from the experiment entry.
    ///
    /// Each kept entry adds a weight to `score`: the absolute difference of the two deltas
    /// when the baseline delta is 0, otherwise the geometric mean of that difference and the
    /// absolute baseline delta.
    ///
    /// Returns `None` when no SOL or token balance change differs.
    fn new(base: &'a Transaction, exp: &'a Transaction) -> Option<Self> {
        let sol_deltas = |tx: &'a Transaction| -> HashMap<&'a str, i64> {
            tx.sol_changes
                .iter()
                .map(|c| (c.pubkey.as_str(), c.delta()))
                .collect()
        };
        let token_deltas = |tx: &'a Transaction| -> HashMap<(&'a str, &'a str), (i64, u8)> {
            tx.token_changes
                .iter()
                .map(|c| ((c.pubkey.as_str(), c.mint.as_str()), (c.delta(), c.decimals)))
                .collect()
        };

        let base_sol = sol_deltas(base);
        let exp_sol = sol_deltas(exp);
        // BTreeSet: union of both runs' keys, in a deterministic (sorted) order.
        let sol_keys: BTreeSet<&str> = base_sol.keys().chain(exp_sol.keys()).copied().collect();
        let sol: Vec<SolDiff> = sol_keys
            .into_iter()
            .filter_map(|pk| {
                let baseline_delta = base_sol.get(pk).copied().unwrap_or(0);
                let experiment_delta = exp_sol.get(pk).copied().unwrap_or(0);
                (baseline_delta != experiment_delta).then(|| SolDiff {
                    pubkey: pk.to_string(),
                    baseline_delta,
                    experiment_delta,
                })
            })
            .collect();

        let base_spl = token_deltas(base);
        let exp_spl = token_deltas(exp);
        let spl_keys: BTreeSet<(&str, &str)> =
            base_spl.keys().chain(exp_spl.keys()).copied().collect();
        let tokens: Vec<SplDiff> = spl_keys
            .into_iter()
            .filter_map(|key| {
                let in_base = base_spl.get(&key).copied();
                let in_exp = exp_spl.get(&key).copied();
                // Every key came from one of the two maps, so the 0 fallback is never used.
                let decimals = in_base.or(in_exp).map_or(0, |(_, dec)| dec);
                let baseline_delta = in_base.map_or(0, |(delta, _)| delta);
                let experiment_delta = in_exp.map_or(0, |(delta, _)| delta);
                (baseline_delta != experiment_delta).then(|| SplDiff {
                    pubkey: key.0.to_string(),
                    mint: key.1.to_string(),
                    decimals,
                    baseline_delta,
                    experiment_delta,
                })
            })
            .collect();

        // Every kept entry has differing deltas and therefore a weight >= 1, so "no entries"
        // is exactly the "zero score" case.
        if sol.is_empty() && tokens.is_empty() {
            return None;
        }

        // `abs_diff` is exact even for deltas of opposite sign whose plain `i64` subtraction
        // would overflow (panicking in debug builds, wrapping in release).
        let weight = |baseline: i64, experiment: i64| -> f64 {
            let diff = experiment.abs_diff(baseline) as f64;
            if baseline == 0 {
                diff
            } else {
                (diff * baseline.unsigned_abs() as f64).sqrt()
            }
        };
        let score = sol
            .iter()
            .map(|d| weight(d.baseline_delta, d.experiment_delta))
            .sum::<f64>()
            + tokens
                .iter()
                .map(|d| weight(d.baseline_delta, d.experiment_delta))
                .sum::<f64>();

        Some(PnlDiff {
            base,
            sol,
            tokens,
            score,
        })
    }
}
// Optional `sim compare` subcommand: print a single report section in full instead of the
// truncated overview. The `///` comments on the variants are rendered by clap as help text.
#[derive(Subcommand, Debug)]
pub enum CompareSection {
    /// Show every transaction that succeeded in the baseline but failed in the experiment
    Regressions,
    /// Show every transaction that failed in the baseline but succeeded in the experiment
    Improvements,
    /// Show every transaction with the same outcome in both runs but a different log line count
    Logs,
    /// Show every transaction that succeeded in both runs but changed balances differently
    Balances,
}
// Command-line arguments for `sim compare`. The `///` comments on the positional fields are
// rendered by clap as their `--help` text.
#[derive(Parser, Debug)]
pub struct CompareArgs {
    /// Simulation output JSON file used as the reference run
    pub baseline: PathBuf,
    /// Simulation output JSON file compared against the baseline
    pub experiment: PathBuf,
    // Optional section to print in full; when omitted, every section is shown truncated.
    #[command(subcommand)]
    pub section: Option<CompareSection>,
}
/// Entry point for `sim compare`: diffs two simulation output files transaction by transaction.
///
/// Transactions are paired by signature. For each baseline transaction that also appears in
/// the experiment, this records:
/// - a regression if it succeeded in the baseline but failed in the experiment;
/// - an improvement if it failed in the baseline but succeeded in the experiment;
/// - a log difference if both runs had the same outcome but a different number of log lines;
/// - a balance difference if both runs succeeded and their SOL/SPL balance changes differ.
///
/// Baseline transactions absent from the experiment are reported as missing. With no section
/// selected, every category is printed truncated to `DEFAULT_LIMIT` rows, followed by a
/// summary. Otherwise only the chosen section is printed in full. Printing balance differences
/// performs a best-effort RPC lookup of token symbols.
///
/// # Errors
/// Returns an error if either input file cannot be read or parsed.
pub async fn compare(args: CompareArgs) -> Result<()> {
    let baseline = read_output(&args.baseline)?;
    let experiment = read_output(&args.experiment)?;

    // Index both runs by signature so transactions pair up regardless of file order.
    let baseline_txs: HashMap<&str, &Transaction> = baseline
        .transactions
        .iter()
        .map(|tx| (tx.signature.as_str(), tx))
        .collect();
    let experiment_txs: HashMap<&str, &Transaction> = experiment
        .transactions
        .iter()
        .map(|tx| (tx.signature.as_str(), tx))
        .collect();

    let mut regressions: Vec<(&Transaction, &Transaction)> = Vec::new();
    let mut improvements: Vec<(&Transaction, &Transaction)> = Vec::new();
    let mut log_diffs: Vec<(&Transaction, &Transaction)> = Vec::new();
    let mut missing: Vec<&str> = Vec::new();
    let mut pnl_diffs: Vec<PnlDiff> = Vec::new();

    for (sig, base_tx) in &baseline_txs {
        let Some(exp_tx) = experiment_txs.get(sig) else {
            missing.push(*sig);
            continue;
        };
        match (base_tx.success, exp_tx.success) {
            (true, false) => regressions.push((base_tx, exp_tx)),
            (false, true) => improvements.push((base_tx, exp_tx)),
            // Same outcome in both runs: only the number of log lines is compared.
            _ => {
                if base_tx.logs.len() != exp_tx.logs.len() {
                    log_diffs.push((base_tx, exp_tx));
                }
            }
        }
        // Balance changes are compared only when both runs succeeded.
        if base_tx.success
            && exp_tx.success
            && let Some(diff) = PnlDiff::new(base_tx, exp_tx)
        {
            pnl_diffs.push(diff);
        }
    }

    // The lists above were filled in `HashMap` iteration order, which is random. Every sort
    // therefore needs a total tie-breaker (the signature). Without one, transactions sharing a
    // slot, or diffs sharing a score, would print in a different order on every run.
    fn by_slot_then_signature(
        (a, _): &(&Transaction, &Transaction),
        (b, _): &(&Transaction, &Transaction),
    ) -> std::cmp::Ordering {
        a.slot
            .cmp(&b.slot)
            .then_with(|| a.signature.as_str().cmp(b.signature.as_str()))
    }
    regressions.sort_by(by_slot_then_signature);
    improvements.sort_by(by_slot_then_signature);
    log_diffs.sort_by(by_slot_then_signature);
    missing.sort();
    // Largest discrepancy first.
    pnl_diffs.sort_by(|a, b| {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.base.signature.as_str().cmp(b.base.signature.as_str()))
    });

    let common = baseline_txs.len() - missing.len();
    println!(
        "Baseline ({}) versus Experiment ({})",
        args.baseline.display(),
        args.experiment.display(),
    );
    println!(
        "Baseline transactions: {}",
        baseline.summary.total_transactions,
    );
    println!(
        "Experiment transactions: {}",
        experiment.summary.total_transactions,
    );
    println!("Common transactions: {common}\n");

    match args.section {
        Some(CompareSection::Regressions) => {
            print_regressions(regressions.as_slice(), None);
        }
        Some(CompareSection::Improvements) => {
            print_improvements(&improvements, None);
        }
        Some(CompareSection::Logs) => {
            print_log_diffs(&log_diffs, None);
        }
        Some(CompareSection::Balances) => {
            let symbols = resolve_token_symbols(&pnl_diffs).await;
            print_pnl_diffs(&pnl_diffs, None, &symbols);
        }
        None => {
            if regressions.is_empty()
                && improvements.is_empty()
                && log_diffs.is_empty()
                && missing.is_empty()
                && pnl_diffs.is_empty()
            {
                println!("{} No differences found", style("✔").green());
                return Ok(());
            }
            print_regressions(regressions.as_slice(), Some(DEFAULT_LIMIT));
            print_improvements(&improvements, Some(DEFAULT_LIMIT));
            print_log_diffs(&log_diffs, Some(DEFAULT_LIMIT));
            print_missing(&missing, Some(DEFAULT_LIMIT));
            let symbols = resolve_token_symbols(&pnl_diffs).await;
            print_pnl_diffs(&pnl_diffs, Some(DEFAULT_LIMIT), &symbols);
            println!(
                "Summary: {} regressions, {} improvements, {} log diffs, {} missing, {} balance diffs",
                regressions.len(),
                improvements.len(),
                log_diffs.len(),
                missing.len(),
                pnl_diffs.len(),
            );
        }
    }
    Ok(())
}
/// Lists transactions that went from success (baseline) to failure (experiment).
///
/// Each row shows the slot and abbreviated signature, followed by the experiment's error when
/// one was recorded. `limit` caps the number of rows (`None` prints all of them). When rows
/// are cut off, a hint to `sim compare regressions` is printed. Empty input prints nothing.
fn print_regressions(regressions: &[(&Transaction, &Transaction)], limit: Option<usize>) {
    if regressions.is_empty() {
        return;
    }
    let total = regressions.len();
    println!(
        "{} Regressions (success → failure): {}",
        style("✖").red(),
        total
    );

    let visible = match limit {
        Some(cap) => cap.min(total),
        None => total,
    };
    for (baseline_tx, experiment_tx) in regressions.iter().take(visible) {
        println!(" slot {} {}", baseline_tx.slot, format(&baseline_tx.signature));
        if let Some(err) = &experiment_tx.error {
            println!(" error: {err}");
        }
    }

    let hidden = total - visible;
    if hidden > 0 {
        println!(" … and {hidden} more (use `sim compare regressions` to view all)");
    }
    println!();
}
/// Lists transactions that went from failure (baseline) to success (experiment).
///
/// Each row shows the slot and abbreviated signature, followed by the baseline's error when
/// one was recorded. `limit` caps the rows printed (`None` means all). A hint to
/// `sim compare improvements` follows when rows were omitted. Empty input prints nothing.
fn print_improvements(improvements: &[(&Transaction, &Transaction)], limit: Option<usize>) {
    let total = improvements.len();
    if total == 0 {
        return;
    }
    println!(
        "{} Improvements (failure → success): {}",
        style("✔").green(),
        total
    );

    let visible = limit.unwrap_or(total).min(total);
    for (baseline_tx, _) in improvements.iter().take(visible) {
        println!(" slot {} {}", baseline_tx.slot, format(&baseline_tx.signature));
        if let Some(previous_error) = &baseline_tx.error {
            println!(" was: {previous_error}");
        }
    }

    if visible < total {
        println!(
            " … and {} more (use `sim compare improvements` to view all)",
            total - visible
        );
    }
    println!();
}
/// Lists transaction pairs whose logs differ, showing each run's log line count.
///
/// `limit` caps the rows printed (`None` means all). A hint to `sim compare logs` follows
/// when rows were omitted. Empty input prints nothing.
fn print_log_diffs(log_diffs: &[(&Transaction, &Transaction)], limit: Option<usize>) {
    if log_diffs.is_empty() {
        return;
    }
    let total = log_diffs.len();
    let shown = match limit {
        Some(max_rows) if max_rows < total => max_rows,
        _ => total,
    };

    println!("{} Log differences: {}", style("~").yellow(), total);
    for (baseline_tx, experiment_tx) in &log_diffs[..shown] {
        let baseline_count = baseline_tx.logs.len();
        let experiment_count = experiment_tx.logs.len();
        println!(
            " slot {} {} ({} baseline logs → {} modified logs)",
            baseline_tx.slot,
            format(&baseline_tx.signature),
            baseline_count,
            experiment_count,
        );
    }

    if shown < total {
        println!(
            " … and {} more (use `sim compare logs` to view all)",
            total - shown
        );
    }
    println!();
}
/// Lists baseline signatures that have no counterpart in the experiment output.
///
/// `limit` caps the rows printed (`None` means all). The number of omitted rows is reported
/// when some were cut off. Empty input prints nothing.
fn print_missing(missing: &[&str], limit: Option<usize>) {
    if missing.is_empty() {
        return;
    }
    println!(
        "{} Missing from experiment: {}",
        style("?").dim(),
        missing.len()
    );

    let cap = limit.unwrap_or(usize::MAX);
    for signature in missing.iter().take(cap) {
        println!(" {}", format(signature));
    }

    let omitted = missing.len().saturating_sub(cap);
    if omitted > 0 {
        println!(" … and {omitted} more");
    }
    println!();
}
/// Best-effort lookup of ticker symbols for every SPL mint referenced by `pnl_diffs`.
///
/// For each mint this derives its metadata PDA (seeds `"metadata"`, the metadata program id,
/// the mint). It then fetches those accounts from `RPC_URL` with `getMultipleAccounts`, in
/// batches of at most 100 keys. Every request has a timeout so a stalled RPC cannot hang the
/// report.
///
/// Some inputs are skipped silently: mints that don't parse, accounts that don't exist, and
/// any network, JSON or decoding failure. The returned map (mint → symbol) may therefore be
/// partial or empty. This function never fails.
async fn resolve_token_symbols(pnl_diffs: &[PnlDiff<'_>]) -> HashMap<String, String> {
    // `getMultipleAccounts` accepts at most 100 pubkeys per request.
    const MAX_ACCOUNTS_PER_REQUEST: usize = 100;
    // The default reqwest client has no timeout; this lookup is only cosmetic.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let mints: BTreeSet<&str> = pnl_diffs
        .iter()
        .flat_map(|d| d.tokens.iter().map(|t| t.mint.as_str()))
        .collect();
    if mints.is_empty() {
        return HashMap::new();
    }
    let Ok(program_id) = METADATA_PROGRAM.parse::<Address>() else {
        return HashMap::new();
    };

    // Pair each parseable mint with the PDA of its metadata account.
    let targets: Vec<(&str, Address)> = mints
        .into_iter()
        .filter_map(|mint| {
            let mint_addr = mint.parse::<Address>().ok()?;
            let (pda, _) = Address::find_program_address(
                &[b"metadata", program_id.as_ref(), mint_addr.as_ref()],
                &program_id,
            );
            Some((mint, pda))
        })
        .collect();
    if targets.is_empty() {
        return HashMap::new();
    }

    let Ok(client) = reqwest::Client::builder().timeout(REQUEST_TIMEOUT).build() else {
        return HashMap::new();
    };
    let mut symbols = HashMap::new();
    for batch in targets.chunks(MAX_ACCOUNTS_PER_REQUEST) {
        let keys: Vec<String> = batch.iter().map(|(_, pda)| pda.to_string()).collect();
        let body = serde_json::json!({
            "jsonrpc": "2.0", "id": 1,
            "method": "getMultipleAccounts",
            "params": [keys, {"encoding": "base64"}]
        });
        let Ok(resp) = client.post(RPC_URL).json(&body).send().await else {
            continue;
        };
        let Ok(json) = resp.json::<serde_json::Value>().await else {
            continue;
        };
        // `result.value` holds one entry per requested key, in request order
        // (`null` for accounts that don't exist).
        let Some(accounts) = json["result"]["value"].as_array() else {
            continue;
        };
        for ((mint, _), account) in batch.iter().zip(accounts) {
            if let Some(b64) = account["data"][0].as_str()
                && let Ok(raw) = BASE64.decode(b64)
                && let Some(symbol) = parse_metadata_symbol(&raw)
            {
                symbols.insert(mint.to_string(), symbol);
            }
        }
    }
    symbols
}
/// Extracts the ticker symbol from a raw token-metadata account.
///
/// Expected layout (integers little-endian):
/// - a 1-byte key;
/// - two 32-byte fields (update authority and mint in the Metaplex layout);
/// - a `u32`-length-prefixed name;
/// - a `u32`-length-prefixed symbol.
///
/// Trailing NUL padding is stripped from the symbol.
///
/// Returns `None` in any of these cases:
/// - the buffer is too short;
/// - a length prefix points past the end of the data (or would overflow `usize`);
/// - the symbol is not UTF-8;
/// - the symbol is empty after trimming.
fn parse_metadata_symbol(data: &[u8]) -> Option<String> {
    /// Reads a `u32`-length-prefixed byte string at `off`, returning it and the offset past it.
    /// All offset arithmetic is checked: the length prefix comes from untrusted RPC data and
    /// could otherwise overflow `usize` on 32-bit targets.
    fn read_prefixed(data: &[u8], off: usize) -> Option<(&[u8], usize)> {
        let len_end = off.checked_add(4)?;
        let len = u32::from_le_bytes(data.get(off..len_end)?.try_into().ok()?);
        let end = len_end.checked_add(usize::try_from(len).ok()?)?;
        Some((data.get(len_end..end)?, end))
    }

    // 1-byte key + two 32-byte fields precede the name.
    const HEADER_LEN: usize = 1 + 32 + 32;
    let (_name, after_name) = read_prefixed(data, HEADER_LEN)?;
    let (symbol_bytes, _) = read_prefixed(data, after_name)?;
    let symbol = std::str::from_utf8(symbol_bytes)
        .ok()?
        .trim_end_matches('\0');
    (!symbol.is_empty()).then(|| symbol.to_string())
}
/// Prints, per transaction, the SOL and SPL balance changes that differ between the runs.
///
/// Rows are printed in the order given.
/// - SOL deltas are divided by `1e9` and shown with 9 decimal places.
/// - Token deltas are scaled by their mint's decimals and labelled `WSOL` for the wrapped-SOL
///   mint. Other tokens use the symbol from `symbols` or `TOK` as a fallback.
/// - Each row ends with the change relative to the baseline delta. It shows `n/a` when the
///   baseline delta is zero, because that ratio is undefined.
///
/// `limit` caps the number of transactions printed (`None` means all). Empty input prints
/// nothing.
fn print_pnl_diffs(
    pnl_diffs: &[PnlDiff<'_>],
    limit: Option<usize>,
    symbols: &HashMap<String, String>,
) {
    if pnl_diffs.is_empty() {
        return;
    }
    println!(
        "{} Balance differences: {}",
        style("$").cyan(),
        pnl_diffs.len()
    );
    let shown = limit.map_or(pnl_diffs.len(), |l| l.min(pnl_diffs.len()));
    for diff in &pnl_diffs[..shown] {
        println!(
            " slot {} {}",
            diff.base.slot,
            format(&diff.base.signature)
        );
        for d in &diff.sol {
            println!(
                " SOL {} {:+.9} → {:+.9} ({})",
                format(&d.pubkey),
                d.baseline_delta as f64 / 1e9,
                d.experiment_delta as f64 / 1e9,
                pct_change_label(d.baseline_delta, d.experiment_delta),
            );
        }
        for d in &diff.tokens {
            let scale = 10f64.powi(d.decimals as i32);
            let prec = d.decimals as usize;
            let ticker = if d.mint == WSOL_MINT {
                "WSOL"
            } else {
                symbols.get(&d.mint).map(|s| s.as_str()).unwrap_or("TOK")
            };
            println!(
                " {ticker} {} {:+.prec$} → {:+.prec$} ({})",
                format(&d.pubkey),
                d.baseline_delta as f64 / scale,
                d.experiment_delta as f64 / scale,
                pct_change_label(d.baseline_delta, d.experiment_delta),
                prec = prec,
            );
        }
        println!();
    }
    if let Some(limit) = limit
        && pnl_diffs.len() > limit
    {
        println!(
            " … and {} more (use `sim compare balances` to view all)",
            pnl_diffs.len() - limit
        );
    }
    println!();
}

/// Formats the change from `baseline` to `experiment` as a signed percentage of `|baseline|`
/// (for example `+12.5000%`), or `n/a` when `baseline` is zero.
fn pct_change_label(baseline: i64, experiment: i64) -> String {
    if baseline == 0 {
        // Dividing by zero here used to print `+inf%`/`-inf%`.
        return "n/a".to_string();
    }
    // Widen before subtracting: deltas of opposite sign can overflow `i64`.
    let change = i128::from(experiment) - i128::from(baseline);
    format!(
        "{:+.4}%",
        change as f64 / baseline.unsigned_abs() as f64 * 100.0
    )
}
/// Reads and deserializes a simulation output JSON file.
///
/// Takes `&Path` rather than `&PathBuf` (Clippy `ptr_arg`). Callers holding a `&PathBuf`
/// still work through deref coercion.
///
/// # Errors
/// Returns an error naming `path` if the file cannot be read, or if its contents cannot be
/// parsed as `SimulationOutput` JSON.
fn read_output(path: &std::path::Path) -> Result<SimulationOutput> {
    let json = std::fs::read_to_string(path)
        .with_context(|| format!("failed to read {}", path.display()))?;
    serde_json::from_str(&json)
        .with_context(|| format!("failed to parse {} as simulation output", path.display()))
}
/// Abbreviates a signature or address for display as `first8…last8`.
///
/// Strings of at most 16 characters are returned unchanged. Length and cut points are
/// measured in `char`s, not bytes. Byte slicing would panic when a cut lands inside a
/// multi-byte UTF-8 character; for ASCII input (e.g. base58) the result is unchanged.
fn format(signature: &str) -> String {
    const KEEP: usize = 8;
    let char_count = signature.chars().count();
    if char_count <= 2 * KEEP {
        return signature.to_string();
    }
    let head: String = signature.chars().take(KEEP).collect();
    let tail: String = signature.chars().skip(char_count - KEEP).collect();
    format!("{head}…{tail}")
}