alloy_json_rpc/
packet.rs

1use crate::{ErrorPayload, Id, Response, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use serde::{
4    de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
5    Deserialize, Serialize,
6};
7use serde_json::value::RawValue;
8use std::{fmt, marker::PhantomData};
9
10/// A [`RequestPacket`] is a [`SerializedRequest`] or a batch of serialized
11/// request.
12#[derive(Clone, Debug)]
13pub enum RequestPacket {
14    /// A single request.
15    Single(SerializedRequest),
16    /// A batch of requests.
17    Batch(Vec<SerializedRequest>),
18}
19
20impl FromIterator<SerializedRequest> for RequestPacket {
21    fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
22        Self::Batch(iter.into_iter().collect())
23    }
24}
25
26impl From<SerializedRequest> for RequestPacket {
27    fn from(req: SerializedRequest) -> Self {
28        Self::Single(req)
29    }
30}
31
32impl Serialize for RequestPacket {
33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: serde::Serializer,
36    {
37        match self {
38            Self::Single(single) => single.serialize(serializer),
39            Self::Batch(batch) => batch.serialize(serializer),
40        }
41    }
42}
43
44impl RequestPacket {
45    /// Create a new empty packet with the given capacity.
46    pub fn with_capacity(capacity: usize) -> Self {
47        Self::Batch(Vec::with_capacity(capacity))
48    }
49
50    /// Serialize the packet as a boxed [`RawValue`].
51    pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
52        match self {
53            Self::Single(single) => Ok(single.take_request()),
54            Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
55        }
56    }
57
58    /// Get the request IDs of all subscription requests in the packet.
59    pub fn subscription_request_ids(&self) -> HashSet<&Id> {
60        match self {
61            Self::Single(single) => {
62                let id = (single.method() == "eth_subscribe").then(|| single.id());
63                HashSet::from_iter(id)
64            }
65            Self::Batch(batch) => batch
66                .iter()
67                .filter(|req| req.method() == "eth_subscribe")
68                .map(|req| req.id())
69                .collect(),
70        }
71    }
72
73    /// Get the number of requests in the packet.
74    pub fn len(&self) -> usize {
75        match self {
76            Self::Single(_) => 1,
77            Self::Batch(batch) => batch.len(),
78        }
79    }
80
81    /// Check if the packet is empty.
82    pub fn is_empty(&self) -> bool {
83        self.len() == 0
84    }
85
86    /// Push a request into the packet.
87    pub fn push(&mut self, req: SerializedRequest) {
88        match self {
89            Self::Batch(batch) => batch.push(req),
90            Self::Single(_) => {
91                let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
92                if let Self::Single(single) = old {
93                    self.push(single);
94                }
95                self.push(req);
96            }
97        }
98    }
99}
100
101/// A [`ResponsePacket`] is a [`Response`] or a batch of responses.
102#[derive(Clone, Debug)]
103pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
104    /// A single response.
105    Single(Response<Payload, ErrData>),
106    /// A batch of responses.
107    Batch(Vec<Response<Payload, ErrData>>),
108}
109
110impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
111    for ResponsePacket<Payload, ErrData>
112{
113    fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
114        let mut iter = iter.into_iter().peekable();
115        // return single if iter has exactly one element, else make a batch
116        if let Some(first) = iter.next() {
117            return if iter.peek().is_none() {
118                Self::Single(first)
119            } else {
120                let mut batch = Vec::new();
121                batch.push(first);
122                batch.extend(iter);
123                Self::Batch(batch)
124            };
125        }
126        Self::Batch(vec![])
127    }
128}
129
130impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
131    fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
132        if value.len() == 1 {
133            Self::Single(value.into_iter().next().unwrap())
134        } else {
135            Self::Batch(value)
136        }
137    }
138}
139
140impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
141where
142    Payload: Deserialize<'de>,
143    ErrData: Deserialize<'de>,
144{
145    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146    where
147        D: Deserializer<'de>,
148    {
149        struct ResponsePacketVisitor<Payload, ErrData> {
150            marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
151        }
152
153        impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
154        where
155            Payload: Deserialize<'de>,
156            ErrData: Deserialize<'de>,
157        {
158            type Value = ResponsePacket<Payload, ErrData>;
159
160            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
161                formatter.write_str("a single response or a batch of responses")
162            }
163
164            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
165            where
166                A: SeqAccess<'de>,
167            {
168                let mut responses = Vec::new();
169
170                while let Some(response) = seq.next_element()? {
171                    responses.push(response);
172                }
173
174                Ok(ResponsePacket::Batch(responses))
175            }
176
177            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
178            where
179                M: MapAccess<'de>,
180            {
181                let response =
182                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
183                Ok(ResponsePacket::Single(response))
184            }
185        }
186
187        deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
188    }
189}
190
191/// A [`BorrowedResponsePacket`] is a [`ResponsePacket`] that has been partially deserialized,
192/// borrowing its contents from the deserializer.
193///
194/// This is used primarily for intermediate deserialization. Most users will not require it.
195///
196/// See the [top-level docs] for more info.
197///
198/// [top-level docs]: crate
199pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
200
201impl BorrowedResponsePacket<'_> {
202    /// Convert this borrowed response packet into an owned packet by copying
203    /// the data from the deserializer (if necessary).
204    pub fn into_owned(self) -> ResponsePacket {
205        match self {
206            Self::Single(single) => ResponsePacket::Single(single.into_owned()),
207            Self::Batch(batch) => {
208                ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
209            }
210        }
211    }
212}
213
214impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
215    /// Returns `true` if the response payload is a success.
216    ///
217    /// For batch responses, this returns `true` if __all__ responses are successful.
218    pub fn is_success(&self) -> bool {
219        match self {
220            Self::Single(single) => single.is_success(),
221            Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
222        }
223    }
224
225    /// Returns `true` if the response payload is an error.
226    ///
227    /// For batch responses, this returns `true` there's at least one error response.
228    pub fn is_error(&self) -> bool {
229        match self {
230            Self::Single(single) => single.is_error(),
231            Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
232        }
233    }
234
235    /// Returns the [ErrorPayload] if the response is an error.
236    ///
237    /// For batch responses, this returns the first error response.
238    pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
239        self.iter_errors().next()
240    }
241
242    /// Returns an iterator over the [ErrorPayload]s in the response.
243    pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
244        match self {
245            Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
246            Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
247        }
248    }
249
250    /// Find responses by a list of IDs.
251    ///
252    /// This is intended to be used in conjunction with
253    /// [`RequestPacket::subscription_request_ids`] to identify subscription
254    /// responses.
255    ///
256    /// # Note
257    ///
258    /// - Responses are not guaranteed to be in the same order.
259    /// - Responses are not guaranteed to be in the set.
260    /// - If the packet contains duplicate IDs, both will be found.
261    pub fn responses_by_ids(&self, ids: &HashSet<Id>) -> Vec<&Response<Payload, ErrData>> {
262        match self {
263            Self::Single(single) if ids.contains(&single.id) => vec![single],
264            Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
265            _ => Vec::new(),
266        }
267    }
268}
269
270/// An Iterator over the [ErrorPayload]s in a [ResponsePacket].
271#[derive(Clone, Debug)]
272enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
273    Single(Option<&'a Response<Payload, ErrData>>),
274    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
275}
276
277impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
278    type Item = &'a ErrorPayload<ErrData>;
279
280    fn next(&mut self) -> Option<Self::Item> {
281        match self {
282            ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
283            ResponsePacketErrorsIter::Batch(batch) => loop {
284                let res = batch.next()?;
285                if let Some(err) = res.payload.as_error() {
286                    return Some(err);
287                }
288            },
289        }
290    }
291}