use std::sync::Arc;
use futures::{
FutureExt,
future::BoxFuture,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tower::{Service, make::MakeService};
use super::RpcServerError;
use crate::{
Bytes,
Substream,
protocol::{
ProtocolExtension,
ProtocolExtensionContext,
ProtocolExtensionError,
ProtocolId,
ProtocolNotificationRx,
rpc::{
RpcError,
RpcServer,
RpcStatus,
body::Body,
context::{RpcCommsBackend, RpcCommsProvider},
either::Either,
message::{Request, Response},
not_found::ProtocolServiceNotFound,
server::{NamedProtocolService, RpcServerHandle},
},
},
};
/// Routes inbound RPC substreams to one of several protocol services by [`ProtocolId`].
///
/// `A` is the most recently added service and `B` is the chain of previously added routes,
/// which always ends in [`ProtocolServiceNotFound`].
pub struct Router<A, B> {
    /// The RPC server that the routes are handed to in `serve`.
    server: RpcServer,
    /// Every protocol registered with this router, in registration order.
    protocol_names: Vec<ProtocolId>,
    /// The dispatch chain; each `Or` matches one protocol and defers to the next on mismatch.
    routes: Or<A, B>,
}
impl<A> Router<A, ProtocolServiceNotFound>
where A: NamedProtocolService
{
    /// Creates a router serving `service` under its `PROTOCOL_NAME`.
    ///
    /// Any other protocol is routed to [`ProtocolServiceNotFound`].
    pub fn new(server: RpcServer, service: A) -> Self {
        let protocol = ProtocolId::from_static(A::PROTOCOL_NAME);
        let routes = {
            // The predicate owns its own copy so `protocol` can still be recorded below.
            let target = protocol.clone();
            Or::new(
                move |candidate: &ProtocolId| target == candidate,
                service,
                ProtocolServiceNotFound,
            )
        };
        Self {
            server,
            protocol_names: vec![protocol],
            routes,
        }
    }
}
impl<A, B> Router<A, B> {
    /// Adds `service` to the router under its `PROTOCOL_NAME` and returns the extended router.
    ///
    /// The new service is checked before all previously added routes. If the same protocol name
    /// is registered twice, the most recently added service handles it, and the name appears
    /// twice in the protocol list.
    pub fn add_service<T>(mut self, service: T) -> Router<T, Or<A, B>>
    where T: NamedProtocolService {
        let expected_protocol = ProtocolId::from_static(<T as NamedProtocolService>::PROTOCOL_NAME);
        self.protocol_names.push(expected_protocol.clone());
        let predicate = move |protocol: &ProtocolId| expected_protocol == protocol;
        Router {
            protocol_names: self.protocol_names,
            server: self.server,
            routes: Or::new(predicate, service, self.routes),
        }
    }

    /// Returns a handle to the underlying RPC server.
    pub fn get_handle(&self) -> RpcServerHandle {
        self.server.get_handle()
    }

    /// Moves the router onto the heap.
    pub fn into_boxed(self) -> Box<Self>
    where Self: 'static {
        Box::new(self)
    }

    /// Returns every protocol registered with this router, in registration order.
    ///
    /// Only reads state, so a shared borrow is sufficient; callers holding `&mut Router` still
    /// work through the usual `&mut T -> &T` coercion.
    // NOTE(review): presumably only exercised by tests, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub(crate) fn all_protocols(&self) -> &[ProtocolId] {
        &self.protocol_names
    }
}
impl<A, B> Router<A, B>
where
    A: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send
        + 'static,
    A::Service: Send + 'static,
    A::Future: Send + 'static,
    <A::Service as Service<Request<Bytes>>>::Future: Send + 'static,
    B: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send
        + 'static,
    B::Service: Send + 'static,
    B::Future: Send + 'static,
    <B::Service as Service<Request<Bytes>>>::Future: Send + 'static,
{
    /// Runs the RPC server over this router's routes.
    ///
    /// Inbound protocol notifications are read from `protocol_notifications`. The future
    /// resolves when the server's own `serve` future does, with its error converted to
    /// [`RpcError`].
    pub(crate) async fn serve<TCommsProvider>(
        self,
        protocol_notifications: ProtocolNotificationRx<Substream>,
        comms_provider: TCommsProvider,
    ) -> Result<(), RpcError>
    where
        TCommsProvider: RpcCommsProvider + Clone + Send + 'static,
    {
        // Only the server and the dispatch chain are needed to serve requests.
        let Self { server, routes, .. } = self;
        server
            .serve(routes, protocol_notifications, comms_provider)
            .await
            .map_err(Into::into)
    }
}
impl<A, B> Service<ProtocolId> for Router<A, B>
where
A: MakeService<
ProtocolId,
Request<Bytes>,
Response = Response<Body>,
Error = RpcStatus,
MakeError = RpcServerError,
> + Send,
B: MakeService<
ProtocolId,
Request<Bytes>,
Response = Response<Body>,
Error = RpcStatus,
MakeError = RpcServerError,
> + Send,
A::Future: Send + 'static,
B::Future: Send + 'static,
{
type Error = <Or<A, B> as Service<ProtocolId>>::Error;
type Future = <Or<A, B> as Service<ProtocolId>>::Future;
type Response = <Or<A, B> as Service<ProtocolId>>::Response;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Service::poll_ready(&mut self.routes, cx)
}
fn call(&mut self, protocol: ProtocolId) -> Self::Future {
Service::call(&mut self.routes, protocol)
}
}
impl<A, B> ProtocolExtension for Router<A, B>
where
    A: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send
        + Sync
        + 'static,
    A::Service: Send + 'static,
    A::Future: Send + 'static,
    <A::Service as Service<Request<Bytes>>>::Future: Send + 'static,
    B: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send
        + Sync
        + 'static,
    B::Service: Send + 'static,
    B::Future: Send + 'static,
    <B::Service as Service<Request<Bytes>>>::Future: Send + 'static,
{
    /// Registers every routed protocol with the comms stack and starts serving them.
    ///
    /// Notifications for the registered protocols arrive on a channel with a buffer of 20. The
    /// router is then served on a spawned tokio task.
    fn install(self: Box<Self>, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
        let router = *self;
        let (notifier, notifications) = mpsc::channel(20);
        context.add_protocol(&router.protocol_names, &notifier);
        let comms_backend = RpcCommsBackend::new(context.peer_manager(), context.connectivity());
        // NOTE(review): the JoinHandle is discarded, so an error returned by `serve` is never
        // observed here — confirm that is intended.
        tokio::spawn(router.serve(notifications, comms_backend));
        Ok(())
    }
}
/// A binary route: protocols accepted by `predicate` are handled by `a`, all others by `b`.
pub struct Or<A, B> {
    /// Decides whether a protocol belongs to `a`.
    predicate: Arc<dyn Fn(&ProtocolId) -> bool + Send + Sync + 'static>,
    /// Service used when `predicate` returns `true`.
    a: A,
    /// Fallback used when `predicate` returns `false`.
    b: B,
}
impl<A, B> Or<A, B> {
    /// Builds a two-way route.
    ///
    /// Protocols for which `predicate` returns `true` go to `a`; everything else goes to `b`.
    pub fn new<P>(predicate: P, a: A, b: B) -> Self
    where P: Fn(&ProtocolId) -> bool + Send + Sync + 'static {
        let predicate: Arc<dyn Fn(&ProtocolId) -> bool + Send + Sync + 'static> = Arc::new(predicate);
        Self { predicate, a, b }
    }
}
/// Two-way protocol dispatch: `a` serves protocols accepted by the predicate, `b` serves the rest.
impl<A, B> Service<ProtocolId> for Or<A, B>
where
    A: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send,
    B: MakeService<
            ProtocolId,
            Request<Bytes>,
            Response = Response<Body>,
            Error = RpcStatus,
            MakeError = RpcServerError,
        > + Send,
    A::Future: Send + 'static,
    B::Future: Send + 'static,
{
    type Error = A::MakeError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
    type Response = Either<A::Service, B::Service>;

    /// Ready only when *both* branches are ready.
    ///
    /// `call` may route to either branch, and tower's `Service` contract requires a service to
    /// report ready before it is called. Both branches are always polled, so each registers its
    /// waker even when the other is pending. An error from either branch is returned immediately.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let a_ready = <A as MakeService<ProtocolId, Request<Bytes>>>::poll_ready(&mut self.a, cx);
        let b_ready = <B as MakeService<ProtocolId, Request<Bytes>>>::poll_ready(&mut self.b, cx);
        match (a_ready, b_ready) {
            (Poll::Ready(Err(err)), _) => Poll::Ready(Err(err)),
            (_, Poll::Ready(Err(err))) => Poll::Ready(Err(err)),
            (Poll::Ready(Ok(())), Poll::Ready(Ok(()))) => Poll::Ready(Ok(())),
            _ => Poll::Pending,
        }
    }

    /// Creates the service for `protocol` from whichever branch the predicate selects.
    ///
    /// The created service is wrapped in the matching `Either` variant.
    fn call(&mut self, protocol: ProtocolId) -> Self::Future {
        if (self.predicate)(&protocol) {
            self.a.make_service(protocol).map(|r| r.map(Either::A)).boxed()
        } else {
            self.b.make_service(protocol).map(|r| r.map(Either::B)).boxed()
        }
    }
}
#[cfg(test)]
mod test {
    use futures::{StreamExt, future};
    use prost::Message;
    use tari_test_utils::unpack_enum;
    use tower::util::BoxService;

    use super::*;

    /// Test service replying `Hello <message>`, served under the `hello` protocol.
    #[derive(Clone)]
    struct HelloService;

    impl NamedProtocolService for HelloService {
        const PROTOCOL_NAME: &'static [u8] = b"hello";
    }

    impl Service<ProtocolId> for HelloService {
        type Error = RpcServerError;
        type Future = future::Ready<Result<Self::Response, Self::Error>>;
        type Response = BoxService<Request<Bytes>, Response<Body>, RpcStatus>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: ProtocolId) -> Self::Future {
            let greeter = tower::service_fn(|req: Request<Bytes>| {
                let msg = req.into_message();
                let str = String::from_utf8_lossy(&msg);
                future::ready(Ok(Response::from_message(format!("Hello {str}"))))
            });
            future::ready(Ok(BoxService::new(greeter)))
        }
    }

    /// Test service replying `Goodbye <message>`, served under the `goodbye` protocol.
    #[derive(Clone)]
    struct GoodbyeService;

    impl NamedProtocolService for GoodbyeService {
        const PROTOCOL_NAME: &'static [u8] = b"goodbye";
    }

    impl Service<ProtocolId> for GoodbyeService {
        type Error = RpcServerError;
        type Future = future::Ready<Result<Self::Response, Self::Error>>;
        type Response = BoxService<Request<Bytes>, Response<Body>, RpcStatus>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: ProtocolId) -> Self::Future {
            let farewell = tower::service_fn(|req: Request<Bytes>| {
                let msg = req.into_message();
                let str = String::from_utf8_lossy(&msg);
                future::ready(Ok(Response::from_message(format!("Goodbye {str}"))))
            });
            future::ready(Ok(BoxService::new(farewell)))
        }
    }

    /// Each registered protocol reaches its own service; unknown protocols yield
    /// `ProtocolServiceNotFound` carrying the requested protocol name.
    #[tokio::test]
    async fn find_route() {
        let mut router = Router::new(RpcServer::new(), HelloService).add_service(GoodbyeService);
        assert_eq!(router.all_protocols(), &[
            HelloService::PROTOCOL_NAME,
            GoodbyeService::PROTOCOL_NAME
        ]);

        let mut hello_svc = router.call(HelloService::PROTOCOL_NAME.into()).await.unwrap();
        let hello_reply = hello_svc
            .call(Request::new(1.into(), b"Kerbal".to_vec().into()))
            .await
            .unwrap();
        let hello_bytes = hello_reply.into_message().next().await.unwrap().unwrap().into_bytes_mut();
        assert_eq!(String::decode(hello_bytes).unwrap(), "Hello Kerbal");

        let mut bye_svc = router.call(GoodbyeService::PROTOCOL_NAME.into()).await.unwrap();
        let bye_reply = bye_svc
            .call(Request::new(1.into(), b"Xel'naga".to_vec().into()))
            .await
            .unwrap();
        let bye_bytes = bye_reply.into_message().next().await.unwrap().unwrap().into_bytes_mut();
        assert_eq!(String::decode(bye_bytes).unwrap(), "Goodbye Xel'naga");

        let Err(err) = router.call(ProtocolId::from_static(b"/totally/real/protocol")).await else {
            panic!("Unexpected success for non-existent route");
        };
        unpack_enum!(RpcServerError::ProtocolServiceNotFound(proto_str) = err);
        assert_eq!(proto_str, "/totally/real/protocol");
    }
}