use std::collections::HashMap;
use std::sync::Arc;
use buffa::Message;
use buffa::view::MessageView;
use crate::codec::{JsonDeserialize, JsonSerialize};
use crate::handler::BidiStreamingHandler;
use crate::handler::BidiStreamingHandlerWrapper;
use crate::handler::BidiStreamingViewHandlerWrapper;
use crate::handler::ClientStreamingHandler;
use crate::handler::ClientStreamingHandlerWrapper;
use crate::handler::ClientStreamingViewHandlerWrapper;
use crate::handler::ErasedBidiStreamingHandler;
use crate::handler::ErasedClientStreamingHandler;
use crate::handler::ErasedHandler;
use crate::handler::ErasedStreamingHandler;
use crate::handler::Handler;
use crate::handler::ServerStreamingHandlerWrapper;
use crate::handler::ServerStreamingViewHandlerWrapper;
use crate::handler::StreamingHandler;
use crate::handler::UnaryHandlerWrapper;
use crate::handler::UnaryViewHandlerWrapper;
use crate::handler::ViewBidiStreamingHandler;
use crate::handler::ViewClientStreamingHandler;
use crate::handler::ViewHandler;
use crate::handler::ViewStreamingHandler;
use crate::spec::IdempotencyLevel;
use crate::spec::Spec;
/// Shape of an RPC's message exchange, as reported by route lookups.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MethodKind {
    /// One request message in, one response message out.
    Unary,
    /// One request message in, a stream of response messages out.
    ServerStreaming,
    /// A stream of request messages in, one response message out.
    ClientStreaming,
    /// Streams of messages in both directions.
    BidiStreaming,
}
/// Storage for a unary route, built from either an owned-message handler
/// (`route`/`route_idempotent`) or a view handler (`route_view*`).
struct UnaryMethod {
    /// `true` when the route was registered through one of the
    /// `*_idempotent` entry points; surfaced by `Dispatcher::lookup`.
    idempotent: bool,
    /// Type-erased handler invoked by `Dispatcher::call_unary`.
    handler: Arc<dyn ErasedHandler>,
}
/// Storage for a server-streaming route (owned-message or view handler).
struct StreamingMethod {
    /// Kind reported by `Dispatcher::lookup`; every constructor in this
    /// module sets `MethodKind::ServerStreaming`.
    kind: MethodKind,
    /// Type-erased handler invoked by `Dispatcher::call_server_streaming`.
    handler: Arc<dyn ErasedStreamingHandler>,
}
/// Storage for a client-streaming route (owned-message or view handler).
struct ClientStreamingMethod {
    /// Type-erased handler invoked by `Dispatcher::call_client_streaming`.
    handler: Arc<dyn ErasedClientStreamingHandler>,
}
/// Storage for a bidirectional-streaming route (owned-message or view handler).
struct BidiStreamingMethod {
    /// Type-erased handler invoked by `Dispatcher::call_bidi_streaming`.
    handler: Arc<dyn ErasedBidiStreamingHandler>,
}
/// A registered handler, tagged by which `Dispatcher::call_*` entry point
/// is allowed to reach it.
enum Method {
    /// Reached through `call_unary`.
    Unary(UnaryMethod),
    /// Reached through `call_server_streaming`.
    Streaming(StreamingMethod),
    /// Reached through `call_client_streaming`.
    ClientStreaming(ClientStreamingMethod),
    /// Reached through `call_bidi_streaming`.
    BidiStreaming(BidiStreamingMethod),
}
/// One entry of the router's path table: the handler plus optional metadata.
struct RegisteredMethod {
    /// Metadata attached later via `Router::with_spec`; `None` until then.
    spec: Option<Spec>,
    /// The handler and its call shape.
    method: Method,
}
impl From<Method> for RegisteredMethod {
fn from(method: Method) -> Self {
Self { method, spec: None }
}
}
/// Registration hook used by [`Router::add_service`], implemented for
/// `Arc<Service>` (presumably by generated code — confirm against the codegen).
///
/// `Marker` exists only to let one `Arc<S>` implement the trait several
/// times, e.g. when a type serves more than one service. Callers then pick an
/// implementation with a turbofish: `router.add_service::<_, Marker>(svc)`.
pub trait ServiceRegister<Marker> {
    /// Adds this service's routes to `router` and hands the router back.
    fn register_service(self, router: Router) -> Router;
}
/// Error produced by [`Router::try_merge`] and [`Router::try_merge_in_place`]
/// when both routers register at least one identical path and the receiving
/// router does not allow overrides.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
#[error("router merge conflict on path(s): {conflicts:?}")]
pub struct RouterMergeError {
    /// Paths registered in both routers, sorted ascending.
    conflicts: Vec<String>,
}
impl RouterMergeError {
    /// The route paths (`"service/method"` form) present in both routers,
    /// in ascending order.
    #[must_use]
    pub fn conflicting_paths(&self) -> &[String] {
        self.conflicts.as_slice()
    }
}
/// Path-keyed table of RPC handlers, implementing `Dispatcher`.
///
/// Built with [`Router::add_service`] / the `route_*` methods and combined
/// with [`Router::merge`] and friends.
#[derive(Default)]
pub struct Router {
    /// When `true`, registering or merging a path that already exists
    /// replaces it instead of panicking / returning an error.
    allow_overrides: bool,
    /// Registered routes keyed by `"{service}/{method}"` (no leading slash).
    methods: HashMap<String, RegisteredMethod>,
}
impl Router {
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn add_service<S: ?Sized, Marker>(self, service: Arc<S>) -> Self
where
Arc<S>: ServiceRegister<Marker>,
{
<Arc<S> as ServiceRegister<Marker>>::register_service(service, self)
}
#[must_use]
pub fn allow_overrides(mut self) -> Self {
self.allow_overrides = true;
self
}
#[must_use]
pub fn merge(mut self, other: Self) -> Self {
self.merge_in_place(other);
self
}
pub fn merge_in_place(&mut self, other: Self) {
if let Err(err) = self.try_merge_in_place(other) {
panic!(
"router merge conflict on path(s) {:?} — both routers register \
these paths. Call `allow_overrides()` if replacing the existing \
routes is intended.",
err.conflicting_paths()
);
}
}
pub fn try_merge(mut self, other: Self) -> Result<Self, RouterMergeError> {
self.try_merge_in_place(other)?;
Ok(self)
}
pub fn try_merge_in_place(&mut self, other: Self) -> Result<(), RouterMergeError> {
if !self.allow_overrides {
let mut conflicts: Vec<String> = other
.methods
.keys()
.filter(|path| self.methods.contains_key(*path))
.cloned()
.collect();
if !conflicts.is_empty() {
conflicts.sort();
return Err(RouterMergeError { conflicts });
}
}
self.methods.extend(other.methods);
Ok(())
}
#[doc(hidden)]
pub fn route<H, Req, Res>(self, service_name: &str, method_name: &str, handler: H) -> Self
where
H: Handler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + JsonSerialize + Send + 'static,
{
self.route_unary_internal(service_name, method_name, handler, false)
}
#[doc(hidden)]
pub fn route_idempotent<H, Req, Res>(
self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: Handler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + JsonSerialize + Send + 'static,
{
self.route_unary_internal(service_name, method_name, handler, true)
}
fn insert_method(&mut self, path: String, method: RegisteredMethod) {
if !self.allow_overrides && self.methods.contains_key(&path) {
panic!(
"router registration conflict on path {path:?} — a route is \
already registered here (registering the same service twice?). \
Call `allow_overrides()` if replacing it is intended."
);
}
self.methods.insert(path, method);
}
fn route_unary_internal<H, Req, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
idempotent: bool,
) -> Self
where
H: Handler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + JsonSerialize + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = UnaryHandlerWrapper::new(handler);
self.insert_method(
path,
Method::Unary(UnaryMethod {
handler: Arc::new(wrapper),
idempotent,
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_server_stream<H, Req, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: StreamingHandler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = ServerStreamingHandlerWrapper::new(handler);
self.insert_method(
path,
Method::Streaming(StreamingMethod {
handler: Arc::new(wrapper),
kind: MethodKind::ServerStreaming,
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_client_stream<H, Req, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: ClientStreamingHandler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + JsonSerialize + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = ClientStreamingHandlerWrapper::new(handler);
self.insert_method(
path,
Method::ClientStreaming(ClientStreamingMethod {
handler: Arc::new(wrapper),
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_bidi_stream<H, Req, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: BidiStreamingHandler<Req, Res>,
Req: Message + JsonDeserialize + Send + 'static,
Res: Message + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = BidiStreamingHandlerWrapper::new(handler);
self.insert_method(
path,
Method::BidiStreaming(BidiStreamingMethod {
handler: Arc::new(wrapper),
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_view<H, ReqView>(self, service_name: &str, method_name: &str, handler: H) -> Self
where
H: ViewHandler<ReqView>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
{
self.route_view_internal(service_name, method_name, handler, false)
}
#[doc(hidden)]
pub fn route_view_idempotent<H, ReqView>(
self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: ViewHandler<ReqView>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
{
self.route_view_internal(service_name, method_name, handler, true)
}
fn route_view_internal<H, ReqView>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
idempotent: bool,
) -> Self
where
H: ViewHandler<ReqView>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
{
let path = format!("{service_name}/{method_name}");
let wrapper = UnaryViewHandlerWrapper::new(handler);
self.insert_method(
path,
Method::Unary(UnaryMethod {
handler: Arc::new(wrapper),
idempotent,
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_view_server_stream<H, ReqView, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: ViewStreamingHandler<ReqView, Res>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
Res: Message + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = ServerStreamingViewHandlerWrapper::new(handler);
self.insert_method(
path,
Method::Streaming(StreamingMethod {
handler: Arc::new(wrapper),
kind: MethodKind::ServerStreaming,
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_view_client_stream<H, ReqView>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: ViewClientStreamingHandler<ReqView>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
{
let path = format!("{service_name}/{method_name}");
let wrapper = ClientStreamingViewHandlerWrapper::new(handler);
self.insert_method(
path,
Method::ClientStreaming(ClientStreamingMethod {
handler: Arc::new(wrapper),
})
.into(),
);
self
}
#[doc(hidden)]
pub fn route_view_bidi_stream<H, ReqView, Res>(
mut self,
service_name: &str,
method_name: &str,
handler: H,
) -> Self
where
H: ViewBidiStreamingHandler<ReqView, Res>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + JsonDeserialize,
Res: Message + Send + 'static,
{
let path = format!("{service_name}/{method_name}");
let wrapper = BidiStreamingViewHandlerWrapper::new(handler);
self.insert_method(
path,
Method::BidiStreaming(BidiStreamingMethod {
handler: Arc::new(wrapper),
})
.into(),
);
self
}
#[must_use]
pub fn with_spec(mut self, spec: Spec) -> Self {
let key = spec.procedure.strip_prefix('/').unwrap_or(spec.procedure);
match self.methods.get_mut(key) {
Some(m) => {
if let Method::Unary(u) = &m.method {
debug_assert_eq!(
u.idempotent,
spec.idempotency_level == IdempotencyLevel::NoSideEffects,
"route {key:?} idempotency disagrees with Spec::idempotency_level — \
pick `route` vs `route_idempotent` to match the Spec"
);
}
m.spec = Some(spec);
}
None => {
debug_assert!(
false,
"Router::with_spec: no route registered at {key:?} — \
call the matching `route_*` first"
);
}
}
self
}
pub fn methods(&self) -> impl Iterator<Item = &str> {
self.methods.keys().map(String::as_str)
}
pub fn has_method(&self, path: &str) -> bool {
self.methods.contains_key(path)
}
}
impl crate::dispatcher::Dispatcher for Router {
fn lookup(&self, path: &str) -> Option<crate::dispatcher::MethodDescriptor> {
use crate::dispatcher::MethodDescriptor;
let m = self.methods.get(path)?;
let mut desc = match &m.method {
Method::Unary(u) => MethodDescriptor::unary(u.idempotent),
Method::Streaming(s) => MethodDescriptor::from_kind(s.kind),
Method::ClientStreaming(_) => MethodDescriptor::client_streaming(),
Method::BidiStreaming(_) => MethodDescriptor::bidi_streaming(),
};
if let Some(spec) = m.spec {
desc = desc.with_spec(spec);
}
Some(desc)
}
fn call_unary(
&self,
path: &str,
ctx: crate::response::RequestContext,
request: crate::Payload,
format: crate::codec::CodecFormat,
) -> crate::dispatcher::UnaryResult {
match self.methods.get(path).map(|m| &m.method) {
Some(Method::Unary(m)) => m.handler.call_erased(ctx, request, format),
_ => crate::dispatcher::unimplemented_unary(path),
}
}
fn call_server_streaming(
&self,
path: &str,
ctx: crate::response::RequestContext,
request: bytes::Bytes,
format: crate::codec::CodecFormat,
) -> crate::dispatcher::StreamingResult {
match self.methods.get(path).map(|m| &m.method) {
Some(Method::Streaming(m)) => m.handler.call_erased(ctx, request, format),
_ => crate::dispatcher::unimplemented_streaming(path),
}
}
fn call_client_streaming(
&self,
path: &str,
ctx: crate::response::RequestContext,
requests: crate::dispatcher::RequestStream,
format: crate::codec::CodecFormat,
) -> crate::dispatcher::UnaryResult {
match self.methods.get(path).map(|m| &m.method) {
Some(Method::ClientStreaming(m)) => m.handler.call_erased(ctx, requests, format),
_ => crate::dispatcher::unimplemented_unary(path),
}
}
fn call_bidi_streaming(
&self,
path: &str,
ctx: crate::response::RequestContext,
requests: crate::dispatcher::RequestStream,
format: crate::codec::CodecFormat,
) -> crate::dispatcher::StreamingResult {
match self.methods.get(path).map(|m| &m.method) {
Some(Method::BidiStreaming(m)) => m.handler.call_erased(ctx, requests, format),
_ => crate::dispatcher::unimplemented_streaming(path),
}
}
}
/// Folds any number of routers into one, starting from an empty
/// [`Router::new`].
///
/// # Panics
///
/// The accumulator never allows overrides, so this panics if any two of the
/// routers register the same path (see [`Router::merge`]).
pub fn merge_routers(routers: impl IntoIterator<Item = Router>) -> Router {
    routers.into_iter().fold(Router::new(), Router::merge)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatcher::Dispatcher;
    use crate::handler_fn;
    use crate::spec::StreamType;
    use buffa_types::Empty;

    /// Handler that ignores its input and replies with an empty message.
    fn unary_handler() -> impl Handler<Empty, Empty> {
        handler_fn(|_ctx, _req: Empty| async { crate::Response::ok(Empty::default()) })
    }

    /// Builds a router holding one plain (non-idempotent) unary route.
    fn single_route(service_name: &str, method_name: &str) -> Router {
        Router::new().route(service_name, method_name, unary_handler())
    }

    #[test]
    fn test_router_registration() {
        let fresh = Router::new();
        assert!(!fresh.has_method("test.Service/Method"));
    }

    // ---- add_service --------------------------------------------------------

    struct TestService;
    struct TestServiceRegisterMarker;

    impl ServiceRegister<TestServiceRegisterMarker> for Arc<TestService> {
        fn register_service(self, router: Router) -> Router {
            router.route("test.Service", "Method", unary_handler())
        }
    }

    #[test]
    fn add_service_forwards_to_generated_registration() {
        let service = Arc::new(TestService);
        let router = Router::new().add_service(service);
        assert!(router.has_method("test.Service/Method"));
    }

    // `DualService` implements the hook under two markers, forcing a turbofish.
    struct DualService;
    struct DualServiceAMarker;
    struct DualServiceBMarker;

    impl ServiceRegister<DualServiceAMarker> for Arc<DualService> {
        fn register_service(self, router: Router) -> Router {
            router.route("test.DualA", "Call", unary_handler())
        }
    }

    impl ServiceRegister<DualServiceBMarker> for Arc<DualService> {
        fn register_service(self, router: Router) -> Router {
            router.route("test.DualB", "Call", unary_handler())
        }
    }

    #[test]
    fn add_service_disambiguates_multi_impl_with_turbofish() {
        let service = Arc::new(DualService);
        let router = Router::new()
            .add_service::<_, DualServiceAMarker>(Arc::clone(&service))
            .add_service::<_, DualServiceBMarker>(service);
        for path in ["test.DualA/Call", "test.DualB/Call"] {
            assert!(router.has_method(path), "missing route {path}");
        }
    }

    #[test]
    #[should_panic(expected = "router registration conflict")]
    fn add_service_panics_on_double_registration() {
        let once = Router::new().add_service(Arc::new(TestService));
        let _ = once.add_service(Arc::new(TestService));
    }

    #[test]
    fn add_service_with_allow_overrides_permits_re_registration() {
        let permissive = Router::new().allow_overrides();
        let router = permissive
            .add_service(Arc::new(TestService))
            .add_service(Arc::new(TestService));
        assert!(router.has_method("test.Service/Method"));
    }

    #[test]
    #[should_panic(expected = "router registration conflict")]
    fn route_panics_on_duplicate_path() {
        let _ = single_route("test.Service", "Method").route(
            "test.Service",
            "Method",
            unary_handler(),
        );
    }

    // ---- merging ------------------------------------------------------------

    #[test]
    fn merge_and_merge_in_place_combine_routes() {
        let mut router =
            single_route("test.First", "Call").merge(single_route("test.Second", "Call"));
        router.merge_in_place(single_route("test.Third", "Call"));
        for path in ["test.First/Call", "test.Second/Call", "test.Third/Call"] {
            assert!(router.has_method(path), "missing route {path}");
        }
    }

    #[test]
    #[should_panic(expected = "router merge conflict")]
    fn merge_panics_on_duplicate_path_by_default() {
        let _ = single_route("test.Service", "Method")
            .merge(single_route("test.Service", "Method"));
    }

    #[test]
    fn merge_with_allow_overrides_replaces_duplicate_routes() {
        let idempotent =
            Router::new().route_idempotent("test.Service", "Method", unary_handler());
        let router = single_route("test.Service", "Method")
            .allow_overrides()
            .merge(idempotent);
        let descriptor = router.lookup("test.Service/Method").expect("route exists");
        assert!(descriptor.idempotent);
    }

    #[test]
    fn try_merge_ok_when_paths_are_disjoint() {
        let router = single_route("test.First", "Call")
            .try_merge(single_route("test.Second", "Call"))
            .expect("disjoint merge succeeds");
        assert!(router.has_method("test.First/Call"));
        assert!(router.has_method("test.Second/Call"));
    }

    #[test]
    fn try_merge_reports_conflicting_paths_without_panicking() {
        let outcome = single_route("test.Service", "Method")
            .try_merge(single_route("test.Service", "Method"));
        match outcome {
            Ok(_) => panic!("conflict must error"),
            Err(err) => {
                assert_eq!(err.conflicting_paths(), ["test.Service/Method".to_string()]);
            }
        }
    }

    #[test]
    fn try_merge_in_place_is_transactional_on_conflict() {
        let mut router =
            single_route("test.Keep", "Call").route("test.Service", "Method", unary_handler());
        let incoming =
            single_route("test.Service", "Method").route("test.New", "Call", unary_handler());
        let err = router
            .try_merge_in_place(incoming)
            .expect_err("conflict must error");
        assert_eq!(err.conflicting_paths(), ["test.Service/Method".to_string()]);
        // Existing routes survive and nothing from `incoming` leaks in.
        assert!(router.has_method("test.Keep/Call"));
        assert!(!router.has_method("test.New/Call"));
    }

    #[test]
    fn try_merge_with_allow_overrides_replaces_and_returns_ok() {
        let idempotent =
            Router::new().route_idempotent("test.Service", "Method", unary_handler());
        let router = single_route("test.Service", "Method")
            .allow_overrides()
            .try_merge(idempotent)
            .expect("overrides suppress the conflict error");
        let descriptor = router.lookup("test.Service/Method").unwrap();
        assert!(descriptor.idempotent);
    }

    // ---- Spec metadata ------------------------------------------------------

    #[test]
    fn with_spec_round_trips_through_lookup() {
        const SPEC: Spec = Spec::server("/test.Svc/Method", StreamType::Unary);
        let router = single_route("test.Svc", "Method").with_spec(SPEC);
        let desc = router.lookup("test.Svc/Method").expect("route exists");
        assert_eq!(desc.spec, Some(SPEC), "lookup must return the attached Spec");
        assert_eq!(desc.kind, MethodKind::Unary);
        assert!(!desc.idempotent);
    }

    #[test]
    fn route_without_with_spec_is_unchanged() {
        let router = single_route("test.Svc", "Method");
        let desc = router.lookup("test.Svc/Method").expect("route exists");
        assert_eq!(desc.spec, None);
    }

    #[test]
    fn merge_routers_preserves_specs() {
        const A: Spec = Spec::server("/svc.A/M", StreamType::Unary);
        const B: Spec = Spec::server("/svc.B/N", StreamType::Unary);
        let merged = merge_routers([
            single_route("svc.A", "M").with_spec(A),
            single_route("svc.B", "N").with_spec(B),
        ]);
        for (path, expected) in [("svc.A/M", A), ("svc.B/N", B)] {
            assert_eq!(merged.lookup(path).unwrap().spec, Some(expected));
        }
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "no route registered")]
    fn with_spec_unknown_route_panics_in_debug() {
        const SPEC: Spec = Spec::server("/test.Svc/Nope", StreamType::Unary);
        let _ = single_route("test.Svc", "Method").with_spec(SPEC);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "idempotency disagrees")]
    fn with_spec_idempotency_mismatch_panics_in_debug() {
        const SPEC: Spec = Spec::server("/test.Svc/Method", StreamType::Unary)
            .with_idempotency_level(IdempotencyLevel::NoSideEffects);
        let _ = single_route("test.Svc", "Method").with_spec(SPEC);
    }
}