use alloy_chains::NamedChain;
use alloy_network::AnyNetwork;
use alloy_provider::{ProviderBuilder, RootProvider};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
use crate::errors::RpcError;
use crate::transport::RateLimitLayer;
/// Shared handle to a provider stored in a [`ProviderPool`].
///
/// Cloning it only increments the `Arc` reference count, so callers can keep using a provider
/// independently of the pool (for example after it has been removed from the pool).
pub type PooledProvider = Arc<RootProvider<AnyNetwork>>;
/// Registry of RPC providers keyed by chain.
///
/// The map lives behind an [`RwLock`], so every operation takes `&self` and the pool can be
/// shared (e.g. through an `Arc<ProviderPool>`) while providers are added, looked up and removed.
#[derive(Debug, Default)]
pub struct ProviderPool {
    /// At most one provider per chain; adding a provider for a registered chain replaces it.
    providers: RwLock<HashMap<NamedChain, PooledProvider>>,
    /// Fallback requests-per-second limit used when a provider is added without its own limit.
    default_rate_limit: Option<u32>,
}
impl ProviderPool {
    /// Creates an empty pool with no default rate limit.
    #[must_use]
    pub fn new() -> Self {
        Self::with_defaults(None)
    }

    /// Creates an empty pool whose providers fall back to `rate_limit` (requests per second)
    /// when they are added without an explicit limit.
    #[must_use]
    pub fn with_defaults(rate_limit: Option<u32>) -> Self {
        Self {
            providers: RwLock::new(HashMap::new()),
            default_rate_limit: rate_limit,
        }
    }

    /// Creates a pool pre-populated with `endpoints`.
    ///
    /// Each endpoint uses its own `rate_limit` when set and `rate_limit` otherwise. Endpoints
    /// are added in order, so a later entry for the same chain replaces an earlier one.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while adding an endpoint (see [`add`](Self::add)).
    pub fn from_endpoints(
        endpoints: Vec<ChainEndpoint>,
        rate_limit: Option<u32>,
    ) -> Result<Self, RpcError> {
        let pool = Self::with_defaults(rate_limit);
        for endpoint in endpoints {
            // `add` already falls back to the pool default, which is `rate_limit` here.
            pool.add(endpoint.chain, &endpoint.url, endpoint.rate_limit)?;
        }
        Ok(pool)
    }

    /// Builds a provider for `url` and registers it for `chain`, replacing any existing one.
    ///
    /// `rate_limit` (requests per second) overrides the pool default; `None` uses the default.
    ///
    /// # Errors
    ///
    /// Returns an error (currently [`RpcError::ProviderUrlInvalid`]) if a provider cannot be
    /// built from `url`, e.g. because it does not parse as a URL. The pool is left unchanged.
    pub fn add(
        &self,
        chain: NamedChain,
        url: &str,
        rate_limit: Option<u32>,
    ) -> Result<(), RpcError> {
        let provider = Arc::new(create_pooled_provider(
            url,
            rate_limit.or(self.default_rate_limit),
        )?);
        // One `insert` both stores the provider and reports whether it replaced another.
        let replaced = self.write_map(|providers| providers.insert(chain, provider).is_some());
        if replaced {
            debug!(chain = ?chain, "Replacing existing provider");
        } else {
            info!(chain = ?chain, url = url, "Added provider to pool");
        }
        Ok(())
    }

    /// Returns the provider registered for `chain`, if any.
    #[must_use]
    pub fn get(&self, chain: NamedChain) -> Option<PooledProvider> {
        self.read_map(|providers| providers.get(&chain).cloned())
    }

    /// Returns the provider for `chain`, building and registering one from `url` if absent.
    ///
    /// When a provider is already registered, `url` and `rate_limit` are ignored and the
    /// existing provider is returned. The check-and-insert happens under a single write lock,
    /// so this method never replaces a provider that another caller has already been handed.
    ///
    /// # Errors
    ///
    /// Returns an error if `chain` has no provider and one cannot be built from `url`
    /// (see [`add`](Self::add)).
    pub fn get_or_add(
        &self,
        chain: NamedChain,
        url: &str,
        rate_limit: Option<u32>,
    ) -> Result<PooledProvider, RpcError> {
        use std::collections::hash_map::Entry;

        // Fast path: a shared read lock suffices when the chain is already registered.
        if let Some(provider) = self.get(chain) {
            return Ok(provider);
        }
        // Build before taking the write lock so URL parsing and client construction never
        // block other users of the pool. If another thread registers `chain` first, this
        // candidate is simply dropped and the winner's provider is returned.
        let candidate = Arc::new(create_pooled_provider(
            url,
            rate_limit.or(self.default_rate_limit),
        )?);
        let (provider, inserted) = self.write_map(|providers| match providers.entry(chain) {
            Entry::Occupied(slot) => (Arc::clone(slot.get()), false),
            Entry::Vacant(slot) => (Arc::clone(slot.insert(candidate)), true),
        });
        if inserted {
            info!(chain = ?chain, url = url, "Added provider to pool");
        }
        Ok(provider)
    }

    /// Unregisters `chain` and returns its provider, if one was registered.
    ///
    /// Clones of the returned handle held elsewhere remain usable.
    pub fn remove(&self, chain: NamedChain) -> Option<PooledProvider> {
        self.write_map(|providers| providers.remove(&chain))
    }

    /// Returns `true` if a provider is registered for `chain`.
    #[must_use]
    pub fn contains(&self, chain: NamedChain) -> bool {
        self.read_map(|providers| providers.contains_key(&chain))
    }

    /// Returns the number of registered providers.
    #[must_use]
    pub fn len(&self) -> usize {
        self.read_map(|providers| providers.len())
    }

    /// Returns `true` if no providers are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.read_map(|providers| providers.is_empty())
    }

    /// Returns the chains that currently have a provider, in unspecified order.
    #[must_use]
    pub fn chains(&self) -> Vec<NamedChain> {
        self.read_map(|providers| providers.keys().copied().collect())
    }

    /// Removes every provider from the pool.
    pub fn clear(&self) {
        self.write_map(|providers| providers.clear());
        info!("Cleared all providers from pool");
    }

    /// Runs `f` with shared access to the provider map.
    ///
    /// A poisoned lock is recovered rather than treated as "empty" or as an error. Every
    /// critical section in this impl is a single `HashMap` operation, so a panic on another
    /// thread cannot leave the map half-updated. Propagating the poison would only make the
    /// pool permanently unusable.
    fn read_map<R>(&self, f: impl FnOnce(&HashMap<NamedChain, PooledProvider>) -> R) -> R {
        use std::sync::PoisonError;

        let providers = self.providers.read().unwrap_or_else(PoisonError::into_inner);
        f(&providers)
    }

    /// Runs `f` with exclusive access to the provider map, recovering from lock poisoning for
    /// the same reason as [`read_map`](Self::read_map). The lock is released when `f` returns,
    /// which keeps logging and other work outside the critical section.
    fn write_map<R>(&self, f: impl FnOnce(&mut HashMap<NamedChain, PooledProvider>) -> R) -> R {
        use std::sync::PoisonError;

        let mut providers = self.providers.write().unwrap_or_else(PoisonError::into_inner);
        f(&mut providers)
    }
}
/// Fluent builder that collects [`ChainEndpoint`]s and turns them into a [`ProviderPool`].
#[derive(Debug, Clone, Default)]
pub struct ProviderPoolBuilder {
    /// Endpoints in insertion order; when the pool is built, a later entry for the same chain
    /// replaces an earlier one.
    endpoints: Vec<ChainEndpoint>,
    /// Requests-per-second limit applied to endpoints that do not set their own.
    default_rate_limit: Option<u32>,
}
impl ProviderPoolBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn add_chain(mut self, chain: NamedChain, url: &str) -> Self {
self.endpoints.push(ChainEndpoint {
chain,
url: url.to_string(),
rate_limit: None,
});
self
}
#[must_use]
pub fn add_chain_with_rate_limit(
mut self,
chain: NamedChain,
url: &str,
rate_limit: u32,
) -> Self {
self.endpoints.push(ChainEndpoint {
chain,
url: url.to_string(),
rate_limit: Some(rate_limit),
});
self
}
#[must_use]
pub fn with_rate_limit(mut self, requests_per_second: u32) -> Self {
self.default_rate_limit = Some(requests_per_second);
self
}
pub fn build(self) -> Result<ProviderPool, RpcError> {
let pool = ProviderPool::with_defaults(self.default_rate_limit);
for endpoint in self.endpoints {
pool.add(
endpoint.chain,
&endpoint.url,
endpoint.rate_limit.or(self.default_rate_limit),
)?;
}
Ok(pool)
}
}
#[derive(Debug, Clone)]
pub struct ChainEndpoint {
pub chain: NamedChain,
pub url: String,
pub rate_limit: Option<u32>,
}
impl ChainEndpoint {
#[must_use]
pub fn new(chain: NamedChain, url: impl Into<String>) -> Self {
Self {
chain,
url: url.into(),
rate_limit: None,
}
}
#[must_use]
pub fn with_rate_limit(mut self, rate_limit: u32) -> Self {
self.rate_limit = Some(rate_limit);
self
}
#[must_use]
pub fn mainnet(url: impl Into<String>) -> Self {
Self::new(NamedChain::Mainnet, url)
}
#[must_use]
pub fn base(url: impl Into<String>) -> Self {
Self::new(NamedChain::Base, url)
}
#[must_use]
pub fn optimism(url: impl Into<String>) -> Self {
Self::new(NamedChain::Optimism, url)
}
#[must_use]
pub fn arbitrum(url: impl Into<String>) -> Self {
Self::new(NamedChain::Arbitrum, url)
}
#[must_use]
pub fn polygon(url: impl Into<String>) -> Self {
Self::new(NamedChain::Polygon, url)
}
#[must_use]
pub fn sepolia(url: impl Into<String>) -> Self {
Self::new(NamedChain::Sepolia, url)
}
}
/// Builds an HTTP-transport provider for `url`, optionally throttled to `rate_limit`
/// requests per second.
///
/// Recommended fillers are disabled, so the provider does not fill transaction fields on its
/// own.
///
/// # Errors
///
/// Returns [`RpcError::ProviderUrlInvalid`] if `url` does not parse as a URL or its scheme is
/// not `http`/`https`.
fn create_pooled_provider(
    url: &str,
    rate_limit: Option<u32>,
) -> Result<RootProvider<AnyNetwork>, RpcError> {
    let parsed_url: url::Url = url.parse().map_err(|e| {
        warn!(url = url, error = ?e, "Invalid provider URL");
        RpcError::ProviderUrlInvalid(url.to_string())
    })?;
    // `Url` accepts any scheme. For example, `localhost:8545` parses with scheme `localhost`,
    // and `ws://…` parses fine as well. Neither is usable by the HTTP transport built below,
    // so reject them here instead of failing later on the first request.
    if !matches!(parsed_url.scheme(), "http" | "https") {
        warn!(
            url = url,
            scheme = parsed_url.scheme(),
            "Unsupported provider URL scheme (expected http or https)"
        );
        return Err(RpcError::ProviderUrlInvalid(url.to_string()));
    }
    let client = match rate_limit {
        Some(limit) => alloy_rpc_client::ClientBuilder::default()
            .layer(RateLimitLayer::per_second(limit))
            .http(parsed_url),
        None => alloy_rpc_client::ClientBuilder::default().http(parsed_url),
    };
    let provider = ProviderBuilder::new()
        .disable_recommended_fillers()
        .network::<AnyNetwork>()
        .connect_client(client);
    Ok(provider)
}
#[cfg(test)]
mod tests {
    use super::*;

    const MAINNET_URL: &str = "https://eth.llamarpc.com";
    const BASE_URL: &str = "https://mainnet.base.org";

    #[test]
    fn test_pool_new() {
        let pool = ProviderPool::new();
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_pool_with_defaults() {
        let pool = ProviderPool::with_defaults(Some(10));
        assert_eq!(pool.default_rate_limit, Some(10));
    }

    #[test]
    fn test_chain_endpoint_constructors() {
        let endpoint = ChainEndpoint::mainnet(MAINNET_URL);
        assert_eq!(endpoint.chain, NamedChain::Mainnet);
        assert_eq!(endpoint.url, MAINNET_URL);
        assert!(endpoint.rate_limit.is_none());
        let endpoint = ChainEndpoint::base(BASE_URL).with_rate_limit(5);
        assert_eq!(endpoint.chain, NamedChain::Base);
        assert_eq!(endpoint.rate_limit, Some(5));
    }

    #[test]
    fn test_pool_builder() {
        let builder = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, MAINNET_URL)
            .add_chain_with_rate_limit(NamedChain::Base, BASE_URL, 5)
            .with_rate_limit(10);
        assert_eq!(builder.endpoints.len(), 2);
        assert_eq!(builder.default_rate_limit, Some(10));
    }

    #[test]
    fn test_pool_builder_build() {
        let pool = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, MAINNET_URL)
            .add_chain_with_rate_limit(NamedChain::Base, BASE_URL, 5)
            .with_rate_limit(10)
            .build()
            .unwrap();
        assert_eq!(pool.len(), 2);
        assert_eq!(pool.default_rate_limit, Some(10));
        assert!(pool.contains(NamedChain::Mainnet));
        assert!(pool.contains(NamedChain::Base));
    }

    #[test]
    fn test_pool_builder_build_rejects_invalid_url() {
        let result = ProviderPoolBuilder::new()
            .add_chain(NamedChain::Mainnet, "not a valid url")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_from_endpoints() {
        let pool = ProviderPool::from_endpoints(
            vec![
                ChainEndpoint::mainnet(MAINNET_URL),
                ChainEndpoint::base(BASE_URL).with_rate_limit(5),
            ],
            Some(10),
        )
        .unwrap();
        assert_eq!(pool.len(), 2);
        assert_eq!(pool.default_rate_limit, Some(10));
    }

    #[test]
    fn test_from_endpoints_rejects_invalid_url() {
        let result =
            ProviderPool::from_endpoints(vec![ChainEndpoint::mainnet("not a valid url")], None);
        assert!(result.is_err());
    }

    #[test]
    fn test_pool_contains_and_chains() {
        let pool = ProviderPool::new();
        let result = pool.add(NamedChain::Mainnet, MAINNET_URL, None);
        assert!(result.is_ok());
        assert!(pool.contains(NamedChain::Mainnet));
        assert!(!pool.contains(NamedChain::Base));
        let chains = pool.chains();
        assert_eq!(chains.len(), 1);
        assert!(chains.contains(&NamedChain::Mainnet));
    }

    #[test]
    fn test_add_replaces_existing_provider() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, MAINNET_URL, None).unwrap();
        let first = pool.get(NamedChain::Mainnet).unwrap();
        pool.add(NamedChain::Mainnet, MAINNET_URL, None).unwrap();
        let second = pool.get(NamedChain::Mainnet).unwrap();
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_get_or_add_inserts_once_then_reuses() {
        let pool = ProviderPool::new();
        let first = pool
            .get_or_add(NamedChain::Mainnet, MAINNET_URL, None)
            .unwrap();
        assert_eq!(pool.len(), 1);
        let second = pool
            .get_or_add(NamedChain::Mainnet, MAINNET_URL, Some(3))
            .unwrap();
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_get_or_add_existing_ignores_url() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, MAINNET_URL, None).unwrap();
        let registered = pool.get(NamedChain::Mainnet).unwrap();
        // The URL is never inspected when the chain is already registered.
        let fetched = pool
            .get_or_add(NamedChain::Mainnet, "not a valid url", None)
            .unwrap();
        assert!(Arc::ptr_eq(&registered, &fetched));
    }

    #[test]
    fn test_get_or_add_invalid_url_leaves_pool_unchanged() {
        let pool = ProviderPool::new();
        let result = pool.get_or_add(NamedChain::Mainnet, "not a valid url", None);
        assert!(result.is_err());
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_remove() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, MAINNET_URL, None).unwrap();
        assert!(pool.contains(NamedChain::Mainnet));
        let removed = pool.remove(NamedChain::Mainnet);
        assert!(removed.is_some());
        assert!(!pool.contains(NamedChain::Mainnet));
        let removed_again = pool.remove(NamedChain::Mainnet);
        assert!(removed_again.is_none());
    }

    #[test]
    fn test_pool_clear() {
        let pool = ProviderPool::new();
        pool.add(NamedChain::Mainnet, MAINNET_URL, None).unwrap();
        pool.add(NamedChain::Base, BASE_URL, None).unwrap();
        assert_eq!(pool.len(), 2);
        pool.clear();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_invalid_url() {
        let pool = ProviderPool::new();
        let result = pool.add(NamedChain::Mainnet, "not a valid url", None);
        assert!(result.is_err());
        assert!(pool.is_empty());
    }
}