use crate::{common::Id, RpcBorrow, RpcSend};
use alloy_primitives::{keccak256, B256};
use http::Extensions;
use serde::{
de::{DeserializeOwned, MapAccess},
ser::SerializeMap,
Deserialize, Serialize,
};
use serde_json::value::RawValue;
use std::{borrow::Cow, marker::PhantomData, mem::MaybeUninit};
/// Metadata describing a JSON-RPC request: its method name, its ID, a
/// subscription flag, and a set of typed extensions.
#[derive(Clone, Debug)]
pub struct RequestMeta {
    /// The RPC method name.
    pub method: Cow<'static, str>,
    /// The request ID.
    pub id: Id,
    /// Explicit subscription flag; modified through the setters on this type.
    is_subscription: bool,
    /// Type-keyed extension storage; the optional HTTP header map lives here.
    extensions: Extensions,
}
impl RequestMeta {
    /// Creates request metadata for `method` and `id`.
    ///
    /// The subscription flag starts out `false` and the extension map starts
    /// out empty.
    pub fn new(method: Cow<'static, str>, id: Id) -> Self {
        let extensions = Extensions::new();
        Self { method, id, is_subscription: false, extensions }
    }

    /// Reports whether this request is a subscription.
    ///
    /// Returns `true` when the subscription flag has been set, or when the
    /// method name is exactly `eth_subscribe` (regardless of the flag).
    pub fn is_subscription(&self) -> bool {
        if self.is_subscription {
            return true;
        }
        self.method == "eth_subscribe"
    }

    /// Marks the request as a subscription.
    pub const fn set_is_subscription(&mut self) {
        self.is_subscription = true;
    }

    /// Sets the subscription flag to `sub`.
    ///
    /// Clearing the flag does not change the result of
    /// [`is_subscription`](Self::is_subscription) for `eth_subscribe` requests.
    pub const fn set_subscription_status(&mut self, sub: bool) {
        self.is_subscription = sub;
    }

    /// Returns a shared reference to the request extensions.
    pub const fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    /// Returns a mutable reference to the request extensions.
    pub const fn extensions_mut(&mut self) -> &mut Extensions {
        &mut self.extensions
    }

    /// Returns the [`http::HeaderMap`] stored in the extensions, if one is present.
    pub fn headers(&self) -> Option<&http::HeaderMap> {
        self.extensions().get::<http::HeaderMap>()
    }

    /// Returns a mutable reference to the [`http::HeaderMap`] stored in the
    /// extensions, inserting a default one first if none is present.
    pub fn headers_mut(&mut self) -> &mut http::HeaderMap {
        self.extensions_mut().get_or_insert_default::<http::HeaderMap>()
    }
}
impl PartialEq for RequestMeta {
    /// Compares the method, the ID, and the subscription flag.
    ///
    /// The extensions are not part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        if self.method != other.method {
            return false;
        }
        if self.id != other.id {
            return false;
        }
        self.is_subscription == other.is_subscription
    }
}
// Marker impl: declares that `RequestMeta` equality is a full equivalence relation.
impl Eq for RequestMeta {}
/// A JSON-RPC request: request metadata plus parameters of type `Params`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Request<Params> {
    /// The request metadata (method, ID, subscription flag, extensions).
    pub meta: RequestMeta,
    /// The request parameters.
    pub params: Params,
}
impl<Params> Request<Params> {
    /// Creates a request from a method name, an ID, and parameters.
    pub fn new(method: impl Into<Cow<'static, str>>, id: Id, params: Params) -> Self {
        let meta = RequestMeta::new(method.into(), id);
        Self { meta, params }
    }

    /// Reports whether the request is a subscription, as determined by its metadata.
    pub fn is_subscription(&self) -> bool {
        let meta = &self.meta;
        meta.is_subscription()
    }

    /// Marks the request as a subscription.
    pub const fn set_is_subscription(&mut self) {
        self.meta.set_subscription_status(true);
    }

    /// Sets the subscription flag in the request metadata to `sub`.
    pub const fn set_subscription_status(&mut self, sub: bool) {
        self.meta.set_subscription_status(sub);
    }

    /// Transforms the parameters with `map`, keeping the metadata unchanged.
    pub fn map_params<NewParams>(
        self,
        map: impl FnOnce(Params) -> NewParams,
    ) -> Request<NewParams> {
        let Self { meta, params } = self;
        Request { meta, params: map(params) }
    }

    /// Transforms the metadata with `f`, keeping the parameters unchanged.
    pub fn map_meta<F>(self, f: F) -> Self
    where
        F: FnOnce(RequestMeta) -> RequestMeta,
    {
        let Self { meta, params } = self;
        Self { meta: f(meta), params }
    }
}
/// A [`Request`] whose parameters are held as a boxed [`RawValue`], i.e. already
/// serialized JSON.
pub type PartiallySerializedRequest = Request<Box<RawValue>>;
impl<Params> Request<Params>
where
    Params: RpcSend,
{
    /// Serializes the parameters into a boxed [`RawValue`], keeping the metadata.
    ///
    /// # Panics
    ///
    /// Panics if serializing the parameters fails.
    pub fn box_params(self) -> PartiallySerializedRequest {
        let params = serde_json::value::to_raw_value(&self.params).unwrap();
        Request { meta: self.meta, params }
    }

    /// Serializes the whole request into a [`SerializedRequest`], which keeps
    /// the metadata alongside the raw JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialization fails.
    pub fn serialize(self) -> serde_json::Result<SerializedRequest> {
        serde_json::value::to_raw_value(&self)
            .map(|request| SerializedRequest { meta: self.meta, request })
    }
}
impl<Params> Request<&Params>
where
    Params: ToOwned,
    Params::Owned: RpcSend,
{
    /// Converts borrowed parameters into their owned form via [`ToOwned`],
    /// keeping the metadata unchanged.
    pub fn into_owned_params(self) -> Request<Params::Owned> {
        let Request { meta, params } = self;
        Request { meta, params: <Params as ToOwned>::to_owned(params) }
    }
}
impl<'a, Params> Request<Params>
where
    Params: AsRef<RawValue> + 'a,
{
    /// Attempts to deserialize the raw parameters into an owned `T`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the raw JSON does not deserialize as `T`.
    pub fn try_params_as<T: DeserializeOwned>(&self) -> serde_json::Result<T> {
        let raw: &RawValue = self.params.as_ref();
        serde_json::from_str(raw.get())
    }

    /// Attempts to deserialize the raw parameters into a `T` that may borrow
    /// from this request for `'a`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the raw JSON does not deserialize as `T`.
    pub fn try_borrow_params_as<T: Deserialize<'a>>(&'a self) -> serde_json::Result<T> {
        let raw: &'a RawValue = self.params.as_ref();
        serde_json::from_str(raw.get())
    }
}
impl<Params> Serialize for Request<Params>
where
    Params: RpcSend,
{
    /// Serializes the request as a map with the keys `method`, `params`, `id`
    /// and `jsonrpc`, in that order.
    ///
    /// The `params` entry is omitted when `Params` is a zero-sized type.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Zero-sized parameter types carry no data, so their `params` entry is left out.
        let include_params = std::mem::size_of::<Params>() != 0;
        let entry_count = if include_params { 4 } else { 3 };

        let mut object = serializer.serialize_map(Some(entry_count))?;
        object.serialize_entry("method", &*self.meta.method)?;
        if include_params {
            object.serialize_entry("params", &self.params)?;
        }
        object.serialize_entry("id", &self.meta.id)?;
        object.serialize_entry("jsonrpc", "2.0")?;
        object.end()
    }
}
impl<'de, Params> Deserialize<'de> for Request<Params>
where
Params: RpcBorrow<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor<Params>(PhantomData<Params>);
impl<'de, Params> serde::de::Visitor<'de> for Visitor<Params>
where
Params: RpcBorrow<'de>,
{
type Value = Request<Params>;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
formatter,
"a JSON-RPC 2.0 request object with params of type {}",
std::any::type_name::<Params>()
)
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut id = None;
let mut params = None;
let mut method = None;
let mut jsonrpc = None;
while let Some(key) = map.next_key()? {
match key {
"id" => {
if id.is_some() {
return Err(serde::de::Error::duplicate_field("id"));
}
id = Some(map.next_value()?);
}
"params" => {
if params.is_some() {
return Err(serde::de::Error::duplicate_field("params"));
}
params = Some(map.next_value()?);
}
"method" => {
if method.is_some() {
return Err(serde::de::Error::duplicate_field("method"));
}
method = Some(map.next_value()?);
}
"jsonrpc" => {
let version: String = map.next_value()?;
if version != "2.0" {
return Err(serde::de::Error::custom(format!(
"unsupported JSON-RPC version: {version}"
)));
}
jsonrpc = Some(());
}
other => {
return Err(serde::de::Error::unknown_field(
other,
&["id", "params", "method", "jsonrpc"],
));
}
}
}
if jsonrpc.is_none() {
return Err(serde::de::Error::missing_field("jsonrpc"));
}
if method.is_none() {
return Err(serde::de::Error::missing_field("method"));
}
if params.is_none() {
if std::mem::size_of::<Params>() == 0 {
unsafe { params = Some(MaybeUninit::<Params>::zeroed().assume_init()) }
} else {
return Err(serde::de::Error::missing_field("params"));
}
}
Ok(Request {
meta: RequestMeta::new(method.unwrap(), id.unwrap_or(Id::None)),
params: params.unwrap(),
})
}
}
deserializer.deserialize_map(Visitor(PhantomData))
}
}
/// A request that has been serialized to raw JSON, kept together with its
/// metadata.
#[derive(Clone, Debug)]
pub struct SerializedRequest {
    /// Metadata of the request that was serialized.
    meta: RequestMeta,
    /// The serialized request as raw JSON.
    request: Box<RawValue>,
}
impl<Params> TryFrom<Request<Params>> for SerializedRequest
where
Params: RpcSend,
{
type Error = serde_json::Error;
fn try_from(value: Request<Params>) -> Result<Self, Self::Error> {
value.serialize()
}
}
impl SerializedRequest {
    /// Returns a shared reference to the request metadata.
    pub const fn meta(&self) -> &RequestMeta {
        &self.meta
    }

    /// Returns a mutable reference to the request metadata.
    pub const fn meta_mut(&mut self) -> &mut RequestMeta {
        &mut self.meta
    }

    /// Returns the HTTP headers stored in the metadata, if any.
    pub fn headers(&self) -> Option<&http::HeaderMap> {
        self.meta().headers()
    }

    /// Returns a mutable reference to the HTTP headers stored in the metadata,
    /// as provided by [`RequestMeta::headers_mut`].
    pub fn headers_mut(&mut self) -> &mut http::HeaderMap {
        self.meta_mut().headers_mut()
    }

    /// Returns the request ID.
    pub const fn id(&self) -> &Id {
        &self.meta.id
    }

    /// Returns the request method name.
    pub fn method(&self) -> &str {
        self.meta.method.as_ref()
    }

    /// Returns a clone of the method name.
    ///
    /// A borrowed method name stays borrowed; an owned one is copied.
    pub fn method_clone(&self) -> Cow<'static, str> {
        self.meta.method.clone()
    }

    /// Marks the request as a subscription.
    pub const fn set_is_subscription(&mut self) {
        self.meta.set_subscription_status(true);
    }

    /// Reports whether the request is a subscription, as determined by its metadata.
    pub fn is_subscription(&self) -> bool {
        let meta = &self.meta;
        meta.is_subscription()
    }

    /// Returns the serialized request as raw JSON.
    pub const fn serialized(&self) -> &RawValue {
        &self.request
    }

    /// Consumes `self` and returns the serialized request, discarding the metadata.
    pub fn into_serialized(self) -> Box<RawValue> {
        let Self { request, .. } = self;
        request
    }

    /// Consumes `self` and returns the metadata together with the serialized request.
    pub fn decompose(self) -> (RequestMeta, Box<RawValue>) {
        let Self { meta, request } = self;
        (meta, request)
    }

    /// Consumes `self` and returns the serialized request.
    ///
    /// Equivalent to [`into_serialized`](Self::into_serialized).
    pub fn take_request(self) -> Box<RawValue> {
        self.into_serialized()
    }

    /// Returns the raw `params` value of the serialized request, if present.
    ///
    /// The stored JSON is re-parsed on every call.
    ///
    /// # Panics
    ///
    /// Panics if the stored JSON cannot be parsed into the internal helper struct.
    pub fn params(&self) -> Option<&RawValue> {
        #[derive(Deserialize)]
        struct Req<'a> {
            #[serde(borrow)]
            params: Option<&'a RawValue>,
        }

        let parsed: Req<'_> = serde_json::from_str(self.request.get()).unwrap();
        parsed.params
    }

    /// Returns the keccak256 hash of the raw `params` text, or of the empty
    /// string when the request has no `params`.
    pub fn params_hash(&self) -> B256 {
        match self.params() {
            Some(raw) => keccak256(raw.get()),
            None => keccak256(""),
        }
    }
}
impl Serialize for SerializedRequest {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.request.serialize(serializer)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::RpcObject;

    /// Round-trips `t` through JSON and checks that both the decoded value and
    /// the re-encoded text match the originals.
    fn test_inner<T: RpcObject + PartialEq>(t: T) {
        let encoded = serde_json::to_string(&t).unwrap();
        let decoded: T = serde_json::from_str(&encoded).unwrap();
        let reencoded = serde_json::to_string(&decoded).unwrap();

        assert_eq!(decoded, t, "deser error for {}", std::any::type_name::<T>());
        assert_eq!(encoded, reencoded, "reser error for {}", std::any::type_name::<T>());
    }

    /// Exercises the round trip with unit, integer, string and vector params,
    /// and with numeric, string and absent IDs.
    #[test]
    fn test_ser_deser() {
        let unit_params = Request::<()>::new("test", 1.into(), ());
        test_inner(unit_params);

        let integer_params = Request::<u64>::new("test", "hello".to_string().into(), 1);
        test_inner(integer_params);

        let string_params = Request::<String>::new("test", Id::None, "test".to_string());
        test_inner(string_params);

        let vector_params = Request::<Vec<u64>>::new("test", u64::MAX.into(), vec![1, 2, 3]);
        test_inner(vector_params);
    }
}