hedged_rpc_client/
client.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    sync::{Arc, Mutex},
5    time::Instant,
6};
7
8use futures::{stream::FuturesUnordered, StreamExt};
9use solana_client::{client_error::ClientError, nonblocking::rpc_client::RpcClient};
10use solana_commitment_config::CommitmentConfig;
11use solana_rpc_client_api::{client_error::ErrorKind, response::Response as RpcResponse};
12use solana_sdk::{account::Account, hash::Hash, pubkey::Pubkey};
13use tokio::time;
14
15use crate::{
16    config::{HedgeConfig, ProviderConfig, ProviderId},
17    errors::HedgedError,
18};
19
/// Internal, mutable per-provider counters.
///
/// One entry per configured provider is kept in the client's shared stats map;
/// `provider_stats` converts these raw counters into public snapshots.
#[derive(Debug, Default)]
struct ProviderStats {
    /// Number of hedged calls this provider won.
    wins: u64,
    /// Sum of the elapsed time (milliseconds) of the calls counted in `wins`;
    /// divided by `wins` to produce the reported average.
    total_latency_ms: f64,
    /// Number of errors recorded against this provider.
    errors: u64,
}
26
/// Snapshot of provider performance statistics.
///
/// Returned by [`HedgedRpcClient::provider_stats`]; the values are copied out of
/// the live counters, so a snapshot does not change after it has been taken.
#[derive(Debug, Clone)]
pub struct ProviderStatsSnapshot {
    /// Number of times this provider won the race.
    pub wins: u64,
    /// Average latency in milliseconds for winning calls.
    ///
    /// Latency is measured from the start of the hedged call to its completion.
    /// Reported as `0.0` when the provider has no wins yet.
    pub avg_latency_ms: f64,
    /// Number of failed calls from this provider.
    pub errors: u64,
}
37
/// A Solana RPC client that hedges requests across multiple providers.
///
/// The client races requests to multiple RPC endpoints and returns the first
/// successful response, implementing the "hedged requests" pattern to reduce
/// tail latency.
///
/// Cloning is cheap: the provider list and the statistics map are held behind
/// `Arc`, so every clone shares the same providers and the same counters.
#[derive(Clone)]
pub struct HedgedRpcClient {
    /// Configured providers, in the order they were passed to `new`.
    /// Requests are issued to providers in this order.
    providers: Arc<Vec<(ProviderId, Arc<RpcClient>)>>,
    /// Hedging strategy (initial fan-out, hedge delay, provider cap, overall timeout).
    cfg: HedgeConfig,
    /// Per-provider counters, shared between clones and guarded by a mutex.
    stats: Arc<Mutex<HashMap<ProviderId, ProviderStats>>>,
}
49
50impl HedgedRpcClient {
51    /// Creates a new hedged RPC client with the specified providers and configuration.
52    ///
53    /// # Arguments
54    /// * `provider_cfgs` - List of RPC provider configurations (URLs and IDs)
55    /// * `cfg` - Hedging strategy configuration
56    ///
57    /// # Example
58    /// ```no_run
59    /// use hedged_rpc_client::{HedgedRpcClient, HedgeConfig, ProviderConfig, ProviderId};
60    /// use std::time::Duration;
61    ///
62    /// let providers = vec![
63    ///     ProviderConfig {
64    ///         id: ProviderId("helius"),
65    ///         url: "https://mainnet.helius-rpc.com".to_string(),
66    ///     },
67    /// ];
68    ///
69    /// let config = HedgeConfig {
70    ///     initial_providers: 1,
71    ///     hedge_after: Duration::from_millis(50),
72    ///     max_providers: 3,
73    ///     min_slot: None,
74    ///     overall_timeout: Duration::from_secs(2),
75    /// };
76    ///
77    /// let client = HedgedRpcClient::new(providers, config);
78    /// ```
79    pub fn new(provider_cfgs: Vec<ProviderConfig>, cfg: HedgeConfig) -> Self {
80        let providers_vec: Vec<(ProviderId, Arc<RpcClient>)> = provider_cfgs
81            .into_iter()
82            .map(|pcfg| {
83                let client = Arc::new(RpcClient::new(pcfg.url));
84                (pcfg.id, client)
85            })
86            .collect();
87
88        let mut stats_map = HashMap::new();
89        for (id, _) in &providers_vec {
90            stats_map.insert(*id, ProviderStats::default());
91        }
92
93        Self {
94            providers: Arc::new(providers_vec),
95            cfg,
96            stats: Arc::new(Mutex::new(stats_map)),
97        }
98    }
99
100    /// Returns a reference to the configured providers.
101    pub fn providers(&self) -> &[(ProviderId, Arc<RpcClient>)] {
102        &self.providers
103    }
104
105    /// Returns a snapshot of accumulated performance statistics for each provider.
106    ///
107    /// Statistics include wins (successful responses), average latency, and error counts.
108    pub fn provider_stats(&self) -> HashMap<ProviderId, ProviderStatsSnapshot> {
109        let stats = self.stats.lock().expect("provider stats mutex poisoned");
110
111        stats
112            .iter()
113            .map(|(id, s)| {
114                let avg = if s.wins > 0 {
115                    s.total_latency_ms / (s.wins as f64)
116                } else {
117                    0.0
118                };
119
120                (
121                    *id,
122                    ProviderStatsSnapshot {
123                        wins: s.wins,
124                        avg_latency_ms: avg,
125                        errors: s.errors,
126                    },
127                )
128            })
129            .collect()
130    }
131
132    /// Core hedged request implementation.
133    ///
134    /// Races the provided RPC call across multiple providers according to the configured
135    /// hedging strategy. Returns the first successful response along with the provider ID.
136    ///
137    /// # Type Parameters
138    /// * `T` - The response type
139    /// * `F` - Closure that creates the RPC call
140    /// * `Fut` - Future returned by the closure
141    async fn hedged_call<T, F, Fut>(&self, f: F) -> Result<(ProviderId, T), HedgedError>
142    where
143        T: Send,
144        F: Fn(Arc<RpcClient>) -> Fut + Send,
145        Fut: Future<Output = Result<T, ClientError>> + Send,
146    {
147        if self.providers.is_empty() {
148            return Err(HedgedError::NoProviders);
149        }
150
151        let max_idx = self.cfg.max_providers.min(self.providers.len());
152        if max_idx == 0 {
153            return Err(HedgedError::NoProviders);
154        }
155        let selected_providers = &self.providers[..max_idx];
156
157        let start = Instant::now();
158        let selected_ids: Vec<ProviderId> = selected_providers.iter().map(|(id, _)| *id).collect();
159
160        let hedging_logic = async {
161            let mut failures = Vec::new();
162            let mut futures = FuturesUnordered::new();
163
164            let spawn_provider = move |provider_id: ProviderId, client: Arc<RpcClient>| {
165                let fut = f(client);
166                async move {
167                    let result = fut.await;
168                    (provider_id, result)
169                }
170            };
171
172            let initial_count = self
173                .cfg
174                .initial_providers
175                .max(1)
176                .min(selected_providers.len());
177
178            for (provider_id, client) in &selected_providers[..initial_count] {
179                futures.push(spawn_provider(*provider_id, client.clone()));
180            }
181
182            let needs_hedging = initial_count < selected_providers.len();
183            let mut hedged = !needs_hedging;
184            let hedge_sleep = time::sleep(self.cfg.hedge_after);
185            tokio::pin!(hedge_sleep);
186
187            loop {
188                if futures.is_empty() && hedged {
189                    break;
190                }
191
192                tokio::select! {
193                    Some((provider_id, result)) = futures.next(), if !futures.is_empty() => {
194                        match result {
195                            Ok(val) => return Ok((provider_id, val)),
196                            Err(e) => failures.push((provider_id, e)),
197                        }
198                    }
199                    _ = &mut hedge_sleep, if needs_hedging && !hedged => {
200                        hedged = true;
201                        for (provider_id, client) in &selected_providers[initial_count..] {
202                            futures.push(spawn_provider(*provider_id, client.clone()));
203                        }
204                    }
205                }
206            }
207
208            Err(HedgedError::AllFailed(failures))
209        };
210
211        let timed = time::timeout(self.cfg.overall_timeout, hedging_logic).await;
212
213        let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
214
215        match timed {
216            Err(_) => {
217                if let Ok(mut stats) = self.stats.lock() {
218                    for id in selected_ids {
219                        if let Some(entry) = stats.get_mut(&id) {
220                            entry.errors += 1;
221                        }
222                    }
223                }
224                Err(HedgedError::Timeout(self.cfg.overall_timeout))
225            }
226            Ok(inner) => match inner {
227                Ok((winner_id, value)) => {
228                    if let Ok(mut stats) = self.stats.lock() {
229                        if let Some(entry) = stats.get_mut(&winner_id) {
230                            entry.wins += 1;
231                            entry.total_latency_ms += elapsed_ms;
232                        }
233                    }
234                    Ok((winner_id, value))
235                }
236                Err(HedgedError::AllFailed(failures)) => {
237                    if let Ok(mut stats) = self.stats.lock() {
238                        for (id, _err) in failures.iter() {
239                            if let Some(entry) = stats.get_mut(id) {
240                                entry.errors += 1;
241                            }
242                        }
243                    }
244                    Err(HedgedError::AllFailed(failures))
245                }
246                Err(e) => Err(e),
247            },
248        }
249    }
250
251    /// Gets the latest blockhash from the fastest responding provider.
252    ///
253    /// Returns the blockhash along with the ID of the provider that responded first.
254    pub async fn get_latest_blockhash(&self) -> Result<(ProviderId, Hash), HedgedError> {
255        let (id, resp) = self
256            .hedged_call(move |client| async move { client.get_latest_blockhash().await })
257            .await?;
258
259        Ok((id, resp))
260    }
261
262    /// Gets the latest blockhash, returning only the hash without provider information.
263    pub async fn get_latest_blockhash_any(&self) -> Result<Hash, HedgedError> {
264        let (_id, resp) = self.get_latest_blockhash().await?;
265        Ok(resp)
266    }
267
268    /// Gets account data from the fastest responding provider.
269    ///
270    /// Returns the account response along with the ID of the provider that responded first.
271    ///
272    /// # Arguments
273    /// * `pubkey` - The account's public key
274    /// * `commitment` - The commitment level for the query
275    pub async fn get_account(
276        &self,
277        pubkey: &Pubkey,
278        commitment: CommitmentConfig,
279    ) -> Result<(ProviderId, RpcResponse<Option<Account>>), HedgedError> {
280        let pk = *pubkey;
281
282        let (id, resp) = self
283            .hedged_call(move |client| {
284                let pk = pk;
285                async move { client.get_account_with_commitment(&pk, commitment).await }
286            })
287            .await?;
288
289        Ok((id, resp))
290    }
291
292    /// Gets account data, returning only the response without provider information.
293    pub async fn get_account_any(
294        &self,
295        pubkey: &Pubkey,
296        commitment: CommitmentConfig,
297    ) -> Result<RpcResponse<Option<Account>>, HedgedError> {
298        let (_id, resp) = self.get_account(pubkey, commitment).await?;
299
300        Ok(resp)
301    }
302
303    /// Gets account data with slot freshness validation.
304    ///
305    /// Returns an error if the response slot is older than the specified minimum slot.
306    /// Useful for ensuring data recency in time-sensitive operations.
307    ///
308    /// # Arguments
309    /// * `pubkey` - The account's public key
310    /// * `commitment` - The commitment level for the query
311    /// * `min_slot` - Minimum acceptable slot number
312    pub async fn get_account_fresh(
313        &self,
314        pubkey: &Pubkey,
315        commitment: CommitmentConfig,
316        min_slot: u64,
317    ) -> Result<(ProviderId, RpcResponse<Option<Account>>), HedgedError> {
318        let (id, resp) = self.get_account(pubkey, commitment).await?;
319        if resp.context.slot < min_slot {
320            return Err(HedgedError::AllFailed(vec![(
321                id,
322                ErrorKind::Custom(format!(
323                    "StaleResponse: min_slot {min_slot}, got {}",
324                    resp.context.slot
325                ))
326                .into(),
327            )]));
328        }
329        Ok((id, resp))
330    }
331}