1use crate::{ErrorPayload, Id, Response, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use serde::{
4 de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
5 Deserialize, Serialize,
6};
7use serde_json::value::RawValue;
8use std::{fmt, marker::PhantomData};
9
/// A JSON-RPC request packet: either one serialized request or a batch of them.
///
/// See the `Serialize` impl for how each variant is encoded (the single request
/// as itself, a batch as a sequence).
#[derive(Clone, Debug)]
pub enum RequestPacket {
    /// A single request.
    Single(SerializedRequest),
    /// A batch of requests.
    Batch(Vec<SerializedRequest>),
}
19
20impl FromIterator<SerializedRequest> for RequestPacket {
21 fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
22 Self::Batch(iter.into_iter().collect())
23 }
24}
25
26impl From<SerializedRequest> for RequestPacket {
27 fn from(req: SerializedRequest) -> Self {
28 Self::Single(req)
29 }
30}
31
32impl Serialize for RequestPacket {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: serde::Serializer,
36 {
37 match self {
38 Self::Single(single) => single.serialize(serializer),
39 Self::Batch(batch) => batch.serialize(serializer),
40 }
41 }
42}
43
44impl RequestPacket {
45 pub fn with_capacity(capacity: usize) -> Self {
47 Self::Batch(Vec::with_capacity(capacity))
48 }
49
50 pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
52 match self {
53 Self::Single(single) => Ok(single.take_request()),
54 Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
55 }
56 }
57
58 pub fn subscription_request_ids(&self) -> HashSet<&Id> {
60 match self {
61 Self::Single(single) => {
62 let id = (single.method() == "eth_subscribe").then(|| single.id());
63 HashSet::from_iter(id)
64 }
65 Self::Batch(batch) => batch
66 .iter()
67 .filter(|req| req.method() == "eth_subscribe")
68 .map(|req| req.id())
69 .collect(),
70 }
71 }
72
73 pub fn len(&self) -> usize {
75 match self {
76 Self::Single(_) => 1,
77 Self::Batch(batch) => batch.len(),
78 }
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.len() == 0
84 }
85
86 pub fn push(&mut self, req: SerializedRequest) {
88 match self {
89 Self::Batch(batch) => batch.push(req),
90 Self::Single(_) => {
91 let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
92 if let Self::Single(single) = old {
93 self.push(single);
94 }
95 self.push(req);
96 }
97 }
98 }
99}
100
/// A JSON-RPC response packet: either one response or a batch of them.
///
/// `Payload` and `ErrData` default to `Box<RawValue>`, i.e. JSON kept in raw, unparsed form.
#[derive(Clone, Debug)]
pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
    /// A single response.
    Single(Response<Payload, ErrData>),
    /// A batch of responses.
    Batch(Vec<Response<Payload, ErrData>>),
}
109
110impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
111 for ResponsePacket<Payload, ErrData>
112{
113 fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
114 let mut iter = iter.into_iter().peekable();
115 if let Some(first) = iter.next() {
117 return if iter.peek().is_none() {
118 Self::Single(first)
119 } else {
120 let mut batch = Vec::new();
121 batch.push(first);
122 batch.extend(iter);
123 Self::Batch(batch)
124 };
125 }
126 Self::Batch(vec![])
127 }
128}
129
130impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
131 fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
132 if value.len() == 1 {
133 Self::Single(value.into_iter().next().unwrap())
134 } else {
135 Self::Batch(value)
136 }
137 }
138}
139
140impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
141where
142 Payload: Deserialize<'de>,
143 ErrData: Deserialize<'de>,
144{
145 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146 where
147 D: Deserializer<'de>,
148 {
149 struct ResponsePacketVisitor<Payload, ErrData> {
150 marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
151 }
152
153 impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
154 where
155 Payload: Deserialize<'de>,
156 ErrData: Deserialize<'de>,
157 {
158 type Value = ResponsePacket<Payload, ErrData>;
159
160 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
161 formatter.write_str("a single response or a batch of responses")
162 }
163
164 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
165 where
166 A: SeqAccess<'de>,
167 {
168 let mut responses = Vec::new();
169
170 while let Some(response) = seq.next_element()? {
171 responses.push(response);
172 }
173
174 Ok(ResponsePacket::Batch(responses))
175 }
176
177 fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
178 where
179 M: MapAccess<'de>,
180 {
181 let response =
182 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
183 Ok(ResponsePacket::Single(response))
184 }
185 }
186
187 deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
188 }
189}
190
/// A [`ResponsePacket`] whose payloads and error data are borrowed [`RawValue`]s living for `'a`.
///
/// NOTE(review): presumably these borrow from the buffer being deserialized — confirm at call
/// sites. Use `into_owned` to detach the packet from that lifetime.
pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
200
201impl BorrowedResponsePacket<'_> {
202 pub fn into_owned(self) -> ResponsePacket {
205 match self {
206 Self::Single(single) => ResponsePacket::Single(single.into_owned()),
207 Self::Batch(batch) => {
208 ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
209 }
210 }
211 }
212}
213
214impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
215 pub fn is_success(&self) -> bool {
219 match self {
220 Self::Single(single) => single.is_success(),
221 Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
222 }
223 }
224
225 pub fn is_error(&self) -> bool {
229 match self {
230 Self::Single(single) => single.is_error(),
231 Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
232 }
233 }
234
235 pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
239 self.iter_errors().next()
240 }
241
242 pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
244 match self {
245 Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
246 Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
247 }
248 }
249
250 pub fn responses_by_ids(&self, ids: &HashSet<Id>) -> Vec<&Response<Payload, ErrData>> {
262 match self {
263 Self::Single(single) if ids.contains(&single.id) => vec![single],
264 Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
265 _ => Vec::new(),
266 }
267 }
268}
269
/// Iterator over the error payloads in a [`ResponsePacket`], as returned by `iter_errors`.
///
/// `Single` holds the response that has not been inspected yet; `next` takes it out, leaving
/// `None`. `Batch` walks the remaining responses of a batch.
#[derive(Clone, Debug)]
enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
    Single(Option<&'a Response<Payload, ErrData>>),
    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
}
276
277impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
278 type Item = &'a ErrorPayload<ErrData>;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 match self {
282 ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
283 ResponsePacketErrorsIter::Batch(batch) => loop {
284 let res = batch.next()?;
285 if let Some(err) = res.payload.as_error() {
286 return Some(err);
287 }
288 },
289 }
290 }
291}