1use crate::{client::RpcClientInner, ClientRef};
2use alloy_json_rpc::{
3 transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcRecv,
4 RpcSend, SerializedRequest,
5};
6use alloy_primitives::map::HashMap;
7use alloy_transport::{
8 BoxTransport, TransportError, TransportErrorKind, TransportFut, TransportResult,
9};
10use futures::FutureExt;
11use pin_project::pin_project;
12use serde_json::value::RawValue;
13use std::{
14 borrow::Cow,
15 future::{Future, IntoFuture},
16 marker::PhantomData,
17 pin::Pin,
18 task::{
19 self, ready,
20 Poll::{self, Ready},
21 },
22};
23use tokio::sync::oneshot;
24use tower::Service;
25
/// Sending half of the one-shot channel used to hand a single raw response
/// (or a transport error) back to the party waiting on it.
pub(crate) type Channel = oneshot::Sender<TransportResult<Box<RawValue>>>;
/// Response channels that are still pending, keyed by request [`Id`].
pub(crate) type ChannelMap = HashMap<Id, Channel>;
28
29#[derive(Debug)]
32#[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"]
33pub struct BatchRequest<'a> {
34 transport: ClientRef<'a>,
36
37 requests: RequestPacket,
39
40 channels: ChannelMap,
42}
43
44#[must_use = "A Waiter does nothing unless the corresponding BatchRequest is sent via `send_batch` and `.await`, AND the Waiter is awaited."]
46#[pin_project]
47#[derive(Debug)]
48pub struct Waiter<Resp, Output = Resp, Map = fn(Resp) -> Output> {
49 #[pin]
50 rx: oneshot::Receiver<TransportResult<Box<RawValue>>>,
51 map: Option<Map>,
52 _resp: PhantomData<fn() -> (Output, Resp)>,
53}
54
55impl<Resp, Output, Map> Waiter<Resp, Output, Map> {
56 pub fn map_resp<NewOutput, NewMap>(self, map: NewMap) -> Waiter<Resp, NewOutput, NewMap>
68 where
69 NewMap: FnOnce(Resp) -> NewOutput,
70 {
71 Waiter { rx: self.rx, map: Some(map), _resp: PhantomData }
72 }
73}
74
75impl<Resp> From<oneshot::Receiver<TransportResult<Box<RawValue>>>> for Waiter<Resp> {
76 fn from(rx: oneshot::Receiver<TransportResult<Box<RawValue>>>) -> Self {
77 Self { rx, map: Some(std::convert::identity), _resp: PhantomData }
78 }
79}
80
81impl<Resp, Output, Map> std::future::Future for Waiter<Resp, Output, Map>
82where
83 Resp: RpcRecv,
84 Map: FnOnce(Resp) -> Output,
85{
86 type Output = TransportResult<Output>;
87
88 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
89 let this = self.get_mut();
90
91 match ready!(this.rx.poll_unpin(cx)) {
92 Ok(resp) => {
93 let resp: Result<Resp, _> = try_deserialize_ok(resp);
94 Ready(resp.map(this.map.take().expect("polled after completion")))
95 }
96 Err(e) => Poll::Ready(Err(TransportErrorKind::custom(e))),
97 }
98 }
99}
100
101#[pin_project::pin_project(project = CallStateProj)]
102#[expect(unnameable_types, missing_debug_implementations)]
103pub enum BatchFuture {
104 Prepared {
105 transport: BoxTransport,
106 requests: RequestPacket,
107 channels: ChannelMap,
108 },
109 AwaitingResponse {
110 channels: ChannelMap,
111 #[pin]
112 fut: TransportFut<'static>,
113 },
114 Complete,
115}
116
117impl<'a> BatchRequest<'a> {
118 pub fn new(transport: &'a RpcClientInner) -> Self {
120 Self {
121 transport,
122 requests: RequestPacket::Batch(Vec::with_capacity(10)),
123 channels: HashMap::with_capacity_and_hasher(10, Default::default()),
124 }
125 }
126
127 fn push_raw(
128 &mut self,
129 request: SerializedRequest,
130 ) -> oneshot::Receiver<TransportResult<Box<RawValue>>> {
131 let (tx, rx) = oneshot::channel();
132 self.channels.insert(request.id().clone(), tx);
133 self.requests.push(request);
134 rx
135 }
136
137 fn push<Params: RpcSend, Resp: RpcRecv>(
138 &mut self,
139 request: Request<Params>,
140 ) -> TransportResult<Waiter<Resp>> {
141 let ser = request.serialize().map_err(TransportError::ser_err)?;
142 Ok(self.push_raw(ser).into())
143 }
144
145 pub fn add_call<Params: RpcSend, Resp: RpcRecv>(
151 &mut self,
152 method: impl Into<Cow<'static, str>>,
153 params: &Params,
154 ) -> TransportResult<Waiter<Resp>> {
155 let request = self.transport.make_request(method, Cow::Borrowed(params));
156 self.push(request)
157 }
158
159 pub fn send(self) -> BatchFuture {
161 BatchFuture::Prepared {
162 transport: self.transport.transport.clone(),
163 requests: self.requests,
164 channels: self.channels,
165 }
166 }
167}
168
169impl IntoFuture for BatchRequest<'_> {
170 type Output = <BatchFuture as Future>::Output;
171 type IntoFuture = BatchFuture;
172
173 fn into_future(self) -> Self::IntoFuture {
174 self.send()
175 }
176}
177
178impl BatchFuture {
179 fn poll_prepared(
180 mut self: Pin<&mut Self>,
181 cx: &mut task::Context<'_>,
182 ) -> Poll<<Self as Future>::Output> {
183 let CallStateProj::Prepared { transport, requests, channels } = self.as_mut().project()
184 else {
185 unreachable!("Called poll_prepared in incorrect state")
186 };
187
188 if let Err(e) = task::ready!(transport.poll_ready(cx)) {
189 self.set(Self::Complete);
190 return Poll::Ready(Err(e));
191 }
192
193 let channels = std::mem::take(channels);
196 let req = std::mem::replace(requests, RequestPacket::Batch(Vec::new()));
197
198 let fut = transport.call(req);
199 self.set(Self::AwaitingResponse { channels, fut });
200 cx.waker().wake_by_ref();
201 Poll::Pending
202 }
203
204 fn poll_awaiting_response(
205 mut self: Pin<&mut Self>,
206 cx: &mut task::Context<'_>,
207 ) -> Poll<<Self as Future>::Output> {
208 let CallStateProj::AwaitingResponse { channels, fut } = self.as_mut().project() else {
209 unreachable!("Called poll_awaiting_response in incorrect state")
210 };
211
212 let responses = match ready!(fut.poll(cx)) {
214 Ok(responses) => responses,
215 Err(e) => {
216 self.set(Self::Complete);
217 return Poll::Ready(Err(e));
218 }
219 };
220
221 match responses {
223 ResponsePacket::Single(single) => {
224 if let Some(tx) = channels.remove(&single.id) {
225 let _ = tx.send(transform_response(single));
226 }
227 }
228 ResponsePacket::Batch(responses) => {
229 for response in responses {
230 if let Some(tx) = channels.remove(&response.id) {
231 let _ = tx.send(transform_response(response));
232 }
233 }
234 }
235 }
236
237 for (id, tx) in channels.drain() {
240 let _ = tx.send(Err(TransportErrorKind::missing_batch_response(id)));
241 }
242
243 self.set(Self::Complete);
244 Poll::Ready(Ok(()))
245 }
246}
247
248impl Future for BatchFuture {
249 type Output = TransportResult<()>;
250
251 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
252 if matches!(*self.as_mut(), Self::Prepared { .. }) {
253 return self.poll_prepared(cx);
254 }
255
256 if matches!(*self.as_mut(), Self::AwaitingResponse { .. }) {
257 return self.poll_awaiting_response(cx);
258 }
259
260 panic!("Called poll on BatchFuture in invalid state")
261 }
262}