use std::future::Future;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use crate::error::QrpcResult;
/// A message type carried by a QRPC dispatcher.
///
/// Each message reports a numeric command id and a byte encoding. The pair
/// `(cmd_id, bytes)` is what [`QrpcMessage::decode_vec`] consumes to rebuild it.
pub trait QrpcMessage: Sized + Send + Sync + 'static {
    /// Numeric command identifier of this message.
    fn cmd_id(&self) -> u32;

    /// Serializes the message payload into an owned byte buffer.
    fn encode_vec(&self) -> Vec<u8>;

    /// Rebuilds a message from its command id and encoded payload.
    ///
    /// # Errors
    ///
    /// Returns an error when `cmd_id`/`data` cannot be decoded. The concrete
    /// error variant is chosen by the implementor.
    fn decode_vec(cmd_id: u32, data: &[u8]) -> QrpcResult<Self>;
}
/// Wrapper through which a callback receives a value derived from the
/// caller-supplied state (see [`FromRefCallback`]).
///
/// The tuple field is public so handlers can destructure it directly, e.g.
/// `|State(s): State<usize>, ..| ...`. It also dereferences, mutably or not,
/// to the wrapped value.
#[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S);

impl<S> Deref for State<S> {
    type Target = S;

    fn deref(&self) -> &S {
        let State(inner) = self;
        inner
    }
}

impl<S> DerefMut for State<S> {
    fn deref_mut(&mut self) -> &mut S {
        let State(inner) = self;
        inner
    }
}
/// Builds `Self` out of a borrowed `T`.
///
/// [`FromRefCallback`] uses this to turn the state `S` passed to
/// [`QrpcCallback::call`] into the `T` its closure expects.
pub trait FromRef<T> {
    /// Produces a `Self` from a borrow of `input`.
    fn from_ref(input: &T) -> Self;
}

/// Any cloneable type can be obtained from a reference to itself by cloning it.
impl<T: Clone> FromRef<T> for T {
    fn from_ref(input: &T) -> Self {
        T::clone(input)
    }
}
/// A handler invoked with caller state `S` and a message of type `M`.
///
/// The returned future is boxed as `dyn Future + Send` with no explicit
/// lifetime, so it is `'static`. It therefore cannot keep borrowing `&self`
/// or `state` and must own whatever it needs.
pub trait QrpcCallback<S: Send + Sync + 'static, M: QrpcMessage>: Send + Sync + 'static {
    /// Handles `message`.
    ///
    /// `ctx` gives access to the dispatcher (send, broadcast, peer queries).
    /// `source_peer_id` is the peer id the caller attributes the message to.
    fn call(
        &self,
        state: &S,
        ctx: Ctx<M>,
        source_peer_id: String,
        message: M,
    ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send>>;
}
/// Messaging backend that [`Ctx`] delegates to.
///
/// Every async operation returns a boxed future borrowing `self` and its
/// arguments for `'a`. Avoiding `async fn` keeps the trait object-safe, which
/// lets it be stored as `Arc<dyn QrpcDispatcher<M>>`.
pub trait QrpcDispatcher<M: QrpcMessage>: Send + Sync + 'static {
    /// Identifier of the local instance.
    fn instance_id(&self) -> &str;

    /// Sends `message` to the peer identified by `target_id`.
    fn send_to<'a>(
        &'a self,
        target_id: &'a str,
        message: &'a M,
    ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send + 'a>>;

    /// Sends `message` to every peer.
    ///
    /// Resolves to a `usize`, presumably the number of peers reached. Confirm
    /// this against the concrete implementations.
    fn broadcast<'a>(
        &'a self,
        message: &'a M,
    ) -> Pin<Box<dyn Future<Output = QrpcResult<usize>> + Send + 'a>>;

    /// Ids of the peers the dispatcher currently knows about.
    fn peer_ids<'a>(&'a self) -> Pin<Box<dyn Future<Output = Vec<String>> + Send + 'a>>;

    /// Waits for `peer_id`, bounded by `timeout`.
    ///
    /// What counts as "ready" and which error a timeout produces are up to the
    /// implementation.
    fn wait_for_peer<'a>(
        &'a self,
        peer_id: &'a str,
        timeout: std::time::Duration,
    ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send + 'a>>;

    /// Shuts the dispatcher down.
    fn shutdown<'a>(&'a self) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
}
/// Handle given to callbacks for interacting with the dispatcher.
///
/// Cloning copies the cached instance-id string and bumps the dispatcher's
/// `Arc` reference count. No other state is duplicated.
pub struct Ctx<M>
where
    M: QrpcMessage,
{
    /// Copy of `dispatcher.instance_id()`, taken when the context is built.
    instance_id: String,
    /// Shared dispatcher that the context's methods delegate to.
    dispatcher: Arc<dyn QrpcDispatcher<M>>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive adds an
// implicit `M: Clone` bound, which would make `Ctx<M>` uncloneable for any
// message type that is not `Clone`. `QrpcMessage` does not require `Clone`,
// and neither field needs it.
impl<M> Clone for Ctx<M>
where
    M: QrpcMessage,
{
    fn clone(&self) -> Self {
        Self {
            instance_id: self.instance_id.clone(),
            dispatcher: Arc::clone(&self.dispatcher),
        }
    }
}
impl<M: QrpcMessage> Ctx<M> {
    /// Wraps `dispatcher`, caching its instance id up front.
    pub(crate) fn new(dispatcher: Arc<dyn QrpcDispatcher<M>>) -> Self {
        let instance_id = dispatcher.instance_id().to_owned();
        Self {
            dispatcher,
            instance_id,
        }
    }

    /// Id of the local instance, as captured when this context was created.
    pub fn instance_id(&self) -> &str {
        self.instance_id.as_str()
    }

    /// Sends `message` to `target_id` through the dispatcher.
    pub async fn send_to(&self, target_id: &str, message: &M) -> QrpcResult<()> {
        let pending = self.dispatcher.send_to(target_id, message);
        pending.await
    }

    /// Broadcasts `message` through the dispatcher and returns its `usize` result.
    pub async fn broadcast(&self, message: &M) -> QrpcResult<usize> {
        let pending = self.dispatcher.broadcast(message);
        pending.await
    }

    /// Peer ids as reported by the dispatcher.
    pub async fn peer_ids(&self) -> Vec<String> {
        let pending = self.dispatcher.peer_ids();
        pending.await
    }

    /// Delegates to the dispatcher's `wait_for_peer` with the given `timeout`.
    pub async fn wait_for_peer(
        &self,
        peer_id: &str,
        timeout: std::time::Duration,
    ) -> QrpcResult<()> {
        let pending = self.dispatcher.wait_for_peer(peer_id, timeout);
        pending.await
    }

    /// Asks the dispatcher to shut down and waits for that to finish.
    pub async fn shutdown(&self) {
        let pending = self.dispatcher.shutdown();
        pending.await
    }
}
/// Adapter that lets a plain async closure act as a [`QrpcCallback`].
///
/// The closure receives a `State<T>`, where `T` is produced from the callback
/// state `S` via [`FromRef`]. `T` and `Fut` appear only in the marker, so the
/// trait impl can name them. The marker uses `fn() -> (T, Fut)` rather than
/// `(T, Fut)` because function pointers are always `Send + Sync`, so the
/// marker never affects those auto traits.
pub struct FromRefCallback<T, F, Fut> {
    inner: F,
    _marker: PhantomData<fn() -> (T, Fut)>,
}

impl<T, F, Fut> FromRefCallback<T, F, Fut> {
    /// Wraps `inner`. No bounds are checked until the value is used as a callback.
    pub fn new(inner: F) -> Self {
        FromRefCallback {
            _marker: PhantomData,
            inner,
        }
    }
}
/// Runs the wrapped closure as a [`QrpcCallback`].
///
/// On every call, a fresh `T` is built from the borrowed `state` with
/// [`FromRef::from_ref`] (a clone, for the blanket impl). The closure's
/// `'static` future therefore owns its data and never borrows `state`.
impl<S, T, M, F, Fut> QrpcCallback<S, M> for FromRefCallback<T, F, Fut>
where
    S: Send + Sync + 'static,
    T: FromRef<S> + Send + 'static,
    M: QrpcMessage,
    F: Fn(State<T>, Ctx<M>, String, M) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = QrpcResult<()>> + Send + 'static,
{
    fn call(
        &self,
        state: &S,
        ctx: Ctx<M>,
        source_peer_id: String,
        message: M,
    ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send>> {
        let handler = &self.inner;
        let wrapped = State(<T as FromRef<S>>::from_ref(state));
        let future = handler(wrapped, ctx, source_peer_id, message);
        Box::pin(future)
    }
}
#[cfg(test)]
mod tests {
    // End-to-end check of `FromRefCallback`, using a trivial message type and
    // a dispatcher stub whose operations all complete immediately.
    use super::*;
    use crate::error::QrpcError;
    use std::sync::Arc;

    /// The only command id `DemoMsg` accepts.
    const DEMO_CMD: u32 = 42;

    /// Test message whose wire encoding is just its raw payload.
    #[derive(Clone)]
    struct DemoMsg(Vec<u8>);

    impl QrpcMessage for DemoMsg {
        fn cmd_id(&self) -> u32 {
            DEMO_CMD
        }

        fn encode_vec(&self) -> Vec<u8> {
            self.0.to_vec()
        }

        fn decode_vec(cmd_id: u32, data: &[u8]) -> QrpcResult<Self> {
            match cmd_id {
                DEMO_CMD => Ok(DemoMsg(data.to_vec())),
                _ => Err(QrpcError::MessageDecode("bad cmd".to_string())),
            }
        }
    }

    /// Dispatcher stub: fixed instance id, one known peer, every call succeeds.
    struct TestDispatcher;

    impl QrpcDispatcher<DemoMsg> for TestDispatcher {
        fn instance_id(&self) -> &str {
            "demo-node"
        }

        fn send_to<'a>(
            &'a self,
            _target_id: &'a str,
            _message: &'a DemoMsg,
        ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }

        fn broadcast<'a>(
            &'a self,
            _message: &'a DemoMsg,
        ) -> Pin<Box<dyn Future<Output = QrpcResult<usize>> + Send + 'a>> {
            Box::pin(async { Ok(0) })
        }

        fn peer_ids<'a>(&'a self) -> Pin<Box<dyn Future<Output = Vec<String>> + Send + 'a>> {
            Box::pin(async { vec![String::from("peer-1")] })
        }

        fn wait_for_peer<'a>(
            &'a self,
            _peer_id: &'a str,
            _timeout: std::time::Duration,
        ) -> Pin<Box<dyn Future<Output = QrpcResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }

        fn shutdown<'a>(&'a self) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            Box::pin(async {})
        }
    }

    #[tokio::test]
    async fn blanket_callback_impl_works() {
        let state = 7usize;
        let dispatcher: Arc<dyn QrpcDispatcher<DemoMsg>> = Arc::new(TestDispatcher);
        let ctx = Ctx::new(dispatcher);

        // The closure checks that every argument reaches it unchanged.
        let cb = FromRefCallback::new(
            |State(value): State<usize>, cx: Ctx<DemoMsg>, from: String, payload: DemoMsg| async move {
                assert_eq!(value, 7);
                assert_eq!(cx.instance_id(), "demo-node");
                assert_eq!(from, "peer-1");
                assert_eq!(payload.0, b"ok");
                Ok(())
            },
        );

        let outcome = cb
            .call(&state, ctx, "peer-1".to_string(), DemoMsg(b"ok".to_vec()))
            .await;
        outcome.expect("callback must succeed");
    }
}