use super::*;
use std::future::Future;
use std::pin::Pin;
type ServerRpcHandlerResponse<E> = Result<Vec<u8>, ServerRpcProtocolError<E>>;
type ServerRpcHandlerFuture<'a, E> =
Pin<Box<dyn Send + Future<Output = ServerRpcHandlerResponse<E>> + 'a>>;
type ServerRpcHandlerType<S, E> =
Box<dyn 'static + Send + Sync + for<'a> Fn(S, &'a [u8]) -> ServerRpcHandlerFuture<'a, E>>;
#[allow(dead_code)]
pub struct ServerRpcHandler<S: ServerRpcService> {
uri: &'static str,
handler: ServerRpcHandlerType<S::ServerState, <S::Format as RpcFormat>::Error>,
}
impl<S: ServerRpcService> ServerRpcHandler<S> {
pub fn new<M, H, F>(uri: &'static str, handler: H) -> Self
where
M: Send + RpcMethod<S>,
H: 'static + Send + Sync + Clone + Fn(S::ServerState, M) -> F,
F: Send + Future<Output = Result<M::Response, M::Error>>,
S::ServerState: 'static + Send,
{
Self {
uri,
handler: Box::new(move |state, buffer| {
let handler = handler.clone();
Box::pin(async move {
let inner = async move {
let req = <S::Format as RpcFormat>::deserialize_request(buffer)
.map_err(|e| RpcError::ServerDeserializeError(e))?;
let res = handler(state, req)
.await
.map_err(|e| RpcError::HandlerError(e))?;
<Result<M::Response, RpcError<M::Error, _>>>::Ok(res)
};
let res = <S::Format as RpcFormat>::serialize_response(inner.await)?;
Ok(res)
})
}),
}
}
}
pub trait ServerRpcService: RpcService + Sized {
type ServerState;
type RegistryItem: ServerRpcRegistryItem<Self>;
}
pub trait ServerRpcRegistryItem<S: ServerRpcService> {
fn handler(&self) -> &ServerRpcHandler<S>;
}
#[derive(thiserror::Error, Debug)]
pub enum ServerRpcProtocolError<E> {
#[error("rpc serialize error")]
SerializeError(#[from] E),
}
#[cfg(feature = "server")]
pub fn find_rpc_handler<S: ServerRpcService>(uri: &str) -> Option<&'static ServerRpcHandler<S>>
where
&'static S::RegistryItem: inventory::Collect,
{
inventory::iter::<&'static S::RegistryItem>
.into_iter()
.map(|h| h.handler())
.filter(|h| h.uri == uri)
.next()
}
#[cfg(feature = "server")]
pub async fn handle_rpc<S: 'static + ServerRpcService>(
uri: &str,
state: S::ServerState,
payload: &[u8],
) -> Result<Vec<u8>, ServerRpcProtocolError<<S::Format as RpcFormat>::Error>>
where
&'static S::RegistryItem: inventory::Collect,
{
if let Some(handler) = find_rpc_handler::<S>(uri) {
(handler.handler)(state, payload).await
} else {
Ok(<S::Format as RpcFormat>::serialize_response::<(), ()>(
Err(RpcError::NoEndpointFound),
)?)
}
}