// alloy_json_rpc/request.rs

1use crate::{common::Id, RpcBorrow, RpcSend};
2use alloy_primitives::{keccak256, B256};
3use serde::{
4    de::{DeserializeOwned, MapAccess},
5    ser::SerializeMap,
6    Deserialize, Serialize,
7};
8use serde_json::value::RawValue;
9use std::{borrow::Cow, marker::PhantomData, mem::MaybeUninit};
10
11/// `RequestMeta` contains the [`Id`] and method name of a request.
12#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct RequestMeta {
14    /// The method name.
15    pub method: Cow<'static, str>,
16    /// The request ID.
17    pub id: Id,
18    /// Whether the request is a subscription, other than `eth_subscribe`.
19    is_subscription: bool,
20}
21
22impl RequestMeta {
23    /// Create a new `RequestMeta`.
24    pub const fn new(method: Cow<'static, str>, id: Id) -> Self {
25        Self { method, id, is_subscription: false }
26    }
27
28    /// Returns `true` if the request is a subscription.
29    pub fn is_subscription(&self) -> bool {
30        self.is_subscription || self.method == "eth_subscribe"
31    }
32
33    /// Indicates that the request is a non-standard subscription (i.e. not
34    /// "eth_subscribe").
35    pub fn set_is_subscription(&mut self) {
36        self.set_subscription_status(true);
37    }
38
39    /// Setter for `is_subscription`. Indicates to RPC clients that the request
40    /// triggers a stream of notifications.
41    pub fn set_subscription_status(&mut self, sub: bool) {
42        self.is_subscription = sub;
43    }
44}
45
46/// A JSON-RPC 2.0 request object.
47///
48/// This is a generic type that can be used to represent any JSON-RPC request.
49/// The `Params` type parameter is used to represent the parameters of the
50/// request, and the `method` field is used to represent the method name.
51///
52/// ### Note
53///
54/// The value of `method` should be known at compile time.
55#[derive(Clone, Debug, PartialEq, Eq)]
56pub struct Request<Params> {
57    /// The request metadata (ID and method).
58    pub meta: RequestMeta,
59    /// The request parameters.
60    pub params: Params,
61}
62
63impl<Params> Request<Params> {
64    /// Create a new `Request`.
65    pub fn new(method: impl Into<Cow<'static, str>>, id: Id, params: Params) -> Self {
66        Self { meta: RequestMeta::new(method.into(), id), params }
67    }
68
69    /// Returns `true` if the request is a subscription.
70    pub fn is_subscription(&self) -> bool {
71        self.meta.is_subscription()
72    }
73
74    /// Indicates that the request is a non-standard subscription (i.e. not
75    /// "eth_subscribe").
76    pub fn set_is_subscription(&mut self) {
77        self.meta.set_is_subscription()
78    }
79
80    /// Setter for `is_subscription`. Indicates to RPC clients that the request
81    /// triggers a stream of notifications.
82    pub fn set_subscription_status(&mut self, sub: bool) {
83        self.meta.set_subscription_status(sub);
84    }
85
86    /// Change type of the request parameters.
87    pub fn map_params<NewParams>(
88        self,
89        map: impl FnOnce(Params) -> NewParams,
90    ) -> Request<NewParams> {
91        Request { meta: self.meta, params: map(self.params) }
92    }
93}
94
95/// A [`Request`] that has been partially serialized.
96///
97/// The request parameters have been serialized, and are represented as a boxed [`RawValue`]. This
98/// is useful for collections containing many requests, as it erases the `Param` type. It can be
99/// created with [`Request::box_params()`].
100///
101/// See the [top-level docs] for more info.
102///
103/// [top-level docs]: crate
104pub type PartiallySerializedRequest = Request<Box<RawValue>>;
105
106impl<Params> Request<Params>
107where
108    Params: RpcSend,
109{
110    /// Serialize the request parameters as a boxed [`RawValue`].
111    ///
112    /// # Panics
113    ///
114    /// If serialization of the params fails.
115    pub fn box_params(self) -> PartiallySerializedRequest {
116        Request { meta: self.meta, params: serde_json::value::to_raw_value(&self.params).unwrap() }
117    }
118
119    /// Serialize the request, including the request parameters.
120    pub fn serialize(self) -> serde_json::Result<SerializedRequest> {
121        let request = serde_json::value::to_raw_value(&self)?;
122        Ok(SerializedRequest { meta: self.meta, request })
123    }
124}
125
126impl<Params> Request<&Params>
127where
128    Params: ToOwned,
129    Params::Owned: RpcSend,
130{
131    /// Clone the request, including the request parameters.
132    pub fn into_owned_params(self) -> Request<Params::Owned> {
133        Request { meta: self.meta, params: self.params.to_owned() }
134    }
135}
136
137impl<'a, Params> Request<Params>
138where
139    Params: AsRef<RawValue> + 'a,
140{
141    /// Attempt to deserialize the params.
142    ///
143    /// To borrow from the params via the deserializer, use
144    /// [`Request::try_borrow_params_as`].
145    ///
146    /// # Returns
147    /// - `Ok(T)` if the params can be deserialized as `T`
148    /// - `Err(e)` if the params cannot be deserialized as `T`
149    pub fn try_params_as<T: DeserializeOwned>(&self) -> serde_json::Result<T> {
150        serde_json::from_str(self.params.as_ref().get())
151    }
152
153    /// Attempt to deserialize the params, borrowing from the params
154    ///
155    /// # Returns
156    /// - `Ok(T)` if the params can be deserialized as `T`
157    /// - `Err(e)` if the params cannot be deserialized as `T`
158    pub fn try_borrow_params_as<T: Deserialize<'a>>(&'a self) -> serde_json::Result<T> {
159        serde_json::from_str(self.params.as_ref().get())
160    }
161}
162
163// manually implemented to avoid adding a type for the protocol-required
164// `jsonrpc` field
165impl<Params> Serialize for Request<Params>
166where
167    Params: RpcSend,
168{
169    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
170    where
171        S: serde::Serializer,
172    {
173        let sized_params = std::mem::size_of::<Params>() != 0;
174
175        let mut map = serializer.serialize_map(Some(3 + sized_params as usize))?;
176        map.serialize_entry("method", &self.meta.method[..])?;
177
178        // Params may be omitted if it is 0-sized
179        if sized_params {
180            map.serialize_entry("params", &self.params)?;
181        }
182
183        map.serialize_entry("id", &self.meta.id)?;
184        map.serialize_entry("jsonrpc", "2.0")?;
185        map.end()
186    }
187}
188
189impl<'de, Params> Deserialize<'de> for Request<Params>
190where
191    Params: RpcBorrow<'de>,
192{
193    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
194    where
195        D: serde::Deserializer<'de>,
196    {
197        struct Visitor<Params>(PhantomData<Params>);
198        impl<'de, Params> serde::de::Visitor<'de> for Visitor<Params>
199        where
200            Params: RpcBorrow<'de>,
201        {
202            type Value = Request<Params>;
203
204            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205                write!(
206                    formatter,
207                    "a JSON-RPC 2.0 request object with params of type {}",
208                    std::any::type_name::<Params>()
209                )
210            }
211
212            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
213            where
214                A: MapAccess<'de>,
215            {
216                let mut id = None;
217                let mut params = None;
218                let mut method = None;
219                let mut jsonrpc = None;
220
221                while let Some(key) = map.next_key()? {
222                    match key {
223                        "id" => {
224                            if id.is_some() {
225                                return Err(serde::de::Error::duplicate_field("id"));
226                            }
227                            id = Some(map.next_value()?);
228                        }
229                        "params" => {
230                            if params.is_some() {
231                                return Err(serde::de::Error::duplicate_field("params"));
232                            }
233                            params = Some(map.next_value()?);
234                        }
235                        "method" => {
236                            if method.is_some() {
237                                return Err(serde::de::Error::duplicate_field("method"));
238                            }
239                            method = Some(map.next_value()?);
240                        }
241                        "jsonrpc" => {
242                            let version: String = map.next_value()?;
243                            if version != "2.0" {
244                                return Err(serde::de::Error::custom(format!(
245                                    "unsupported JSON-RPC version: {}",
246                                    version
247                                )));
248                            }
249                            jsonrpc = Some(());
250                        }
251                        other => {
252                            return Err(serde::de::Error::unknown_field(
253                                other,
254                                &["id", "params", "method", "jsonrpc"],
255                            ));
256                        }
257                    }
258                }
259                if jsonrpc.is_none() {
260                    return Err(serde::de::Error::missing_field("jsonrpc"));
261                }
262                if method.is_none() {
263                    return Err(serde::de::Error::missing_field("method"));
264                }
265
266                if params.is_none() {
267                    if std::mem::size_of::<Params>() == 0 {
268                        // SAFETY: params is a ZST, so it's safe to fail to initialize it
269                        unsafe { params = Some(MaybeUninit::<Params>::zeroed().assume_init()) }
270                    } else {
271                        return Err(serde::de::Error::missing_field("params"));
272                    }
273                }
274
275                Ok(Request {
276                    meta: RequestMeta::new(method.unwrap(), id.unwrap_or(Id::None)),
277                    params: params.unwrap(),
278                })
279            }
280        }
281
282        deserializer.deserialize_map(Visitor(PhantomData))
283    }
284}
285
286/// A JSON-RPC 2.0 request object that has been serialized, with its [`Id`] and
287/// method preserved.
288///
289/// This struct is used to represent a request that has been serialized, but
290/// not yet sent. It is used by RPC clients to build batch requests and manage
291/// in-flight requests.
292#[derive(Clone, Debug)]
293pub struct SerializedRequest {
294    meta: RequestMeta,
295    request: Box<RawValue>,
296}
297
298impl<Params> TryFrom<Request<Params>> for SerializedRequest
299where
300    Params: RpcSend,
301{
302    type Error = serde_json::Error;
303
304    fn try_from(value: Request<Params>) -> Result<Self, Self::Error> {
305        value.serialize()
306    }
307}
308
309impl SerializedRequest {
310    /// Returns the request metadata (ID and Method).
311    pub const fn meta(&self) -> &RequestMeta {
312        &self.meta
313    }
314
315    /// Returns the request ID.
316    pub const fn id(&self) -> &Id {
317        &self.meta.id
318    }
319
320    /// Returns the request method.
321    pub fn method(&self) -> &str {
322        &self.meta.method
323    }
324
325    /// Mark the request as a non-standard subscription (i.e. not
326    /// `eth_subscribe`)
327    pub fn set_is_subscription(&mut self) {
328        self.meta.set_is_subscription();
329    }
330
331    /// Returns `true` if the request is a subscription.
332    pub fn is_subscription(&self) -> bool {
333        self.meta.is_subscription()
334    }
335
336    /// Returns the serialized request.
337    pub const fn serialized(&self) -> &RawValue {
338        &self.request
339    }
340
341    /// Consume the serialized request, returning the underlying [`RawValue`].
342    pub fn into_serialized(self) -> Box<RawValue> {
343        self.request
344    }
345
346    /// Consumes the serialized request, returning the underlying
347    /// [`RequestMeta`] and the [`RawValue`].
348    pub fn decompose(self) -> (RequestMeta, Box<RawValue>) {
349        (self.meta, self.request)
350    }
351
352    /// Take the serialized request, consuming the [`SerializedRequest`].
353    pub fn take_request(self) -> Box<RawValue> {
354        self.request
355    }
356
357    /// Get a reference to the serialized request's params.
358    ///
359    /// This partially deserializes the request, and should be avoided if
360    /// possible.
361    pub fn params(&self) -> Option<&RawValue> {
362        #[derive(Deserialize)]
363        struct Req<'a> {
364            #[serde(borrow)]
365            params: Option<&'a RawValue>,
366        }
367
368        let req: Req<'_> = serde_json::from_str(self.request.get()).unwrap();
369        req.params
370    }
371
372    /// Get the hash of the serialized request's params.
373    ///
374    /// This partially deserializes the request, and should be avoided if
375    /// possible.
376    pub fn params_hash(&self) -> B256 {
377        self.params().map_or_else(|| keccak256(""), |params| keccak256(params.get()))
378    }
379}
380
381impl Serialize for SerializedRequest {
382    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
383    where
384        S: serde::Serializer,
385    {
386        self.request.serialize(serializer)
387    }
388}
389
#[cfg(test)]
mod test {
    use super::*;
    use crate::RpcObject;

    /// Serialize `value`, deserialize it back, then serialize again. Checks
    /// that the round trip preserves both the value and its JSON encoding.
    fn test_inner<T: RpcObject + PartialEq>(value: T) {
        let type_name = std::any::type_name::<T>();
        let first_json = serde_json::to_string(&value).unwrap();
        let decoded: T = serde_json::from_str(&first_json).unwrap();
        let second_json = serde_json::to_string(&decoded).unwrap();
        assert_eq!(decoded, value, "deser error for {}", type_name);
        assert_eq!(first_json, second_json, "reser error for {}", type_name);
    }

    #[test]
    fn test_ser_deser() {
        test_inner(Request::<()>::new("test", 1.into(), ()));
        test_inner(Request::<u64>::new("test", "hello".to_string().into(), 1));
        test_inner(Request::<String>::new("test", Id::None, "test".to_string()));
        test_inner(Request::<Vec<u64>>::new("test", u64::MAX.into(), vec![1, 2, 3]));
    }
}