use crate::{poller::PollerBuilder, BatchRequest, ClientBuilder, RpcCall};
use alloy_json_rpc::{Id, Request, RpcRecv, RpcSend};
use alloy_transport::{mock::Asserter, BoxTransport, IntoBoxTransport};
use std::{
borrow::Cow,
ops::Deref,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
},
time::Duration,
};
use tower::{layer::util::Identity, ServiceBuilder};
/// A weak (non-owning) [`Weak`] handle to an [`RpcClientInner`].
pub type WeakClient = Weak<RpcClientInner>;
/// A borrowed reference to an [`RpcClientInner`].
pub type ClientRef<'a> = &'a RpcClientInner;
/// Parameter type for requests that carry no parameters: a zero-length array of units.
pub type NoParams = [(); 0];
/// An optional pubsub frontend. Only compiled when the `pubsub` feature is enabled.
#[cfg(feature = "pubsub")]
type MaybePubsub = Option<alloy_pubsub::PubSubFrontend>;
/// A JSON-RPC client handle.
///
/// This is a newtype around an [`Arc`] of [`RpcClientInner`], which holds the transport and the
/// request-ID / poll-interval state.
#[derive(Debug)]
pub struct RpcClient(Arc<RpcClientInner>);
impl Clone for RpcClient {
fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
}
impl RpcClient {
    /// Returns a [`ClientBuilder`] wrapping a freshly created, empty [`ServiceBuilder`].
    pub const fn builder() -> ClientBuilder<Identity> {
        let builder = ServiceBuilder::new();
        ClientBuilder { builder }
    }
}
impl RpcClient {
    /// Creates a client over the transport `t`, with no explicit pubsub frontend.
    ///
    /// `is_local` is handed to the inner client unchanged.
    pub fn new(t: impl IntoBoxTransport, is_local: bool) -> Self {
        Self::new_maybe_pubsub(
            t,
            is_local,
            #[cfg(feature = "pubsub")]
            None,
        )
    }

    /// Creates a client backed by a mock transport built from `asserter`.
    ///
    /// The resulting client is flagged as local.
    pub fn mocked(asserter: Asserter) -> Self {
        let transport = alloy_transport::mock::MockTransport::new(asserter);
        Self::new(transport, true)
    }

    /// Creates a client over an HTTP transport pointing at `url`.
    ///
    /// Whether the client is flagged as local is decided by the transport's `guess_local`.
    #[cfg(feature = "reqwest")]
    pub fn new_http(url: reqwest::Url) -> Self {
        let transport = alloy_transport_http::Http::new(url);
        let local = transport.guess_local();
        Self::new(transport, local)
    }

    /// Creates a client over an HTTP transport pointing at `url`, reusing the given `client`.
    ///
    /// Whether the client is flagged as local is decided by the transport's `guess_local`.
    #[cfg(feature = "reqwest")]
    pub fn new_http_with_client(client: reqwest::Client, url: reqwest::Url) -> Self {
        let transport = alloy_transport_http::Http::with_client(client, url);
        let local = transport.guess_local();
        Self::new(transport, local)
    }

    /// Shared constructor: builds the inner client (optionally with a pubsub frontend when the
    /// `pubsub` feature is enabled) and wraps it in an [`Arc`].
    fn new_maybe_pubsub(
        t: impl IntoBoxTransport,
        is_local: bool,
        #[cfg(feature = "pubsub")] pubsub: MaybePubsub,
    ) -> Self {
        let inner = RpcClientInner::new_maybe_pubsub(
            t,
            is_local,
            #[cfg(feature = "pubsub")]
            pubsub,
        );
        Self(Arc::new(inner))
    }

    /// Creates a client whose transport is `layer(main_transport)`.
    ///
    /// With the `pubsub` feature enabled, `main_transport` is first checked (via downcast of a
    /// boxed clone) for being a pubsub frontend; if it is, a clone of that frontend is stored in
    /// the client so it remains reachable after the transport has been wrapped by `layer`.
    pub(crate) fn new_layered<F, T, R>(is_local: bool, main_transport: T, layer: F) -> Self
    where
        F: FnOnce(T) -> R,
        T: IntoBoxTransport,
        R: IntoBoxTransport,
    {
        #[cfg(feature = "pubsub")]
        {
            // Inspect a boxed clone so `main_transport` itself can still be moved into `layer`.
            let boxed = main_transport.clone().into_box_transport();
            let frontend = boxed.as_any().downcast_ref::<alloy_pubsub::PubSubFrontend>().cloned();
            Self::new_maybe_pubsub(layer(main_transport), is_local, frontend)
        }
        #[cfg(not(feature = "pubsub"))]
        Self::new(layer(main_transport), is_local)
    }

    /// Wraps an already constructed [`RpcClientInner`] in a new client handle.
    pub fn from_inner(inner: RpcClientInner) -> Self {
        let shared = Arc::new(inner);
        Self(shared)
    }

    /// Borrows the underlying [`Arc`].
    pub const fn inner(&self) -> &Arc<RpcClientInner> {
        let Self(shared) = self;
        shared
    }

    /// Consumes the handle and returns the underlying [`Arc`].
    pub fn into_inner(self) -> Arc<RpcClientInner> {
        let Self(shared) = self;
        shared
    }

    /// Returns a [`Weak`] pointer to the inner client.
    pub fn get_weak(&self) -> WeakClient {
        Arc::downgrade(self.inner())
    }

    /// Returns a plain reference to the inner client.
    pub fn get_ref(&self) -> ClientRef<'_> {
        self.0.as_ref()
    }

    /// Sets the poll interval on the inner client and returns the handle for chaining.
    ///
    /// The interval is stored on the inner client itself, not on this handle.
    pub fn with_poll_interval(self, poll_interval: Duration) -> Self {
        self.0.set_poll_interval(poll_interval);
        self
    }

    /// Creates a [`PollerBuilder`] for `method` with fixed `params`.
    ///
    /// The builder receives a weak pointer to this client rather than a strong one.
    pub fn prepare_static_poller<Params, Resp>(
        &self,
        method: impl Into<Cow<'static, str>>,
        params: Params,
    ) -> PollerBuilder<Params, Resp>
    where
        Params: RpcSend + 'static,
        Resp: RpcRecv + Clone,
    {
        let weak = Arc::downgrade(&self.0);
        PollerBuilder::new(weak, method, params)
    }

    /// Creates a new [`BatchRequest`] that borrows this client.
    #[inline]
    pub fn new_batch(&self) -> BatchRequest<'_> {
        BatchRequest::new(&self.0)
    }
}
impl Deref for RpcClient {
    type Target = RpcClientInner;

    /// Dereferences to the inner client, so its methods can be called directly on the handle.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// The state behind an [`RpcClient`]: the transport plus request-ID and poll-interval state.
#[derive(Debug)]
pub struct RpcClientInner {
/// The boxed transport that requests are created against.
pub(crate) transport: BoxTransport,
/// Optional pubsub frontend stored next to the transport (only with the `pubsub` feature).
#[cfg(feature = "pubsub")]
pub(crate) pubsub: MaybePubsub,
/// Whether the client is flagged as local; `new` uses it to pick the default poll interval.
pub(crate) is_local: bool,
/// Counter from which request IDs are drawn.
pub(crate) id: AtomicU64,
/// Poll interval, stored as a number of milliseconds.
pub(crate) poll_interval: AtomicU64,
}
impl RpcClientInner {
    /// Creates the inner client state for the transport `t`.
    ///
    /// The request-ID counter starts at 0 and no pubsub frontend is stored. The poll interval
    /// defaults to 250 ms when `is_local` is `true` and to 7000 ms otherwise.
    #[inline]
    pub fn new(t: impl IntoBoxTransport, is_local: bool) -> Self {
        // Default poll interval in milliseconds, chosen by locality.
        let default_poll_ms = if is_local { 250 } else { 7000 };
        Self {
            transport: t.into_box_transport(),
            #[cfg(feature = "pubsub")]
            pubsub: None,
            is_local,
            id: AtomicU64::new(0),
            poll_interval: AtomicU64::new(default_poll_ms),
        }
    }

    /// Same as [`Self::new`], but (with the `pubsub` feature) stores the given pubsub frontend.
    pub(crate) fn new_maybe_pubsub(
        t: impl IntoBoxTransport,
        is_local: bool,
        #[cfg(feature = "pubsub")] pubsub: MaybePubsub,
    ) -> Self {
        Self {
            #[cfg(feature = "pubsub")]
            pubsub,
            ..Self::new(t, is_local)
        }
    }

    /// Replaces the request-ID counter so that the next issued ID is `id`.
    #[inline]
    pub fn with_id(self, id: u64) -> Self {
        Self { id: AtomicU64::new(id), ..self }
    }

    /// Returns the current poll interval.
    pub fn poll_interval(&self) -> Duration {
        Duration::from_millis(self.poll_interval.load(Ordering::Relaxed))
    }

    /// Sets the poll interval.
    ///
    /// The interval is stored with millisecond resolution. Durations whose millisecond count does
    /// not fit in a `u64` are clamped to `u64::MAX` milliseconds.
    pub fn set_poll_interval(&self, poll_interval: Duration) {
        // `Duration::as_millis` yields a `u128`; a bare `as u64` cast would silently wrap for
        // very large durations (e.g. `Duration::MAX`), so saturate instead.
        let millis = u64::try_from(poll_interval.as_millis()).unwrap_or(u64::MAX);
        self.poll_interval.store(millis, Ordering::Relaxed);
    }

    /// Returns a reference to the underlying transport.
    #[inline]
    pub const fn transport(&self) -> &BoxTransport {
        &self.transport
    }

    /// Returns a mutable reference to the underlying transport.
    #[inline]
    pub const fn transport_mut(&mut self) -> &mut BoxTransport {
        &mut self.transport
    }

    /// Consumes the client and returns the underlying transport.
    #[inline]
    pub fn into_transport(self) -> BoxTransport {
        self.transport
    }

    /// Returns the pubsub frontend, if there is one.
    ///
    /// The explicitly stored frontend takes precedence; otherwise the transport itself is
    /// downcast to a pubsub frontend. Returns `None` when neither is available.
    #[cfg(feature = "pubsub")]
    #[inline]
    #[track_caller]
    pub fn pubsub_frontend(&self) -> Option<&alloy_pubsub::PubSubFrontend> {
        self.pubsub
            .as_ref()
            .or_else(|| self.transport.as_any().downcast_ref::<alloy_pubsub::PubSubFrontend>())
    }

    /// Returns the pubsub frontend.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::pubsub_frontend`] returns `None`.
    #[cfg(feature = "pubsub")]
    #[inline]
    #[track_caller]
    pub fn expect_pubsub_frontend(&self) -> &alloy_pubsub::PubSubFrontend {
        self.pubsub_frontend().expect("called pubsub_frontend on a non-pubsub transport")
    }

    /// Builds a [`Request`] for `method` with `params`, tagged with the next request ID.
    ///
    /// Every call consumes one ID from the counter, whether or not the request is ever sent.
    #[inline]
    pub fn make_request<Params: RpcSend>(
        &self,
        method: impl Into<Cow<'static, str>>,
        params: Params,
    ) -> Request<Params> {
        Request::new(method, self.next_id(), params)
    }

    /// Returns whether the client is flagged as local.
    #[inline]
    pub const fn is_local(&self) -> bool {
        self.is_local
    }

    /// Sets the local flag.
    ///
    /// This only updates the flag; the stored poll interval is left untouched.
    #[inline]
    pub const fn set_local(&mut self, is_local: bool) {
        self.is_local = is_local;
    }

    /// Atomically advances the ID counter by one and returns the value it held before.
    #[inline]
    fn increment_id(&self) -> u64 {
        self.id.fetch_add(1, Ordering::Relaxed)
    }

    /// Reserves and returns the next request [`Id`].
    #[inline]
    pub fn next_id(&self) -> Id {
        self.increment_id().into()
    }

    /// Prepares an [`RpcCall`] for `method` with `params`.
    ///
    /// The call is given a fresh request ID and a clone of this client's transport.
    #[doc(alias = "prepare")]
    pub fn request<Params: RpcSend, Resp: RpcRecv>(
        &self,
        method: impl Into<Cow<'static, str>>,
        params: Params,
    ) -> RpcCall<Params, Resp> {
        let request = self.make_request(method, params);
        RpcCall::new(request, self.transport.clone())
    }

    /// Prepares an [`RpcCall`] for `method` with an empty parameter list.
    pub fn request_noparams<Resp: RpcRecv>(
        &self,
        method: impl Into<Cow<'static, str>>,
    ) -> RpcCall<NoParams, Resp> {
        self.request(method, [])
    }
}
#[cfg(feature = "pubsub")]
mod pubsub_impl {
    use super::*;
    use alloy_pubsub::{PubSubConnect, RawSubscription, Subscription};
    use alloy_transport::TransportResult;

    impl RpcClientInner {
        /// Fetches the raw subscription registered under `id` from the pubsub frontend.
        ///
        /// # Panics
        ///
        /// Panics if the client has no pubsub frontend, or if the frontend's lookup returns an
        /// error.
        pub async fn get_raw_subscription(&self, id: alloy_primitives::B256) -> RawSubscription {
            let frontend = self.expect_pubsub_frontend();
            frontend.get_subscription(id).await.unwrap()
        }

        /// Fetches the subscription registered under `id` and converts it into a typed
        /// [`Subscription`].
        ///
        /// # Panics
        ///
        /// Panics under the same conditions as [`Self::get_raw_subscription`].
        pub async fn get_subscription<T: serde::de::DeserializeOwned>(
            &self,
            id: alloy_primitives::B256,
        ) -> Subscription<T> {
            let raw = self.get_raw_subscription(id).await;
            Subscription::from(raw)
        }
    }

    impl RpcClient {
        /// Connects to a pubsub backend through a default [`ClientBuilder`].
        pub async fn connect_pubsub<C: PubSubConnect>(connect: C) -> TransportResult<Self> {
            ClientBuilder::default().pubsub(connect).await
        }

        /// Returns the channel size reported by the pubsub frontend.
        ///
        /// # Panics
        ///
        /// Panics if the client has no pubsub frontend.
        #[track_caller]
        pub fn channel_size(&self) -> usize {
            let frontend = self.expect_pubsub_frontend();
            frontend.channel_size()
        }

        /// Sets the channel size on the pubsub frontend.
        ///
        /// # Panics
        ///
        /// Panics if the client has no pubsub frontend.
        #[track_caller]
        pub fn set_channel_size(&self, size: usize) {
            let frontend = self.expect_pubsub_frontend();
            frontend.set_channel_size(size)
        }
    }
}
#[cfg(test)]
mod tests {
    /// An interval set through `with_poll_interval` must be read back by `poll_interval`.
    ///
    /// Gated on the `reqwest` feature because `RpcClient::new_http` and `reqwest::Url` only exist
    /// when it is enabled; without the gate, building the tests with that feature disabled fails
    /// to compile. The imports live inside the function so nothing is left unused when the test
    /// is compiled out.
    #[cfg(feature = "reqwest")]
    #[test]
    fn test_client_with_poll_interval() {
        use super::*;
        use similar_asserts::assert_eq;

        let poll_interval = Duration::from_millis(5_000);
        let client = RpcClient::new_http(reqwest::Url::parse("http://localhost").unwrap())
            .with_poll_interval(poll_interval);
        assert_eq!(client.poll_interval(), poll_interval);
    }
}