1use std::collections::hash_map::Entry;
28use std::fmt::{self, Debug};
29use std::future::Future;
30use std::ops::{Deref, DerefMut};
31use std::sync::Arc;
32
33use crate::error::RegisterMethodError;
34use crate::id_providers::RandomIntegerIdProvider;
35use crate::server::helpers::MethodSink;
36use crate::server::subscription::{
37 BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink, Subscribers, Subscription,
38 SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit, SubscriptionState, sub_message_to_json,
39};
40use crate::server::{LOG_TARGET, MethodResponse, ResponsePayload};
41use crate::traits::ToRpcParams;
42use futures_util::{FutureExt, future::BoxFuture};
43use http::Extensions;
44use jsonrpsee_types::error::{ErrorCode, ErrorObject};
45use jsonrpsee_types::{
46 ErrorObjectOwned, Id, Params, Request, Response, ResponseSuccess, SubscriptionId as RpcSubscriptionId,
47};
48use rustc_hash::FxHashMap;
49use serde::de::DeserializeOwned;
50use serde_json::value::RawValue;
51use tokio::sync::{mpsc, oneshot};
52
53use super::{IntoResponse, sub_err_to_json};
54
55pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, MaxResponseSize, Extensions) -> MethodResponse>;
60pub type AsyncMethod<'a> = Arc<
62 dyn Send
63 + Sync
64 + Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Extensions) -> BoxFuture<'a, MethodResponse>,
65>;
66
67pub type SubscriptionMethod<'a> =
69 Arc<dyn Send + Sync + Fn(Id, Params, MethodSink, SubscriptionState, Extensions) -> BoxFuture<'a, MethodResponse>>;
70type UnsubscriptionMethod =
72 Arc<dyn Send + Sync + Fn(Id, Params, ConnectionId, MaxResponseSize, Extensions) -> MethodResponse>;
73
74#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, serde::Deserialize, serde::Serialize)]
76pub struct ConnectionId(pub usize);
77
78impl From<u32> for ConnectionId {
79 fn from(id: u32) -> Self {
80 Self(id as usize)
81 }
82}
83
84impl From<usize> for ConnectionId {
85 fn from(id: usize) -> Self {
86 Self(id)
87 }
88}
89
90pub type MaxResponseSize = usize;
92
93pub type RawRpcResponse = (Box<RawValue>, mpsc::Receiver<Box<RawValue>>);
98
99#[derive(thiserror::Error, Debug)]
101pub enum MethodsError {
102 #[error(transparent)]
104 Parse(#[from] serde_json::Error),
105 #[error(transparent)]
107 JsonRpc(#[from] ErrorObjectOwned),
108 #[error("Invalid subscription ID: `{0}`")]
109 InvalidSubscriptionId(String),
111}
112
113#[derive(Debug)]
118pub enum CallOrSubscription {
119 Subscription(MethodResponse),
122 Call(MethodResponse),
124}
125
126impl CallOrSubscription {
127 pub fn as_response(&self) -> &MethodResponse {
129 match &self {
130 Self::Subscription(r) => r,
131 Self::Call(r) => r,
132 }
133 }
134
135 pub fn into_response(self) -> MethodResponse {
137 match self {
138 Self::Subscription(r) => r,
139 Self::Call(r) => r,
140 }
141 }
142}
143
144#[derive(Clone)]
146pub enum MethodCallback {
147 Sync(SyncMethod),
149 Async(AsyncMethod<'static>),
151 Subscription(SubscriptionMethod<'static>),
153 Unsubscription(UnsubscriptionMethod),
155}
156
/// The kind of handler that a method name resolves to.
#[derive(Debug, Copy, Clone)]
pub enum MethodKind {
	/// A subscription method.
	Subscription,
	/// An unsubscription method.
	Unsubscription,
	/// A regular method call.
	MethodCall,
	/// No method is registered under the requested name.
	NotFound,
}

impl std::fmt::Display for MethodKind {
	/// Writes a short, lowercase label for the kind.
	///
	/// Formatter options such as width or fill are not applied.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let label = match self {
			MethodKind::Subscription => "subscription",
			MethodKind::Unsubscription => "unsubscription",
			MethodKind::MethodCall => "method call",
			MethodKind::NotFound => "method not found",
		};
		f.write_str(label)
	}
}
182
183pub enum MethodResult<T> {
185 Sync(T),
187 Async(BoxFuture<'static, T>),
189}
190
191impl<T: Debug> Debug for MethodResult<T> {
192 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193 match self {
194 MethodResult::Sync(result) => result.fmt(f),
195 MethodResult::Async(_) => f.write_str("<future>"),
196 }
197 }
198}
199
200impl Debug for MethodCallback {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 match self {
203 Self::Async(_) => write!(f, "Async"),
204 Self::Sync(_) => write!(f, "Sync"),
205 Self::Subscription(_) => write!(f, "Subscription"),
206 Self::Unsubscription(_) => write!(f, "Unsubscription"),
207 }
208 }
209}
210
211#[derive(Default, Debug, Clone)]
213pub struct Methods {
214 callbacks: Arc<FxHashMap<&'static str, MethodCallback>>,
215 extensions: Extensions,
216}
217
218impl Methods {
219 pub fn new() -> Self {
221 Self::default()
222 }
223
224 pub fn verify_method_name(&mut self, name: &'static str) -> Result<(), RegisterMethodError> {
226 if self.callbacks.contains_key(name) {
227 return Err(RegisterMethodError::AlreadyRegistered(name.into()));
228 }
229
230 Ok(())
231 }
232
233 pub fn verify_and_insert(
236 &mut self,
237 name: &'static str,
238 callback: MethodCallback,
239 ) -> Result<&mut MethodCallback, RegisterMethodError> {
240 match self.mut_callbacks().entry(name) {
241 Entry::Occupied(_) => Err(RegisterMethodError::AlreadyRegistered(name.into())),
242 Entry::Vacant(vacant) => Ok(vacant.insert(callback)),
243 }
244 }
245
246 fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> {
248 Arc::make_mut(&mut self.callbacks)
249 }
250
251 pub fn merge(&mut self, other: impl Into<Methods>) -> Result<(), RegisterMethodError> {
254 let mut other = other.into();
255
256 for name in other.callbacks.keys() {
257 self.verify_method_name(name)?;
258 }
259
260 let callbacks = self.mut_callbacks();
261
262 for (name, callback) in other.mut_callbacks().drain() {
263 callbacks.insert(name, callback);
264 }
265
266 Ok(())
267 }
268
269 pub fn method(&self, method_name: &str) -> Option<&MethodCallback> {
271 self.callbacks.get(method_name)
272 }
273
274 pub fn method_with_name(&self, method_name: &str) -> Option<(&'static str, &MethodCallback)> {
277 self.callbacks.get_key_value(method_name).map(|(k, v)| (*k, v))
278 }
279
280 pub async fn call<Params: ToRpcParams, T: DeserializeOwned + Clone>(
304 &self,
305 method: &str,
306 params: Params,
307 ) -> Result<T, MethodsError> {
308 let params = params.to_rpc_params()?;
309 let req = Request::borrowed(method, params.as_ref().map(|p| p.as_ref()), Id::Number(0));
310 tracing::trace!(target: LOG_TARGET, "[Methods::call] Method: {:?}, params: {:?}", method, params);
311 let (rp, _) = self.inner_call(req, 1, mock_subscription_permit()).await;
312
313 let rp = serde_json::from_str::<Response<T>>(rp.get())?;
314 ResponseSuccess::try_from(rp).map(|s| s.result).map_err(|e| MethodsError::JsonRpc(e.into_owned()))
315 }
316
317 pub async fn raw_json_request(
352 &self,
353 request: &str,
354 buf_size: usize,
355 ) -> Result<(Box<RawValue>, mpsc::Receiver<Box<RawValue>>), serde_json::Error> {
356 tracing::trace!("[Methods::raw_json_request] Request: {:?}", request);
357 let req: Request = serde_json::from_str(request)?;
358 let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
359
360 Ok((resp, rx))
361 }
362
363 async fn inner_call(
365 &self,
366 req: Request<'_>,
367 buf_size: usize,
368 subscription_permit: SubscriptionPermit,
369 ) -> RawRpcResponse {
370 let (tx, mut rx) = mpsc::channel(buf_size);
371 let Request { id, method, params, .. } = req;
374 let params = Params::new(params.as_ref().map(|params| params.as_ref().get()));
375 let max_response_size = usize::MAX;
376 let conn_id = ConnectionId(0);
377 let mut ext = self.extensions.clone();
378 ext.insert(conn_id);
379
380 let response = match self.method(&method) {
381 None => MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)),
382 Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, ext),
383 Some(MethodCallback::Async(cb)) => {
384 (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, ext).await
385 }
386 Some(MethodCallback::Subscription(cb)) => {
387 let conn_state =
388 SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit };
389 let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, ext).await;
390
391 let _ = rx.recv().await.expect("Every call must at least produce one response; qed");
396
397 res
398 }
399 Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, ext),
400 };
401
402 let is_success = response.is_success();
403 let (rp, notif, _) = response.into_parts();
404
405 if let Some(n) = notif {
406 n.notify(is_success);
407 }
408
409 tracing::trace!(target: LOG_TARGET, "[Methods::inner_call] Method: {}, response: {}", method, rp);
410
411 (rp, rx)
412 }
413
414 pub async fn subscribe_unbounded(
443 &self,
444 sub_method: &str,
445 params: impl ToRpcParams,
446 ) -> Result<Subscription, MethodsError> {
447 self.subscribe(sub_method, params, u32::MAX as usize).await
448 }
449
450 pub async fn subscribe(
454 &self,
455 sub_method: &str,
456 params: impl ToRpcParams,
457 buf_size: usize,
458 ) -> Result<Subscription, MethodsError> {
459 let params = params.to_rpc_params()?;
460 let req = Request::borrowed(sub_method, params.as_ref().map(|p| p.as_ref()), Id::Number(0));
461
462 tracing::trace!(target: LOG_TARGET, "[Methods::subscribe] Method: {}, params: {:?}", sub_method, params);
463
464 let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
465 let as_success: ResponseSuccess<&RawValue> = serde_json::from_str::<Response<_>>(resp.get())?.try_into()?;
466 let sub_id: RpcSubscriptionId = serde_json::from_str(as_success.result.get())?;
467
468 Ok(Subscription { sub_id: sub_id.into_owned(), rx })
469 }
470
471 pub fn method_names(&self) -> impl Iterator<Item = &'static str> + '_ {
473 self.callbacks.keys().copied()
474 }
475
476 pub fn extensions(&mut self) -> &Extensions {
478 &self.extensions
479 }
480
481 pub fn extensions_mut(&mut self) -> &mut Extensions {
510 &mut self.extensions
511 }
512}
513
514impl<Context> Deref for RpcModule<Context> {
515 type Target = Methods;
516
517 fn deref(&self) -> &Methods {
518 &self.methods
519 }
520}
521
522impl<Context> DerefMut for RpcModule<Context> {
523 fn deref_mut(&mut self) -> &mut Methods {
524 &mut self.methods
525 }
526}
527
528#[derive(Debug, Clone)]
532pub struct RpcModule<Context> {
533 ctx: Arc<Context>,
534 methods: Methods,
535}
536
537impl<Context> RpcModule<Context> {
538 pub fn new(ctx: Context) -> Self {
540 Self::from_arc(Arc::new(ctx))
541 }
542
543 pub fn from_arc(ctx: Arc<Context>) -> Self {
547 Self { ctx, methods: Default::default() }
548 }
549
550 pub fn remove_context(self) -> RpcModule<()> {
552 let mut module = RpcModule::new(());
553 module.methods = self.methods;
554 module
555 }
556}
557
558impl<Context> From<RpcModule<Context>> for Methods {
559 fn from(module: RpcModule<Context>) -> Methods {
560 module.methods
561 }
562}
563
564impl<Context: Send + Sync + 'static> RpcModule<Context> {
565 pub fn register_method<R, F>(
576 &mut self,
577 method_name: &'static str,
578 callback: F,
579 ) -> Result<&mut MethodCallback, RegisterMethodError>
580 where
581 Context: Send + Sync + 'static,
582 R: IntoResponse + 'static,
583 F: Fn(Params, &Context, &Extensions) -> R + Send + Sync + 'static,
584 {
585 let ctx = self.ctx.clone();
586 self.methods.verify_and_insert(
587 method_name,
588 MethodCallback::Sync(Arc::new(move |id, params, max_response_size, extensions| {
589 let rp = callback(params, &*ctx, &extensions).into_response();
590 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
591 })),
592 )
593 }
594
595 pub fn remove_method(&mut self, method_name: &'static str) -> Option<MethodCallback> {
600 self.methods.mut_callbacks().remove(method_name)
601 }
602
603 pub fn register_async_method<R, Fun, Fut>(
616 &mut self,
617 method_name: &'static str,
618 callback: Fun,
619 ) -> Result<&mut MethodCallback, RegisterMethodError>
620 where
621 R: IntoResponse + 'static,
622 Fut: Future<Output = R> + Send,
623 Fun: (Fn(Params<'static>, Arc<Context>, Extensions) -> Fut) + Clone + Send + Sync + 'static,
624 {
625 let ctx = self.ctx.clone();
626 self.methods.verify_and_insert(
627 method_name,
628 MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
629 let ctx = ctx.clone();
630 let callback = callback.clone();
631
632 let future = async move {
635 let rp = callback(params, ctx, extensions.clone()).await.into_response();
636 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
637 };
638 future.boxed()
639 })),
640 )
641 }
642
643 pub fn register_blocking_method<R, F>(
647 &mut self,
648 method_name: &'static str,
649 callback: F,
650 ) -> Result<&mut MethodCallback, RegisterMethodError>
651 where
652 Context: Send + Sync + 'static,
653 R: IntoResponse + 'static,
654 F: Fn(Params, Arc<Context>, Extensions) -> R + Clone + Send + Sync + 'static,
655 {
656 let ctx = self.ctx.clone();
657 let callback = self.methods.verify_and_insert(
658 method_name,
659 MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
660 let ctx = ctx.clone();
661 let callback = callback.clone();
662
663 let extensions2 = extensions.clone();
666
667 tokio::task::spawn_blocking(move || {
668 let rp = callback(params, ctx, extensions2.clone()).into_response();
669 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions2)
670 })
671 .map(|result| match result {
672 Ok(r) => r,
673 Err(err) => {
674 tracing::error!(target: LOG_TARGET, "Join error for blocking RPC method: {:?}", err);
675 MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::InternalError))
676 .with_extensions(extensions)
677 }
678 })
679 .boxed()
680 })),
681 )?;
682
683 Ok(callback)
684 }
685
686 pub fn register_subscription<R, F, Fut>(
781 &mut self,
782 subscribe_method_name: &'static str,
783 notif_method_name: &'static str,
784 unsubscribe_method_name: &'static str,
785 callback: F,
786 ) -> Result<&mut MethodCallback, RegisterMethodError>
787 where
788 Context: Send + Sync + 'static,
789 F: (Fn(Params<'static>, PendingSubscriptionSink, Arc<Context>, Extensions) -> Fut)
790 + Send
791 + Sync
792 + Clone
793 + 'static,
794 Fut: Future<Output = R> + Send + 'static,
795 R: IntoSubscriptionCloseResponse + Send,
796 {
797 let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
798 let ctx = self.ctx.clone();
799
800 let callback = {
802 self.methods.verify_and_insert(
803 subscribe_method_name,
804 MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
805 let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
806
807 let (tx, rx) = oneshot::channel();
809 let (accepted_tx, accepted_rx) = oneshot::channel();
810
811 let sub_id = uniq_sub.sub_id.clone();
812 let method = notif_method_name;
813
814 let sink = PendingSubscriptionSink {
815 inner: method_sink.clone(),
816 method: notif_method_name,
817 subscribers: subscribers.clone(),
818 uniq_sub,
819 id: id.clone().into_owned(),
820 subscribe: tx,
821 permit: conn.subscription_permit,
822 };
823
824 let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone());
832
833 tokio::spawn(async move {
834 let response = match futures_util::future::try_join(sub_fut.map(|f| Ok(f)), accepted_rx).await {
836 Ok((r, _)) => r.into_response(),
837 Err(_) => return,
839 };
840
841 match response {
842 SubscriptionCloseResponse::Notif(msg) => {
843 let json = sub_message_to_json(msg, &sub_id, method);
844 let _ = method_sink.send(json).await;
845 }
846 SubscriptionCloseResponse::NotifErr(err) => {
847 let json = sub_err_to_json(err, sub_id, method);
848 let _ = method_sink.send(json).await;
849 }
850 SubscriptionCloseResponse::None => (),
851 }
852 });
853
854 let id = id.clone().into_owned();
855
856 Box::pin(async move {
857 let rp = match rx.await {
858 Ok(rp) => {
859 if rp.is_success() {
862 let _ = accepted_tx.send(());
863 }
864 rp
865 }
866 Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
867 };
868
869 rp.with_extensions(extensions)
870 })
871 })),
872 )?
873 };
874
875 Ok(callback)
876 }
877
878 pub fn register_subscription_raw<R, F>(
924 &mut self,
925 subscribe_method_name: &'static str,
926 notif_method_name: &'static str,
927 unsubscribe_method_name: &'static str,
928 callback: F,
929 ) -> Result<&mut MethodCallback, RegisterMethodError>
930 where
931 Context: Send + Sync + 'static,
932 F: (Fn(Params, PendingSubscriptionSink, Arc<Context>, &Extensions) -> R) + Send + Sync + Clone + 'static,
933 R: IntoSubscriptionCloseResponse,
934 {
935 let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
936 let ctx = self.ctx.clone();
937
938 let callback = {
940 self.methods.verify_and_insert(
941 subscribe_method_name,
942 MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
943 let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
944
945 let (tx, rx) = oneshot::channel();
947
948 let sink = PendingSubscriptionSink {
949 inner: method_sink.clone(),
950 method: notif_method_name,
951 subscribers: subscribers.clone(),
952 uniq_sub,
953 id: id.clone().into_owned(),
954 subscribe: tx,
955 permit: conn.subscription_permit,
956 };
957
958 callback(params, sink, ctx.clone(), &extensions);
959
960 let id = id.clone().into_owned();
961
962 Box::pin(async move {
963 let rp = match rx.await {
964 Ok(rp) => rp,
965 Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
966 };
967
968 rp.with_extensions(extensions)
969 })
970 })),
971 )?
972 };
973
974 Ok(callback)
975 }
976
977 fn verify_and_register_unsubscribe(
980 &mut self,
981 subscribe_method_name: &'static str,
982 unsubscribe_method_name: &'static str,
983 ) -> Result<Subscribers, RegisterMethodError> {
984 if subscribe_method_name == unsubscribe_method_name {
985 return Err(RegisterMethodError::SubscriptionNameConflict(subscribe_method_name.into()));
986 }
987
988 self.methods.verify_method_name(subscribe_method_name)?;
989 self.methods.verify_method_name(unsubscribe_method_name)?;
990
991 let subscribers = Subscribers::default();
992
993 {
995 let subscribers = subscribers.clone();
996 self.methods.mut_callbacks().insert(
997 unsubscribe_method_name,
998 MethodCallback::Unsubscription(Arc::new(move |id, params, conn_id, max_response_size, extensions| {
999 let sub_id = match params.one::<RpcSubscriptionId>() {
1000 Ok(sub_id) => sub_id,
1001 Err(_) => {
1002 tracing::warn!(
1003 target: LOG_TARGET,
1004 "Unsubscribe call `{}` failed: couldn't parse subscription id={:?} request id={:?}",
1005 unsubscribe_method_name,
1006 params,
1007 id
1008 );
1009
1010 return MethodResponse::response(id, ResponsePayload::success(false), max_response_size)
1011 .with_extensions(extensions);
1012 }
1013 };
1014
1015 let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() };
1016 let result = subscribers.lock().remove(&key).is_some();
1017
1018 if !result {
1019 tracing::debug!(
1020 target: LOG_TARGET,
1021 "Unsubscribe call `{}` subscription key={:?} not an active subscription",
1022 unsubscribe_method_name,
1023 key,
1024 );
1025 }
1026
1027 MethodResponse::response(id, ResponsePayload::success(result), max_response_size)
1028 })),
1029 );
1030 }
1031
1032 Ok(subscribers)
1033 }
1034
1035 pub fn register_alias(
1037 &mut self,
1038 alias: &'static str,
1039 existing_method: &'static str,
1040 ) -> Result<(), RegisterMethodError> {
1041 self.methods.verify_method_name(alias)?;
1042
1043 let callback = match self.methods.callbacks.get(existing_method) {
1044 Some(callback) => callback.clone(),
1045 None => return Err(RegisterMethodError::MethodNotFound(existing_method.into())),
1046 };
1047
1048 self.methods.mut_callbacks().insert(alias, callback);
1049
1050 Ok(())
1051 }
1052}
1053
1054fn mock_subscription_permit() -> SubscriptionPermit {
1055 BoundedSubscriptions::new(1).acquire().expect("1 permit should exist; qed")
1056}