use std::time::Duration;
use alloy_network::AnyNetwork;
use alloy_provider::RootProvider;
use alloy_rpc_client::{ClientBuilder, RpcClient};
use crate::errors::RpcError;
use crate::transport::RateLimitLayer;
use super::config::ProviderConfig;
use super::http_client::reqwest_client_with_timeout;
use super::AnyHttpProvider;
/// Turns the two mutually exclusive rate-limit settings of a [`ProviderConfig`] into an
/// optional [`RateLimitLayer`].
///
/// - only `rate_limit_per_second` set → `RateLimitLayer::per_second(rps)`
/// - only `min_delay` set → `RateLimitLayer::with_min_delay(delay)`
/// - neither set → `Ok(None)` (no rate limiting)
///
/// # Errors
///
/// Returns [`RpcError::ConflictingRateLimit`], carrying both values, when both settings are
/// present.
#[track_caller]
pub(super) fn rate_limit_layer_for(
    rate_limit_per_second: Option<u32>,
    min_delay: Option<Duration>,
) -> Result<Option<RateLimitLayer>, RpcError> {
    if let (Some(rps), Some(delay)) = (rate_limit_per_second, min_delay) {
        return Err(RpcError::ConflictingRateLimit {
            rate_limit_per_second: rps,
            min_delay: delay,
        });
    }

    // From here on at most one axis is set.
    if let Some(rps) = rate_limit_per_second {
        return Ok(Some(RateLimitLayer::per_second(rps)));
    }
    if let Some(delay) = min_delay {
        return Ok(Some(RateLimitLayer::with_min_delay(delay)));
    }
    Ok(None)
}
/// Builds the HTTP [`RpcClient`] shared by every HTTP provider constructor in this module.
///
/// The rate-limit settings are validated *before* the URL is parsed. As a result, a config
/// that sets both `rate_limit_per_second` and `min_delay` is always reported as
/// [`RpcError::ConflictingRateLimit`], whatever the URL. This matches `create_ws_provider`,
/// which also validates rate limiting before it touches the endpoint.
///
/// # Errors
///
/// - [`RpcError::ConflictingRateLimit`] if both rate-limit axes are set.
/// - [`RpcError::ProviderUrlInvalid`] if `config.url` does not parse as a URL.
/// - Any error propagated from `reqwest_client_with_timeout` when `config.timeout` is set.
pub(super) fn build_http_client(config: ProviderConfig) -> Result<RpcClient, RpcError> {
    // Configuration conflicts first, so HTTP and WS report the same error for the same config.
    let layer = rate_limit_layer_for(config.rate_limit_per_second, config.min_delay)?;

    let url: url::Url = config
        .url
        .parse()
        .map_err(|e| RpcError::ProviderUrlInvalid(e.to_string()))?;

    // Build the custom reqwest client (if any) once, up front. The dispatch below still needs
    // one arm per combination because `.layer(..)` changes the builder's type.
    let http_client = match config.timeout {
        Some(timeout) => Some(reqwest_client_with_timeout(timeout)?),
        None => None,
    };

    let builder = ClientBuilder::default();
    let client = match (layer, http_client) {
        (Some(layer), Some(http)) => builder.layer(layer).http_with_client(http, url),
        (Some(layer), None) => builder.layer(layer).http(url),
        (None, Some(http)) => builder.http_with_client(http, url),
        (None, None) => builder.http(url),
    };
    Ok(client)
}
/// Creates an HTTP provider over [`AnyNetwork`] from `config`.
///
/// This is the network-agnostic counterpart of `create_typed_http_provider`. Both go through
/// the same shared HTTP client builder.
///
/// # Errors
///
/// Propagates every error from the shared HTTP client builder: conflicting rate-limit
/// settings, an unparsable URL, or a failure to build the timeout-configured HTTP client.
pub fn create_http_provider(config: ProviderConfig) -> Result<AnyHttpProvider, RpcError> {
    let client = build_http_client(config)?;
    let provider = RootProvider::<AnyNetwork>::new(client);
    Ok(provider)
}
/// Connects a WebSocket provider over [`AnyNetwork`] using `config`.
///
/// `config.timeout` is not applied to WebSocket providers; a warning is logged if it is set.
/// Rate-limit settings are validated before any connection attempt is made.
///
/// # Errors
///
/// - [`RpcError::ConflictingRateLimit`] if both rate-limit axes are set.
/// - [`RpcError::ProviderConnectionFailed`] if the WebSocket connection cannot be established.
#[cfg(feature = "ws")]
pub async fn create_ws_provider(
    config: ProviderConfig,
) -> Result<alloy_provider::RootProvider<AnyNetwork>, RpcError> {
    use alloy_provider::WsConnect;

    if config.timeout.is_some() {
        tracing::warn!(
            "ProviderConfig::timeout is ignored for WebSocket providers; \
             alloy_provider::WsConnect does not expose a per-request timeout knob"
        );
    }

    let rate_limit = rate_limit_layer_for(config.rate_limit_per_second, config.min_delay)?;
    let connect = WsConnect::new(&config.url);
    let builder = ClientBuilder::default();

    // `.layer(..)` changes the builder's type, so each arm keeps its own call chain. Both
    // arms resolve to the same `Result`, which lets the error be mapped exactly once.
    let connected = match rate_limit {
        Some(layer) => builder.layer(layer).ws(connect).await,
        None => builder.ws(connect).await,
    };
    let client =
        connected.map_err(|err| RpcError::ProviderConnectionFailed(err.to_string()))?;

    Ok(RootProvider::<AnyNetwork>::new(client))
}
/// Creates an HTTP provider typed over the network `N`.
///
/// It uses the same client construction as `create_http_provider`, so it accepts the same
/// configuration and reports the same errors. Only the provider's network type differs.
///
/// # Errors
///
/// Propagates any error from building the underlying HTTP client: conflicting rate-limit
/// settings, an unparsable URL, or a failure to build the timeout-configured HTTP client.
pub fn create_typed_http_provider<N>(
    config: ProviderConfig,
) -> Result<alloy_provider::RootProvider<N>, RpcError>
where
    N: alloy_network::Network,
{
    build_http_client(config).map(RootProvider::<N>::new)
}
/// Creates an [`AnyNetwork`] HTTP provider for `url` using `ProviderConfig::new(url)` as-is.
///
/// # Errors
///
/// Same as `create_http_provider`.
pub fn simple_http_provider(url: &str) -> Result<AnyHttpProvider, RpcError> {
    let config = ProviderConfig::new(url);
    create_http_provider(config)
}
/// Creates an [`AnyNetwork`] HTTP provider for `url`, rate-limited to `requests_per_second`.
///
/// The limit is applied through `ProviderConfig::with_rate_limit`, which sets the config's
/// `rate_limit_per_second`.
///
/// # Errors
///
/// Same as `create_http_provider`.
pub fn rate_limited_http_provider(
    url: &str,
    requests_per_second: u32,
) -> Result<AnyHttpProvider, RpcError> {
    let config = ProviderConfig::new(url).with_rate_limit(requests_per_second);
    create_http_provider(config)
}
#[cfg(test)]
mod tests {
    // `Duration` is in scope through this glob: the parent module imports it.
    use super::*;

    #[test]
    fn test_create_http_provider_invalid_url() {
        let result = create_http_provider(ProviderConfig::new("not-a-valid-url"));
        // Pin the variant, not just "some error", so a regression elsewhere in the builder
        // cannot mask a broken URL check.
        assert!(
            matches!(result, Err(RpcError::ProviderUrlInvalid(_))),
            "expected ProviderUrlInvalid"
        );
    }

    #[test]
    fn test_create_http_provider_valid_url() {
        let result = create_http_provider(ProviderConfig::new("http://localhost:8545"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_http_provider_with_rate_limit() {
        let result =
            create_http_provider(ProviderConfig::new("http://localhost:8545").with_rate_limit(10));
        assert!(result.is_ok());
    }

    #[test]
    fn test_simple_http_provider() {
        let result = simple_http_provider("http://localhost:8545");
        assert!(result.is_ok());
    }

    #[test]
    fn test_rate_limited_http_provider() {
        let result = rate_limited_http_provider("http://localhost:8545", 10);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_typed_http_provider() {
        use alloy_network::Ethereum;
        let result =
            create_typed_http_provider::<Ethereum>(ProviderConfig::new("http://localhost:8545"));
        assert!(result.is_ok());
    }

    #[test]
    fn typed_http_provider_accepts_full_dispatch_matrix() {
        use alloy_network::Ethereum;
        let url = "http://localhost:8545";
        create_typed_http_provider::<Ethereum>(ProviderConfig::new(url)).expect("no rate limiting");
        create_typed_http_provider::<Ethereum>(ProviderConfig::new(url).with_rate_limit(10))
            .expect("rate_limit_per_second");
        create_typed_http_provider::<Ethereum>(
            ProviderConfig::new(url).with_min_delay(Duration::from_millis(250)),
        )
        .expect("min_delay only");
    }

    #[test]
    fn typed_http_provider_rejects_both_rate_limit_axes() {
        use alloy_network::Ethereum;
        let url = "http://localhost:8545";
        let err = create_typed_http_provider::<Ethereum>(
            ProviderConfig::new(url)
                .with_rate_limit(5)
                .with_min_delay(Duration::from_millis(250)),
        )
        .expect_err("both axes must be rejected");
        match err {
            RpcError::ConflictingRateLimit {
                rate_limit_per_second,
                min_delay,
            } => {
                assert_eq!(rate_limit_per_second, 5);
                assert_eq!(min_delay, Duration::from_millis(250));
            }
            other => panic!("expected ConflictingRateLimit, got {other:?}"),
        }
    }

    #[test]
    fn shared_builder_accepts_full_dispatch_matrix() {
        let url = "http://localhost:8545";
        build_http_client(ProviderConfig::new(url)).expect("no rate limiting");
        build_http_client(ProviderConfig::new(url).with_rate_limit(10))
            .expect("rate_limit_per_second");
        build_http_client(ProviderConfig::new(url).with_min_delay(Duration::from_millis(250)))
            .expect("min_delay only");
    }

    #[test]
    fn shared_builder_rejects_both_rate_limit_axes() {
        let url = "http://localhost:8545";
        let err = build_http_client(
            ProviderConfig::new(url)
                .with_rate_limit(5)
                .with_min_delay(Duration::from_millis(250)),
        )
        .expect_err("both axes must be rejected");
        assert!(
            matches!(err, RpcError::ConflictingRateLimit { .. }),
            "expected ConflictingRateLimit, got {err:?}"
        );
    }

    #[test]
    fn rate_limit_layer_for_covers_full_matrix() {
        assert!(
            rate_limit_layer_for(None, None)
                .expect("unset axes must not error")
                .is_none(),
            "both axes unset must produce no layer"
        );
        let rps_only = rate_limit_layer_for(Some(10), None)
            .expect("rate_limit_per_second alone must not error")
            .expect("rate_limit_per_second alone must produce a layer");
        assert!(
            format!("{rps_only:?}").contains("capacity: 10"),
            "rate_limit_per_second arm must produce a per_second layer with the given budget; got {rps_only:?}"
        );
        let delay_only = rate_limit_layer_for(None, Some(Duration::from_millis(250)))
            .expect("min_delay alone must not error")
            .expect("min_delay alone must produce a layer");
        assert!(
            format!("{delay_only:?}").contains("capacity: 1"),
            "min_delay arm must produce a single-token (capacity = 1) layer; got {delay_only:?}"
        );
        let err = rate_limit_layer_for(Some(5), Some(Duration::from_millis(250)))
            .expect_err("both axes set must be rejected");
        match err {
            RpcError::ConflictingRateLimit {
                rate_limit_per_second,
                min_delay,
            } => {
                assert_eq!(rate_limit_per_second, 5);
                assert_eq!(min_delay, Duration::from_millis(250));
            }
            other => panic!("expected ConflictingRateLimit, got {other:?}"),
        }
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn create_ws_provider_accepts_full_dispatch_matrix() {
        // The URL is deliberately unusable, so every call fails. What matters is *where* it
        // fails: each single-axis config must pass rate-limit validation and fail only at
        // connect time. A bare `is_err()` would also pass if validation wrongly rejected
        // the config with `ConflictingRateLimit`.
        let url = "not-a-valid-ws-url";
        let cases = [
            ("no rate limiting", ProviderConfig::new(url)),
            (
                "rate_limit_per_second",
                ProviderConfig::new(url).with_rate_limit(10),
            ),
            (
                "min_delay only",
                ProviderConfig::new(url).with_min_delay(Duration::from_millis(250)),
            ),
        ];
        for (label, config) in cases {
            let err = create_ws_provider(config).await.expect_err(label);
            assert!(
                matches!(err, RpcError::ProviderConnectionFailed(_)),
                "{label}: expected ProviderConnectionFailed, got {err:?}"
            );
        }
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn create_ws_provider_rejects_both_rate_limit_axes() {
        let url = "not-a-valid-ws-url";
        let err = create_ws_provider(
            ProviderConfig::new(url)
                .with_rate_limit(5)
                .with_min_delay(Duration::from_millis(250)),
        )
        .await
        .expect_err("both axes must be rejected");
        assert!(
            matches!(err, RpcError::ConflictingRateLimit { .. }),
            "expected ConflictingRateLimit, got {err:?}"
        );
    }
}