1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{Arc, Mutex},
5 time::Instant,
6};
7
8use futures::{stream::FuturesUnordered, StreamExt};
9use solana_client::{client_error::ClientError, nonblocking::rpc_client::RpcClient};
10use solana_commitment_config::CommitmentConfig;
11use solana_rpc_client_api::{client_error::ErrorKind, response::Response as RpcResponse};
12use solana_sdk::{account::Account, hash::Hash, pubkey::Pubkey};
13use tokio::time;
14
15use crate::{
16 config::{HedgeConfig, ProviderConfig, ProviderId},
17 errors::HedgedError,
18};
19
/// Running per-provider counters kept inside the client's shared stats map.
#[derive(Debug, Default)]
struct ProviderStats {
    /// Number of hedged calls in which this provider returned the winning response.
    wins: u64,
    /// Sum, in milliseconds, of the elapsed time of every call this provider won.
    /// Divided by `wins` to produce an average when a snapshot is taken.
    total_latency_ms: f64,
    /// Number of hedged calls that ended with an error charged to this provider.
    errors: u64,
}
26
/// Point-in-time, caller-owned copy of one provider's statistics.
#[derive(Debug, Clone)]
pub struct ProviderStatsSnapshot {
    /// Number of hedged calls this provider won.
    pub wins: u64,
    /// Mean latency in milliseconds over the won calls; `0.0` when there are no wins.
    pub avg_latency_ms: f64,
    /// Number of errors charged to this provider.
    pub errors: u64,
}
37
/// RPC client that issues the same request to several providers and keeps the first success.
///
/// Cloning is cheap: the provider list and the statistics map are both behind `Arc`, so
/// every clone shares the same underlying clients and counters.
#[derive(Clone)]
pub struct HedgedRpcClient {
    /// Providers in priority order; earlier entries are contacted first.
    providers: Arc<Vec<(ProviderId, Arc<RpcClient>)>>,
    /// Hedging parameters (provider limits, hedge delay, overall timeout).
    cfg: HedgeConfig,
    /// Per-provider win/latency/error counters, shared between clones.
    stats: Arc<Mutex<HashMap<ProviderId, ProviderStats>>>,
}
49
50impl HedgedRpcClient {
51 pub fn new(provider_cfgs: Vec<ProviderConfig>, cfg: HedgeConfig) -> Self {
80 let providers_vec: Vec<(ProviderId, Arc<RpcClient>)> = provider_cfgs
81 .into_iter()
82 .map(|pcfg| {
83 let client = Arc::new(RpcClient::new(pcfg.url));
84 (pcfg.id, client)
85 })
86 .collect();
87
88 let mut stats_map = HashMap::new();
89 for (id, _) in &providers_vec {
90 stats_map.insert(*id, ProviderStats::default());
91 }
92
93 Self {
94 providers: Arc::new(providers_vec),
95 cfg,
96 stats: Arc::new(Mutex::new(stats_map)),
97 }
98 }
99
100 pub fn providers(&self) -> &[(ProviderId, Arc<RpcClient>)] {
102 &self.providers
103 }
104
105 pub fn provider_stats(&self) -> HashMap<ProviderId, ProviderStatsSnapshot> {
109 let stats = self.stats.lock().expect("provider stats mutex poisoned");
110
111 stats
112 .iter()
113 .map(|(id, s)| {
114 let avg = if s.wins > 0 {
115 s.total_latency_ms / (s.wins as f64)
116 } else {
117 0.0
118 };
119
120 (
121 *id,
122 ProviderStatsSnapshot {
123 wins: s.wins,
124 avg_latency_ms: avg,
125 errors: s.errors,
126 },
127 )
128 })
129 .collect()
130 }
131
132 async fn hedged_call<T, F, Fut>(&self, f: F) -> Result<(ProviderId, T), HedgedError>
142 where
143 T: Send,
144 F: Fn(Arc<RpcClient>) -> Fut + Send,
145 Fut: Future<Output = Result<T, ClientError>> + Send,
146 {
147 if self.providers.is_empty() {
148 return Err(HedgedError::NoProviders);
149 }
150
151 let max_idx = self.cfg.max_providers.min(self.providers.len());
152 if max_idx == 0 {
153 return Err(HedgedError::NoProviders);
154 }
155 let selected_providers = &self.providers[..max_idx];
156
157 let start = Instant::now();
158 let selected_ids: Vec<ProviderId> = selected_providers.iter().map(|(id, _)| *id).collect();
159
160 let hedging_logic = async {
161 let mut failures = Vec::new();
162 let mut futures = FuturesUnordered::new();
163
164 let spawn_provider = move |provider_id: ProviderId, client: Arc<RpcClient>| {
165 let fut = f(client);
166 async move {
167 let result = fut.await;
168 (provider_id, result)
169 }
170 };
171
172 let initial_count = self
173 .cfg
174 .initial_providers
175 .max(1)
176 .min(selected_providers.len());
177
178 for (provider_id, client) in &selected_providers[..initial_count] {
179 futures.push(spawn_provider(*provider_id, client.clone()));
180 }
181
182 let needs_hedging = initial_count < selected_providers.len();
183 let mut hedged = !needs_hedging;
184 let hedge_sleep = time::sleep(self.cfg.hedge_after);
185 tokio::pin!(hedge_sleep);
186
187 loop {
188 if futures.is_empty() && hedged {
189 break;
190 }
191
192 tokio::select! {
193 Some((provider_id, result)) = futures.next(), if !futures.is_empty() => {
194 match result {
195 Ok(val) => return Ok((provider_id, val)),
196 Err(e) => failures.push((provider_id, e)),
197 }
198 }
199 _ = &mut hedge_sleep, if needs_hedging && !hedged => {
200 hedged = true;
201 for (provider_id, client) in &selected_providers[initial_count..] {
202 futures.push(spawn_provider(*provider_id, client.clone()));
203 }
204 }
205 }
206 }
207
208 Err(HedgedError::AllFailed(failures))
209 };
210
211 let timed = time::timeout(self.cfg.overall_timeout, hedging_logic).await;
212
213 let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
214
215 match timed {
216 Err(_) => {
217 if let Ok(mut stats) = self.stats.lock() {
218 for id in selected_ids {
219 if let Some(entry) = stats.get_mut(&id) {
220 entry.errors += 1;
221 }
222 }
223 }
224 Err(HedgedError::Timeout(self.cfg.overall_timeout))
225 }
226 Ok(inner) => match inner {
227 Ok((winner_id, value)) => {
228 if let Ok(mut stats) = self.stats.lock() {
229 if let Some(entry) = stats.get_mut(&winner_id) {
230 entry.wins += 1;
231 entry.total_latency_ms += elapsed_ms;
232 }
233 }
234 Ok((winner_id, value))
235 }
236 Err(HedgedError::AllFailed(failures)) => {
237 if let Ok(mut stats) = self.stats.lock() {
238 for (id, _err) in failures.iter() {
239 if let Some(entry) = stats.get_mut(id) {
240 entry.errors += 1;
241 }
242 }
243 }
244 Err(HedgedError::AllFailed(failures))
245 }
246 Err(e) => Err(e),
247 },
248 }
249 }
250
251 pub async fn get_latest_blockhash(&self) -> Result<(ProviderId, Hash), HedgedError> {
255 let (id, resp) = self
256 .hedged_call(move |client| async move { client.get_latest_blockhash().await })
257 .await?;
258
259 Ok((id, resp))
260 }
261
262 pub async fn get_latest_blockhash_any(&self) -> Result<Hash, HedgedError> {
264 let (_id, resp) = self.get_latest_blockhash().await?;
265 Ok(resp)
266 }
267
268 pub async fn get_account(
276 &self,
277 pubkey: &Pubkey,
278 commitment: CommitmentConfig,
279 ) -> Result<(ProviderId, RpcResponse<Option<Account>>), HedgedError> {
280 let pk = *pubkey;
281
282 let (id, resp) = self
283 .hedged_call(move |client| {
284 let pk = pk;
285 async move { client.get_account_with_commitment(&pk, commitment).await }
286 })
287 .await?;
288
289 Ok((id, resp))
290 }
291
292 pub async fn get_account_any(
294 &self,
295 pubkey: &Pubkey,
296 commitment: CommitmentConfig,
297 ) -> Result<RpcResponse<Option<Account>>, HedgedError> {
298 let (_id, resp) = self.get_account(pubkey, commitment).await?;
299
300 Ok(resp)
301 }
302
303 pub async fn get_account_fresh(
313 &self,
314 pubkey: &Pubkey,
315 commitment: CommitmentConfig,
316 min_slot: u64,
317 ) -> Result<(ProviderId, RpcResponse<Option<Account>>), HedgedError> {
318 let (id, resp) = self.get_account(pubkey, commitment).await?;
319 if resp.context.slot < min_slot {
320 return Err(HedgedError::AllFailed(vec![(
321 id,
322 ErrorKind::Custom(format!(
323 "StaleResponse: min_slot {min_slot}, got {}",
324 resp.context.slot
325 ))
326 .into(),
327 )]));
328 }
329 Ok((id, resp))
330 }
331}