1use alloy_primitives::Bytes;
2use alloy_sol_types::{SolError, SolInterface};
3use serde::{
4    de::{DeserializeOwned, MapAccess, Visitor},
5    Deserialize, Deserializer, Serialize,
6};
7use serde_json::{
8    value::{to_raw_value, RawValue},
9    Value,
10};
11use std::{
12    borrow::{Borrow, Cow},
13    fmt,
14    marker::PhantomData,
15};
16
17use crate::RpcSend;
18
19const INTERNAL_ERROR: Cow<'static, str> = Cow::Borrowed("Internal error");
20
21#[derive(Clone, Debug, Serialize, PartialEq, Eq)]
27pub struct ErrorPayload<ErrData = Box<RawValue>> {
28    pub code: i64,
30    pub message: Cow<'static, str>,
32    pub data: Option<ErrData>,
34}
35
36impl<E> ErrorPayload<E> {
37    pub const fn parse_error() -> Self {
39        Self { code: -32700, message: Cow::Borrowed("Parse error"), data: None }
40    }
41
42    pub const fn invalid_request() -> Self {
44        Self { code: -32600, message: Cow::Borrowed("Invalid Request"), data: None }
45    }
46
47    pub const fn method_not_found() -> Self {
49        Self { code: -32601, message: Cow::Borrowed("Method not found"), data: None }
50    }
51
52    pub const fn invalid_params() -> Self {
54        Self { code: -32602, message: Cow::Borrowed("Invalid params"), data: None }
55    }
56
57    pub const fn internal_error() -> Self {
59        Self { code: -32603, message: INTERNAL_ERROR, data: None }
60    }
61
62    pub const fn internal_error_message(message: Cow<'static, str>) -> Self {
64        Self { code: -32603, message, data: None }
65    }
66
67    pub const fn internal_error_with_obj(data: E) -> Self
70    where
71        E: RpcSend,
72    {
73        Self { code: -32603, message: INTERNAL_ERROR, data: Some(data) }
74    }
75
76    pub const fn internal_error_with_message_and_obj(message: Cow<'static, str>, data: E) -> Self
78    where
79        E: RpcSend,
80    {
81        Self { code: -32603, message, data: Some(data) }
82    }
83
84    pub fn is_retry_err(&self) -> bool {
87        if self.code == 429 {
89            return true;
90        }
91
92        if self.code == -32005 {
94            return true;
95        }
96
97        if self.code == -32016 && self.message.contains("rate limit") {
99            return true;
100        }
101
102        if self.code == -32012 && self.message.contains("credits") {
105            return true;
106        }
107
108        if self.code == -32007 && self.message.contains("request limit reached") {
111            return true;
112        }
113
114        match self.message.as_ref() {
115            "header not found" => true,
117            "daily request count exceeded, request rate limited" => true,
119            msg => {
120                msg.contains("rate limit")
121                    || msg.contains("rate exceeded")
122                    || msg.contains("too many requests")
123                    || msg.contains("credits limited")
124                    || msg.contains("request limit")
125            }
126        }
127    }
128}
129
130impl<T> From<T> for ErrorPayload<T>
131where
132    T: std::error::Error + RpcSend,
133{
134    fn from(value: T) -> Self {
135        Self { code: -32603, message: INTERNAL_ERROR, data: Some(value) }
136    }
137}
138
139impl<E> ErrorPayload<E>
140where
141    E: RpcSend,
142{
143    pub fn serialize_payload(&self) -> serde_json::Result<ErrorPayload> {
145        Ok(ErrorPayload {
146            code: self.code,
147            message: self.message.clone(),
148            data: match self.data.as_ref() {
149                Some(data) => Some(to_raw_value(data)?),
150                None => None,
151            },
152        })
153    }
154}
155
156fn spelunk_revert(value: &Value) -> Option<Bytes> {
161    match value {
162        Value::String(s) => s.parse().ok(),
163        Value::Object(o) => o.values().find_map(spelunk_revert),
164        _ => None,
165    }
166}
167
168impl<ErrData: fmt::Display> fmt::Display for ErrorPayload<ErrData> {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        write!(
171            f,
172            "error code {}: {}{}",
173            self.code,
174            self.message,
175            self.data.as_ref().map(|data| format!(", data: {data}")).unwrap_or_default()
176        )
177    }
178}
179
180pub type BorrowedErrorPayload<'a> = ErrorPayload<&'a RawValue>;
188
189impl BorrowedErrorPayload<'_> {
190    pub fn into_owned(self) -> ErrorPayload {
193        ErrorPayload {
194            code: self.code,
195            message: self.message,
196            data: self.data.map(|data| data.to_owned()),
197        }
198    }
199}
200
201impl<'de, ErrData: Deserialize<'de>> Deserialize<'de> for ErrorPayload<ErrData> {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: Deserializer<'de>,
205    {
206        enum Field {
207            Code,
208            Message,
209            Data,
210            Unknown,
211        }
212
213        impl<'de> Deserialize<'de> for Field {
214            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215            where
216                D: Deserializer<'de>,
217            {
218                struct FieldVisitor;
219
220                impl serde::de::Visitor<'_> for FieldVisitor {
221                    type Value = Field;
222
223                    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
224                        formatter.write_str("`code`, `message` and `data`")
225                    }
226
227                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
228                    where
229                        E: serde::de::Error,
230                    {
231                        match value {
232                            "code" => Ok(Field::Code),
233                            "message" => Ok(Field::Message),
234                            "data" => Ok(Field::Data),
235                            _ => Ok(Field::Unknown),
236                        }
237                    }
238                }
239                deserializer.deserialize_identifier(FieldVisitor)
240            }
241        }
242
243        struct ErrorPayloadVisitor<T>(PhantomData<T>);
244
245        impl<'de, Data> Visitor<'de> for ErrorPayloadVisitor<Data>
246        where
247            Data: Deserialize<'de>,
248        {
249            type Value = ErrorPayload<Data>;
250
251            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
252                write!(formatter, "a JSON-RPC 2.0 error object")
253            }
254
255            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
256            where
257                A: MapAccess<'de>,
258            {
259                let mut code = None;
260                let mut message = None;
261                let mut data = None;
262
263                while let Some(key) = map.next_key()? {
264                    match key {
265                        Field::Code => {
266                            if code.is_some() {
267                                return Err(serde::de::Error::duplicate_field("code"));
268                            }
269                            code = Some(map.next_value()?);
270                        }
271                        Field::Message => {
272                            if message.is_some() {
273                                return Err(serde::de::Error::duplicate_field("message"));
274                            }
275                            message = Some(map.next_value()?);
276                        }
277                        Field::Data => {
278                            if data.is_some() {
279                                return Err(serde::de::Error::duplicate_field("data"));
280                            }
281                            data = Some(map.next_value()?);
282                        }
283                        Field::Unknown => {
284                            let _: serde::de::IgnoredAny = map.next_value()?;
285                            }
287                    }
288                }
289                Ok(ErrorPayload {
290                    code: code.ok_or_else(|| serde::de::Error::missing_field("code"))?,
291                    message: message.unwrap_or_default(),
292                    data,
293                })
294            }
295        }
296
297        deserializer.deserialize_any(ErrorPayloadVisitor(PhantomData))
298    }
299}
300
301impl<'a, Data> ErrorPayload<Data>
302where
303    Data: Borrow<RawValue> + 'a,
304{
305    pub fn try_data_as<T: Deserialize<'a>>(&'a self) -> Option<serde_json::Result<T>> {
314        self.data.as_ref().map(|data| serde_json::from_str(data.borrow().get()))
315    }
316
317    pub fn deser_data<T: DeserializeOwned>(self) -> Result<ErrorPayload<T>, Self> {
324        match self.try_data_as::<T>() {
325            Some(Ok(data)) => {
326                Ok(ErrorPayload { code: self.code, message: self.message, data: Some(data) })
327            }
328            _ => Err(self),
329        }
330    }
331
332    pub fn as_revert_data(&self) -> Option<Bytes> {
344        if self.message.contains("revert") {
345            let value = Value::deserialize(self.data.as_ref()?.borrow()).ok()?;
346            spelunk_revert(&value)
347        } else {
348            None
349        }
350    }
351
352    pub fn as_decoded_interface_error<E: SolInterface>(&self) -> Option<E> {
355        self.as_revert_data().and_then(|data| E::abi_decode(&data).ok())
356    }
357
358    pub fn as_decoded_error<E: SolError>(&self) -> Option<E> {
360        self.as_revert_data().and_then(|data| E::abi_decode(&data).ok())
361    }
362}
363
#[cfg(test)]
mod test {
    use alloy_primitives::U256;
    use alloy_sol_types::sol;

    use super::BorrowedErrorPayload;
    use crate::ErrorPayload;

    /// Builds an owned payload without `data`, for exercising code/message logic.
    fn bare_payload(code: i64, message: &'static str) -> ErrorPayload {
        ErrorPayload { code, message: message.into(), data: None }
    }

    #[test]
    fn smooth_borrowing() {
        let json = r#"{ "code": -32000, "message": "b", "data": null }"#;
        let payload: BorrowedErrorPayload<'_> = serde_json::from_str(json).unwrap();

        assert_eq!(payload.code, -32000);
        assert_eq!(payload.message, "b");
        assert_eq!(payload.data.unwrap().get(), "null");
    }

    #[test]
    fn smooth_deser() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct TestData {
            a: u32,
            b: Option<String>,
        }

        let json = r#"{ "code": -32000, "message": "b", "data": { "a": 5, "b": null } }"#;

        let payload: BorrowedErrorPayload<'_> = serde_json::from_str(json).unwrap();
        let data: TestData = payload.try_data_as().unwrap().unwrap();
        assert_eq!(data, TestData { a: 5, b: None });
    }

    #[test]
    fn missing_data() {
        let json = r#"{"code":-32007,"message":"20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com"}"#;
        let payload: ErrorPayload = serde_json::from_str(json).unwrap();

        assert_eq!(payload.code, -32007);
        assert_eq!(payload.message, "20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com");
        assert!(payload.data.is_none());
    }

    #[test]
    fn custom_error_decoding() {
        sol!(
            #[derive(Debug, PartialEq, Eq)]
            library Errors {
                error SomeCustomError(uint256 a);
            }
        );

        let json = r#"{"code":3,"message":"execution reverted: ","data":"0x810f00230000000000000000000000000000000000000000000000000000000000000001"}"#;
        let payload: ErrorPayload = serde_json::from_str(json).unwrap();

        let Errors::ErrorsErrors::SomeCustomError(value) =
            payload.as_decoded_interface_error::<Errors::ErrorsErrors>().unwrap();

        assert_eq!(value.a, U256::from(1));

        let decoded_err = payload.as_decoded_error::<Errors::SomeCustomError>().unwrap();

        assert_eq!(decoded_err, Errors::SomeCustomError { a: U256::from(1) });
    }

    #[test]
    fn retry_detection() {
        // Retryable on the code alone.
        assert!(bare_payload(429, "anything").is_retry_err());
        assert!(bare_payload(-32005, "anything").is_retry_err());
        // Retryable on code plus message fragment.
        assert!(bare_payload(-32012, "credits limited to 6000/sec").is_retry_err());
        // Retryable on the message alone.
        assert!(bare_payload(-32000, "header not found").is_retry_err());
        assert!(bare_payload(-32000, "too many requests").is_retry_err());
        // Not retryable.
        assert!(!bare_payload(-32016, "something else").is_retry_err());
        assert!(!bare_payload(3, "execution reverted").is_retry_err());
    }

    #[test]
    fn display_formatting() {
        assert_eq!(bare_payload(-32000, "boom").to_string(), "error code -32000: boom");

        let json = r#"{"code":1,"message":"m","data":{"a":1}}"#;
        let with_data: ErrorPayload = serde_json::from_str(json).unwrap();
        assert_eq!(with_data.to_string(), r#"error code 1: m, data: {"a":1}"#);
    }

    #[test]
    fn revert_data_extraction() {
        let json = r#"{"code":3,"message":"execution reverted","data":{"data":"0x1234"}}"#;
        let reverted: ErrorPayload = serde_json::from_str(json).unwrap();
        assert_eq!(reverted.as_revert_data().unwrap().to_vec(), vec![0x12, 0x34]);

        // Without "revert" in the message, data is not inspected at all.
        let json = r#"{"code":3,"message":"oops","data":"0x1234"}"#;
        let not_reverted: ErrorPayload = serde_json::from_str(json).unwrap();
        assert!(not_reverted.as_revert_data().is_none());
    }

    #[test]
    fn typed_data() {
        let json = r#"{"code":1,"message":"m","data":5}"#;
        let raw: ErrorPayload = serde_json::from_str(json).unwrap();

        // A mismatched target type hands the untouched payload back.
        let raw = raw.deser_data::<String>().unwrap_err();
        assert_eq!(raw.code, 1);

        let typed = raw.deser_data::<u64>().unwrap();
        assert_eq!(typed.data, Some(5));
        assert_eq!(typed.message, "m");
    }
}