use std::collections::HashMap;
use std::sync::Arc;
use buffa::Message;
use serde::Serialize;
use serde::de::DeserializeOwned;
use buffa::view::MessageView;
use crate::handler::BidiStreamingHandler;
use crate::handler::BidiStreamingHandlerWrapper;
use crate::handler::BidiStreamingViewHandlerWrapper;
use crate::handler::ClientStreamingHandler;
use crate::handler::ClientStreamingHandlerWrapper;
use crate::handler::ClientStreamingViewHandlerWrapper;
use crate::handler::ErasedBidiStreamingHandler;
use crate::handler::ErasedClientStreamingHandler;
use crate::handler::ErasedHandler;
use crate::handler::ErasedStreamingHandler;
use crate::handler::Handler;
use crate::handler::ServerStreamingHandlerWrapper;
use crate::handler::ServerStreamingViewHandlerWrapper;
use crate::handler::StreamingHandler;
use crate::handler::UnaryHandlerWrapper;
use crate::handler::UnaryViewHandlerWrapper;
use crate::handler::ViewBidiStreamingHandler;
use crate::handler::ViewClientStreamingHandler;
use crate::handler::ViewHandler;
use crate::handler::ViewStreamingHandler;
/// The streaming shape of an RPC method, as reported in a method descriptor.
///
/// Marked `#[non_exhaustive]`: code outside this crate that matches on it must
/// include a wildcard arm so new kinds can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum MethodKind {
    /// One request message, one response message.
    Unary,
    /// One request message, a stream of response messages.
    ServerStreaming,
    /// A stream of request messages, one response message.
    ClientStreaming,
    /// Request and response messages both streamed.
    BidiStreaming,
}
/// Registry entry for a unary method (owned-message or view-based handler).
struct UnaryMethod {
    /// Surfaced to callers through `Dispatcher::lookup` via `MethodDescriptor::unary`.
    idempotent: bool,
    /// Type-erased handler that `call_unary` forwards requests to.
    handler: Arc<dyn ErasedHandler>,
}
/// Registry entry for a server-streaming method (owned-message or view-based handler).
struct StreamingMethod {
    /// Kind reported by `Dispatcher::lookup`; the `route*` builders in this file
    /// always set it to `MethodKind::ServerStreaming`.
    kind: MethodKind,
    /// Type-erased handler that `call_server_streaming` forwards requests to.
    handler: Arc<dyn ErasedStreamingHandler>,
}
/// Registry entry for a client-streaming method.
///
/// Created by `route_client_stream` / `route_view_client_stream` and invoked by
/// `call_client_streaming`.
struct ClientStreamingMethod {
    // Type-erased wrapper around the user's handler.
    handler: Arc<dyn ErasedClientStreamingHandler>,
}
/// Registry entry for a bidirectional-streaming method.
///
/// Created by `route_bidi_stream` / `route_view_bidi_stream` and invoked by
/// `call_bidi_streaming`.
struct BidiStreamingMethod {
    // Type-erased wrapper around the user's handler.
    handler: Arc<dyn ErasedBidiStreamingHandler>,
}
/// A registered method, tagged by how it is dispatched.
///
/// Each `call_*` method of the `Dispatcher` impl only accepts its own variant;
/// a path registered under a different variant is answered as unimplemented.
enum Method {
    // Dispatched by `call_unary`.
    Unary(UnaryMethod),
    // Server streaming; dispatched by `call_server_streaming`.
    Streaming(StreamingMethod),
    // Dispatched by `call_client_streaming`.
    ClientStreaming(ClientStreamingMethod),
    // Dispatched by `call_bidi_streaming`.
    BidiStreaming(BidiStreamingMethod),
}
/// Maps `"{service}/{method}"` paths to registered RPC handlers.
///
/// Build one with [`Router::new`] and the `route*` builder methods, or combine
/// several with [`merge_routers`].
#[derive(Default)]
pub struct Router {
    // Keyed by the `"{service_name}/{method_name}"` string built at registration.
    methods: HashMap<String, Method>,
}

impl std::fmt::Debug for Router {
    /// Shows the registered paths; handlers are type-erased trait objects and
    /// are therefore omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut paths: Vec<&str> = self.methods.keys().map(String::as_str).collect();
        // `HashMap` iteration order is unspecified; sort so the output is stable.
        paths.sort_unstable();
        f.debug_struct("Router").field("methods", &paths).finish()
    }
}
impl Router {
    /// Creates a router with no registered methods.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a unary handler at `"{service_name}/{method_name}"`.
    ///
    /// The method is reported as non-idempotent by `Dispatcher::lookup`. Any
    /// method already registered at the same path, of any kind, is replaced.
    #[must_use]
    pub fn route<H, Req, Res>(self, service_name: &str, method_name: &str, handler: H) -> Self
    where
        H: Handler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        self.route_unary_internal(service_name, method_name, handler, false)
    }

    /// Like [`Router::route`], but the method is reported as idempotent by
    /// `Dispatcher::lookup`.
    #[must_use]
    pub fn route_idempotent<H, Req, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: Handler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        self.route_unary_internal(service_name, method_name, handler, true)
    }

    /// Wraps an owned-message unary handler and registers it with the given
    /// idempotency flag.
    fn route_unary_internal<H, Req, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
        idempotent: bool,
    ) -> Self
    where
        H: Handler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = UnaryHandlerWrapper::new(handler);
        let method = Method::Unary(UnaryMethod {
            handler: Arc::new(wrapper),
            idempotent,
        });
        self.register_method(service_name, method_name, method)
    }

    /// Registers a server-streaming handler (one request, streamed responses)
    /// at `"{service_name}/{method_name}"`, replacing any existing entry.
    #[must_use]
    pub fn route_server_stream<H, Req, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: StreamingHandler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = ServerStreamingHandlerWrapper::new(handler);
        let method = Method::Streaming(StreamingMethod {
            handler: Arc::new(wrapper),
            kind: MethodKind::ServerStreaming,
        });
        self.register_method(service_name, method_name, method)
    }

    /// Registers a client-streaming handler (streamed requests, one response)
    /// at `"{service_name}/{method_name}"`, replacing any existing entry.
    #[must_use]
    pub fn route_client_stream<H, Req, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ClientStreamingHandler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = ClientStreamingHandlerWrapper::new(handler);
        let method = Method::ClientStreaming(ClientStreamingMethod {
            handler: Arc::new(wrapper),
        });
        self.register_method(service_name, method_name, method)
    }

    /// Registers a bidirectional-streaming handler at
    /// `"{service_name}/{method_name}"`, replacing any existing entry.
    #[must_use]
    pub fn route_bidi_stream<H, Req, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: BidiStreamingHandler<Req, Res>,
        Req: Message + DeserializeOwned + Send + 'static,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = BidiStreamingHandlerWrapper::new(handler);
        let method = Method::BidiStreaming(BidiStreamingMethod {
            handler: Arc::new(wrapper),
        });
        self.register_method(service_name, method_name, method)
    }

    /// Registers a unary handler whose request arrives as a `MessageView`
    /// (`ReqView`) instead of an owned message. Reported as non-idempotent;
    /// replaces any existing entry at the same path.
    #[must_use]
    pub fn route_view<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ViewHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        self.route_view_internal(service_name, method_name, handler, false)
    }

    /// Like [`Router::route_view`], but the method is reported as idempotent by
    /// `Dispatcher::lookup`.
    #[must_use]
    pub fn route_view_idempotent<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ViewHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        self.route_view_internal(service_name, method_name, handler, true)
    }

    /// Wraps a view-based unary handler and registers it with the given
    /// idempotency flag.
    fn route_view_internal<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
        idempotent: bool,
    ) -> Self
    where
        H: ViewHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = UnaryViewHandlerWrapper::new(handler);
        let method = Method::Unary(UnaryMethod {
            handler: Arc::new(wrapper),
            idempotent,
        });
        self.register_method(service_name, method_name, method)
    }

    /// View-based counterpart of [`Router::route_server_stream`].
    #[must_use]
    pub fn route_view_server_stream<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ViewStreamingHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = ServerStreamingViewHandlerWrapper::new(handler);
        let method = Method::Streaming(StreamingMethod {
            handler: Arc::new(wrapper),
            kind: MethodKind::ServerStreaming,
        });
        self.register_method(service_name, method_name, method)
    }

    /// View-based counterpart of [`Router::route_client_stream`].
    #[must_use]
    pub fn route_view_client_stream<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ViewClientStreamingHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = ClientStreamingViewHandlerWrapper::new(handler);
        let method = Method::ClientStreaming(ClientStreamingMethod {
            handler: Arc::new(wrapper),
        });
        self.register_method(service_name, method_name, method)
    }

    /// View-based counterpart of [`Router::route_bidi_stream`].
    #[must_use]
    pub fn route_view_bidi_stream<H, ReqView, Res>(
        self,
        service_name: &str,
        method_name: &str,
        handler: H,
    ) -> Self
    where
        H: ViewBidiStreamingHandler<ReqView, Res>,
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
        Res: Message + Serialize + Send + 'static,
    {
        let wrapper = BidiStreamingViewHandlerWrapper::new(handler);
        let method = Method::BidiStreaming(BidiStreamingMethod {
            handler: Arc::new(wrapper),
        });
        self.register_method(service_name, method_name, method)
    }

    /// Stores `method` under the key `"{service_name}/{method_name}"`.
    ///
    /// Uses `HashMap::insert`, so an existing entry at that key is silently
    /// replaced rather than rejected.
    fn register_method(mut self, service_name: &str, method_name: &str, method: Method) -> Self {
        self.methods.insert(format!("{service_name}/{method_name}"), method);
        self
    }

    /// Iterates over the paths of all registered methods.
    ///
    /// Iteration order is unspecified (the backing store is a `HashMap`).
    pub fn methods(&self) -> impl Iterator<Item = &str> {
        self.methods.keys().map(String::as_str)
    }

    /// Returns `true` if a method of any kind is registered at `path`.
    pub fn has_method(&self, path: &str) -> bool {
        self.methods.contains_key(path)
    }
}
impl crate::dispatcher::Dispatcher for Router {
    /// Describes the method registered at `path`, or returns `None` when the
    /// path is unknown.
    fn lookup(&self, path: &str) -> Option<crate::dispatcher::MethodDescriptor> {
        use crate::dispatcher::MethodDescriptor;
        self.methods.get(path).map(|method| match method {
            Method::Unary(unary) => MethodDescriptor::unary(unary.idempotent),
            // Streaming entries carry their own kind and are reported as non-idempotent.
            Method::Streaming(streaming) => MethodDescriptor {
                kind: streaming.kind,
                idempotent: false,
            },
            Method::ClientStreaming(_) => MethodDescriptor::client_streaming(),
            Method::BidiStreaming(_) => MethodDescriptor::bidi_streaming(),
        })
    }

    /// Invokes the unary handler at `path`; a missing path or a path of another
    /// kind yields the unimplemented unary result.
    fn call_unary(
        &self,
        path: &str,
        ctx: crate::handler::Context,
        request: bytes::Bytes,
        format: crate::codec::CodecFormat,
    ) -> crate::dispatcher::UnaryResult {
        if let Some(Method::Unary(entry)) = self.methods.get(path) {
            entry.handler.call_erased(ctx, request, format)
        } else {
            crate::dispatcher::unimplemented_unary(path)
        }
    }

    /// Invokes the server-streaming handler at `path`; a missing path or a path
    /// of another kind yields the unimplemented streaming result.
    fn call_server_streaming(
        &self,
        path: &str,
        ctx: crate::handler::Context,
        request: bytes::Bytes,
        format: crate::codec::CodecFormat,
    ) -> crate::dispatcher::StreamingResult {
        if let Some(Method::Streaming(entry)) = self.methods.get(path) {
            entry.handler.call_erased(ctx, request, format)
        } else {
            crate::dispatcher::unimplemented_streaming(path)
        }
    }

    /// Invokes the client-streaming handler at `path`. The result is unary-shaped,
    /// so the unary "unimplemented" result is used as the fallback.
    fn call_client_streaming(
        &self,
        path: &str,
        ctx: crate::handler::Context,
        requests: crate::dispatcher::RequestStream,
        format: crate::codec::CodecFormat,
    ) -> crate::dispatcher::UnaryResult {
        if let Some(Method::ClientStreaming(entry)) = self.methods.get(path) {
            entry.handler.call_erased(ctx, requests, format)
        } else {
            crate::dispatcher::unimplemented_unary(path)
        }
    }

    /// Invokes the bidirectional-streaming handler at `path`; a missing path or a
    /// path of another kind yields the unimplemented streaming result.
    fn call_bidi_streaming(
        &self,
        path: &str,
        ctx: crate::handler::Context,
        requests: crate::dispatcher::RequestStream,
        format: crate::codec::CodecFormat,
    ) -> crate::dispatcher::StreamingResult {
        if let Some(Method::BidiStreaming(entry)) = self.methods.get(path) {
            entry.handler.call_erased(ctx, requests, format)
        } else {
            crate::dispatcher::unimplemented_streaming(path)
        }
    }
}
/// Combines several routers into a single router.
///
/// Routers are folded in iteration order. When two of them register the same
/// path, the entry from the later router wins (`HashMap::extend` overwrites).
pub fn merge_routers(routers: impl IntoIterator<Item = Router>) -> Router {
    routers.into_iter().fold(Router::new(), |mut combined, next| {
        combined.methods.extend(next.methods);
        combined
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatcher::Dispatcher;

    #[test]
    fn test_router_registration() {
        let router = Router::new();
        assert!(!router.has_method("test.Service/Method"));
    }

    /// An empty router exposes no paths and cannot describe unknown ones.
    #[test]
    fn empty_router_lists_and_resolves_nothing() {
        let router = Router::new();
        assert_eq!(router.methods().count(), 0);
        assert!(router.lookup("test.Service/Method").is_none());
    }

    /// Merging empty inputs (including no inputs at all) yields an empty router.
    #[test]
    fn merging_empty_routers_yields_empty_router() {
        let merged = merge_routers(vec![Router::new(), Router::new()]);
        assert_eq!(merged.methods().count(), 0);

        let merged = merge_routers(std::iter::empty::<Router>());
        assert!(!merged.has_method("test.Service/Method"));
    }
}