1use crate::{ErrorPayload, Id, Response, ResponsePayload, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use http::HeaderMap;
4use serde::{
5 de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
6 Deserialize, Serialize,
7};
8use serde_json::value::RawValue;
9use std::{borrow::Borrow, fmt, hash::Hash, marker::PhantomData};
10
/// A packet of outgoing requests: either a single [`SerializedRequest`] or a batch of them.
///
/// When serialized, a single request is written as itself and a batch is written as a
/// sequence of requests.
#[derive(Clone, Debug)]
pub enum RequestPacket {
    /// A single request.
    Single(SerializedRequest),
    /// A batch of requests.
    Batch(Vec<SerializedRequest>),
}
20
21impl FromIterator<SerializedRequest> for RequestPacket {
22 fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
23 Self::Batch(iter.into_iter().collect())
24 }
25}
26
27impl From<SerializedRequest> for RequestPacket {
28 fn from(req: SerializedRequest) -> Self {
29 Self::Single(req)
30 }
31}
32
33impl Serialize for RequestPacket {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: serde::Serializer,
37 {
38 match self {
39 Self::Single(single) => single.serialize(serializer),
40 Self::Batch(batch) => batch.serialize(serializer),
41 }
42 }
43}
44
45impl RequestPacket {
46 pub fn with_capacity(capacity: usize) -> Self {
48 Self::Batch(Vec::with_capacity(capacity))
49 }
50
51 pub const fn as_single(&self) -> Option<&SerializedRequest> {
53 match self {
54 Self::Single(req) => Some(req),
55 Self::Batch(_) => None,
56 }
57 }
58
59 pub fn as_batch(&self) -> Option<&[SerializedRequest]> {
61 match self {
62 Self::Batch(req) => Some(req.as_slice()),
63 Self::Single(_) => None,
64 }
65 }
66
67 pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
69 match self {
70 Self::Single(single) => Ok(single.take_request()),
71 Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
72 }
73 }
74
75 pub fn subscription_request_ids(&self) -> HashSet<&Id> {
77 match self {
78 Self::Single(single) => {
79 let id = single.is_subscription().then(|| single.id());
80 HashSet::from_iter(id)
81 }
82 Self::Batch(batch) => {
83 batch.iter().filter(|req| req.is_subscription()).map(|req| req.id()).collect()
84 }
85 }
86 }
87
88 pub fn len(&self) -> usize {
90 match self {
91 Self::Single(_) => 1,
92 Self::Batch(batch) => batch.len(),
93 }
94 }
95
96 pub fn is_empty(&self) -> bool {
98 self.len() == 0
99 }
100
101 pub fn push(&mut self, req: SerializedRequest) {
103 match self {
104 Self::Batch(batch) => batch.push(req),
105 Self::Single(_) => {
106 let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
107 if let Self::Single(single) = old {
108 self.push(single);
109 }
110 self.push(req);
111 }
112 }
113 }
114
115 pub fn requests(&self) -> &[SerializedRequest] {
117 match self {
118 Self::Single(req) => std::slice::from_ref(req),
119 Self::Batch(req) => req.as_slice(),
120 }
121 }
122
123 pub fn requests_mut(&mut self) -> &mut [SerializedRequest] {
125 match self {
126 Self::Single(req) => std::slice::from_mut(req),
127 Self::Batch(req) => req.as_mut_slice(),
128 }
129 }
130
131 pub fn method_names(&self) -> impl Iterator<Item = &str> + '_ {
133 self.requests().iter().map(|req| req.method())
134 }
135
136 pub fn headers(&self) -> HeaderMap {
139 let Some(single_req) = self.as_single() else {
141 return HeaderMap::new();
142 };
143 if let Some(http_header_extension) = single_req.meta().extensions().get::<HeaderMap>() {
145 return http_header_extension.clone();
146 };
147
148 HeaderMap::new()
149 }
150}
151
/// A packet of responses: either a single [`Response`] or a batch of them.
///
/// The payload and error-data type parameters default to raw JSON (`Box<RawValue>`).
#[derive(Clone, Debug)]
pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
    /// A single response.
    Single(Response<Payload, ErrData>),
    /// A batch of responses.
    Batch(Vec<Response<Payload, ErrData>>),
}
160
161impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
162 for ResponsePacket<Payload, ErrData>
163{
164 fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
165 let mut iter = iter.into_iter().peekable();
166 if let Some(first) = iter.next() {
168 return if iter.peek().is_none() {
169 Self::Single(first)
170 } else {
171 let mut batch = Vec::new();
172 batch.push(first);
173 batch.extend(iter);
174 Self::Batch(batch)
175 };
176 }
177 Self::Batch(vec![])
178 }
179}
180
181impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
182 fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
183 if value.len() == 1 {
184 Self::Single(value.into_iter().next().unwrap())
185 } else {
186 Self::Batch(value)
187 }
188 }
189}
190
191impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
192where
193 Payload: Deserialize<'de>,
194 ErrData: Deserialize<'de>,
195{
196 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
197 where
198 D: Deserializer<'de>,
199 {
200 struct ResponsePacketVisitor<Payload, ErrData> {
201 marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
202 }
203
204 impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
205 where
206 Payload: Deserialize<'de>,
207 ErrData: Deserialize<'de>,
208 {
209 type Value = ResponsePacket<Payload, ErrData>;
210
211 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
212 formatter.write_str("a single response or a batch of responses")
213 }
214
215 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
216 where
217 A: SeqAccess<'de>,
218 {
219 let mut responses = Vec::new();
220
221 while let Some(response) = seq.next_element()? {
222 responses.push(response);
223 }
224
225 Ok(ResponsePacket::Batch(responses))
226 }
227
228 fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
229 where
230 M: MapAccess<'de>,
231 {
232 let response =
233 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
234 Ok(ResponsePacket::Single(response))
235 }
236 }
237
238 deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
239 }
240}
241
/// A [`ResponsePacket`] whose payloads and error data are borrowed `&'a RawValue`s instead of
/// owned boxes.
///
/// Call `into_owned` on it to obtain an owned [`ResponsePacket`].
pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
251
252impl BorrowedResponsePacket<'_> {
253 pub fn into_owned(self) -> ResponsePacket {
256 match self {
257 Self::Single(single) => ResponsePacket::Single(single.into_owned()),
258 Self::Batch(batch) => {
259 ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
260 }
261 }
262 }
263}
264
265impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
266 pub const fn as_single(&self) -> Option<&Response<Payload, ErrData>> {
268 match self {
269 Self::Single(resp) => Some(resp),
270 Self::Batch(_) => None,
271 }
272 }
273
274 pub fn as_batch(&self) -> Option<&[Response<Payload, ErrData>]> {
276 match self {
277 Self::Batch(resp) => Some(resp.as_slice()),
278 Self::Single(_) => None,
279 }
280 }
281
282 pub fn single_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
284 self.as_single().map(|resp| &resp.payload)
285 }
286
287 pub fn is_success(&self) -> bool {
291 match self {
292 Self::Single(single) => single.is_success(),
293 Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
294 }
295 }
296
297 pub fn is_error(&self) -> bool {
301 match self {
302 Self::Single(single) => single.is_error(),
303 Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
304 }
305 }
306
307 pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
311 self.iter_errors().next()
312 }
313
314 pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
316 match self {
317 Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
318 Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
319 }
320 }
321
322 pub fn first_error_code(&self) -> Option<i64> {
324 self.as_error().map(|error| error.code)
325 }
326
327 pub fn first_error_message(&self) -> Option<&str> {
329 self.as_error().map(|error| error.message.as_ref())
330 }
331
332 pub fn first_error_data(&self) -> Option<&ErrData> {
334 self.as_error().and_then(|error| error.data.as_ref())
335 }
336
337 pub fn responses(&self) -> &[Response<Payload, ErrData>] {
339 match self {
340 Self::Single(req) => std::slice::from_ref(req),
341 Self::Batch(req) => req.as_slice(),
342 }
343 }
344
345 pub fn payloads(&self) -> impl Iterator<Item = &ResponsePayload<Payload, ErrData>> + '_ {
347 self.responses().iter().map(|resp| &resp.payload)
348 }
349
350 pub fn first_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
352 self.payloads().next()
353 }
354
355 pub fn response_ids(&self) -> impl Iterator<Item = &Id> + '_ {
357 self.responses().iter().map(|resp| &resp.id)
358 }
359
360 pub fn responses_by_ids<K>(&self, ids: &HashSet<K>) -> Vec<&Response<Payload, ErrData>>
372 where
373 K: Borrow<Id> + Eq + Hash,
374 {
375 match self {
376 Self::Single(single) if ids.contains(&single.id) => vec![single],
377 Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
378 _ => Vec::new(),
379 }
380 }
381}
382
/// Iterator over the error payloads of a [`ResponsePacket`], as returned by
/// `ResponsePacket::iter_errors`.
#[derive(Clone, Debug)]
enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
    /// Holds the single response until it has been inspected, after which it is `None`.
    Single(Option<&'a Response<Payload, ErrData>>),
    /// Walks the not-yet-inspected responses of a batch.
    Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
}
389
390impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
391 type Item = &'a ErrorPayload<ErrData>;
392
393 fn next(&mut self) -> Option<Self::Item> {
394 match self {
395 ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
396 ResponsePacketErrorsIter::Batch(batch) => loop {
397 let res = batch.next()?;
398 if let Some(err) = res.payload.as_error() {
399 return Some(err);
400 }
401 },
402 }
403 }
404}