1use std::{
8 collections::HashMap,
9 env,
10 sync::Arc,
11 time::{Duration, Instant},
12};
13
14use hedged_rpc_client::{HedgeConfig, HedgedRpcClient, ProviderConfig, ProviderId};
15use solana_commitment_config::CommitmentConfig;
16use solana_sdk::pubkey::Pubkey;
17use tokio::sync::{mpsc, Semaphore};
18
19const NUM_CALLS: usize = 50_000;
20const MAX_IN_FLIGHT: usize = 256;
21
22#[derive(Debug)]
23enum CallOutcome {
24 Ok {
25 provider: ProviderId,
26 latency: Duration,
27 },
28 Err {
29 error: String,
30 latency: Duration,
31 },
32}
33
34#[derive(Debug)]
35struct CallResult {
36 call_idx: usize,
37 outcome: CallOutcome,
38}
39
40fn provider_from_env(env_key: &str, id: &'static str) -> Option<ProviderConfig> {
41 env::var(env_key).ok().map(|url| ProviderConfig {
42 id: ProviderId(id),
43 url,
44 })
45}
46
47#[tokio::main]
48async fn main() -> Result<(), Box<dyn std::error::Error>> {
49 let mut providers = Vec::new();
50
51 if let Some(p) = provider_from_env("HELIUS_RPC_URL", "helius") {
52 providers.push(p);
53 }
54 if let Some(p) = provider_from_env("TRITON_RPC_URL", "triton") {
55 providers.push(p);
56 }
57 if let Some(p) = provider_from_env("QUICKNODE_RPC_URL", "quicknode") {
58 providers.push(p);
59 }
60
61 if providers.is_empty() {
62 eprintln!("No providers configured.");
63 eprintln!("Set at least one of: HELIUS_RPC_URL, TRITON_RPC_URL, QUICKNODE_RPC_URL");
64 return Ok(());
65 }
66
67 let cfg = HedgeConfig {
68 initial_providers: providers.len(),
69 hedge_after: Duration::from_millis(20),
70 max_providers: providers.len(),
71 min_slot: None,
72 overall_timeout: Duration::from_secs(1),
73 };
74
75 let client = HedgedRpcClient::new(providers, cfg);
76
77 let addr: Pubkey = "So11111111111111111111111111111111111111112".parse()?;
78 let commitment = CommitmentConfig::processed();
79
80 let (tx, mut rx) = mpsc::channel::<CallResult>(MAX_IN_FLIGHT * 2);
81 let semaphore = Arc::new(Semaphore::new(MAX_IN_FLIGHT));
82 let consumer = tokio::spawn(async move {
83 let mut results: Vec<CallResult> = Vec::with_capacity(NUM_CALLS);
84
85 while let Some(res) = rx.recv().await {
86 match &res.outcome {
87 CallOutcome::Ok { provider, latency } => {
88 println!(
89 "[call {:05}] OK provider={} latency={:?}",
90 res.call_idx, provider.0, latency
91 );
92 }
93 CallOutcome::Err { error, latency } => {
94 println!(
95 "[call {:05}] ERR latency={:?} error={}",
96 res.call_idx, latency, error
97 );
98 }
99 }
100
101 results.push(res);
102 }
103
104 results
105 });
106
107 for i in 0..NUM_CALLS {
108 let client_clone = client.clone();
109 let tx_clone = tx.clone();
110 let addr_copy = addr;
111 let commitment_clone = commitment;
112 let sem = semaphore.clone();
113
114 tokio::spawn(async move {
115 let _permit = sem.acquire_owned().await.expect("semaphore closed");
116
117 let start = Instant::now();
118 let res = client_clone.get_account(&addr_copy, commitment_clone).await;
119 let elapsed = start.elapsed();
120
121 let outcome = match res {
122 Ok((provider, _resp)) => CallOutcome::Ok {
123 provider,
124 latency: elapsed,
125 },
126 Err(e) => CallOutcome::Err {
127 error: e.to_string(),
128 latency: elapsed,
129 },
130 };
131
132 let _ = tx_clone
133 .send(CallResult {
134 call_idx: i,
135 outcome,
136 })
137 .await;
138 });
139 }
140
141 drop(tx);
142 let mut results = consumer.await?;
143
144 results.sort_by_key(|r| r.call_idx);
145
146 let mut wins: HashMap<&'static str, usize> = HashMap::new();
147 let mut total_latency: HashMap<&'static str, Duration> = HashMap::new();
148 let mut error_count = 0usize;
149
150 for r in &results {
151 match &r.outcome {
152 CallOutcome::Ok { provider, latency } => {
153 let name = provider.0;
154 *wins.entry(name).or_insert(0) += 1;
155 *total_latency.entry(name).or_insert(Duration::ZERO) += *latency;
156 }
157 CallOutcome::Err { .. } => {
158 error_count += 1;
159 }
160 }
161 }
162
163 println!("\n=== summary ===");
164 println!("total calls : {}", NUM_CALLS);
165 println!("successes : {}", NUM_CALLS - error_count);
166 println!("errors (any kind) : {}", error_count);
167
168 for (provider, count) in wins.iter() {
169 let total = total_latency[provider];
170 let avg_ms = total.as_secs_f64() * 1000.0 / (*count as f64);
171 println!(
172 "provider {:>10}: wins = {:6}, avg_latency = {:8.3} ms",
173 provider, count, avg_ms,
174 );
175 }
176
177 Ok(())
178}