1use super::*;
2
3use std::future::Future;
4use std::pin::Pin;
5
6type ServerRpcHandlerResponse<E> = Result<Vec<u8>, ServerRpcProtocolError<E>>;
7type ServerRpcHandlerFuture<'a, E> =
8 Pin<Box<dyn Send + Future<Output = ServerRpcHandlerResponse<E>> + 'a>>;
9type ServerRpcHandlerType<S, E> =
10 Box<dyn 'static + Send + Sync + for<'a> Fn(S, &'a [u8]) -> ServerRpcHandlerFuture<'a, E>>;
11
12#[allow(dead_code)]
16pub struct ServerRpcHandler<S: ServerRpcService> {
17 uri: &'static str,
18 handler: ServerRpcHandlerType<S::ServerState, <S::Format as RpcFormat>::Error>,
19}
20
21impl<S: ServerRpcService> ServerRpcHandler<S> {
22 pub fn new<M, H, F>(uri: &'static str, handler: H) -> Self
24 where
25 M: Send + RpcMethod<S>,
26 H: 'static + Send + Sync + Clone + Fn(S::ServerState, M) -> F,
27 F: Send + Future<Output = Result<M::Response, M::Error>>,
28 S::ServerState: 'static + Send,
29 {
30 Self {
31 uri,
32 handler: Box::new(move |state, buffer| {
33 let handler = handler.clone();
34 Box::pin(async move {
35 let inner = async move {
36 let req = <S::Format as RpcFormat>::deserialize_request(buffer)
37 .map_err(|e| RpcError::ServerDeserializeError(e))?;
38
39 let res = handler(state, req)
40 .await
41 .map_err(|e| RpcError::HandlerError(e))?;
42 <Result<M::Response, RpcError<M::Error, _>>>::Ok(res)
43 };
44
45 let res = <S::Format as RpcFormat>::serialize_response(inner.await)?;
46
47 Ok(res)
48 })
49 }),
50 }
51 }
52}
53
54pub trait ServerRpcService: RpcService + Sized {
59 type ServerState;
60 type RegistryItem: ServerRpcRegistryItem<Self>;
61}
62
63pub trait ServerRpcRegistryItem<S: ServerRpcService> {
64 fn handler(&self) -> &ServerRpcHandler<S>;
65}
66
67#[derive(thiserror::Error, Debug)]
69pub enum ServerRpcProtocolError<E> {
70 #[error("rpc serialize error")]
72 SerializeError(#[from] E),
73}
74
/// Looks up the registered handler for `uri` among all registry entries of service `S`.
///
/// Entries are the `&'static S::RegistryItem` values collected by `inventory`. The URI
/// must match exactly. If several entries share a URI, the first one yielded by the
/// registry iterator is returned. Returns `None` when no entry matches.
#[cfg(feature = "server")]
pub fn find_rpc_handler<S: ServerRpcService>(uri: &str) -> Option<&'static ServerRpcHandler<S>>
where
    &'static S::RegistryItem: inventory::Collect,
{
    inventory::iter::<&'static S::RegistryItem>
        .into_iter()
        .map(|item| item.handler())
        .find(|handler| handler.uri == uri)
}
87
/// Dispatches a raw RPC request for service `S` to the handler registered under `uri`.
///
/// When a handler matches, its serialized response is returned. When none is registered,
/// a serialized `RpcError::NoEndpointFound` is returned as a normal `Ok` response rather
/// than an `Err`. An `Err` only signals a protocol-level serialization failure.
#[cfg(feature = "server")]
pub async fn handle_rpc<S: 'static + ServerRpcService>(
    uri: &str,
    state: S::ServerState,
    payload: &[u8],
) -> Result<Vec<u8>, ServerRpcProtocolError<<S::Format as RpcFormat>::Error>>
where
    &'static S::RegistryItem: inventory::Collect,
{
    match find_rpc_handler::<S>(uri) {
        Some(registered) => (registered.handler)(state, payload).await,
        // Unknown endpoints are reported to the caller inside the response body.
        None => <S::Format as RpcFormat>::serialize_response::<(), ()>(Err(
            RpcError::NoEndpointFound,
        ))
        .map_err(ServerRpcProtocolError::from),
    }
}