// jsonrpc_sys/batch.rs

1use serde::de::{SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use crate::Request;
5
6/// Represents either one, or multiple JSON-RPC [`Request`]s.
7///
8/// Note that this type is only really useful for deserializing requests, using the
9/// [`UnknownParams`] type as the parameter type.
10///
11/// [`UnknownParams`]: crate::UnknownParams
12#[derive(Debug, Clone)]
13pub enum MaybeBatchedRequests<'a, P> {
14    /// A single request.
15    Single(Request<'a, P>),
16    /// A batch of requests.
17    Batch(Vec<Request<'a, P>>),
18}
19
20impl<'a, T> Serialize for MaybeBatchedRequests<'a, T>
21where
22    T: Clone + Serialize,
23{
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: Serializer,
27    {
28        match self {
29            Self::Batch(batch) => batch.serialize(serializer),
30            Self::Single(single) => single.serialize(serializer),
31        }
32    }
33}
34
35impl<'de, 'a, P> Deserialize<'de> for MaybeBatchedRequests<'a, P>
36where
37    'de: 'a,
38    P: Deserialize<'de>,
39{
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        struct MaybeBatchedVisitor<P>(std::marker::PhantomData<P>);
45
46        impl<'de, P> Visitor<'de> for MaybeBatchedVisitor<P>
47        where
48            P: Deserialize<'de>,
49        {
50            type Value = MaybeBatchedRequests<'de, P>;
51
52            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
53                formatter.write_str("a JSON-RPC 2.0 request")
54            }
55
56            fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
57            where
58                A: SeqAccess<'de>,
59            {
60                Vec::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
61                    .map(MaybeBatchedRequests::Batch)
62            }
63
64            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
65            where
66                A: serde::de::MapAccess<'de>,
67            {
68                Request::deserialize(serde::de::value::MapAccessDeserializer::new(map))
69                    .map(MaybeBatchedRequests::Single)
70            }
71        }
72
73        deserializer.deserialize_any(MaybeBatchedVisitor(std::marker::PhantomData))
74    }
75}