alloy_json_rpc/
packet.rs

1use crate::{ErrorPayload, Id, Response, ResponsePayload, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use http::HeaderMap;
4use serde::{
5    de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
6    Deserialize, Serialize,
7};
8use serde_json::value::RawValue;
9use std::{borrow::Borrow, fmt, hash::Hash, marker::PhantomData};
10
11/// A [`RequestPacket`] is a [`SerializedRequest`] or a batch of serialized
12/// request.
13#[derive(Clone, Debug)]
14pub enum RequestPacket {
15    /// A single request.
16    Single(SerializedRequest),
17    /// A batch of requests.
18    Batch(Vec<SerializedRequest>),
19}
20
21impl FromIterator<SerializedRequest> for RequestPacket {
22    fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
23        Self::Batch(iter.into_iter().collect())
24    }
25}
26
27impl From<SerializedRequest> for RequestPacket {
28    fn from(req: SerializedRequest) -> Self {
29        Self::Single(req)
30    }
31}
32
33impl Serialize for RequestPacket {
34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35    where
36        S: serde::Serializer,
37    {
38        match self {
39            Self::Single(single) => single.serialize(serializer),
40            Self::Batch(batch) => batch.serialize(serializer),
41        }
42    }
43}
44
45impl RequestPacket {
46    /// Create a new empty packet with the given capacity.
47    pub fn with_capacity(capacity: usize) -> Self {
48        Self::Batch(Vec::with_capacity(capacity))
49    }
50
51    /// Returns the [`SerializedRequest`] if this packet is [`ResponsePacket::Single`]
52    pub const fn as_single(&self) -> Option<&SerializedRequest> {
53        match self {
54            Self::Single(req) => Some(req),
55            Self::Batch(_) => None,
56        }
57    }
58
59    /// Returns the batch of [`SerializedRequest`] if this packet is [`ResponsePacket::Batch`]
60    pub fn as_batch(&self) -> Option<&[SerializedRequest]> {
61        match self {
62            Self::Batch(req) => Some(req.as_slice()),
63            Self::Single(_) => None,
64        }
65    }
66
67    /// Serialize the packet as a boxed [`RawValue`].
68    pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
69        match self {
70            Self::Single(single) => Ok(single.take_request()),
71            Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
72        }
73    }
74
75    /// Get the request IDs of all subscription requests in the packet.
76    pub fn subscription_request_ids(&self) -> HashSet<&Id> {
77        match self {
78            Self::Single(single) => {
79                let id = single.is_subscription().then(|| single.id());
80                HashSet::from_iter(id)
81            }
82            Self::Batch(batch) => {
83                batch.iter().filter(|req| req.is_subscription()).map(|req| req.id()).collect()
84            }
85        }
86    }
87
88    /// Get the number of requests in the packet.
89    pub fn len(&self) -> usize {
90        match self {
91            Self::Single(_) => 1,
92            Self::Batch(batch) => batch.len(),
93        }
94    }
95
96    /// Check if the packet is empty.
97    pub fn is_empty(&self) -> bool {
98        self.len() == 0
99    }
100
101    /// Push a request into the packet.
102    pub fn push(&mut self, req: SerializedRequest) {
103        match self {
104            Self::Batch(batch) => batch.push(req),
105            Self::Single(_) => {
106                let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
107                if let Self::Single(single) = old {
108                    self.push(single);
109                }
110                self.push(req);
111            }
112        }
113    }
114
115    /// Returns all [`SerializedRequest`].
116    pub fn requests(&self) -> &[SerializedRequest] {
117        match self {
118            Self::Single(req) => std::slice::from_ref(req),
119            Self::Batch(req) => req.as_slice(),
120        }
121    }
122
123    /// Returns a mutable reference to all [`SerializedRequest`].
124    pub fn requests_mut(&mut self) -> &mut [SerializedRequest] {
125        match self {
126            Self::Single(req) => std::slice::from_mut(req),
127            Self::Batch(req) => req.as_mut_slice(),
128        }
129    }
130
131    /// Returns an iterator over the requests' method names
132    pub fn method_names(&self) -> impl Iterator<Item = &str> + '_ {
133        self.requests().iter().map(|req| req.method())
134    }
135
136    /// Retrieves the [`HeaderMap`] from the request metadata if available;
137    /// otherwise, returns an empty map. This functionality is only supported for single requests.
138    pub fn headers(&self) -> HeaderMap {
139        // If this is a batch request, we cannot return headers.
140        let Some(single_req) = self.as_single() else {
141            return HeaderMap::new();
142        };
143        // If the request provides a `HeaderMap` return it.
144        if let Some(http_header_extension) = single_req.meta().extensions().get::<HeaderMap>() {
145            return http_header_extension.clone();
146        };
147
148        HeaderMap::new()
149    }
150}
151
152/// A [`ResponsePacket`] is a [`Response`] or a batch of responses.
153#[derive(Clone, Debug)]
154pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
155    /// A single response.
156    Single(Response<Payload, ErrData>),
157    /// A batch of responses.
158    Batch(Vec<Response<Payload, ErrData>>),
159}
160
161impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
162    for ResponsePacket<Payload, ErrData>
163{
164    fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
165        let mut iter = iter.into_iter().peekable();
166        // return single if iter has exactly one element, else make a batch
167        if let Some(first) = iter.next() {
168            return if iter.peek().is_none() {
169                Self::Single(first)
170            } else {
171                let mut batch = Vec::new();
172                batch.push(first);
173                batch.extend(iter);
174                Self::Batch(batch)
175            };
176        }
177        Self::Batch(vec![])
178    }
179}
180
181impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
182    fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
183        if value.len() == 1 {
184            Self::Single(value.into_iter().next().unwrap())
185        } else {
186            Self::Batch(value)
187        }
188    }
189}
190
191impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
192where
193    Payload: Deserialize<'de>,
194    ErrData: Deserialize<'de>,
195{
196    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
197    where
198        D: Deserializer<'de>,
199    {
200        struct ResponsePacketVisitor<Payload, ErrData> {
201            marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
202        }
203
204        impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
205        where
206            Payload: Deserialize<'de>,
207            ErrData: Deserialize<'de>,
208        {
209            type Value = ResponsePacket<Payload, ErrData>;
210
211            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
212                formatter.write_str("a single response or a batch of responses")
213            }
214
215            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
216            where
217                A: SeqAccess<'de>,
218            {
219                let mut responses = Vec::new();
220
221                while let Some(response) = seq.next_element()? {
222                    responses.push(response);
223                }
224
225                Ok(ResponsePacket::Batch(responses))
226            }
227
228            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
229            where
230                M: MapAccess<'de>,
231            {
232                let response =
233                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
234                Ok(ResponsePacket::Single(response))
235            }
236        }
237
238        deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
239    }
240}
241
242/// A [`BorrowedResponsePacket`] is a [`ResponsePacket`] that has been partially deserialized,
243/// borrowing its contents from the deserializer.
244///
245/// This is used primarily for intermediate deserialization. Most users will not require it.
246///
247/// See the [top-level docs] for more info.
248///
249/// [top-level docs]: crate
250pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
251
252impl BorrowedResponsePacket<'_> {
253    /// Convert this borrowed response packet into an owned packet by copying
254    /// the data from the deserializer (if necessary).
255    pub fn into_owned(self) -> ResponsePacket {
256        match self {
257            Self::Single(single) => ResponsePacket::Single(single.into_owned()),
258            Self::Batch(batch) => {
259                ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
260            }
261        }
262    }
263}
264
265impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
266    /// Returns the [`Response`] if this packet is [`ResponsePacket::Single`].
267    pub const fn as_single(&self) -> Option<&Response<Payload, ErrData>> {
268        match self {
269            Self::Single(resp) => Some(resp),
270            Self::Batch(_) => None,
271        }
272    }
273
274    /// Returns the batch of [`Response`] if this packet is [`ResponsePacket::Batch`].
275    pub fn as_batch(&self) -> Option<&[Response<Payload, ErrData>]> {
276        match self {
277            Self::Batch(resp) => Some(resp.as_slice()),
278            Self::Single(_) => None,
279        }
280    }
281
282    /// Returns the [`ResponsePayload`] if this packet is [`ResponsePacket::Single`].
283    pub fn single_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
284        self.as_single().map(|resp| &resp.payload)
285    }
286
287    /// Returns `true` if the response payload is a success.
288    ///
289    /// For batch responses, this returns `true` if __all__ responses are successful.
290    pub fn is_success(&self) -> bool {
291        match self {
292            Self::Single(single) => single.is_success(),
293            Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
294        }
295    }
296
297    /// Returns `true` if the response payload is an error.
298    ///
299    /// For batch responses, this returns `true` there's at least one error response.
300    pub fn is_error(&self) -> bool {
301        match self {
302            Self::Single(single) => single.is_error(),
303            Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
304        }
305    }
306
307    /// Returns the [ErrorPayload] if the response is an error.
308    ///
309    /// For batch responses, this returns the first error response.
310    pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
311        self.iter_errors().next()
312    }
313
314    /// Returns an iterator over the [ErrorPayload]s in the response.
315    pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
316        match self {
317            Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
318            Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
319        }
320    }
321
322    /// Returns the first error code in this packet if it contains any error responses.
323    pub fn first_error_code(&self) -> Option<i64> {
324        self.as_error().map(|error| error.code)
325    }
326
327    /// Returns the first error message in this packet if it contains any error responses.
328    pub fn first_error_message(&self) -> Option<&str> {
329        self.as_error().map(|error| error.message.as_ref())
330    }
331
332    /// Returns the first error data in this packet if it contains any error responses.
333    pub fn first_error_data(&self) -> Option<&ErrData> {
334        self.as_error().and_then(|error| error.data.as_ref())
335    }
336
337    /// Returns a all [`Response`].
338    pub fn responses(&self) -> &[Response<Payload, ErrData>] {
339        match self {
340            Self::Single(req) => std::slice::from_ref(req),
341            Self::Batch(req) => req.as_slice(),
342        }
343    }
344
345    /// Returns an iterator over the responses' payloads.
346    pub fn payloads(&self) -> impl Iterator<Item = &ResponsePayload<Payload, ErrData>> + '_ {
347        self.responses().iter().map(|resp| &resp.payload)
348    }
349
350    /// Returns the first [`ResponsePayload`] in this packet.
351    pub fn first_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
352        self.payloads().next()
353    }
354
355    /// Returns an iterator over the responses' identifiers.
356    pub fn response_ids(&self) -> impl Iterator<Item = &Id> + '_ {
357        self.responses().iter().map(|resp| &resp.id)
358    }
359
360    /// Find responses by a list of IDs.
361    ///
362    /// This is intended to be used in conjunction with
363    /// [`RequestPacket::subscription_request_ids`] to identify subscription
364    /// responses.
365    ///
366    /// # Note
367    ///
368    /// - Responses are not guaranteed to be in the same order.
369    /// - Responses are not guaranteed to be in the set.
370    /// - If the packet contains duplicate IDs, both will be found.
371    pub fn responses_by_ids<K>(&self, ids: &HashSet<K>) -> Vec<&Response<Payload, ErrData>>
372    where
373        K: Borrow<Id> + Eq + Hash,
374    {
375        match self {
376            Self::Single(single) if ids.contains(&single.id) => vec![single],
377            Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
378            _ => Vec::new(),
379        }
380    }
381}
382
383/// An Iterator over the [ErrorPayload]s in a [ResponsePacket].
384#[derive(Clone, Debug)]
385enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
386    Single(Option<&'a Response<Payload, ErrData>>),
387    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
388}
389
390impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
391    type Item = &'a ErrorPayload<ErrData>;
392
393    fn next(&mut self) -> Option<Self::Item> {
394        match self {
395            ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
396            ResponsePacketErrorsIter::Batch(batch) => loop {
397                let res = batch.next()?;
398                if let Some(err) = res.payload.as_error() {
399                    return Some(err);
400                }
401            },
402        }
403    }
404}