use std::collections::hash_map::Entry;
use std::fmt::{self, Debug};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use crate::error::RegisterMethodError;
use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::MethodSink;
use crate::server::subscription::{
BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink, Subscribers, Subscription,
SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit, SubscriptionState, sub_message_to_json,
};
use crate::server::{LOG_TARGET, MethodResponse, ResponsePayload};
use crate::traits::ToRpcParams;
use futures_util::{FutureExt, future::BoxFuture};
use http::Extensions;
use jsonrpsee_types::error::{ErrorCode, ErrorObject};
use jsonrpsee_types::{
ErrorObjectOwned, Id, Params, Request, Response, ResponseSuccess, SubscriptionId as RpcSubscriptionId,
};
use rustc_hash::FxHashMap;
use serde::de::DeserializeOwned;
use serde_json::value::RawValue;
use tokio::sync::{mpsc, oneshot};
use super::{IntoResponse, sub_err_to_json};
/// Handler for a synchronous call.
///
/// Receives the request id, params, maximum response size and per-call extensions,
/// and returns the response directly.
pub type SyncMethod = Arc<dyn Fn(Id, Params, MaxResponseSize, Extensions) -> MethodResponse + Send + Sync>;

/// Handler for an asynchronous call.
///
/// Also receives the connection id; it returns a boxed future that resolves to the response.
pub type AsyncMethod<'a> = Arc<
    dyn Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Extensions) -> BoxFuture<'a, MethodResponse>
        + Send
        + Sync,
>;

/// Handler for a subscribe call.
///
/// Receives the sink used to push messages to the client and the subscription state.
pub type SubscriptionMethod<'a> = Arc<
    dyn Fn(Id, Params, MethodSink, SubscriptionState, Extensions) -> BoxFuture<'a, MethodResponse>
        + Send
        + Sync,
>;

/// Handler for an unsubscribe call; returns the response synchronously.
type UnsubscriptionMethod =
    Arc<dyn Fn(Id, Params, ConnectionId, MaxResponseSize, Extensions) -> MethodResponse + Send + Sync>;
/// Numeric identifier of a connection; the in-process helpers on `Methods` use `ConnectionId(0)`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, serde::Deserialize, serde::Serialize)]
pub struct ConnectionId(pub usize);

impl From<usize> for ConnectionId {
    /// Wraps the raw value unchanged.
    fn from(raw: usize) -> Self {
        ConnectionId(raw)
    }
}

impl From<u32> for ConnectionId {
    /// Widens the raw value to `usize` with an `as` cast.
    fn from(raw: u32) -> Self {
        ConnectionId(raw as usize)
    }
}
/// Upper bound on the size of a serialized response, forwarded to `MethodResponse::response`
/// (presumably in bytes — confirm). The in-process helpers on [`Methods`] pass `usize::MAX`.
pub type MaxResponseSize = usize;
/// Output of an in-process call: the serialized response, plus the receiving half of the
/// channel whose sender backs the method sink (e.g. subscription notifications arrive here).
pub type RawRpcResponse = (Box<RawValue>, mpsc::Receiver<Box<RawValue>>);
/// Errors returned by the in-process helpers such as [`Methods::call`] and [`Methods::subscribe`].
#[derive(thiserror::Error, Debug)]
pub enum MethodsError {
/// JSON (de)serialization failed, e.g. the response could not be parsed into the expected type.
#[error(transparent)]
Parse(#[from] serde_json::Error),
/// The method answered with a JSON-RPC error object.
#[error(transparent)]
JsonRpc(#[from] ErrorObjectOwned),
/// The given subscription id was invalid (not constructed by the code visible in this file section).
#[error("Invalid subscription ID: `{0}`")]
InvalidSubscriptionId(String),
}
/// A method response tagged with whether it came from a subscription or a plain call.
#[derive(Debug)]
pub enum CallOrSubscription {
    /// Response to a subscription request.
    Subscription(MethodResponse),
    /// Response to a regular method call.
    Call(MethodResponse),
}

impl CallOrSubscription {
    /// Borrows the wrapped response, whichever variant holds it.
    pub fn as_response(&self) -> &MethodResponse {
        match self {
            Self::Subscription(inner) | Self::Call(inner) => inner,
        }
    }

    /// Consumes the wrapper and returns the response it holds.
    pub fn into_response(self) -> MethodResponse {
        match self {
            Self::Subscription(inner) | Self::Call(inner) => inner,
        }
    }
}
/// A registered method handler. Every variant wraps an `Arc`, so cloning only bumps a reference count.
#[derive(Clone)]
pub enum MethodCallback {
/// Handler invoked synchronously; returns the response directly.
Sync(SyncMethod),
/// Handler returning a boxed future that resolves to the response.
Async(AsyncMethod<'static>),
/// Subscribe handler; receives the method sink and the subscription state (connection id, id provider, permit).
Subscription(SubscriptionMethod<'static>),
/// Unsubscribe handler; receives the connection id used to locate the subscription.
Unsubscription(UnsubscriptionMethod),
}
/// The kind of RPC call a request resolved to.
#[derive(Debug, Copy, Clone)]
pub enum MethodKind {
    /// A subscribe call.
    Subscription,
    /// An unsubscribe call.
    Unsubscription,
    /// A regular method call.
    MethodCall,
    /// No method with the requested name is registered.
    NotFound,
}

impl std::fmt::Display for MethodKind {
    /// Writes a short lower-case description of the kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            MethodKind::Subscription => "subscription",
            MethodKind::Unsubscription => "unsubscription",
            MethodKind::MethodCall => "method call",
            MethodKind::NotFound => "method not found",
        })
    }
}
/// A value that is either available immediately or produced by a boxed future.
pub enum MethodResult<T> {
    /// Value computed synchronously.
    Sync(T),
    /// Value produced once the boxed future completes.
    Async(BoxFuture<'static, T>),
}

impl<T: Debug> Debug for MethodResult<T> {
    /// Prints the inner value for `Sync`; futures are opaque, so `Async` prints `<future>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Async(_) => f.write_str("<future>"),
            Self::Sync(value) => Debug::fmt(value, f),
        }
    }
}
impl Debug for MethodCallback {
    /// Prints only the variant name; the wrapped closures are not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Sync(_) => "Sync",
            Self::Async(_) => "Async",
            Self::Subscription(_) => "Subscription",
            Self::Unsubscription(_) => "Unsubscription",
        };
        f.write_str(label)
    }
}
/// Registry of RPC method callbacks keyed by method name, plus extensions that are cloned into
/// each in-process call.
///
/// The callback map sits behind an `Arc`, so cloning `Methods` is cheap; mutation goes through
/// `Arc::make_mut`, which copies the map first if another clone still shares it.
#[derive(Default, Debug, Clone)]
pub struct Methods {
callbacks: Arc<FxHashMap<&'static str, MethodCallback>>,
extensions: Extensions,
}
impl Methods {
    /// Creates an empty set of methods.
    pub fn new() -> Self {
        Self::default()
    }

    /// Checks that no method named `name` is registered yet.
    ///
    /// Only reads the registry, so a shared borrow suffices (callers holding `&mut Methods`
    /// are unaffected).
    ///
    /// # Errors
    ///
    /// Returns [`RegisterMethodError::AlreadyRegistered`] if `name` is taken.
    pub fn verify_method_name(&self, name: &'static str) -> Result<(), RegisterMethodError> {
        if self.callbacks.contains_key(name) {
            return Err(RegisterMethodError::AlreadyRegistered(name.into()));
        }
        Ok(())
    }

    /// Inserts `callback` under `name` unless the name is taken, using a single map lookup.
    ///
    /// # Errors
    ///
    /// Returns [`RegisterMethodError::AlreadyRegistered`] if `name` is taken; the existing
    /// callback is left untouched.
    pub fn verify_and_insert(
        &mut self,
        name: &'static str,
        callback: MethodCallback,
    ) -> Result<&mut MethodCallback, RegisterMethodError> {
        match self.mut_callbacks().entry(name) {
            Entry::Occupied(_) => Err(RegisterMethodError::AlreadyRegistered(name.into())),
            Entry::Vacant(vacant) => Ok(vacant.insert(callback)),
        }
    }

    /// Mutable access to the callback map.
    ///
    /// `Arc::make_mut` clones the map first if it is shared with another `Methods` clone, so
    /// mutations never leak into other clones.
    fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> {
        Arc::make_mut(&mut self.callbacks)
    }

    /// Moves every method of `other` into `self`.
    ///
    /// All names are checked before anything is inserted, so on error `self` is unchanged.
    /// Only callbacks are merged; `other`'s extensions are dropped.
    ///
    /// # Errors
    ///
    /// Returns [`RegisterMethodError::AlreadyRegistered`] for a name of `other` that already
    /// exists in `self`.
    pub fn merge(&mut self, other: impl Into<Methods>) -> Result<(), RegisterMethodError> {
        let mut other = other.into();
        for name in other.callbacks.keys() {
            self.verify_method_name(name)?;
        }
        self.mut_callbacks().extend(other.mut_callbacks().drain());
        Ok(())
    }

    /// Looks up the callback registered under `method_name`.
    pub fn method(&self, method_name: &str) -> Option<&MethodCallback> {
        self.callbacks.get(method_name)
    }

    /// Like [`Methods::method`], but also returns the registered `'static` name.
    pub fn method_with_name(&self, method_name: &str) -> Option<(&'static str, &MethodCallback)> {
        self.callbacks.get_key_value(method_name).map(|(k, v)| (*k, v))
    }

    /// Calls `method` in-process and deserializes the `result` field of its response into `T`.
    ///
    /// The request uses id `0`, connection id `0` and a one-slot message channel whose receiver
    /// is discarded.
    ///
    /// # Errors
    ///
    /// - [`MethodsError::Parse`] if the response cannot be deserialized into `T` (failures of
    ///   `to_rpc_params` are converted through `?` as well).
    /// - [`MethodsError::JsonRpc`] if the method answered with a JSON-RPC error.
    pub async fn call<Params: ToRpcParams, T: DeserializeOwned + Clone>(
        &self,
        method: &str,
        params: Params,
    ) -> Result<T, MethodsError> {
        let params = params.to_rpc_params()?;
        let req = Request::borrowed(method, params.as_ref().map(|p| p.as_ref()), Id::Number(0));
        tracing::trace!(target: LOG_TARGET, "[Methods::call] Method: {:?}, params: {:?}", method, params);
        let (rp, _) = self.inner_call(req, 1, mock_subscription_permit()).await;
        let rp = serde_json::from_str::<Response<T>>(rp.get())?;
        ResponseSuccess::try_from(rp).map(|s| s.result).map_err(|e| MethodsError::JsonRpc(e.into_owned()))
    }

    /// Parses `request` as a JSON-RPC request and executes it in-process.
    ///
    /// Returns the raw response and the receiver of a channel with capacity `buf_size` (a value
    /// of `0` is treated as `1`), on which messages sent through the method sink arrive.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `request` is not a valid request object.
    pub async fn raw_json_request(
        &self,
        request: &str,
        buf_size: usize,
    ) -> Result<(Box<RawValue>, mpsc::Receiver<Box<RawValue>>), serde_json::Error> {
        tracing::trace!(target: LOG_TARGET, "[Methods::raw_json_request] Request: {:?}", request);
        let req: Request = serde_json::from_str(request)?;
        let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
        Ok((resp, rx))
    }

    /// Dispatches `req` to its callback as if it came from connection `0`, with no response
    /// size limit.
    async fn inner_call(
        &self,
        req: Request<'_>,
        buf_size: usize,
        subscription_permit: SubscriptionPermit,
    ) -> RawRpcResponse {
        // `mpsc::channel` panics on a zero capacity; clamp so a caller passing `0` gets a
        // usable one-slot channel instead of a panic.
        let (tx, mut rx) = mpsc::channel(buf_size.max(1));
        let Request { id, method, params, .. } = req;
        let params = Params::new(params.as_ref().map(|params| params.as_ref().get()));
        let max_response_size = usize::MAX;
        let conn_id = ConnectionId(0);
        let mut ext = self.extensions.clone();
        ext.insert(conn_id);

        let response = match self.method(&method) {
            None => MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)),
            Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, ext),
            Some(MethodCallback::Async(cb)) => {
                (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, ext).await
            }
            Some(MethodCallback::Subscription(cb)) => {
                let conn_state =
                    SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit };
                let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, ext).await;
                // The subscribe callback also pushes one message onto the sink (the `expect`
                // asserts it); drain it so `rx` only yields what follows. Presumably it mirrors
                // `res` — confirm against `PendingSubscriptionSink`.
                let _ = rx.recv().await.expect("Every call must at least produce one response; qed");
                res
            }
            Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, ext),
        };

        let is_success = response.is_success();
        let (rp, notif, _) = response.into_parts();
        if let Some(n) = notif {
            n.notify(is_success);
        }
        tracing::trace!(target: LOG_TARGET, "[Methods::inner_call] Method: {}, response: {}", method, rp);
        (rp, rx)
    }

    /// [`Methods::subscribe`] with a channel capacity of `u32::MAX`.
    pub async fn subscribe_unbounded(
        &self,
        sub_method: &str,
        params: impl ToRpcParams,
    ) -> Result<Subscription, MethodsError> {
        self.subscribe(sub_method, params, u32::MAX as usize).await
    }

    /// Calls the subscribe method `sub_method` in-process and returns the resulting subscription.
    ///
    /// `buf_size` is the capacity of the notification channel (a value of `0` is treated as `1`).
    ///
    /// # Errors
    ///
    /// - [`MethodsError::Parse`] if the response or the subscription id cannot be parsed.
    /// - [`MethodsError::JsonRpc`] if the method answered with a JSON-RPC error.
    pub async fn subscribe(
        &self,
        sub_method: &str,
        params: impl ToRpcParams,
        buf_size: usize,
    ) -> Result<Subscription, MethodsError> {
        let params = params.to_rpc_params()?;
        let req = Request::borrowed(sub_method, params.as_ref().map(|p| p.as_ref()), Id::Number(0));
        tracing::trace!(target: LOG_TARGET, "[Methods::subscribe] Method: {}, params: {:?}", sub_method, params);
        let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
        let as_success: ResponseSuccess<&RawValue> = serde_json::from_str::<Response<_>>(resp.get())?.try_into()?;
        let sub_id: RpcSubscriptionId = serde_json::from_str(as_success.result.get())?;
        Ok(Subscription { sub_id: sub_id.into_owned(), rx })
    }

    /// Iterates over the names of all registered methods (in unspecified order).
    pub fn method_names(&self) -> impl Iterator<Item = &'static str> + '_ {
        self.callbacks.keys().copied()
    }

    /// Extensions configured on this set of methods; they are cloned into each in-process call.
    ///
    /// Read-only, so it only needs a shared borrow.
    pub fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    /// Mutable access to the configured extensions.
    pub fn extensions_mut(&mut self) -> &mut Extensions {
        &mut self.extensions
    }
}
impl<Context> Deref for RpcModule<Context> {
    type Target = Methods;

    /// Exposes the registered methods so the module can be used wherever `&Methods` is expected.
    fn deref(&self) -> &Self::Target {
        &self.methods
    }
}

impl<Context> DerefMut for RpcModule<Context> {
    /// Mutable counterpart of `deref`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.methods
    }
}
/// A set of RPC methods together with a shared context that is handed to their callbacks.
///
/// Derefs to [`Methods`], so lookup and in-process call helpers are available on the module.
#[derive(Debug, Clone)]
pub struct RpcModule<Context> {
ctx: Arc<Context>,
methods: Methods,
}
impl<Context> RpcModule<Context> {
pub fn new(ctx: Context) -> Self {
Self::from_arc(Arc::new(ctx))
}
pub fn from_arc(ctx: Arc<Context>) -> Self {
Self { ctx, methods: Default::default() }
}
pub fn remove_context(self) -> RpcModule<()> {
let mut module = RpcModule::new(());
module.methods = self.methods;
module
}
}
impl<Context> From<RpcModule<Context>> for Methods {
fn from(module: RpcModule<Context>) -> Methods {
module.methods
}
}
impl<Context: Send + Sync + 'static> RpcModule<Context> {
    /// Registers a synchronous method.
    ///
    /// `callback` is invoked directly with the params, a reference to the context and the
    /// per-call extensions. Its result is converted via [`IntoResponse`] and the extensions are
    /// attached to the response.
    ///
    /// # Errors
    ///
    /// Fails if `method_name` is already registered.
    pub fn register_method<R, F>(
        &mut self,
        method_name: &'static str,
        callback: F,
    ) -> Result<&mut MethodCallback, RegisterMethodError>
    where
        Context: Send + Sync + 'static,
        R: IntoResponse + 'static,
        F: Fn(Params, &Context, &Extensions) -> R + Send + Sync + 'static,
    {
        let ctx = self.ctx.clone();
        self.methods.verify_and_insert(
            method_name,
            MethodCallback::Sync(Arc::new(move |id, params, max_response_size, extensions| {
                let rp = callback(params, &*ctx, &extensions).into_response();
                MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
            })),
        )
    }

    /// Removes and returns the callback registered under `method_name`, if any.
    ///
    /// Subscribe and unsubscribe methods are separate entries; removing one leaves the other.
    pub fn remove_method(&mut self, method_name: &'static str) -> Option<MethodCallback> {
        self.methods.mut_callbacks().remove(method_name)
    }

    /// Registers an asynchronous method.
    ///
    /// `callback` receives owned params, an `Arc` of the context and a copy of the extensions.
    /// Its future is boxed, and the extensions are attached to the response.
    ///
    /// # Errors
    ///
    /// Fails if `method_name` is already registered.
    pub fn register_async_method<R, Fun, Fut>(
        &mut self,
        method_name: &'static str,
        callback: Fun,
    ) -> Result<&mut MethodCallback, RegisterMethodError>
    where
        R: IntoResponse + 'static,
        Fut: Future<Output = R> + Send,
        Fun: (Fn(Params<'static>, Arc<Context>, Extensions) -> Fut) + Clone + Send + Sync + 'static,
    {
        let ctx = self.ctx.clone();
        self.methods.verify_and_insert(
            method_name,
            MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
                let ctx = ctx.clone();
                let callback = callback.clone();
                let future = async move {
                    let rp = callback(params, ctx, extensions.clone()).await.into_response();
                    MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
                };
                future.boxed()
            })),
        )
    }

    /// Registers a method whose blocking `callback` runs via `tokio::task::spawn_blocking`.
    ///
    /// If the blocking task fails to join (e.g. it panicked), the error is logged. The caller
    /// then gets an internal error addressed to the original request id.
    ///
    /// # Errors
    ///
    /// Fails if `method_name` is already registered.
    pub fn register_blocking_method<R, F>(
        &mut self,
        method_name: &'static str,
        callback: F,
    ) -> Result<&mut MethodCallback, RegisterMethodError>
    where
        Context: Send + Sync + 'static,
        R: IntoResponse + 'static,
        F: Fn(Params, Arc<Context>, Extensions) -> R + Clone + Send + Sync + 'static,
    {
        let ctx = self.ctx.clone();
        self.methods.verify_and_insert(
            method_name,
            MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
                let ctx = ctx.clone();
                let callback = callback.clone();
                let extensions2 = extensions.clone();
                // `id` moves into the blocking task; keep a copy so a join failure can still
                // answer the right request (JSON-RPC only allows a null id when it is unknown).
                let fallback_id = id.clone();
                tokio::task::spawn_blocking(move || {
                    let rp = callback(params, ctx, extensions2.clone()).into_response();
                    MethodResponse::response(id, rp, max_response_size).with_extensions(extensions2)
                })
                .map(|result| match result {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(target: LOG_TARGET, "Join error for blocking RPC method: {:?}", err);
                        MethodResponse::error(fallback_id, ErrorObject::from(ErrorCode::InternalError))
                            .with_extensions(extensions)
                    }
                })
                .boxed()
            })),
        )
    }

    /// Registers a subscription plus its unsubscribe method.
    ///
    /// On each subscribe call, `callback` receives a [`PendingSubscriptionSink`], and the
    /// returned future yields whatever response is sent through that sink. If the sink is
    /// dropped without responding, the future yields an internal error.
    ///
    /// Once the subscription response was a success and the callback's future completes, its
    /// output is converted via [`IntoSubscriptionCloseResponse`]. If that produces a final
    /// notification, it is sent as `notif_method_name`.
    ///
    /// # Errors
    ///
    /// Fails if the two method names are equal or either one is already registered.
    pub fn register_subscription<R, F, Fut>(
        &mut self,
        subscribe_method_name: &'static str,
        notif_method_name: &'static str,
        unsubscribe_method_name: &'static str,
        callback: F,
    ) -> Result<&mut MethodCallback, RegisterMethodError>
    where
        Context: Send + Sync + 'static,
        F: (Fn(Params<'static>, PendingSubscriptionSink, Arc<Context>, Extensions) -> Fut)
            + Send
            + Sync
            + Clone
            + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoSubscriptionCloseResponse + Send,
    {
        let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
        let ctx = self.ctx.clone();

        self.methods.verify_and_insert(
            subscribe_method_name,
            MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
                let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
                // The sink sends the subscription response through `tx`; it is awaited on `rx` below.
                let (tx, rx) = oneshot::channel();
                // Fired only for a successful subscription response; gates the close notification.
                let (accepted_tx, accepted_rx) = oneshot::channel();
                let sub_id = uniq_sub.sub_id.clone();
                let sink = PendingSubscriptionSink {
                    inner: method_sink.clone(),
                    method: notif_method_name,
                    subscribers: subscribers.clone(),
                    uniq_sub,
                    id: id.clone().into_owned(),
                    subscribe: tx,
                    permit: conn.subscription_permit,
                };
                let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone());

                tokio::spawn(async move {
                    // If `accepted_tx` is dropped (subscription rejected), skip the close message.
                    let response = match futures_util::future::try_join(sub_fut.map(Ok), accepted_rx).await {
                        Ok((r, _)) => r.into_response(),
                        Err(_) => return,
                    };
                    match response {
                        SubscriptionCloseResponse::Notif(msg) => {
                            let json = sub_message_to_json(msg, &sub_id, notif_method_name);
                            let _ = method_sink.send(json).await;
                        }
                        SubscriptionCloseResponse::NotifErr(err) => {
                            let json = sub_err_to_json(err, sub_id, notif_method_name);
                            let _ = method_sink.send(json).await;
                        }
                        SubscriptionCloseResponse::None => (),
                    }
                });

                let id = id.into_owned();
                Box::pin(async move {
                    let rp = match rx.await {
                        Ok(rp) => {
                            if rp.is_success() {
                                let _ = accepted_tx.send(());
                            }
                            rp
                        }
                        Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
                    };
                    rp.with_extensions(extensions)
                })
            })),
        )
    }

    /// Like [`RpcModule::register_subscription`], but `callback` is a plain function.
    ///
    /// Its return value is ignored here, and no close notification is sent on its behalf. The
    /// call's response is whatever the [`PendingSubscriptionSink`] sends; if it is dropped
    /// without responding, the response is an internal error.
    ///
    /// # Errors
    ///
    /// Fails if the two method names are equal or either one is already registered.
    pub fn register_subscription_raw<R, F>(
        &mut self,
        subscribe_method_name: &'static str,
        notif_method_name: &'static str,
        unsubscribe_method_name: &'static str,
        callback: F,
    ) -> Result<&mut MethodCallback, RegisterMethodError>
    where
        Context: Send + Sync + 'static,
        F: (Fn(Params, PendingSubscriptionSink, Arc<Context>, &Extensions) -> R) + Send + Sync + Clone + 'static,
        R: IntoSubscriptionCloseResponse,
    {
        let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
        let ctx = self.ctx.clone();

        self.methods.verify_and_insert(
            subscribe_method_name,
            MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
                let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
                let (tx, rx) = oneshot::channel();
                let sink = PendingSubscriptionSink {
                    inner: method_sink.clone(),
                    method: notif_method_name,
                    subscribers: subscribers.clone(),
                    uniq_sub,
                    id: id.clone().into_owned(),
                    subscribe: tx,
                    permit: conn.subscription_permit,
                };
                callback(params, sink, ctx.clone(), &extensions);

                let id = id.into_owned();
                Box::pin(async move {
                    let rp = match rx.await {
                        Ok(rp) => rp,
                        Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
                    };
                    rp.with_extensions(extensions)
                })
            })),
        )
    }

    /// Validates both names and registers the unsubscribe method.
    ///
    /// Returns the subscriber set it shares with the subscribe method. The unsubscribe handler
    /// answers `true` if it removed an active subscription and `false` otherwise, including
    /// when the id cannot be parsed. Subscriptions are keyed by (connection id, subscription
    /// id), so a connection can only remove its own subscriptions.
    fn verify_and_register_unsubscribe(
        &mut self,
        subscribe_method_name: &'static str,
        unsubscribe_method_name: &'static str,
    ) -> Result<Subscribers, RegisterMethodError> {
        if subscribe_method_name == unsubscribe_method_name {
            return Err(RegisterMethodError::SubscriptionNameConflict(subscribe_method_name.into()));
        }

        self.methods.verify_method_name(subscribe_method_name)?;
        self.methods.verify_method_name(unsubscribe_method_name)?;

        let subscribers = Subscribers::default();

        {
            let subscribers = subscribers.clone();
            self.methods.mut_callbacks().insert(
                unsubscribe_method_name,
                MethodCallback::Unsubscription(Arc::new(move |id, params, conn_id, max_response_size, extensions| {
                    let sub_id = match params.one::<RpcSubscriptionId>() {
                        Ok(sub_id) => sub_id,
                        Err(_) => {
                            tracing::warn!(
                                target: LOG_TARGET,
                                "Unsubscribe call `{}` failed: couldn't parse subscription id={:?} request id={:?}",
                                unsubscribe_method_name,
                                params,
                                id
                            );
                            return MethodResponse::response(id, ResponsePayload::success(false), max_response_size)
                                .with_extensions(extensions);
                        }
                    };

                    let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() };
                    let result = subscribers.lock().remove(&key).is_some();

                    if !result {
                        tracing::debug!(
                            target: LOG_TARGET,
                            "Unsubscribe call `{}` subscription key={:?} not an active subscription",
                            unsubscribe_method_name,
                            key,
                        );
                    }

                    // Attach the per-call extensions on the success path too, like every other handler.
                    MethodResponse::response(id, ResponsePayload::success(result), max_response_size)
                        .with_extensions(extensions)
                })),
            );
        }

        Ok(subscribers)
    }

    /// Registers `alias` as another name for `existing_method`.
    ///
    /// The callback is cloned (an `Arc` bump), so both names share the same handler.
    ///
    /// # Errors
    ///
    /// Fails if `alias` is already registered or `existing_method` does not exist.
    pub fn register_alias(
        &mut self,
        alias: &'static str,
        existing_method: &'static str,
    ) -> Result<(), RegisterMethodError> {
        self.methods.verify_method_name(alias)?;

        let callback = match self.methods.callbacks.get(existing_method) {
            Some(callback) => callback.clone(),
            None => return Err(RegisterMethodError::MethodNotFound(existing_method.into())),
        };

        self.methods.mut_callbacks().insert(alias, callback);
        Ok(())
    }
}
/// Builds a subscription permit for in-process calls from a fresh one-slot bound.
///
/// The only permit of a fresh bound of size one is always available, so the `expect` cannot fire.
fn mock_subscription_permit() -> SubscriptionPermit {
    let single_slot = BoundedSubscriptions::new(1);
    single_slot.acquire().expect("1 permit should exist; qed")
}