alloy_json_rpc/
packet.rs

1use crate::{ErrorPayload, Id, Response, ResponsePayload, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use http::HeaderMap;
4use serde::{
5    de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
6    Deserialize, Serialize,
7};
8use serde_json::value::RawValue;
9use std::{borrow::Borrow, fmt, hash::Hash, marker::PhantomData};
10
11/// A [`RequestPacket`] is a [`SerializedRequest`] or a batch of serialized
12/// request.
13#[derive(Clone, Debug)]
14pub enum RequestPacket {
15    /// A single request.
16    Single(SerializedRequest),
17    /// A batch of requests.
18    Batch(Vec<SerializedRequest>),
19}
20
21impl FromIterator<SerializedRequest> for RequestPacket {
22    fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
23        Self::Batch(iter.into_iter().collect())
24    }
25}
26
27impl From<SerializedRequest> for RequestPacket {
28    fn from(req: SerializedRequest) -> Self {
29        Self::Single(req)
30    }
31}
32
33impl Serialize for RequestPacket {
34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35    where
36        S: serde::Serializer,
37    {
38        match self {
39            Self::Single(single) => single.serialize(serializer),
40            Self::Batch(batch) => batch.serialize(serializer),
41        }
42    }
43}
44
45impl RequestPacket {
46    /// Create a new empty packet with the given capacity.
47    pub fn with_capacity(capacity: usize) -> Self {
48        Self::Batch(Vec::with_capacity(capacity))
49    }
50
51    /// Returns the [`SerializedRequest`] if this packet is [`RequestPacket::Single`]
52    pub const fn as_single(&self) -> Option<&SerializedRequest> {
53        match self {
54            Self::Single(req) => Some(req),
55            Self::Batch(_) => None,
56        }
57    }
58
59    /// Returns the batch of [`SerializedRequest`] if this packet is [`RequestPacket::Batch`]
60    pub const fn as_batch(&self) -> Option<&[SerializedRequest]> {
61        match self {
62            Self::Batch(req) => Some(req.as_slice()),
63            Self::Single(_) => None,
64        }
65    }
66
67    /// Serialize the packet as a boxed [`RawValue`].
68    pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
69        match self {
70            Self::Single(single) => Ok(single.take_request()),
71            Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
72        }
73    }
74
75    /// Get the request IDs of all subscription requests in the packet.
76    pub fn subscription_request_ids(&self) -> HashSet<&Id> {
77        match self {
78            Self::Single(single) => {
79                let id = single.is_subscription().then(|| single.id());
80                HashSet::from_iter(id)
81            }
82            Self::Batch(batch) => {
83                batch.iter().filter(|req| req.is_subscription()).map(|req| req.id()).collect()
84            }
85        }
86    }
87
88    /// Get the number of requests in the packet.
89    pub const fn len(&self) -> usize {
90        match self {
91            Self::Single(_) => 1,
92            Self::Batch(batch) => batch.len(),
93        }
94    }
95
96    /// Check if the packet is empty.
97    pub const fn is_empty(&self) -> bool {
98        self.len() == 0
99    }
100
101    /// Push a request into the packet.
102    pub fn push(&mut self, req: SerializedRequest) {
103        match self {
104            Self::Batch(batch) => batch.push(req),
105            Self::Single(_) => {
106                let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
107                if let Self::Single(single) = old {
108                    self.push(single);
109                }
110                self.push(req);
111            }
112        }
113    }
114
115    /// Returns all [`SerializedRequest`].
116    pub const fn requests(&self) -> &[SerializedRequest] {
117        match self {
118            Self::Single(req) => std::slice::from_ref(req),
119            Self::Batch(req) => req.as_slice(),
120        }
121    }
122
123    /// Returns a mutable reference to all [`SerializedRequest`].
124    pub const fn requests_mut(&mut self) -> &mut [SerializedRequest] {
125        match self {
126            Self::Single(req) => std::slice::from_mut(req),
127            Self::Batch(req) => req.as_mut_slice(),
128        }
129    }
130
131    /// Returns an iterator over the requests' method names
132    pub fn method_names(&self) -> impl Iterator<Item = &str> + '_ {
133        self.requests().iter().map(|req| req.method())
134    }
135
136    /// Retrieves the combined headers from all requests in the packet. If
137    /// multiple requests contain the same header, the last one wins.
138    pub fn headers(&self) -> HeaderMap {
139        self.requests().iter().fold(HeaderMap::new(), |mut acc, req| {
140            if let Some(http_header_extension) = req.meta().extensions().get::<HeaderMap>() {
141                acc.extend(http_header_extension.iter().map(|(k, v)| (k.clone(), v.clone())));
142            };
143            acc
144        })
145    }
146}
147
148/// A [`ResponsePacket`] is a [`Response`] or a batch of responses.
149#[derive(Clone, Debug)]
150pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
151    /// A single response.
152    Single(Response<Payload, ErrData>),
153    /// A batch of responses.
154    Batch(Vec<Response<Payload, ErrData>>),
155}
156
157impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
158    for ResponsePacket<Payload, ErrData>
159{
160    fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
161        let mut iter = iter.into_iter().peekable();
162        // return single if iter has exactly one element, else make a batch
163        if let Some(first) = iter.next() {
164            return if iter.peek().is_none() {
165                Self::Single(first)
166            } else {
167                let mut batch = Vec::new();
168                batch.push(first);
169                batch.extend(iter);
170                Self::Batch(batch)
171            };
172        }
173        Self::Batch(vec![])
174    }
175}
176
177impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
178    fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
179        if value.len() == 1 {
180            Self::Single(value.into_iter().next().unwrap())
181        } else {
182            Self::Batch(value)
183        }
184    }
185}
186
187impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
188where
189    Payload: Deserialize<'de>,
190    ErrData: Deserialize<'de>,
191{
192    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193    where
194        D: Deserializer<'de>,
195    {
196        struct ResponsePacketVisitor<Payload, ErrData> {
197            marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
198        }
199
200        impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
201        where
202            Payload: Deserialize<'de>,
203            ErrData: Deserialize<'de>,
204        {
205            type Value = ResponsePacket<Payload, ErrData>;
206
207            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
208                formatter.write_str("a single response or a batch of responses")
209            }
210
211            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
212            where
213                A: SeqAccess<'de>,
214            {
215                let mut responses = Vec::new();
216
217                while let Some(response) = seq.next_element()? {
218                    responses.push(response);
219                }
220
221                Ok(ResponsePacket::Batch(responses))
222            }
223
224            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
225            where
226                M: MapAccess<'de>,
227            {
228                let response =
229                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
230                Ok(ResponsePacket::Single(response))
231            }
232        }
233
234        deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
235    }
236}
237
238/// A [`BorrowedResponsePacket`] is a [`ResponsePacket`] that has been partially deserialized,
239/// borrowing its contents from the deserializer.
240///
241/// This is used primarily for intermediate deserialization. Most users will not require it.
242///
243/// See the [top-level docs] for more info.
244///
245/// [top-level docs]: crate
246pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
247
248impl BorrowedResponsePacket<'_> {
249    /// Convert this borrowed response packet into an owned packet by copying
250    /// the data from the deserializer (if necessary).
251    pub fn into_owned(self) -> ResponsePacket {
252        match self {
253            Self::Single(single) => ResponsePacket::Single(single.into_owned()),
254            Self::Batch(batch) => {
255                ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
256            }
257        }
258    }
259}
260
261impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
262    /// Returns the [`Response`] if this packet is [`ResponsePacket::Single`].
263    pub const fn as_single(&self) -> Option<&Response<Payload, ErrData>> {
264        match self {
265            Self::Single(resp) => Some(resp),
266            Self::Batch(_) => None,
267        }
268    }
269
270    /// Returns the batch of [`Response`] if this packet is [`ResponsePacket::Batch`].
271    pub const fn as_batch(&self) -> Option<&[Response<Payload, ErrData>]> {
272        match self {
273            Self::Batch(resp) => Some(resp.as_slice()),
274            Self::Single(_) => None,
275        }
276    }
277
278    /// Returns the [`ResponsePayload`] if this packet is [`ResponsePacket::Single`].
279    pub fn single_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
280        self.as_single().map(|resp| &resp.payload)
281    }
282
283    /// Returns `true` if the response payload is a success.
284    ///
285    /// For batch responses, this returns `true` if __all__ responses are successful.
286    pub fn is_success(&self) -> bool {
287        match self {
288            Self::Single(single) => single.is_success(),
289            Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
290        }
291    }
292
293    /// Returns `true` if the response payload is an error.
294    ///
295    /// For batch responses, this returns `true` there's at least one error response.
296    pub fn is_error(&self) -> bool {
297        match self {
298            Self::Single(single) => single.is_error(),
299            Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
300        }
301    }
302
303    /// Returns the [ErrorPayload] if the response is an error.
304    ///
305    /// For batch responses, this returns the first error response.
306    pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
307        self.iter_errors().next()
308    }
309
310    /// Returns an iterator over the [ErrorPayload]s in the response.
311    pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
312        match self {
313            Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
314            Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
315        }
316    }
317
318    /// Returns the first error code in this packet if it contains any error responses.
319    pub fn first_error_code(&self) -> Option<i64> {
320        self.as_error().map(|error| error.code)
321    }
322
323    /// Returns the first error message in this packet if it contains any error responses.
324    pub fn first_error_message(&self) -> Option<&str> {
325        self.as_error().map(|error| error.message.as_ref())
326    }
327
328    /// Returns the first error data in this packet if it contains any error responses.
329    pub fn first_error_data(&self) -> Option<&ErrData> {
330        self.as_error().and_then(|error| error.data.as_ref())
331    }
332
333    /// Returns a all [`Response`].
334    pub const fn responses(&self) -> &[Response<Payload, ErrData>] {
335        match self {
336            Self::Single(req) => std::slice::from_ref(req),
337            Self::Batch(req) => req.as_slice(),
338        }
339    }
340
341    /// Returns an iterator over the responses' payloads.
342    pub fn payloads(&self) -> impl Iterator<Item = &ResponsePayload<Payload, ErrData>> + '_ {
343        self.responses().iter().map(|resp| &resp.payload)
344    }
345
346    /// Returns the first [`ResponsePayload`] in this packet.
347    pub fn first_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
348        self.payloads().next()
349    }
350
351    /// Returns an iterator over the responses' identifiers.
352    pub fn response_ids(&self) -> impl Iterator<Item = &Id> + '_ {
353        self.responses().iter().map(|resp| &resp.id)
354    }
355
356    /// Find responses by a list of IDs.
357    ///
358    /// This is intended to be used in conjunction with
359    /// [`RequestPacket::subscription_request_ids`] to identify subscription
360    /// responses.
361    ///
362    /// # Note
363    ///
364    /// - Responses are not guaranteed to be in the same order.
365    /// - Responses are not guaranteed to be in the set.
366    /// - If the packet contains duplicate IDs, both will be found.
367    pub fn responses_by_ids<K>(&self, ids: &HashSet<K>) -> Vec<&Response<Payload, ErrData>>
368    where
369        K: Borrow<Id> + Eq + Hash,
370    {
371        match self {
372            Self::Single(single) if ids.contains(&single.id) => vec![single],
373            Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
374            _ => Vec::new(),
375        }
376    }
377}
378
379/// An Iterator over the [ErrorPayload]s in a [ResponsePacket].
380#[derive(Clone, Debug)]
381enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
382    Single(Option<&'a Response<Payload, ErrData>>),
383    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
384}
385
386impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
387    type Item = &'a ErrorPayload<ErrData>;
388
389    fn next(&mut self) -> Option<Self::Item> {
390        match self {
391            ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
392            ResponsePacketErrorsIter::Batch(batch) => loop {
393                let res = batch.next()?;
394                if let Some(err) = res.payload.as_error() {
395                    return Some(err);
396                }
397            },
398        }
399    }
400}