1use std::{
8 collections::HashMap,
9 env,
10 sync::Arc,
11 time::{Duration, Instant},
12};
13
14use hedged_rpc_client::{HedgeConfig, HedgedRpcClient, ProviderConfig, ProviderId};
15use solana_commitment_config::CommitmentConfig;
16use solana_sdk::pubkey::Pubkey;
17use tokio::sync::{mpsc, Semaphore};
18
/// Outcome of a single `get_account` call issued by a runner.
#[derive(Debug)]
enum CallOutcome {
    /// The call returned successfully.
    Ok {
        /// Provider id returned alongside the response by `get_account`;
        /// tallied as a "win" for that provider in the runner stats.
        provider: ProviderId,
        /// Wall-clock time spent awaiting the `get_account` call.
        latency: Duration,
    },
    /// The call returned an error.
    Err {
        /// The client error rendered with `to_string()`.
        error: String,
        /// Wall-clock time spent awaiting the `get_account` call before it failed.
        latency: Duration,
    },
}
30
/// A call outcome tagged with the index of the call that produced it.
#[derive(Debug)]
struct CallResult {
    /// Zero-based index of the call within its runner (`0..num_calls`).
    /// Calls run concurrently, so results may arrive out of index order.
    call_idx: usize,
    /// What happened to the call.
    outcome: CallOutcome,
}
36
/// Aggregated results of one runner, summarised by `main` after the run.
#[derive(Debug)]
struct RunnerStats {
    /// Runner name used as a prefix in log and summary output.
    label: &'static str,
    /// Number of calls the runner was asked to make.
    total_calls: usize,
    /// Number of calls that returned `Ok`.
    successes: usize,
    /// Number of failed calls.
    errors: usize,
    /// Mean latency of the successful calls, in milliseconds.
    /// This is `0.0` when there were no successes, which is a placeholder, not a measurement.
    avg_latency_ms: f64,
    /// Successful-call count keyed by the provider id that served each call.
    per_provider_wins: HashMap<&'static str, usize>,
}
46
47fn provider_from_env(env_key: &str, id: &'static str) -> Option<ProviderConfig> {
48 env::var(env_key).ok().map(|url| ProviderConfig {
49 id: ProviderId(id),
50 url,
51 })
52}
53
54#[tokio::main]
55async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
56 let mut providers = Vec::new();
57
58 if let Some(p) = provider_from_env("HELIUS_RPC_URL", "helius") {
59 providers.push(p);
60 }
61 if let Some(p) = provider_from_env("TRITON_RPC_URL", "triton") {
62 providers.push(p);
63 }
64 if let Some(p) = provider_from_env("QUICKNODE_RPC_URL", "quicknode") {
65 providers.push(p);
66 }
67
68 if providers.is_empty() {
69 eprintln!("No providers configured.");
70 eprintln!("Set at least one of: HELIUS_RPC_URL, TRITON_RPC_URL, QUICKNODE_RPC_URL");
71 return Ok(());
72 }
73
74 let cfg = HedgeConfig {
75 initial_providers: 2,
76 hedge_after: Duration::from_millis(20),
77 max_providers: providers.len(),
78 min_slot: None,
79 overall_timeout: Duration::from_secs(2),
80 };
81
82 let client_a = HedgedRpcClient::new(providers.clone(), cfg.clone());
83 let client_b = HedgedRpcClient::new(providers, cfg);
84
85 let addr: Pubkey = "So11111111111111111111111111111111111111112".parse()?;
86 let commitment = CommitmentConfig::processed();
87 let num_calls_per_runner: usize = 10_000;
88 let max_in_flight_per_runner: usize = 256;
89
90 let runner_a = tokio::spawn(run_runner(
91 "A",
92 client_a,
93 addr,
94 commitment,
95 num_calls_per_runner,
96 max_in_flight_per_runner,
97 ));
98
99 let runner_b = tokio::spawn(run_runner(
100 "B",
101 client_b,
102 addr,
103 commitment,
104 num_calls_per_runner,
105 max_in_flight_per_runner,
106 ));
107
108 let stats_a = runner_a.await??;
109 let stats_b = runner_b.await??;
110
111 println!("\n=== comparison ===");
112 println!(
113 "Runner {}: total={}, successes={}, errors={}, avg_latency={:.3} ms",
114 stats_a.label,
115 stats_a.total_calls,
116 stats_a.successes,
117 stats_a.errors,
118 stats_a.avg_latency_ms
119 );
120 for (provider, wins) in &stats_a.per_provider_wins {
121 println!(
122 " [{}] wins from provider {} = {}",
123 stats_a.label, provider, wins
124 );
125 }
126
127 println!(
128 "Runner {}: total={}, successes={}, errors={}, avg_latency={:.3} ms",
129 stats_b.label,
130 stats_b.total_calls,
131 stats_b.successes,
132 stats_b.errors,
133 stats_b.avg_latency_ms
134 );
135 for (provider, wins) in &stats_b.per_provider_wins {
136 println!(
137 " [{}] wins from provider {} = {}",
138 stats_b.label, provider, wins
139 );
140 }
141
142 if stats_a.avg_latency_ms < stats_b.avg_latency_ms {
143 println!(
144 "\n=> Runner {} was faster on average by {:.3} ms",
145 stats_a.label,
146 stats_b.avg_latency_ms - stats_a.avg_latency_ms
147 );
148 } else if stats_b.avg_latency_ms < stats_a.avg_latency_ms {
149 println!(
150 "\n=> Runner {} was faster on average by {:.3} ms",
151 stats_b.label,
152 stats_a.avg_latency_ms - stats_b.avg_latency_ms
153 );
154 } else {
155 println!("\n=> Both runners had the same average latency.");
156 }
157
158 Ok(())
159}
160
161async fn run_runner(
162 label: &'static str,
163 client: HedgedRpcClient,
164 addr: Pubkey,
165 commitment: CommitmentConfig,
166 num_calls: usize,
167 max_in_flight: usize,
168) -> Result<RunnerStats, Box<dyn std::error::Error + Send + Sync>> {
169 let (tx, mut rx) = mpsc::channel::<CallResult>(max_in_flight * 2);
170 let semaphore = Arc::new(Semaphore::new(max_in_flight));
171
172 let consumer = tokio::spawn(async move {
173 let mut results: Vec<CallResult> = Vec::with_capacity(num_calls);
174
175 while let Some(res) = rx.recv().await {
176 match &res.outcome {
177 CallOutcome::Ok { provider, latency } => {
178 println!(
179 "[{} call {:05}] OK provider={} latency={:?}",
180 label, res.call_idx, provider.0, latency
181 );
182 }
183 CallOutcome::Err { error, latency } => {
184 println!(
185 "[{} call {:05}] ERR latency={:?} error={}",
186 label, res.call_idx, latency, error
187 );
188 }
189 }
190
191 results.push(res);
192 }
193
194 results
195 });
196
197 for i in 0..num_calls {
198 let client_clone = client.clone();
199 let tx_clone = tx.clone();
200 let addr_copy = addr;
201 let commitment_clone = commitment;
202 let sem = semaphore.clone();
203
204 tokio::spawn(async move {
205 let _permit = sem.acquire_owned().await.expect("semaphore closed");
206
207 let start = Instant::now();
208 let res = client_clone.get_account(&addr_copy, commitment_clone).await;
209 let elapsed = start.elapsed();
210
211 let outcome = match res {
212 Ok((provider, _resp)) => CallOutcome::Ok {
213 provider,
214 latency: elapsed,
215 },
216 Err(e) => CallOutcome::Err {
217 error: e.to_string(),
218 latency: elapsed,
219 },
220 };
221
222 let _ = tx_clone
223 .send(CallResult {
224 call_idx: i,
225 outcome,
226 })
227 .await;
228 });
229 }
230
231 drop(tx);
232 let results = consumer.await?;
233
234 let mut successes = 0usize;
235 let mut errors = 0usize;
236 let mut sum_latency = Duration::ZERO;
237 let mut per_provider_wins: HashMap<&'static str, usize> = HashMap::new();
238
239 for r in &results {
240 match &r.outcome {
241 CallOutcome::Ok { provider, latency } => {
242 successes += 1;
243 sum_latency += *latency;
244 *per_provider_wins.entry(provider.0).or_insert(0) += 1;
245 }
246 CallOutcome::Err { .. } => {
247 errors += 1;
248 }
249 }
250 }
251
252 let avg_latency_ms = if successes > 0 {
253 (sum_latency.as_secs_f64() * 1000.0) / (successes as f64)
254 } else {
255 0.0
256 };
257
258 Ok(RunnerStats {
259 label,
260 total_calls: num_calls,
261 successes,
262 errors,
263 avg_latency_ms,
264 per_provider_wins,
265 })
266}