1use alloy_primitives::Bytes;
2use alloy_sol_types::{SolError, SolInterface};
3use serde::{
4 de::{DeserializeOwned, MapAccess, Visitor},
5 Deserialize, Deserializer, Serialize,
6};
7use serde_json::{
8 value::{to_raw_value, RawValue},
9 Value,
10};
11use std::{
12 borrow::{Borrow, Cow},
13 fmt,
14 marker::PhantomData,
15};
16
17use crate::RpcSend;
18
/// Message used for every internal-error (`-32603`) payload constructed in this file.
const INTERNAL_ERROR: Cow<'static, str> = Cow::Borrowed("Internal error");
20
21#[derive(Clone, Debug, Serialize, PartialEq, Eq)]
27pub struct ErrorPayload<ErrData = Box<RawValue>> {
28 pub code: i64,
30 pub message: Cow<'static, str>,
32 pub data: Option<ErrData>,
34}
35
36impl<E> ErrorPayload<E> {
37 pub const fn parse_error() -> Self {
39 Self { code: -32700, message: Cow::Borrowed("Parse error"), data: None }
40 }
41
42 pub const fn invalid_request() -> Self {
44 Self { code: -32600, message: Cow::Borrowed("Invalid Request"), data: None }
45 }
46
47 pub const fn method_not_found() -> Self {
49 Self { code: -32601, message: Cow::Borrowed("Method not found"), data: None }
50 }
51
52 pub const fn invalid_params() -> Self {
54 Self { code: -32602, message: Cow::Borrowed("Invalid params"), data: None }
55 }
56
57 pub const fn internal_error() -> Self {
59 Self { code: -32603, message: INTERNAL_ERROR, data: None }
60 }
61
62 pub const fn internal_error_message(message: Cow<'static, str>) -> Self {
64 Self { code: -32603, message, data: None }
65 }
66
67 pub const fn internal_error_with_obj(data: E) -> Self
70 where
71 E: RpcSend,
72 {
73 Self { code: -32603, message: INTERNAL_ERROR, data: Some(data) }
74 }
75
76 pub const fn internal_error_with_message_and_obj(message: Cow<'static, str>, data: E) -> Self
78 where
79 E: RpcSend,
80 {
81 Self { code: -32603, message, data: Some(data) }
82 }
83
84 pub fn is_retry_err(&self) -> bool {
87 if self.code == 429 {
89 return true;
90 }
91
92 if self.code == -32005 {
94 return true;
95 }
96
97 if self.code == -32016 && self.message.contains("rate limit") {
99 return true;
100 }
101
102 if self.code == -32012 && self.message.contains("credits") {
105 return true;
106 }
107
108 if self.code == -32007 && self.message.contains("request limit reached") {
111 return true;
112 }
113
114 if self.code == 1008 {
117 return true;
118 }
119
120 if self.code == -32055 {
122 return true;
123 }
124
125 match self.message.as_ref() {
126 "header not found" => true,
128 "daily request count exceeded, request rate limited" => true,
130 msg => {
131 msg.contains("rate limit")
132 || msg.contains("rate exceeded")
133 || msg.contains("too many requests")
134 || msg.contains("credits limited")
135 || msg.contains("request limit")
136 || msg.contains("maximum number of concurrent requests")
137 }
138 }
139 }
140}
141
142impl<T> From<T> for ErrorPayload<T>
143where
144 T: std::error::Error + RpcSend,
145{
146 fn from(value: T) -> Self {
147 Self { code: -32603, message: INTERNAL_ERROR, data: Some(value) }
148 }
149}
150
151impl<E> ErrorPayload<E>
152where
153 E: RpcSend,
154{
155 pub fn serialize_payload(&self) -> serde_json::Result<ErrorPayload> {
157 Ok(ErrorPayload {
158 code: self.code,
159 message: self.message.clone(),
160 data: match self.data.as_ref() {
161 Some(data) => Some(to_raw_value(data)?),
162 None => None,
163 },
164 })
165 }
166}
167
168fn spelunk_revert(value: &Value) -> Option<Bytes> {
173 match value {
174 Value::String(s) => s.parse().ok(),
175 Value::Object(o) => o.values().find_map(spelunk_revert),
176 _ => None,
177 }
178}
179
180impl<ErrData: fmt::Display> fmt::Display for ErrorPayload<ErrData> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 write!(
183 f,
184 "error code {}: {}{}",
185 self.code,
186 self.message,
187 self.data.as_ref().map(|data| format!(", data: {data}")).unwrap_or_default()
188 )
189 }
190}
191
/// An [`ErrorPayload`] whose `data` is a `&RawValue` borrowed from the input it
/// was deserialized from, rather than an owned copy. Use `into_owned` to detach
/// it from that input.
pub type BorrowedErrorPayload<'a> = ErrorPayload<&'a RawValue>;
200
201impl BorrowedErrorPayload<'_> {
202 pub fn into_owned(self) -> ErrorPayload {
205 ErrorPayload {
206 code: self.code,
207 message: self.message,
208 data: self.data.map(|data| data.to_owned()),
209 }
210 }
211}
212
/// Hand-written deserializer for a JSON-RPC error object.
///
/// Behavior: `code` is required; a missing `message` becomes the empty string;
/// unknown members are skipped; repeating `code`, `message` or `data` is an
/// error. Any `data` member that is present, including `null`, is deserialized
/// as `ErrData` and stored as `Some(..)`.
impl<'de, ErrData: Deserialize<'de>> Deserialize<'de> for ErrorPayload<ErrData> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Member names this object understands; anything else maps to `Unknown`.
        enum Field {
            Code,
            Message,
            Data,
            Unknown,
        }

        impl<'de> Deserialize<'de> for Field {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                struct FieldVisitor;

                impl serde::de::Visitor<'_> for FieldVisitor {
                    type Value = Field;

                    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                        formatter.write_str("`code`, `message` and `data`")
                    }

                    // Only string identifiers are handled; unrecognised names are
                    // accepted as `Unknown` so they can be skipped, not rejected.
                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
                    where
                        E: serde::de::Error,
                    {
                        match value {
                            "code" => Ok(Field::Code),
                            "message" => Ok(Field::Message),
                            "data" => Ok(Field::Data),
                            _ => Ok(Field::Unknown),
                        }
                    }
                }
                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct ErrorPayloadVisitor<T>(PhantomData<T>);

        impl<'de, Data> Visitor<'de> for ErrorPayloadVisitor<Data>
        where
            Data: Deserialize<'de>,
        {
            type Value = ErrorPayload<Data>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(formatter, "a JSON-RPC 2.0 error object")
            }

            // The only visit_* method implemented, so non-map input is rejected.
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut code = None;
                let mut message = None;
                let mut data = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Code => {
                            if code.is_some() {
                                return Err(serde::de::Error::duplicate_field("code"));
                            }
                            code = Some(map.next_value()?);
                        }
                        Field::Message => {
                            if message.is_some() {
                                return Err(serde::de::Error::duplicate_field("message"));
                            }
                            message = Some(map.next_value()?);
                        }
                        Field::Data => {
                            if data.is_some() {
                                return Err(serde::de::Error::duplicate_field("data"));
                            }
                            // A present member (even `null`) is deserialized as `Data`.
                            data = Some(map.next_value()?);
                        }
                        Field::Unknown => {
                            // Consume and discard the value of an unrecognised member.
                            let _: serde::de::IgnoredAny = map.next_value()?;
                        }
                    }
                }
                Ok(ErrorPayload {
                    code: code.ok_or_else(|| serde::de::Error::missing_field("code"))?,
                    // A missing message falls back to `Cow::default()`, i.e. "".
                    message: message.unwrap_or_default(),
                    data,
                })
            }
        }

        deserializer.deserialize_any(ErrorPayloadVisitor(PhantomData))
    }
}
312
313impl<'a, Data> ErrorPayload<Data>
314where
315 Data: Borrow<RawValue> + 'a,
316{
317 pub fn try_data_as<T: Deserialize<'a>>(&'a self) -> Option<serde_json::Result<T>> {
326 self.data.as_ref().map(|data| serde_json::from_str(data.borrow().get()))
327 }
328
329 pub fn deser_data<T: DeserializeOwned>(self) -> Result<ErrorPayload<T>, Self> {
336 match self.try_data_as::<T>() {
337 Some(Ok(data)) => {
338 Ok(ErrorPayload { code: self.code, message: self.message, data: Some(data) })
339 }
340 _ => Err(self),
341 }
342 }
343
344 pub fn as_revert_data(&self) -> Option<Bytes> {
356 if self.message.contains("revert") {
357 let value = Value::deserialize(self.data.as_ref()?.borrow()).ok()?;
358 spelunk_revert(&value)
359 } else {
360 None
361 }
362 }
363
364 pub fn as_decoded_interface_error<E: SolInterface>(&self) -> Option<E> {
367 self.as_revert_data().and_then(|data| E::abi_decode(&data).ok())
368 }
369
370 pub fn as_decoded_error<E: SolError>(&self) -> Option<E> {
372 self.as_revert_data().and_then(|data| E::abi_decode(&data).ok())
373 }
374}
375
#[cfg(test)]
mod test {
    use alloy_primitives::U256;
    use alloy_sol_types::sol;

    use super::BorrowedErrorPayload;
    use crate::ErrorPayload;

    /// Parses a JSON fixture into an owned payload; fixtures are expected to be valid.
    fn parse(json: &str) -> ErrorPayload {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn smooth_borrowing() {
        let json = r#"{ "code": -32000, "message": "b", "data": null }"#;
        let payload: BorrowedErrorPayload<'_> = serde_json::from_str(json).unwrap();

        assert_eq!(payload.code, -32000);
        assert_eq!(payload.message, "b");
        assert_eq!(payload.data.unwrap().get(), "null");
    }

    #[test]
    fn smooth_deser() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct TestData {
            a: u32,
            b: Option<String>,
        }

        let json = r#"{ "code": -32000, "message": "b", "data": { "a": 5, "b": null } }"#;

        let payload: BorrowedErrorPayload<'_> = serde_json::from_str(json).unwrap();
        let data: TestData = payload.try_data_as().unwrap().unwrap();
        assert_eq!(data, TestData { a: 5, b: None });
    }

    #[test]
    fn missing_data() {
        let json = r#"{"code":-32007,"message":"20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com"}"#;
        let payload: ErrorPayload = serde_json::from_str(json).unwrap();

        assert_eq!(payload.code, -32007);
        assert_eq!(payload.message, "20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com");
        assert!(payload.data.is_none());
    }

    #[test]
    fn custom_error_decoding() {
        sol!(
            #[derive(Debug, PartialEq, Eq)]
            library Errors {
                error SomeCustomError(uint256 a);
            }
        );

        let json = r#"{"code":3,"message":"execution reverted: ","data":"0x810f00230000000000000000000000000000000000000000000000000000000000000001"}"#;
        let payload: ErrorPayload = serde_json::from_str(json).unwrap();

        let Errors::ErrorsErrors::SomeCustomError(value) =
            payload.as_decoded_interface_error::<Errors::ErrorsErrors>().unwrap();

        assert_eq!(value.a, U256::from(1));

        let decoded_err = payload.as_decoded_error::<Errors::SomeCustomError>().unwrap();

        assert_eq!(decoded_err, Errors::SomeCustomError { a: U256::from(1) });
    }

    #[test]
    fn max_concurrent_requests() {
        let json = r#"{"code":1008,"message":"You have exceeded the maximum number of concurrent requests on a single WebSocket. At most 200 concurrent requests are allowed per WebSocket."}"#;
        let payload: ErrorPayload = serde_json::from_str(json).unwrap();
        assert!(payload.is_retry_err());
    }

    #[test]
    fn retry_classification() {
        let retryable = [
            r#"{"code":429,"message":"slow down"}"#,
            r#"{"code":-32005,"message":"project limit"}"#,
            r#"{"code":-32055,"message":""}"#,
            r#"{"code":-32016,"message":"rate limit exceeded for this IP"}"#,
            r#"{"code":-32000,"message":"header not found"}"#,
            r#"{"code":-32000,"message":"too many requests"}"#,
        ];
        for json in retryable {
            assert!(parse(json).is_retry_err(), "expected retryable: {json}");
        }

        let fatal = [
            r#"{"code":-32000,"message":"execution reverted"}"#,
            r#"{"code":-32016,"message":"something else"}"#,
            r#"{"code":-32601,"message":"Method not found"}"#,
        ];
        for json in fatal {
            assert!(!parse(json).is_retry_err(), "expected not retryable: {json}");
        }
    }

    #[test]
    fn display_formatting() {
        let payload = parse(r#"{"code":3,"message":"execution reverted","data":"0x01"}"#);
        assert_eq!(payload.to_string(), r#"error code 3: execution reverted, data: "0x01""#);

        let payload = parse(r#"{"code":-32601,"message":"Method not found"}"#);
        assert_eq!(payload.to_string(), "error code -32601: Method not found");
    }

    #[test]
    fn deserialization_edge_cases() {
        // `message` is optional and defaults to the empty string.
        let payload = parse(r#"{"code":-32000}"#);
        assert_eq!(payload.message, "");
        assert!(payload.data.is_none());

        // Unknown members are skipped.
        let payload = parse(r#"{"code":1,"message":"m","extra":{"nested":[1,2]}}"#);
        assert_eq!(payload.code, 1);
        assert_eq!(payload.message, "m");

        // `code` is mandatory, and members may not repeat.
        assert!(serde_json::from_str::<ErrorPayload>(r#"{"message":"no code"}"#).is_err());
        assert!(serde_json::from_str::<ErrorPayload>(r#"{"code":1,"code":2}"#).is_err());
    }

    #[test]
    fn deser_data_typed() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct Reason {
            reason: String,
        }

        let typed = parse(r#"{"code":-32000,"message":"oops","data":{"reason":"nope"}}"#)
            .deser_data::<Reason>()
            .unwrap();
        assert_eq!(typed.code, -32000);
        assert_eq!(typed.message, "oops");
        assert_eq!(typed.data, Some(Reason { reason: "nope".to_string() }));

        // Without a `data` member the original payload is handed back.
        let untouched = parse(r#"{"code":-32000,"message":"oops"}"#).deser_data::<Reason>();
        assert!(untouched.is_err());
    }

    #[test]
    fn revert_data_lookup() {
        let nested = parse(
            r#"{"code":3,"message":"execution reverted","data":{"inner":{"data":"0x0102"}}}"#,
        );
        assert_eq!(nested.as_revert_data().unwrap().to_vec(), vec![1u8, 2]);

        // No "revert" in the message: data is not inspected at all.
        let unrelated = parse(r#"{"code":-32000,"message":"some failure","data":"0x01"}"#);
        assert!(unrelated.as_revert_data().is_none());
    }
}