1use crate::{ErrorPayload, Id, Response, ResponsePayload, SerializedRequest};
2use alloy_primitives::map::HashSet;
3use http::HeaderMap;
4use serde::{
5 de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
6 Deserialize, Serialize,
7};
8use serde_json::value::RawValue;
9use std::{borrow::Borrow, fmt, hash::Hash, marker::PhantomData};
10
11#[derive(Clone, Debug)]
14pub enum RequestPacket {
15 Single(SerializedRequest),
17 Batch(Vec<SerializedRequest>),
19}
20
21impl FromIterator<SerializedRequest> for RequestPacket {
22 fn from_iter<T: IntoIterator<Item = SerializedRequest>>(iter: T) -> Self {
23 Self::Batch(iter.into_iter().collect())
24 }
25}
26
27impl From<SerializedRequest> for RequestPacket {
28 fn from(req: SerializedRequest) -> Self {
29 Self::Single(req)
30 }
31}
32
33impl Serialize for RequestPacket {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: serde::Serializer,
37 {
38 match self {
39 Self::Single(single) => single.serialize(serializer),
40 Self::Batch(batch) => batch.serialize(serializer),
41 }
42 }
43}
44
45impl RequestPacket {
46 pub fn with_capacity(capacity: usize) -> Self {
48 Self::Batch(Vec::with_capacity(capacity))
49 }
50
51 pub const fn as_single(&self) -> Option<&SerializedRequest> {
53 match self {
54 Self::Single(req) => Some(req),
55 Self::Batch(_) => None,
56 }
57 }
58
59 pub const fn as_batch(&self) -> Option<&[SerializedRequest]> {
61 match self {
62 Self::Batch(req) => Some(req.as_slice()),
63 Self::Single(_) => None,
64 }
65 }
66
67 pub fn serialize(self) -> serde_json::Result<Box<RawValue>> {
69 match self {
70 Self::Single(single) => Ok(single.take_request()),
71 Self::Batch(batch) => serde_json::value::to_raw_value(&batch),
72 }
73 }
74
75 pub fn subscription_request_ids(&self) -> HashSet<&Id> {
77 match self {
78 Self::Single(single) => {
79 let id = single.is_subscription().then(|| single.id());
80 HashSet::from_iter(id)
81 }
82 Self::Batch(batch) => {
83 batch.iter().filter(|req| req.is_subscription()).map(|req| req.id()).collect()
84 }
85 }
86 }
87
88 pub const fn len(&self) -> usize {
90 match self {
91 Self::Single(_) => 1,
92 Self::Batch(batch) => batch.len(),
93 }
94 }
95
96 pub const fn is_empty(&self) -> bool {
98 self.len() == 0
99 }
100
101 pub fn push(&mut self, req: SerializedRequest) {
103 match self {
104 Self::Batch(batch) => batch.push(req),
105 Self::Single(_) => {
106 let old = std::mem::replace(self, Self::Batch(Vec::with_capacity(10)));
107 if let Self::Single(single) = old {
108 self.push(single);
109 }
110 self.push(req);
111 }
112 }
113 }
114
115 pub const fn requests(&self) -> &[SerializedRequest] {
117 match self {
118 Self::Single(req) => std::slice::from_ref(req),
119 Self::Batch(req) => req.as_slice(),
120 }
121 }
122
123 pub const fn requests_mut(&mut self) -> &mut [SerializedRequest] {
125 match self {
126 Self::Single(req) => std::slice::from_mut(req),
127 Self::Batch(req) => req.as_mut_slice(),
128 }
129 }
130
131 pub fn method_names(&self) -> impl Iterator<Item = &str> + '_ {
133 self.requests().iter().map(|req| req.method())
134 }
135
136 pub fn headers(&self) -> HeaderMap {
139 self.requests().iter().fold(HeaderMap::new(), |mut acc, req| {
140 if let Some(http_header_extension) = req.meta().extensions().get::<HeaderMap>() {
141 acc.extend(http_header_extension.iter().map(|(k, v)| (k.clone(), v.clone())));
142 };
143 acc
144 })
145 }
146}
147
148#[derive(Clone, Debug)]
150pub enum ResponsePacket<Payload = Box<RawValue>, ErrData = Box<RawValue>> {
151 Single(Response<Payload, ErrData>),
153 Batch(Vec<Response<Payload, ErrData>>),
155}
156
157impl<Payload, ErrData> FromIterator<Response<Payload, ErrData>>
158 for ResponsePacket<Payload, ErrData>
159{
160 fn from_iter<T: IntoIterator<Item = Response<Payload, ErrData>>>(iter: T) -> Self {
161 let mut iter = iter.into_iter().peekable();
162 if let Some(first) = iter.next() {
164 return if iter.peek().is_none() {
165 Self::Single(first)
166 } else {
167 let mut batch = Vec::new();
168 batch.push(first);
169 batch.extend(iter);
170 Self::Batch(batch)
171 };
172 }
173 Self::Batch(vec![])
174 }
175}
176
177impl<Payload, ErrData> From<Vec<Response<Payload, ErrData>>> for ResponsePacket<Payload, ErrData> {
178 fn from(value: Vec<Response<Payload, ErrData>>) -> Self {
179 if value.len() == 1 {
180 Self::Single(value.into_iter().next().unwrap())
181 } else {
182 Self::Batch(value)
183 }
184 }
185}
186
187impl<'de, Payload, ErrData> Deserialize<'de> for ResponsePacket<Payload, ErrData>
188where
189 Payload: Deserialize<'de>,
190 ErrData: Deserialize<'de>,
191{
192 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193 where
194 D: Deserializer<'de>,
195 {
196 struct ResponsePacketVisitor<Payload, ErrData> {
197 marker: PhantomData<fn() -> ResponsePacket<Payload, ErrData>>,
198 }
199
200 impl<'de, Payload, ErrData> Visitor<'de> for ResponsePacketVisitor<Payload, ErrData>
201 where
202 Payload: Deserialize<'de>,
203 ErrData: Deserialize<'de>,
204 {
205 type Value = ResponsePacket<Payload, ErrData>;
206
207 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
208 formatter.write_str("a single response or a batch of responses")
209 }
210
211 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
212 where
213 A: SeqAccess<'de>,
214 {
215 let mut responses = Vec::new();
216
217 while let Some(response) = seq.next_element()? {
218 responses.push(response);
219 }
220
221 Ok(ResponsePacket::Batch(responses))
222 }
223
224 fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
225 where
226 M: MapAccess<'de>,
227 {
228 let response =
229 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
230 Ok(ResponsePacket::Single(response))
231 }
232 }
233
234 deserializer.deserialize_any(ResponsePacketVisitor { marker: PhantomData })
235 }
236}
237
238pub type BorrowedResponsePacket<'a> = ResponsePacket<&'a RawValue, &'a RawValue>;
247
248impl BorrowedResponsePacket<'_> {
249 pub fn into_owned(self) -> ResponsePacket {
252 match self {
253 Self::Single(single) => ResponsePacket::Single(single.into_owned()),
254 Self::Batch(batch) => {
255 ResponsePacket::Batch(batch.into_iter().map(Response::into_owned).collect())
256 }
257 }
258 }
259}
260
261impl<Payload, ErrData> ResponsePacket<Payload, ErrData> {
262 pub const fn as_single(&self) -> Option<&Response<Payload, ErrData>> {
264 match self {
265 Self::Single(resp) => Some(resp),
266 Self::Batch(_) => None,
267 }
268 }
269
270 pub const fn as_batch(&self) -> Option<&[Response<Payload, ErrData>]> {
272 match self {
273 Self::Batch(resp) => Some(resp.as_slice()),
274 Self::Single(_) => None,
275 }
276 }
277
278 pub fn single_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
280 self.as_single().map(|resp| &resp.payload)
281 }
282
283 pub fn is_success(&self) -> bool {
287 match self {
288 Self::Single(single) => single.is_success(),
289 Self::Batch(batch) => batch.iter().all(|res| res.is_success()),
290 }
291 }
292
293 pub fn is_error(&self) -> bool {
297 match self {
298 Self::Single(single) => single.is_error(),
299 Self::Batch(batch) => batch.iter().any(|res| res.is_error()),
300 }
301 }
302
303 pub fn as_error(&self) -> Option<&ErrorPayload<ErrData>> {
307 self.iter_errors().next()
308 }
309
310 pub fn iter_errors(&self) -> impl Iterator<Item = &ErrorPayload<ErrData>> + '_ {
312 match self {
313 Self::Single(single) => ResponsePacketErrorsIter::Single(Some(single)),
314 Self::Batch(batch) => ResponsePacketErrorsIter::Batch(batch.iter()),
315 }
316 }
317
318 pub fn first_error_code(&self) -> Option<i64> {
320 self.as_error().map(|error| error.code)
321 }
322
323 pub fn first_error_message(&self) -> Option<&str> {
325 self.as_error().map(|error| error.message.as_ref())
326 }
327
328 pub fn first_error_data(&self) -> Option<&ErrData> {
330 self.as_error().and_then(|error| error.data.as_ref())
331 }
332
333 pub const fn responses(&self) -> &[Response<Payload, ErrData>] {
335 match self {
336 Self::Single(req) => std::slice::from_ref(req),
337 Self::Batch(req) => req.as_slice(),
338 }
339 }
340
341 pub fn payloads(&self) -> impl Iterator<Item = &ResponsePayload<Payload, ErrData>> + '_ {
343 self.responses().iter().map(|resp| &resp.payload)
344 }
345
346 pub fn first_payload(&self) -> Option<&ResponsePayload<Payload, ErrData>> {
348 self.payloads().next()
349 }
350
351 pub fn response_ids(&self) -> impl Iterator<Item = &Id> + '_ {
353 self.responses().iter().map(|resp| &resp.id)
354 }
355
356 pub fn responses_by_ids<K>(&self, ids: &HashSet<K>) -> Vec<&Response<Payload, ErrData>>
368 where
369 K: Borrow<Id> + Eq + Hash,
370 {
371 match self {
372 Self::Single(single) if ids.contains(&single.id) => vec![single],
373 Self::Batch(batch) => batch.iter().filter(|res| ids.contains(&res.id)).collect(),
374 _ => Vec::new(),
375 }
376 }
377}
378
379#[derive(Clone, Debug)]
381enum ResponsePacketErrorsIter<'a, Payload, ErrData> {
382 Single(Option<&'a Response<Payload, ErrData>>),
383 Batch(std::slice::Iter<'a, Response<Payload, ErrData>>),
384}
385
386impl<'a, Payload, ErrData> Iterator for ResponsePacketErrorsIter<'a, Payload, ErrData> {
387 type Item = &'a ErrorPayload<ErrData>;
388
389 fn next(&mut self) -> Option<Self::Item> {
390 match self {
391 ResponsePacketErrorsIter::Single(single) => single.take()?.payload.as_error(),
392 ResponsePacketErrorsIter::Batch(batch) => loop {
393 let res = batch.next()?;
394 if let Some(err) = res.payload.as_error() {
395 return Some(err);
396 }
397 },
398 }
399 }
400}