1use serde::de::{SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use crate::Request;
5
6#[derive(Debug, Clone)]
13pub enum MaybeBatchedRequests<'a, P> {
14 Single(Request<'a, P>),
16 Batch(Vec<Request<'a, P>>),
18}
19
20impl<'a, T> Serialize for MaybeBatchedRequests<'a, T>
21where
22 T: Clone + Serialize,
23{
24 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25 where
26 S: Serializer,
27 {
28 match self {
29 Self::Batch(batch) => batch.serialize(serializer),
30 Self::Single(single) => single.serialize(serializer),
31 }
32 }
33}
34
35impl<'de, 'a, P> Deserialize<'de> for MaybeBatchedRequests<'a, P>
36where
37 'de: 'a,
38 P: Deserialize<'de>,
39{
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: Deserializer<'de>,
43 {
44 struct MaybeBatchedVisitor<P>(std::marker::PhantomData<P>);
45
46 impl<'de, P> Visitor<'de> for MaybeBatchedVisitor<P>
47 where
48 P: Deserialize<'de>,
49 {
50 type Value = MaybeBatchedRequests<'de, P>;
51
52 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
53 formatter.write_str("a JSON-RPC 2.0 request")
54 }
55
56 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
57 where
58 A: SeqAccess<'de>,
59 {
60 Vec::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
61 .map(MaybeBatchedRequests::Batch)
62 }
63
64 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
65 where
66 A: serde::de::MapAccess<'de>,
67 {
68 Request::deserialize(serde::de::value::MapAccessDeserializer::new(map))
69 .map(MaybeBatchedRequests::Single)
70 }
71 }
72
73 deserializer.deserialize_any(MaybeBatchedVisitor(std::marker::PhantomData))
74 }
75}