use super::{RpcClientT, RawRpcFuture, RawRpcSubscription};
use crate::{Error, UserError};
use core::future::Future;
use futures::StreamExt;
use serde_json::value::RawValue;
use std::sync::{Arc, Mutex};
use std::collections::{HashMap, VecDeque};
type MethodHandlerFnOnce = Box<dyn FnOnce(&str, Option<Box<serde_json::value::RawValue>>) -> RawRpcFuture<'static, Box<RawValue>> + Send + Sync + 'static>;
type SubscriptionHandlerFnOnce = Box<dyn FnOnce(&str, Option<Box<serde_json::value::RawValue>>, &str) -> RawRpcFuture<'static, RawRpcSubscription> + Send + Sync + 'static>;
type MethodHandlerFn = Box<dyn FnMut(&str, Option<Box<serde_json::value::RawValue>>) -> RawRpcFuture<'static, Box<RawValue>> + Send + Sync + 'static>;
type SubscriptionHandlerFn = Box<dyn FnMut(&str, Option<Box<serde_json::value::RawValue>>, &str) -> RawRpcFuture<'static, RawRpcSubscription> + Send + Sync + 'static>;
/// Builder that collects method and subscription handlers and turns them
/// into a [`MockRpcClient`] via [`MockRpcClientBuilder::build`].
#[derive(Default)]
pub struct MockRpcClientBuilder {
// Queues of single-use method handlers, keyed by method name.
method_handlers_once: HashMap<String, VecDeque<MethodHandlerFnOnce>>,
// Reusable method handlers, keyed by method name (one per name).
method_handlers: HashMap<String, MethodHandlerFn>,
// Optional reusable handler that is not tied to a specific method name.
method_fallback: Option<MethodHandlerFn>,
// Queues of single-use subscription handlers, keyed by subscribe method name.
subscription_handlers_once: HashMap<String, VecDeque<SubscriptionHandlerFnOnce>>,
// Reusable subscription handlers, keyed by subscribe method name (one per name).
subscription_handlers: HashMap<String, SubscriptionHandlerFn>,
// Optional reusable handler that is not tied to a specific subscription name.
subscription_fallback: Option<SubscriptionHandlerFn>
}
impl MockRpcClientBuilder {
pub fn method_handler_once<MethodHandler, MFut, MRes>(mut self, name: impl Into<String>, f: MethodHandler) -> Self
where
MethodHandler: FnOnce(Option<Box<serde_json::value::RawValue>>) -> MFut + Send + Sync + 'static,
MFut: Future<Output = MRes> + Send + 'static,
MRes: IntoHandlerResponse,
{
let handler: MethodHandlerFnOnce = Box::new(move |_method: &str, params: Option<Box<serde_json::value::RawValue>>| {
let fut = f(params);
Box::pin(async move { fut.await.into_handler_response() })
});
self.method_handlers_once.entry(name.into()).or_default().push_back(handler);
self
}
pub fn method_handler<MethodHandler, MFut, MRes>(mut self, name: impl Into<String>, mut f: MethodHandler) -> Self
where
MethodHandler: FnMut(Option<Box<serde_json::value::RawValue>>) -> MFut + Send + Sync + 'static,
MFut: Future<Output = MRes> + Send + 'static,
MRes: IntoHandlerResponse,
{
let handler: MethodHandlerFn = Box::new(move |_method: &str, params: Option<Box<serde_json::value::RawValue>>| {
let fut = f(params);
Box::pin(async move { fut.await.into_handler_response() })
});
self.method_handlers.insert(name.into(), handler);
self
}
pub fn method_fallback<MethodHandler, MFut, MRes>(mut self, mut f: MethodHandler) -> Self
where
MethodHandler: FnMut(String, Option<Box<serde_json::value::RawValue>>) -> MFut + Send + Sync + 'static,
MFut: Future<Output = MRes> + Send + 'static,
MRes: IntoHandlerResponse,
{
let handler: MethodHandlerFn = Box::new(move |method: &str, params: Option<Box<serde_json::value::RawValue>>| {
let fut = f(method.to_owned(), params);
Box::pin(async move { fut.await.into_handler_response() })
});
self.method_fallback = Some(handler);
self
}
pub fn subscription_handler_once<SubscriptionHandler, SFut, SRes>(mut self, name: impl Into<String>, f: SubscriptionHandler) -> Self
where
SubscriptionHandler: FnOnce(Option<Box<serde_json::value::RawValue>>, String) -> SFut + Send + Sync + 'static,
SFut: Future<Output = SRes> + Send + 'static,
SRes: IntoSubscriptionResponse,
{
let handler: SubscriptionHandlerFnOnce = Box::new(move |_sub: &str, params: Option<Box<serde_json::value::RawValue>>, unsub: &str| {
let fut = f(params, unsub.to_owned());
Box::pin(async move { fut.await.into_subscription_response() })
});
self.subscription_handlers_once.entry(name.into()).or_default().push_back(handler);
self
}
pub fn subscription_handler<SubscriptionHandler, SFut, SRes>(mut self, name: impl Into<String>, mut f: SubscriptionHandler) -> Self
where
SubscriptionHandler: FnMut(Option<Box<serde_json::value::RawValue>>, String) -> SFut + Send + Sync + 'static,
SFut: Future<Output = SRes> + Send + 'static,
SRes: IntoSubscriptionResponse,
{
let handler: SubscriptionHandlerFn = Box::new(move |_sub: &str, params: Option<Box<serde_json::value::RawValue>>, unsub: &str| {
let fut = f(params, unsub.to_owned());
Box::pin(async move { fut.await.into_subscription_response() })
});
self.subscription_handlers.insert(name.into(), handler);
self
}
pub fn subscription_fallback<SubscriptionHandler, SFut, SRes>(mut self, mut f: SubscriptionHandler) -> Self
where
SubscriptionHandler: FnMut(String, Option<Box<serde_json::value::RawValue>>, String) -> SFut + Send + Sync + 'static,
SFut: Future<Output = SRes> + Send + 'static,
SRes: IntoSubscriptionResponse,
{
let handler: SubscriptionHandlerFn = Box::new(move |sub: &str, params: Option<Box<serde_json::value::RawValue>>, unsub: &str| {
let fut = f(sub.to_owned(), params, unsub.to_owned());
Box::pin(async move { fut.await.into_subscription_response() })
});
self.subscription_fallback = Some(handler);
self
}
pub fn build(self) -> MockRpcClient {
MockRpcClient {
method_handlers_once: Arc::new(Mutex::new(self.method_handlers_once)),
method_handlers: Arc::new(Mutex::new(self.method_handlers)),
method_fallback: self.method_fallback.map(|f| Arc::new(Mutex::new(f))),
subscription_handlers_once: Arc::new(Mutex::new(self.subscription_handlers_once)),
subscription_handlers: Arc::new(Mutex::new(self.subscription_handlers)),
subscription_fallback: self.subscription_fallback.map(|f| Arc::new(Mutex::new(f))),
}
}
}
/// Mock RPC client that answers requests and subscriptions using handlers
/// registered through [`MockRpcClientBuilder`].
///
/// Every field is behind an `Arc`, so clones share the same handlers.
#[derive(Clone)]
pub struct MockRpcClient {
// Queues of single-use method handlers, keyed by method name.
method_handlers_once: Arc<Mutex<HashMap<String, VecDeque<MethodHandlerFnOnce>>>>,
// Reusable method handlers, keyed by method name.
method_handlers: Arc<Mutex<HashMap<String, MethodHandlerFn>>>,
// Optional handler used when no named method handler applies.
method_fallback: Option<Arc<Mutex<MethodHandlerFn>>>,
// Queues of single-use subscription handlers, keyed by subscribe method name.
subscription_handlers_once: Arc<Mutex<HashMap<String, VecDeque<SubscriptionHandlerFnOnce>>>>,
// Reusable subscription handlers, keyed by subscribe method name.
subscription_handlers: Arc<Mutex<HashMap<String, SubscriptionHandlerFn>>>,
// Optional handler used when no named subscription handler applies.
subscription_fallback: Option<Arc<Mutex<SubscriptionHandlerFn>>>,
}
impl MockRpcClient {
pub fn builder() -> MockRpcClientBuilder {
MockRpcClientBuilder::default()
}
}
impl RpcClientT for MockRpcClient {
fn request_raw<'a>(
&'a self,
method: &'a str,
params: Option<Box<serde_json::value::RawValue>>,
) -> RawRpcFuture<'a, Box<serde_json::value::RawValue>> {
let mut handlers_once = self.method_handlers_once.lock().unwrap();
if let Some(handlers) = handlers_once.get_mut(method) {
if let Some(handler) = handlers.pop_front() {
return handler(method, params)
}
}
drop(handlers_once);
let mut handlers = self.method_handlers.lock().unwrap();
if let Some(handler) = handlers.get_mut(method) {
return handler(method, params)
}
drop(handlers);
if let Some(handler) = &self.method_fallback {
let mut handler = handler.lock().unwrap();
return handler(method, params)
}
Box::pin(async move { Err(UserError::method_not_found().into()) })
}
fn subscribe_raw<'a>(
&'a self,
sub: &'a str,
params: Option<Box<serde_json::value::RawValue>>,
unsub: &'a str,
) -> RawRpcFuture<'a, RawRpcSubscription> {
let mut handlers_once = self.subscription_handlers_once.lock().unwrap();
if let Some(handlers) = handlers_once.get_mut(sub) {
if let Some(handler) = handlers.pop_front() {
return handler(sub, params, unsub)
}
}
drop(handlers_once);
let mut handlers = self.subscription_handlers.lock().unwrap();
if let Some(handler) = handlers.get_mut(sub) {
return handler(sub, params, unsub)
}
drop(handlers);
if let Some(handler) = &self.subscription_fallback {
let mut handler = handler.lock().unwrap();
return handler(sub, params, unsub)
}
Box::pin(async move { Err(UserError::method_not_found().into()) })
}
}
/// Wrapper marking a value that should be returned as JSON.
///
/// `Json<T>` can be returned from a handler for any `T: serde::Serialize`;
/// the wrapped value is serialized when the response is produced.
///
/// The common traits are derived so the wrapper can be printed, cloned and
/// compared whenever the wrapped type supports it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);
impl Json<serde_json::Value> {
    /// Converts any serializable `item` into a `Json<serde_json::Value>`.
    ///
    /// # Panics
    ///
    /// Panics if `item` cannot be converted to a `serde_json::Value`.
    pub fn value_of<T: serde::Serialize>(item: T) -> Self {
        let value = serde_json::to_value(item)
            .expect("item cannot be converted to a serde_json::Value");
        Self(value)
    }
}
/// Conversion from a method handler's return value into the raw JSON
/// response (or error) that the mock client hands back.
pub trait IntoHandlerResponse {
/// Consumes `self` and produces either the raw JSON value or an [`Error`].
fn into_handler_response(self) -> Result<Box<RawValue>, Error>;
}
impl <T: IntoHandlerResponse> IntoHandlerResponse for Result<T, Error> {
fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
self.and_then(|val| val.into_handler_response())
}
}
/// `Some(value)` converts the inner value; `None` becomes a "method not found"
/// error.
impl<T: IntoHandlerResponse> IntoHandlerResponse for Option<T> {
    fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
        match self {
            Some(val) => val.into_handler_response(),
            None => Err(UserError::method_not_found().into()),
        }
    }
}
/// An already-raw JSON value is returned as-is.
impl IntoHandlerResponse for Box<RawValue> {
fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
Ok(self)
}
}
/// A `serde_json::Value` is serialized into a raw JSON value.
impl IntoHandlerResponse for serde_json::Value {
fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
serialize_to_raw_value(&self)
}
}
impl <T: serde::Serialize> IntoHandlerResponse for Json<T> {
fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
serialize_to_raw_value(&self.0)
}
}
/// `Infallible` has no values, so this conversion can never actually run.
impl IntoHandlerResponse for core::convert::Infallible {
fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
// Empty match: there is no variant to handle.
match self {}
}
}
fn serialize_to_raw_value<T: serde::Serialize>(val: &T) -> Result<Box<RawValue>, Error> {
let res = serde_json::to_string(val).map_err(Error::Serialization)?;
let raw_value = RawValue::from_string(res).map_err(Error::Serialization)?;
Ok(raw_value)
}
/// Conversion from a subscription handler's return value into the
/// [`RawRpcSubscription`] (or error) that the mock client hands back.
pub trait IntoSubscriptionResponse {
/// Consumes `self` and produces either the subscription or an [`Error`].
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error>;
}
/// A `(response, id)` pair converts the response and then sets the
/// subscription ID to `id`, overwriting any ID the response produced.
impl<T: IntoSubscriptionResponse, S: Into<String>> IntoSubscriptionResponse for (T, S) {
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        let (source, id) = self;
        let mut subscription = source.into_subscription_response()?;
        subscription.id = Some(id.into());
        Ok(subscription)
    }
}
impl <T: IntoHandlerResponse + Send + 'static> IntoSubscriptionResponse for tokio::sync::mpsc::Receiver<T> {
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
struct IntoStream<T>(tokio::sync::mpsc::Receiver<T>);
impl <T> futures::Stream for IntoStream<T> {
type Item = T;
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
self.0.poll_recv(cx)
}
}
Ok(RawRpcSubscription {
stream: Box::pin(IntoStream(self).map(|item| item.into_handler_response())),
id: None,
})
}
}
impl <T: IntoHandlerResponse + Send + 'static> IntoSubscriptionResponse for tokio::sync::mpsc::UnboundedReceiver<T> {
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
struct IntoStream<T>(tokio::sync::mpsc::UnboundedReceiver<T>);
impl <T> futures::Stream for IntoStream<T> {
type Item = T;
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
self.0.poll_recv(cx)
}
}
Ok(RawRpcSubscription {
stream: Box::pin(IntoStream(self).map(|item| item.into_handler_response())),
id: None,
})
}
}
/// An already-built subscription is returned as-is.
impl IntoSubscriptionResponse for RawRpcSubscription {
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
Ok(self)
}
}
impl <T: IntoSubscriptionResponse> IntoSubscriptionResponse for Result<T, Error> {
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
self.and_then(|res| res.into_subscription_response())
}
}
/// A `Vec` becomes a finite subscription that yields each element in order,
/// converted via [`IntoHandlerResponse`], and has no ID.
impl<T: IntoHandlerResponse + Send + 'static> IntoSubscriptionResponse for Vec<T> {
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        let responses = futures::stream::iter(self).map(|item| item.into_handler_response());
        Ok(RawRpcSubscription {
            stream: Box::pin(responses),
            id: None,
        })
    }
}
/// `Some(sub)` converts the inner value; `None` becomes an empty subscription
/// with no ID.
impl<T: IntoSubscriptionResponse + Send + 'static> IntoSubscriptionResponse for Option<T> {
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        let Some(sub) = self else {
            return Ok(RawRpcSubscription {
                stream: Box::pin(futures::stream::empty()),
                id: None,
            });
        };
        sub.into_subscription_response()
    }
}
/// A fixed-size array becomes a finite subscription that yields each element
/// in order, converted via [`IntoHandlerResponse`], and has no ID.
impl<T: IntoHandlerResponse + Send + 'static, const N: usize> IntoSubscriptionResponse for [T; N] {
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        let responses = futures::stream::iter(self).map(|item| item.into_handler_response());
        Ok(RawRpcSubscription {
            stream: Box::pin(responses),
            id: None,
        })
    }
}
/// `Infallible` has no values, so this conversion can never actually run.
impl IntoSubscriptionResponse for core::convert::Infallible {
fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
// Empty match: there is no variant to handle.
match self {}
}
}
/// Pairs two subscription responses so that the items of the first are
/// followed by the items of the second.
///
/// `Debug` and `Clone` are derived so the pair can be inspected or duplicated
/// whenever both wrapped types support it.
#[derive(Debug, Clone)]
pub struct AndThen<A, B>(pub A, pub B);
/// Chains the stream of the second response after the stream of the first.
///
/// Both sides are converted before either error is inspected; an error from
/// the first side takes precedence. The resulting ID is the first side's ID
/// if it has one, otherwise the second side's.
impl<A: IntoSubscriptionResponse, B: IntoSubscriptionResponse> IntoSubscriptionResponse
    for AndThen<A, B>
{
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        let first = self.0.into_subscription_response();
        let second = self.1.into_subscription_response();

        let mut combined = first?;
        let tail = second?;

        combined.stream = Box::pin(combined.stream.chain(tail.stream));
        combined.id = combined.id.or(tail.id);
        Ok(combined)
    }
}
/// One of two possible response types, letting a handler return different
/// types from different branches.
///
/// The common traits are derived so values can be printed, cloned and
/// compared whenever both wrapped types support it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Either<A, B> {
    /// The first alternative.
    A(A),
    /// The second alternative.
    B(B),
}
/// Delegates to whichever alternative is present.
impl<A: IntoHandlerResponse, B: IntoHandlerResponse> IntoHandlerResponse for Either<A, B> {
    fn into_handler_response(self) -> Result<Box<RawValue>, Error> {
        match self {
            Self::A(first) => first.into_handler_response(),
            Self::B(second) => second.into_handler_response(),
        }
    }
}
/// Delegates to whichever alternative is present.
impl<A: IntoSubscriptionResponse, B: IntoSubscriptionResponse> IntoSubscriptionResponse
    for Either<A, B>
{
    fn into_subscription_response(self) -> Result<RawRpcSubscription, Error> {
        match self {
            Self::A(first) => first.into_subscription_response(),
            Self::B(second) => second.into_subscription_response(),
        }
    }
}
// Tests driving `MockRpcClient` through the `RpcClient` wrapper.
#[cfg(test)]
mod test {
use crate::{RpcClient, rpc_params};
use super::*;
// A handler that echoes its params back: the response deserializes to the
// same values that were sent.
#[tokio::test]
async fn test_method_params() {
let rpc_client = MockRpcClient::builder()
.method_handler("foo", async |params| {
Json(params)
})
.build();
let rpc_client = RpcClient::new(rpc_client);
let res: (i32,i32,i32) = rpc_client.request("foo", rpc_params![1, 2, 3]).await.unwrap();
assert_eq!(res, (1,2,3));
let res: (String,) = rpc_client.request("foo", rpc_params!["hello"]).await.unwrap();
assert_eq!(res, ("hello".to_owned(),));
}
// The named handler answers every "foo" call; any other method name is
// answered by the fallback, which echoes the method name.
#[tokio::test]
async fn test_method_handler_then_fallback() {
let rpc_client = MockRpcClient::builder()
.method_handler("foo", async |_params| {
Json(1)
})
.method_fallback(async |name, _params| {
Json(name)
})
.build();
let rpc_client = RpcClient::new(rpc_client);
for i in [1,1,1,1] {
let res: i32 = rpc_client.request("foo", rpc_params![]).await.unwrap();
assert_eq!(res, i);
}
for name in ["bar", "wibble", "steve"] {
let res: String = rpc_client.request(name, rpc_params![]).await.unwrap();
assert_eq!(res, name);
}
}
// The single-use handler answers the first call only; the reusable handler
// answers every call after that.
#[tokio::test]
async fn test_method_once_then_handler() {
let rpc_client = MockRpcClient::builder()
.method_handler_once("foo", async |_params| {
Json(1)
})
.method_handler("foo", async |_params| {
Json(2)
})
.build();
let rpc_client = RpcClient::new(rpc_client);
for i in [1,2,2,2,2] {
let res: i32 = rpc_client.request("foo", rpc_params![]).await.unwrap();
assert_eq!(res, i);
}
}
// Single-use handlers are consumed in registration order; once they run out
// and nothing else is registered, the call fails with "method not found".
#[tokio::test]
async fn test_method_once() {
let rpc_client = MockRpcClient::builder()
.method_handler_once("foo", async |_params| {
Json(1)
})
.method_handler_once("foo", async |_params| {
Json(2)
})
.method_handler_once("foo", async |_params| {
Json(3)
})
.build();
let rpc_client = RpcClient::new(rpc_client);
for i in [1,2,3] {
let res: i32 = rpc_client.request("foo", rpc_params![]).await.unwrap();
assert_eq!(res, i);
}
let err = rpc_client.request::<i32>("foo", rpc_params![]).await.unwrap_err();
let not_found_code = UserError::method_not_found().code;
assert!(matches!(err, Error::User(u) if u.code == not_found_code));
}
// Subscriptions use the same precedence: single-use handler first, then the
// reusable handler for the name, then the fallback for other names.
#[tokio::test]
async fn test_subscription_once_then_handler_then_fallback() {
let rpc_client = MockRpcClient::builder()
.subscription_handler_once("foo", async |_params, _unsub| {
vec![Json(0), Json(0)]
})
.subscription_handler("foo", async |_params, _unsub| {
vec![Json(1), Json(2), Json(3)]
})
.subscription_fallback(async |_name, _params, _unsub| {
vec![Json(4)]
})
.build();
let rpc_client = RpcClient::new(rpc_client);
let sub = rpc_client.subscribe::<i32>("foo", rpc_params![], "unsub").await.unwrap();
let res: Vec<i32> = sub.map(|i| i.unwrap()).collect().await;
assert_eq!(res, vec![0,0]);
for _ in 1..5 {
let sub = rpc_client.subscribe::<i32>("foo", rpc_params![], "unsub").await.unwrap();
let res: Vec<i32> = sub.map(|i| i.unwrap()).collect().await;
assert_eq!(res, vec![1,2,3]);
}
let sub = rpc_client.subscribe::<i32>("bar", rpc_params![], "unsub").await.unwrap();
let res: Vec<i32> = sub.map(|i| i.unwrap()).collect().await;
assert_eq!(res, vec![4]);
}
// `AndThen` yields the fixed items first and then whatever arrives on the
// channel; the stream ends once the sender is dropped.
#[tokio::test]
async fn test_subscription_and_then_with_channel() {
let (tx, rx) = tokio::sync::mpsc::channel(10);
let rpc_client = MockRpcClient::builder()
.subscription_handler_once("foo", async move |_params, _unsub| {
AndThen(
vec![Json(1), Json(2), Json(3)],
rx
)
})
.build();
let rpc_client = RpcClient::new(rpc_client);
// Feed the remaining items from a background task, spaced out in time.
tokio::spawn(async move {
for i in 4..=6 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
tx.send(Json(i)).await.unwrap();
}
});
let sub = rpc_client.subscribe::<i32>("foo", rpc_params![], "unsub").await.unwrap();
let res: Vec<i32> = sub.map(|i| i.unwrap()).collect().await;
assert_eq!(res, vec![1,2,3,4,5,6]);
}
}