alloy_rpc_client/
batch.rs

1use crate::{client::RpcClientInner, ClientRef};
2use alloy_json_rpc::{
3    transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcRecv,
4    RpcSend, SerializedRequest,
5};
6use alloy_primitives::map::HashMap;
7use alloy_transport::{
8    BoxTransport, TransportError, TransportErrorKind, TransportFut, TransportResult,
9};
10use futures::FutureExt;
11use pin_project::pin_project;
12use serde_json::value::RawValue;
13use std::{
14    borrow::Cow,
15    future::{Future, IntoFuture},
16    marker::PhantomData,
17    pin::Pin,
18    task::{
19        self, ready,
20        Poll::{self, Ready},
21    },
22};
23use tokio::sync::oneshot;
24use tower::Service;
25
26pub(crate) type Channel = oneshot::Sender<TransportResult<Box<RawValue>>>;
27pub(crate) type ChannelMap = HashMap<Id, Channel>;
28
29/// A batch JSON-RPC request, used to bundle requests into a single transport
30/// call.
31#[derive(Debug)]
32#[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"]
33pub struct BatchRequest<'a> {
34    /// The transport via which the batch will be sent.
35    transport: ClientRef<'a>,
36
37    /// The requests to be sent.
38    requests: RequestPacket,
39
40    /// The channels to send the responses through.
41    channels: ChannelMap,
42}
43
44/// Awaits a single response for a request that has been included in a batch.
45#[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."]
46#[pin_project]
47#[derive(Debug)]
48pub struct Waiter<Resp, Output = Resp, Map = fn(Resp) -> Output> {
49    #[pin]
50    rx: oneshot::Receiver<TransportResult<Box<RawValue>>>,
51    map: Option<Map>,
52    _resp: PhantomData<fn() -> (Output, Resp)>,
53}
54
55impl<Resp, Output, Map> Waiter<Resp, Output, Map> {
56    /// Map the response to a different type. This is usable for converting
57    /// the response to a more usable type, e.g. changing `U64` to `u64`.
58    ///
59    /// ## Note
60    ///
61    /// Carefully review the rust documentation on [fn pointers] before passing
62    /// them to this function. Unless the pointer is specifically coerced to a
63    /// `fn(_) -> _`, the `NewMap` will be inferred as that function's unique
64    /// type. This can lead to confusing error messages.
65    ///
66    /// [fn pointers]: https://doc.rust-lang.org/std/primitive.fn.html#creating-function-pointers
67    pub fn map_resp<NewOutput, NewMap>(self, map: NewMap) -> Waiter<Resp, NewOutput, NewMap>
68    where
69        NewMap: FnOnce(Resp) -> NewOutput,
70    {
71        Waiter { rx: self.rx, map: Some(map), _resp: PhantomData }
72    }
73}
74
75impl<Resp> From<oneshot::Receiver<TransportResult<Box<RawValue>>>> for Waiter<Resp> {
76    fn from(rx: oneshot::Receiver<TransportResult<Box<RawValue>>>) -> Self {
77        Self { rx, map: Some(std::convert::identity), _resp: PhantomData }
78    }
79}
80
81impl<Resp, Output, Map> std::future::Future for Waiter<Resp, Output, Map>
82where
83    Resp: RpcRecv,
84    Map: FnOnce(Resp) -> Output,
85{
86    type Output = TransportResult<Output>;
87
88    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
89        let this = self.get_mut();
90
91        match ready!(this.rx.poll_unpin(cx)) {
92            Ok(resp) => {
93                let resp: Result<Resp, _> = try_deserialize_ok(resp);
94                Ready(resp.map(this.map.take().expect("polled after completion")))
95            }
96            Err(e) => Poll::Ready(Err(TransportErrorKind::custom(e))),
97        }
98    }
99}
100
101#[pin_project::pin_project(project = CallStateProj)]
102#[expect(unnameable_types, missing_debug_implementations)]
103pub enum BatchFuture {
104    Prepared {
105        transport: BoxTransport,
106        requests: RequestPacket,
107        channels: ChannelMap,
108    },
109    AwaitingResponse {
110        channels: ChannelMap,
111        #[pin]
112        fut: TransportFut<'static>,
113    },
114    Complete,
115}
116
117impl<'a> BatchRequest<'a> {
118    /// Create a new batch request.
119    pub fn new(transport: &'a RpcClientInner) -> Self {
120        Self {
121            transport,
122            requests: RequestPacket::Batch(Vec::with_capacity(10)),
123            channels: HashMap::with_capacity_and_hasher(10, Default::default()),
124        }
125    }
126
127    fn push_raw(
128        &mut self,
129        request: SerializedRequest,
130    ) -> oneshot::Receiver<TransportResult<Box<RawValue>>> {
131        let (tx, rx) = oneshot::channel();
132        self.channels.insert(request.id().clone(), tx);
133        self.requests.push(request);
134        rx
135    }
136
137    fn push<Params: RpcSend, Resp: RpcRecv>(
138        &mut self,
139        request: Request<Params>,
140    ) -> TransportResult<Waiter<Resp>> {
141        let ser = request.serialize().map_err(TransportError::ser_err)?;
142        Ok(self.push_raw(ser).into())
143    }
144
145    /// Add a call to the batch.
146    ///
147    /// ### Errors
148    ///
149    /// If the request cannot be serialized, this will return an error.
150    pub fn add_call<Params: RpcSend, Resp: RpcRecv>(
151        &mut self,
152        method: impl Into<Cow<'static, str>>,
153        params: &Params,
154    ) -> TransportResult<Waiter<Resp>> {
155        let request = self.transport.make_request(method, Cow::Borrowed(params));
156        self.push(request)
157    }
158
159    /// Send the batch future via its connection.
160    pub fn send(self) -> BatchFuture {
161        BatchFuture::Prepared {
162            transport: self.transport.transport.clone(),
163            requests: self.requests,
164            channels: self.channels,
165        }
166    }
167}
168
169impl IntoFuture for BatchRequest<'_> {
170    type Output = <BatchFuture as Future>::Output;
171    type IntoFuture = BatchFuture;
172
173    fn into_future(self) -> Self::IntoFuture {
174        self.send()
175    }
176}
177
178impl BatchFuture {
179    fn poll_prepared(
180        mut self: Pin<&mut Self>,
181        cx: &mut task::Context<'_>,
182    ) -> Poll<<Self as Future>::Output> {
183        let CallStateProj::Prepared { transport, requests, channels } = self.as_mut().project()
184        else {
185            unreachable!("Called poll_prepared in incorrect state")
186        };
187
188        if let Err(e) = task::ready!(transport.poll_ready(cx)) {
189            self.set(Self::Complete);
190            return Poll::Ready(Err(e));
191        }
192
193        // We only have mut refs, and we want ownership, so we just replace with 0-capacity
194        // collections.
195        let channels = std::mem::take(channels);
196        let req = std::mem::replace(requests, RequestPacket::Batch(Vec::new()));
197
198        let fut = transport.call(req);
199        self.set(Self::AwaitingResponse { channels, fut });
200        cx.waker().wake_by_ref();
201        Poll::Pending
202    }
203
204    fn poll_awaiting_response(
205        mut self: Pin<&mut Self>,
206        cx: &mut task::Context<'_>,
207    ) -> Poll<<Self as Future>::Output> {
208        let CallStateProj::AwaitingResponse { channels, fut } = self.as_mut().project() else {
209            unreachable!("Called poll_awaiting_response in incorrect state")
210        };
211
212        // Has the service responded yet?
213        let responses = match ready!(fut.poll(cx)) {
214            Ok(responses) => responses,
215            Err(e) => {
216                self.set(Self::Complete);
217                return Poll::Ready(Err(e));
218            }
219        };
220
221        // Send all responses via channels
222        match responses {
223            ResponsePacket::Single(single) => {
224                if let Some(tx) = channels.remove(&single.id) {
225                    let _ = tx.send(transform_response(single));
226                }
227            }
228            ResponsePacket::Batch(responses) => {
229                for response in responses {
230                    if let Some(tx) = channels.remove(&response.id) {
231                        let _ = tx.send(transform_response(response));
232                    }
233                }
234            }
235        }
236
237        // Any channels remaining in the map are missing responses.
238        // To avoid hanging futures, we send an error.
239        for (id, tx) in channels.drain() {
240            let _ = tx.send(Err(TransportErrorKind::missing_batch_response(id)));
241        }
242
243        self.set(Self::Complete);
244        Poll::Ready(Ok(()))
245    }
246}
247
248impl Future for BatchFuture {
249    type Output = TransportResult<()>;
250
251    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
252        if matches!(*self.as_mut(), Self::Prepared { .. }) {
253            return self.poll_prepared(cx);
254        }
255
256        if matches!(*self.as_mut(), Self::AwaitingResponse { .. }) {
257            return self.poll_awaiting_response(cx);
258        }
259
260        panic!("Called poll on BatchFuture in invalid state")
261    }
262}