use alloy_primitives::Bytes;
use alloy_sol_types::{SolError, SolInterface};
use serde::{
de::{DeserializeOwned, MapAccess, Visitor},
Deserialize, Deserializer, Serialize,
};
use serde_json::{
value::{to_raw_value, RawValue},
Value,
};
use std::{
borrow::{Borrow, Cow},
fmt,
marker::PhantomData,
};
use crate::RpcSend;
const INTERNAL_ERROR: Cow<'static, str> = Cow::Borrowed("Internal error");
#[derive(Clone, Debug, Serialize, PartialEq, Eq)]
pub struct ErrorPayload<ErrData = Box<RawValue>> {
pub code: i64,
pub message: Cow<'static, str>,
pub data: Option<ErrData>,
}
impl<E> ErrorPayload<E> {
    /// JSON-RPC 2.0 "Parse error" (`-32700`), without `data`.
    pub const fn parse_error() -> Self {
        Self { code: -32700, message: Cow::Borrowed("Parse error"), data: None }
    }

    /// JSON-RPC 2.0 "Invalid Request" (`-32600`), without `data`.
    pub const fn invalid_request() -> Self {
        Self { code: -32600, message: Cow::Borrowed("Invalid Request"), data: None }
    }

    /// JSON-RPC 2.0 "Method not found" (`-32601`), without `data`.
    pub const fn method_not_found() -> Self {
        Self { code: -32601, message: Cow::Borrowed("Method not found"), data: None }
    }

    /// JSON-RPC 2.0 "Invalid params" (`-32602`), without `data`.
    pub const fn invalid_params() -> Self {
        Self { code: -32602, message: Cow::Borrowed("Invalid params"), data: None }
    }

    /// Internal error (`-32603`) with the default `"Internal error"` message and no `data`.
    pub const fn internal_error() -> Self {
        Self::internal_error_message(INTERNAL_ERROR)
    }

    /// Internal error (`-32603`) with a caller-supplied message and no `data`.
    pub const fn internal_error_message(message: Cow<'static, str>) -> Self {
        Self { code: -32603, message, data: None }
    }

    /// Internal error (`-32603`) with the default `"Internal error"` message and `data`.
    pub const fn internal_error_with_obj(data: E) -> Self
    where
        E: RpcSend,
    {
        Self::internal_error_with_message_and_obj(INTERNAL_ERROR, data)
    }

    /// Internal error (`-32603`) with both a caller-supplied message and `data`.
    pub const fn internal_error_with_message_and_obj(message: Cow<'static, str>, data: E) -> Self
    where
        E: RpcSend,
    {
        Self { code: -32603, message, data: Some(data) }
    }

    /// Heuristically decides whether this error is transient (e.g. provider rate limiting),
    /// so that retrying the request is worthwhile.
    ///
    /// Some codes are retryable on their own, others only when the message agrees. Failing
    /// that, the message is compared against a fixed set of known phrases.
    pub fn is_retry_err(&self) -> bool {
        let msg: &str = &self.message;

        let by_code = match self.code {
            429 | -32005 | 1008 | -32055 => true,
            -32016 => msg.contains("rate limit"),
            -32012 => msg.contains("credits"),
            -32007 => msg.contains("request limit reached"),
            _ => false,
        };

        by_code
            || matches!(
                msg,
                "header not found" | "daily request count exceeded, request rate limited"
            )
            || [
                "rate limit",
                "rate exceeded",
                "too many requests",
                "credits limited",
                "request limit",
                "maximum number of concurrent requests",
            ]
            .iter()
            .any(|&phrase| msg.contains(phrase))
    }
}
/// Wraps an RPC-serializable error as an internal error (`-32603`) that carries the error
/// itself as `data`.
impl<T> From<T> for ErrorPayload<T>
where
    T: std::error::Error + RpcSend,
{
    fn from(value: T) -> Self {
        Self::internal_error_with_obj(value)
    }
}
impl<E> ErrorPayload<E>
where
E: RpcSend,
{
pub fn serialize_payload(&self) -> serde_json::Result<ErrorPayload> {
Ok(ErrorPayload {
code: self.code,
message: self.message.clone(),
data: match self.data.as_ref() {
Some(data) => Some(to_raw_value(data)?),
None => None,
},
})
}
}
/// Recursively searches `value` for a string that parses into [`Bytes`].
///
/// A string is parsed directly. An object is searched member by member, and the first hit
/// in the map's iteration order is returned. Every other JSON type (including arrays)
/// yields `None`.
fn spelunk_revert(value: &Value) -> Option<Bytes> {
    if let Value::String(text) = value {
        return text.parse().ok();
    }
    if let Value::Object(members) = value {
        return members.values().find_map(spelunk_revert);
    }
    None
}
/// Renders the payload as `error code {code}: {message}`, followed by `, data: {data}` when
/// a `data` value is present.
impl<ErrData: fmt::Display> fmt::Display for ErrorPayload<ErrData> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "error code {}: {}", self.code, self.message)?;
        // Stream `data` straight into the formatter. Going through `format!` would allocate a
        // temporary `String`, and it would panic if `data`'s `Display` impl reported an error.
        if let Some(data) = &self.data {
            write!(f, ", data: {data}")?;
        }
        Ok(())
    }
}
/// An [`ErrorPayload`] whose `data` borrows raw JSON from the input it was parsed from.
pub type BorrowedErrorPayload<'a> = ErrorPayload<&'a RawValue>;

impl BorrowedErrorPayload<'_> {
    /// Detaches this payload from its input by copying any borrowed `data` into an owned
    /// `Box<RawValue>`. `code` and `message` are moved over as-is.
    pub fn into_owned(self) -> ErrorPayload {
        let Self { code, message, data } = self;
        ErrorPayload { code, message, data: data.map(|raw| raw.to_owned()) }
    }
}
impl<'de, ErrData: Deserialize<'de>> Deserialize<'de> for ErrorPayload<ErrData> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Code,
Message,
Data,
Unknown,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;
impl serde::de::Visitor<'_> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("`code`, `message` and `data`")
}
fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: serde::de::Error,
{
match value {
"code" => Ok(Field::Code),
"message" => Ok(Field::Message),
"data" => Ok(Field::Data),
_ => Ok(Field::Unknown),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct ErrorPayloadVisitor<T>(PhantomData<T>);
impl<'de, Data> Visitor<'de> for ErrorPayloadVisitor<Data>
where
Data: Deserialize<'de>,
{
type Value = ErrorPayload<Data>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a JSON-RPC 2.0 error object")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut code = None;
let mut message = None;
let mut data = None;
while let Some(key) = map.next_key()? {
match key {
Field::Code => {
if code.is_some() {
return Err(serde::de::Error::duplicate_field("code"));
}
code = Some(map.next_value()?);
}
Field::Message => {
if message.is_some() {
return Err(serde::de::Error::duplicate_field("message"));
}
message = Some(map.next_value()?);
}
Field::Data => {
if data.is_some() {
return Err(serde::de::Error::duplicate_field("data"));
}
data = Some(map.next_value()?);
}
Field::Unknown => {
let _: serde::de::IgnoredAny = map.next_value()?;
}
}
}
Ok(ErrorPayload {
code: code.ok_or_else(|| serde::de::Error::missing_field("code"))?,
message: message.unwrap_or_default(),
data,
})
}
}
deserializer.deserialize_any(ErrorPayloadVisitor(PhantomData))
}
}
impl<'a, Data> ErrorPayload<Data>
where
    Data: Borrow<RawValue> + 'a,
{
    /// Deserializes the `data` member as `T`, borrowing from it where `T` allows.
    ///
    /// Returns:
    /// - `None` when there is no `data`.
    /// - `Some(Ok(_))` when `data` decodes as `T`.
    /// - `Some(Err(_))` when `data` is present but does not decode.
    pub fn try_data_as<T: Deserialize<'a>>(&'a self) -> Option<serde_json::Result<T>> {
        let raw: &'a RawValue = self.data.as_ref()?.borrow();
        Some(serde_json::from_str(raw.get()))
    }

    /// Converts this payload into one whose `data` is decoded as `T`.
    ///
    /// # Errors
    ///
    /// Gives back `self` unchanged when `data` is absent or fails to decode as `T`.
    pub fn deser_data<T: DeserializeOwned>(self) -> Result<ErrorPayload<T>, Self> {
        let Some(Ok(decoded)) = self.try_data_as::<T>() else {
            return Err(self);
        };
        Ok(ErrorPayload { code: self.code, message: self.message, data: Some(decoded) })
    }

    /// Extracts revert data from `data` when the message mentions `"revert"`.
    ///
    /// Returns `None` in any of these cases:
    /// - the message does not contain `"revert"`;
    /// - there is no `data`;
    /// - `data` is not valid JSON;
    /// - no string reachable through nested objects parses into [`Bytes`].
    pub fn as_revert_data(&self) -> Option<Bytes> {
        if !self.message.contains("revert") {
            return None;
        }
        let raw: &RawValue = self.data.as_ref()?.borrow();
        let value = Value::deserialize(raw).ok()?;
        spelunk_revert(&value)
    }

    /// ABI-decodes the revert data (see [`Self::as_revert_data`]) as one of the errors of
    /// the Solidity interface `E`. Returns `None` if there is no revert data or decoding fails.
    pub fn as_decoded_interface_error<E: SolInterface>(&self) -> Option<E> {
        let revert = self.as_revert_data()?;
        E::abi_decode(&revert).ok()
    }

    /// ABI-decodes the revert data (see [`Self::as_revert_data`]) as the Solidity error `E`.
    /// Returns `None` if there is no revert data or decoding fails.
    pub fn as_decoded_error<E: SolError>(&self) -> Option<E> {
        let revert = self.as_revert_data()?;
        E::abi_decode(&revert).ok()
    }
}
#[cfg(test)]
mod test {
    use alloy_primitives::U256;
    use alloy_sol_types::sol;

    use super::BorrowedErrorPayload;
    use crate::ErrorPayload;

    /// An explicit `"data": null` is preserved verbatim as a borrowed raw value.
    #[test]
    fn smooth_borrowing() {
        let input = r#"{ "code": -32000, "message": "b", "data": null }"#;
        let parsed: BorrowedErrorPayload<'_> = serde_json::from_str(input).unwrap();

        assert_eq!(parsed.code, -32000);
        assert_eq!(parsed.message, "b");
        let raw = parsed.data.expect("`data` key is present");
        assert_eq!(raw.get(), "null");
    }

    /// Structured `data` can be decoded on demand into a caller-chosen type.
    #[test]
    fn smooth_deser() {
        #[derive(Debug, PartialEq, serde::Deserialize)]
        struct TestData {
            a: u32,
            b: Option<String>,
        }

        let input = r#"{ "code": -32000, "message": "b", "data": { "a": 5, "b": null } }"#;
        let parsed: BorrowedErrorPayload<'_> = serde_json::from_str(input).unwrap();
        let decoded: TestData = parsed
            .try_data_as()
            .expect("`data` is present")
            .expect("`data` matches TestData");

        assert_eq!(decoded, TestData { a: 5, b: None });
    }

    /// A payload without a `data` key deserializes with `data == None`.
    #[test]
    fn missing_data() {
        let input = r#"{"code":-32007,"message":"20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com"}"#;
        let expected_message = "20/second request limit reached - reduce calls per second or upgrade your account at quicknode.com";
        let parsed: ErrorPayload = serde_json::from_str(input).unwrap();

        assert_eq!(parsed.code, -32007);
        assert_eq!(parsed.message, expected_message);
        assert!(parsed.data.is_none());
    }

    /// Revert data is decoded both through the interface enum and the single error type.
    #[test]
    fn custom_error_decoding() {
        sol!(
            #[derive(Debug, PartialEq, Eq)]
            library Errors {
                error SomeCustomError(uint256 a);
            }
        );

        let input = r#"{"code":3,"message":"execution reverted: ","data":"0x810f00230000000000000000000000000000000000000000000000000000000000000001"}"#;
        let parsed: ErrorPayload = serde_json::from_str(input).unwrap();

        let Errors::ErrorsErrors::SomeCustomError(via_interface) =
            parsed.as_decoded_interface_error::<Errors::ErrorsErrors>().unwrap();
        assert_eq!(via_interface.a, U256::from(1));

        let via_error = parsed.as_decoded_error::<Errors::SomeCustomError>().unwrap();
        assert_eq!(via_error, Errors::SomeCustomError { a: U256::from(1) });
    }

    /// The WebSocket concurrency-limit error is classified as retryable.
    #[test]
    fn max_concurrent_requests() {
        let input = r#"{"code":1008,"message":"You have exceeded the maximum number of concurrent requests on a single WebSocket. At most 200 concurrent requests are allowed per WebSocket."}"#;
        let parsed: ErrorPayload = serde_json::from_str(input).unwrap();

        assert!(parsed.is_retry_err());
    }

    /// A transaction hash carried in `data` can be recovered from the wrapping `RpcError`.
    #[test]
    fn extract_transaction_hash_box_raw_value() {
        use crate::RpcError;
        use alloy_primitives::B256;
        use serde_json::value::RawValue;

        let tx_hash = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";
        let input = format!(r#"{{"code":5,"message":"insufficient funds","data":"{tx_hash}"}}"#);
        let parsed: ErrorPayload = serde_json::from_str(&input).unwrap();
        let rpc_error: RpcError<(), Box<RawValue>> = RpcError::ErrorResp(parsed);

        let extracted = rpc_error.tx_hash_data().expect("`data` holds a transaction hash");
        assert_eq!(extracted, tx_hash.parse::<B256>().unwrap());
    }
}