use alloy_chains::NamedChain;
use alloy_network::AnyNetwork;
use alloy_provider::RootProvider;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::config::policy::RpcPolicy;
use crate::errors::RpcError;
use super::config::ProviderConfig;
use super::factory::build_http_client;
/// Shared handle to a pooled provider.
///
/// Cloning only bumps the `Arc` refcount; every clone refers to the same
/// underlying `RootProvider`.
pub type PooledProvider = Arc<RootProvider<AnyNetwork>>;

/// Registry holding at most one RPC provider per chain.
///
/// All mutating methods take `&self`; the map is guarded by an `RwLock`, so a
/// pool can be shared (e.g. behind an `Arc`) as long as
/// `RootProvider<AnyNetwork>` is `Send + Sync`.
#[derive(Debug, Default)]
pub struct ProviderPool {
    /// Registered providers keyed by chain.
    providers: RwLock<HashMap<NamedChain, PooledProvider>>,
    /// Rate limit applied to providers added without their own rate limit.
    default_rate_limit: Option<u32>,
    /// Timeout applied to providers added without their own timeout.
    default_timeout: Option<Duration>,
}
impl ProviderPool {
#[must_use]
pub fn new() -> Self {
Self {
providers: RwLock::new(HashMap::new()),
default_rate_limit: None,
default_timeout: None,
}
}
#[must_use]
pub fn with_defaults(rate_limit: Option<u32>) -> Self {
Self {
providers: RwLock::new(HashMap::new()),
default_rate_limit: rate_limit,
default_timeout: None,
}
}
#[must_use]
pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
self.default_timeout = Some(timeout);
self
}
pub fn from_endpoints(
endpoints: Vec<ChainEndpoint>,
rate_limit: Option<u32>,
) -> Result<Self, RpcError> {
let pool = Self::with_defaults(rate_limit);
for endpoint in endpoints {
pool.add_endpoint(&endpoint)?;
}
Ok(pool)
}
pub fn add(
&self,
chain: NamedChain,
url: &str,
rate_limit: Option<u32>,
) -> Result<(), RpcError> {
self.add_inner(chain, url, rate_limit, None, None)
}
pub fn add_endpoint(&self, endpoint: &ChainEndpoint) -> Result<(), RpcError> {
self.add_inner(
endpoint.chain,
&endpoint.url,
endpoint.rate_limit,
endpoint.timeout,
endpoint.min_delay,
)
}
fn add_inner(
&self,
chain: NamedChain,
url: &str,
rate_limit: Option<u32>,
timeout: Option<Duration>,
min_delay: Option<Duration>,
) -> Result<(), RpcError> {
let provider = create_pooled_provider(
url,
rate_limit.or(self.default_rate_limit),
timeout.or(self.default_timeout),
min_delay,
)?;
let mut providers = self.providers.write().map_err(|_| {
RpcError::ProviderConnectionFailed("Provider pool lock poisoned".to_string())
})?;
if providers.contains_key(&chain) {
debug!(chain = ?chain, "Replacing existing provider");
} else {
info!(chain = ?chain, url = url, "Added provider to pool");
}
providers.insert(chain, Arc::new(provider));
Ok(())
}
#[must_use]
pub fn get(&self, chain: NamedChain) -> Option<PooledProvider> {
self.providers
.read()
.ok()
.and_then(|providers| providers.get(&chain).cloned())
}
pub fn get_or_add(
&self,
chain: NamedChain,
url: &str,
rate_limit: Option<u32>,
) -> Result<PooledProvider, RpcError> {
if let Some(provider) = self.get(chain) {
return Ok(provider);
}
self.add(chain, url, rate_limit)?;
self.get(chain).ok_or_else(|| {
RpcError::ProviderConnectionFailed("Failed to retrieve newly added provider".into())
})
}
pub fn remove(&self, chain: NamedChain) -> Option<PooledProvider> {
self.providers
.write()
.ok()
.and_then(|mut providers| providers.remove(&chain))
}
#[must_use]
pub fn contains(&self, chain: NamedChain) -> bool {
self.providers
.read()
.ok()
.is_some_and(|providers| providers.contains_key(&chain))
}
#[must_use]
pub fn len(&self) -> usize {
self.providers
.read()
.map(|providers| providers.len())
.unwrap_or(0)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn chains(&self) -> Vec<NamedChain> {
self.providers
.read()
.map(|providers| providers.keys().copied().collect())
.unwrap_or_default()
}
pub fn clear(&self) {
if let Ok(mut providers) = self.providers.write() {
providers.clear();
info!("Cleared all providers from pool");
}
}
}
/// Fluent builder for [`ProviderPool`].
///
/// Collects endpoints and pool-wide defaults. No provider is created until
/// [`ProviderPoolBuilder::build`] is called.
#[derive(Debug, Default)]
pub struct ProviderPoolBuilder {
    /// Endpoints to register, in the order they were added.
    endpoints: Vec<ChainEndpoint>,
    /// Rate limit for endpoints that do not set their own.
    default_rate_limit: Option<u32>,
    /// Timeout for endpoints with neither an explicit nor a policy timeout.
    default_timeout: Option<Duration>,
    /// Per-chain timeouts captured by `with_rpc_policy`.
    rpc_policy_timeouts: HashMap<NamedChain, Duration>,
    /// Per-chain minimum delays captured by `with_rpc_policy`.
    rpc_policy_min_delays: HashMap<NamedChain, Duration>,
}
impl ProviderPoolBuilder {
    /// Returns an empty builder: no endpoints, defaults, or policy data.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues `chain` at `url` with no per-endpoint overrides.
    #[must_use]
    pub fn add_chain(self, chain: NamedChain, url: &str) -> Self {
        self.add_endpoint(ChainEndpoint::new(chain, url))
    }

    /// Queues `chain` at `url` with an explicit per-endpoint rate limit.
    #[must_use]
    pub fn add_chain_with_rate_limit(
        self,
        chain: NamedChain,
        url: &str,
        rate_limit: u32,
    ) -> Self {
        self.add_endpoint(ChainEndpoint::new(chain, url).with_rate_limit(rate_limit))
    }

    /// Queues `chain` at `url` with an explicit per-endpoint timeout.
    #[must_use]
    pub fn add_chain_with_timeout(
        self,
        chain: NamedChain,
        url: &str,
        timeout: Duration,
    ) -> Self {
        self.add_endpoint(ChainEndpoint::new(chain, url).with_timeout(timeout))
    }

    /// Queues a fully specified endpoint.
    #[must_use]
    pub fn add_endpoint(mut self, endpoint: ChainEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }

    /// Sets the rate limit used by endpoints that do not set their own.
    #[must_use]
    pub fn with_rate_limit(self, requests_per_second: u32) -> Self {
        Self {
            default_rate_limit: Some(requests_per_second),
            ..self
        }
    }

    /// Sets the fallback timeout for endpoints with neither an explicit nor a
    /// policy-provided timeout.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            default_timeout: Some(timeout),
            ..self
        }
    }

    /// Captures per-chain RPC settings from `policy` for the endpoints queued
    /// so far.
    ///
    /// Only endpoints already added are consulted. Chains added after this
    /// call receive no policy settings, so call it after the last `add_*`.
    /// A policy timeout always overrides an earlier capture for the same
    /// chain. A policy `rate_limit_delay` is recorded only when `Some`.
    ///
    /// NOTE(review): a zero `rate_limit_delay` is stored as-is. This bypasses
    /// the nonzero check in `ChainEndpoint::with_min_delay`; confirm that
    /// `ProviderConfig::with_min_delay` tolerates it.
    #[must_use]
    pub fn with_rpc_policy<P: RpcPolicy>(mut self, policy: &P) -> Self {
        for chain in self.endpoints.iter().map(|ep| ep.chain) {
            let settings = policy.rpc_config(chain);
            self.rpc_policy_timeouts.insert(chain, settings.rpc_timeout);
            if let Some(delay) = settings.rate_limit_delay {
                self.rpc_policy_min_delays.insert(chain, delay);
            }
        }
        self
    }

    /// Builds every queued endpoint into a new [`ProviderPool`].
    ///
    /// Each setting is resolved in the following precedence order:
    /// - rate limit: the endpoint's own value, then the builder default.
    /// - timeout: the endpoint's own value, then the value captured by
    ///   `with_rpc_policy`, then the builder default.
    /// - min delay: the endpoint's own value, then the policy value.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while building an endpoint's provider;
    /// later endpoints are not processed.
    pub fn build(self) -> Result<ProviderPool, RpcError> {
        let mut pool = ProviderPool::with_defaults(self.default_rate_limit);
        if let Some(fallback) = self.default_timeout {
            pool = pool.with_default_timeout(fallback);
        }
        for endpoint in self.endpoints {
            let chain = endpoint.chain;
            let resolved = ChainEndpoint {
                rate_limit: endpoint.rate_limit.or(self.default_rate_limit),
                timeout: endpoint
                    .timeout
                    .or(self.rpc_policy_timeouts.get(&chain).copied()),
                min_delay: endpoint
                    .min_delay
                    .or(self.rpc_policy_min_delays.get(&chain).copied()),
                ..endpoint
            };
            pool.add_endpoint(&resolved)?;
        }
        Ok(pool)
    }
}
/// One chain's RPC endpoint plus optional per-endpoint overrides.
///
/// When the endpoint is registered in a [`ProviderPool`], unset `rate_limit`
/// and `timeout` fall back to the pool defaults; `min_delay` has no
/// pool-level default.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate must
/// construct it via [`ChainEndpoint::new`] or a chain shortcut such as
/// [`ChainEndpoint::mainnet`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ChainEndpoint {
    /// Chain served by this endpoint; also the key it is stored under in the pool.
    pub chain: NamedChain,
    /// RPC URL handed to `ProviderConfig::new`.
    pub url: String,
    /// Per-endpoint rate limit; presumably requests per second (the builder's
    /// `with_rate_limit` parameter is `requests_per_second`). TODO: confirm in
    /// `ProviderConfig`.
    pub rate_limit: Option<u32>,
    /// Per-endpoint timeout forwarded to `ProviderConfig::with_timeout`.
    pub timeout: Option<Duration>,
    /// Minimum delay forwarded to `ProviderConfig::with_min_delay`. Presumably
    /// the spacing between requests; verify there.
    pub min_delay: Option<Duration>,
}
impl ChainEndpoint {
    /// Creates an endpoint for `chain` at `url` with every override unset.
    #[must_use]
    pub fn new(chain: NamedChain, url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self {
            chain,
            url,
            rate_limit: None,
            timeout: None,
            min_delay: None,
        }
    }

    /// Returns the endpoint with its per-endpoint rate limit set.
    #[must_use]
    pub fn with_rate_limit(self, rate_limit: u32) -> Self {
        Self {
            rate_limit: Some(rate_limit),
            ..self
        }
    }

    /// Returns the endpoint with its per-endpoint timeout set.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Returns the endpoint with its minimum delay set.
    ///
    /// # Panics
    ///
    /// Panics if `delay` is zero. The check is delegated to
    /// `super::assert_nonzero_min_delay`.
    #[must_use]
    #[track_caller]
    pub fn with_min_delay(self, delay: Duration) -> Self {
        super::assert_nonzero_min_delay(delay);
        Self {
            min_delay: Some(delay),
            ..self
        }
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Mainnet, url)`.
    #[must_use]
    pub fn mainnet(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Mainnet, url)
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Base, url)`.
    #[must_use]
    pub fn base(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Base, url)
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Optimism, url)`.
    #[must_use]
    pub fn optimism(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Optimism, url)
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Arbitrum, url)`.
    #[must_use]
    pub fn arbitrum(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Arbitrum, url)
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Polygon, url)`.
    #[must_use]
    pub fn polygon(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Polygon, url)
    }

    /// Shortcut for `ChainEndpoint::new(NamedChain::Sepolia, url)`.
    #[must_use]
    pub fn sepolia(url: impl Into<String>) -> Self {
        Self::new(NamedChain::Sepolia, url)
    }
}
fn create_pooled_provider(
url: &str,
rate_limit: Option<u32>,
timeout: Option<Duration>,
min_delay: Option<Duration>,
) -> Result<RootProvider<AnyNetwork>, RpcError> {
let mut config = ProviderConfig::new(url).with_rate_limit_opt(rate_limit);
if let Some(t) = timeout {
config = config.with_timeout(t);
}
if let Some(d) = min_delay {
config = config.with_min_delay(d);
}
let client = build_http_client(config).inspect_err(|e| {
warn!(url = url, error = ?e, "Failed to build pooled provider");
})?;
Ok(RootProvider::<AnyNetwork>::new(client))
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh pool is empty and `len` / `is_empty` agree.
    #[test]
    fn test_pool_new() {
        let pool = ProviderPool::new();
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
    }

    // `with_defaults` stores the pool-wide fallback rate limit.
    #[test]
    fn test_pool_with_defaults() {
        let pool = ProviderPool::with_defaults(Some(10));
        assert_eq!(pool.default_rate_limit, Some(10));
    }

    // Chain shortcuts set the chain and URL and leave overrides unset.
    #[test]
    fn test_chain_endpoint_constructors() {
        let endpoint = ChainEndpoint::mainnet("https://eth.llamarpc.com");
        assert_eq!(endpoint.chain, NamedChain::Mainnet);
        assert_eq!(endpoint.url, "https://eth.llamarpc.com");
        assert!(endpoint.rate_limit.is_none());
        let endpoint = ChainEndpoint::base("https://mainnet.base.org").with_rate_limit(5);
        assert_eq!(endpoint.chain, NamedChain::Base);
        assert_eq!(endpoint.rate_limit, Some(5));
    }

    // The builder only queues endpoints and defaults; nothing is built yet.
    #[test]
    fn test_pool_builder() {
        let builder = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, "https://eth.llamarpc.com")
            .add_chain_with_rate_limit(NamedChain::Base, "https://mainnet.base.org", 5)
            .with_rate_limit(10);
        assert_eq!(builder.endpoints.len(), 2);
        assert_eq!(builder.default_rate_limit, Some(10));
    }

    // `build` creates one provider per queued endpoint and propagates URL errors.
    #[test]
    fn test_pool_builder_build() {
        let pool = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, "https://eth.llamarpc.com")
            .add_chain(NamedChain::Base, "https://mainnet.base.org")
            .build()
            .unwrap();
        assert_eq!(pool.len(), 2);
        assert!(pool.contains(NamedChain::Mainnet));
        assert!(pool.contains(NamedChain::Base));

        let result = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, "not a valid url")
            .build();
        assert!(result.is_err());
    }

    // `contains` and `chains` reflect exactly the registered chains.
    #[test]
    fn test_pool_contains_and_chains() {
        let pool = ProviderPool::new();
        let result = pool.add(NamedChain::Mainnet, "https://eth.llamarpc.com", None);
        assert!(result.is_ok());
        assert!(pool.contains(NamedChain::Mainnet));
        assert!(!pool.contains(NamedChain::Base));
        let chains = pool.chains();
        assert_eq!(chains.len(), 1);
        assert!(chains.contains(&NamedChain::Mainnet));
    }

    // `get_or_add` returns the already-registered provider (same `Arc`) and
    // ignores the URL on that path, so an invalid URL does not error.
    #[test]
    fn test_get_or_add_reuses_existing_provider() {
        let pool = ProviderPool::new();
        let first = pool
            .get_or_add(NamedChain::Mainnet, "https://eth.llamarpc.com", None)
            .unwrap();
        let second = pool
            .get_or_add(NamedChain::Mainnet, "not a valid url", None)
            .unwrap();
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(pool.len(), 1);
    }

    // `remove` hands back the provider once, then `None`.
    #[test]
    fn test_pool_remove() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, "https://eth.llamarpc.com", None)
            .unwrap();
        assert!(pool.contains(NamedChain::Mainnet));
        let removed = pool.remove(NamedChain::Mainnet);
        assert!(removed.is_some());
        assert!(!pool.contains(NamedChain::Mainnet));
        let removed_again = pool.remove(NamedChain::Mainnet);
        assert!(removed_again.is_none());
    }

    // `clear` drops every registered provider.
    #[test]
    fn test_pool_clear() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, "https://eth.llamarpc.com", None)
            .unwrap();
        pool.add(NamedChain::Base, "https://mainnet.base.org", None)
            .unwrap();
        assert_eq!(pool.len(), 2);
        pool.clear();
        assert!(pool.is_empty());
    }

    // Provider construction rejects a malformed URL.
    #[test]
    fn test_invalid_url() {
        let pool = ProviderPool::new();
        let result = pool.add(NamedChain::Mainnet, "not a valid url", None);
        assert!(result.is_err());
    }

    // A zero minimum delay is a caller bug and panics.
    #[test]
    #[should_panic(expected = "min_delay must be > 0")]
    fn test_chain_endpoint_with_min_delay_rejects_zero() {
        let _ = ChainEndpoint::new(NamedChain::Mainnet, "https://eth.llamarpc.com")
            .with_min_delay(Duration::ZERO);
    }

    // Pins the (order-dependent) contract of `with_rpc_policy`: it only
    // covers endpoints queued before the call.
    #[test]
    fn with_rpc_policy_only_covers_chains_added_before_it() {
        use crate::config::policy::RpcConfig;
        struct FixedPolicy {
            timeout: Duration,
            rate_limit_delay: Option<Duration>,
        }
        impl RpcPolicy for FixedPolicy {
            fn rpc_config(&self, _: NamedChain) -> RpcConfig {
                RpcConfig {
                    rpc_timeout: self.timeout,
                    rate_limit_delay: self.rate_limit_delay,
                }
            }
        }
        let policy = FixedPolicy {
            timeout: Duration::from_secs(5),
            rate_limit_delay: Some(Duration::from_millis(250)),
        };
        let before = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, "http://localhost:8545")
            .with_rpc_policy(&policy);
        assert_eq!(
            before
                .rpc_policy_timeouts
                .get(&NamedChain::Mainnet)
                .copied(),
            Some(Duration::from_secs(5)),
            "policy timeout must apply when chain is added before with_rpc_policy",
        );
        assert_eq!(
            before
                .rpc_policy_min_delays
                .get(&NamedChain::Mainnet)
                .copied(),
            Some(Duration::from_millis(250)),
            "policy rate_limit_delay must apply when chain is added before with_rpc_policy",
        );
        let after = ProviderPoolBuilder::new()
            .with_rpc_policy(&policy)
            .add_chain(NamedChain::Mainnet, "http://localhost:8545");
        assert!(
            !after.rpc_policy_timeouts.contains_key(&NamedChain::Mainnet),
            "policy timeout must not apply when chain is added after with_rpc_policy",
        );
        assert!(
            !after
                .rpc_policy_min_delays
                .contains_key(&NamedChain::Mainnet),
            "policy rate_limit_delay must not apply when chain is added after with_rpc_policy",
        );
    }
}