use crate::{
routes::{BatchFuture, MakeErasedHandler, RouteFuture},
types::InboundData,
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, HandlerCtx, Method, MethodId,
RegistrationError, Route,
};
use core::fmt;
use serde_json::value::RawValue;
use std::{borrow::Cow, collections::BTreeMap, convert::Infallible, sync::Arc, task::Poll};
use tower::Service;
use tracing::debug_span;
/// A JSON-RPC method router.
///
/// A `Router` maps method names to handlers, with a fallback that is used when
/// a request names no registered method. `S` is the state type that handlers
/// still need before they can run; it is supplied either per call
/// (`call_with_state`) or once up front (`with_state`).
///
/// The routing table lives behind an [`Arc`], so cloning a `Router` only bumps
/// a reference count.
#[must_use = "Routers do nothing unless served."]
pub struct Router<S> {
    // Shared routing table. Builder methods take it back out with
    // `into_inner` (unwrapping the `Arc` when uniquely owned, cloning the
    // tables otherwise) and re-wrap the modified copy.
    inner: Arc<RouterInner<S>>,
}
impl<S> Clone for Router<S> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<S> Default for Router<S>
where
S: Send + Sync + Clone + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<S> Router<S>
where
    S: Send + Sync + Clone + 'static,
{
    /// Create an empty router with the default fallback and no service name.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RouterInner::new()),
        }
    }

    /// Create an empty router that reports itself as `service_name`.
    pub fn new_named(service_name: &'static str) -> Self {
        Self {
            inner: Arc::new(RouterInner {
                service_name: Some(service_name),
                ..RouterInner::new()
            }),
        }
    }

    /// The service name of this router, or `"ajj"` if none was set.
    pub fn service_name(&self) -> &'static str {
        self.inner.service_name()
    }

    /// Set (or replace) the service name of this router.
    pub fn set_name(self, service_name: &'static str) -> Self {
        tap_inner!(self, mut this => {
            this.service_name = Some(service_name);
        })
    }

    /// Take ownership of the routing table.
    ///
    /// If this is the only handle, the table is moved out of the `Arc`;
    /// otherwise every field is cloned so other handles are unaffected.
    fn into_inner(self) -> RouterInner<S> {
        match Arc::try_unwrap(self.inner) {
            Ok(inner) => inner,
            Err(arc) => RouterInner {
                routes: arc.routes.clone(),
                last_id: arc.last_id,
                fallback: arc.fallback.clone(),
                service_name: arc.service_name,
                name_to_id: arc.name_to_id.clone(),
                id_to_name: arc.id_to_name.clone(),
            },
        }
    }

    /// Supply `state` to every registered method and to the fallback,
    /// producing a router whose handlers require state `S2` instead.
    ///
    /// NOTE(review): the exact effect on each handler is defined by
    /// `Method::with_state`, which is not visible here.
    pub fn with_state<S2>(self, state: S) -> Router<S2> {
        map_inner!(self, inner => inner.with_state(&state))
    }

    /// Replace the fallback with an already-built, stateless [`Route`].
    pub(crate) fn fallback_stateless(self, route: Route) -> Self {
        tap_inner!(self, mut this => {
            this.fallback = Method::Ready(route);
        })
    }

    /// Replace the fallback with a type-erased handler that still needs state.
    fn fallback_erased<E>(self, handler: E) -> Self
    where
        E: ErasedIntoRoute<S>,
    {
        tap_inner!(self, mut this => {
            this.fallback = Method::Needs(BoxedIntoRoute(handler.clone_box()));
        })
    }

    /// Set the handler invoked when a request names no registered method.
    pub fn fallback<H, T>(self, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        self.fallback_erased(MakeErasedHandler::from_handler(handler))
    }

    /// Set a `tower` service as the handler invoked when a request names no
    /// registered method.
    pub fn fallback_service<T>(self, service: T) -> Self
    where
        T: Service<
                HandlerArgs,
                Response = Option<Box<RawValue>>,
                Error = Infallible,
                Future: Send + 'static,
            > + Clone
            + Send
            + Sync
            + 'static,
    {
        self.fallback_stateless(Route::new(service))
    }

    /// Register `handler` under the method name `method`.
    ///
    /// # Panics
    ///
    /// Panics if `method` is already registered on this router.
    #[track_caller]
    pub fn route<H, T>(self, method: impl Into<Cow<'static, str>>, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        tap_inner!(self, mut this => {
            this = this.route(method, handler);
        })
    }

    /// Register every method of `other` on this router as `{prefix}_{name}`.
    ///
    /// Trailing underscores on `prefix` are stripped first, so `"eth"` and
    /// `"eth_"` behave the same. The fallback and service name of `other` are
    /// discarded.
    ///
    /// # Panics
    ///
    /// Panics if a prefixed name collides with a method already registered on
    /// this router.
    #[track_caller]
    pub fn nest(self, prefix: impl Into<Cow<'static, str>>, other: Self) -> Self {
        let prefix = prefix.into();
        // Drop any trailing separator so the join below adds exactly one `_`.
        let prefix = prefix.trim_end_matches('_');
        let mut this = self.into_inner();
        let RouterInner {
            routes, id_to_name, ..
        } = other.into_inner();
        for (id, handler) in routes {
            let existing_name = id_to_name
                .get(&id)
                .expect("nested router has missing name for existing method");
            let method = format!("{prefix}_{existing_name}");
            panic_on_err!(this.enroll_method(method.into(), handler));
        }
        Self {
            inner: Arc::new(this),
        }
    }

    /// Register every method of `other` on this router under its existing name.
    ///
    /// The fallback and service name of `other` are discarded.
    ///
    /// # Panics
    ///
    /// Panics if any method name of `other` is already registered on this
    /// router.
    #[track_caller]
    pub fn merge(self, other: Self) -> Self {
        let mut this = self.into_inner();
        let RouterInner {
            routes,
            mut id_to_name,
            ..
        } = other.into_inner();
        for (id, handler) in routes {
            // Each id is visited once, so the name can be moved out instead of
            // cloned.
            let existing_name = id_to_name
                .remove(&id)
                .expect("merged router has missing name for existing method");
            panic_on_err!(this.enroll_method(existing_name, handler));
        }
        Self {
            inner: Arc::new(this),
        }
    }

    /// Dispatch a single request to its handler (or the fallback), passing
    /// `state`. The returned future runs inside a `Router::call_with_state`
    /// debug span parented to the request's span.
    pub fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
        let id = args.req().id_owned();
        let method = args.req().method();
        let span = debug_span!(parent: args.span(), "Router::call_with_state", %method, ?id);
        self.inner.call_with_state(args, state).with_span(span)
    }

    /// Dispatch every request in `inbound`, each with its own clone of `state`.
    ///
    /// Each request gets a child context whose span is parented to a shared
    /// `BatchFuture::poll` span, which records the request count (`reqs`) and
    /// the number of futures pushed (`futs`).
    pub fn call_batch_with_state(
        &self,
        ctx: HandlerCtx,
        inbound: InboundData,
        state: S,
    ) -> BatchFuture {
        let mut fut =
            BatchFuture::new_with_capacity(inbound.single(), self.service_name(), inbound.len());
        let batch_span = debug_span!(parent: ctx.span(), "BatchFuture::poll", reqs = inbound.len(), futs = tracing::field::Empty);
        for req in inbound.iter() {
            let req = req.map(|req| {
                let ctx = ctx.child_ctx(self, Some(&batch_span));
                let request_span = ctx.span().clone();
                let args = HandlerArgs::new(ctx, req);
                self.inner
                    .call_with_state(args, state.clone())
                    .with_span(request_span)
            });
            fut.push_parse_result(req);
        }
        batch_span.record("futs", fut.len());
        fut.with_span(batch_span)
    }

    /// Convert into an axum router that serves this router via `POST path`.
    #[cfg(feature = "axum")]
    pub fn into_axum(self, path: &str) -> axum::Router<S> {
        axum::Router::new().route(path, axum::routing::post(crate::axum::IntoAxum::from(self)))
    }

    /// Like [`Router::into_axum`], but the axum adapter is given the tokio
    /// runtime `handle`.
    #[cfg(feature = "axum")]
    pub fn into_axum_with_handle(
        self,
        path: &str,
        handle: tokio::runtime::Handle,
    ) -> axum::Router<S> {
        axum::Router::new().route(
            path,
            axum::routing::post(crate::axum::IntoAxum::new(self, handle)),
        )
    }
}
impl Router<()> {
    /// Serve this router using the pub/sub transport `connect`.
    ///
    /// Resolves to whatever `connect.serve(self)` resolves to: a
    /// `ServerShutdown` handle on success, or the transport's error.
    #[cfg(feature = "pubsub")]
    pub async fn serve_pubsub<C>(
        self,
        connect: C,
    ) -> Result<crate::pubsub::ServerShutdown, C::Error>
    where
        C: crate::pubsub::Connect,
    {
        connect.serve(self).await
    }

    /// Build the axum websocket configuration, which owns a handle to this
    /// router.
    #[cfg(all(feature = "axum", feature = "pubsub"))]
    pub fn to_axum_cfg(&self) -> crate::pubsub::AxumWsCfg {
        let router = Self::clone(self);
        crate::pubsub::AxumWsCfg::new(router)
    }

    /// Build an axum router serving JSON-RPC over HTTP `POST` at `post_route`
    /// and over websockets at `ws_route`.
    #[cfg(all(feature = "axum", feature = "pubsub"))]
    pub fn into_axum_with_ws(self, post_route: &str, ws_route: &str) -> axum::Router<()> {
        // The websocket config must be built before `self` is consumed below.
        let ws_cfg = self.to_axum_cfg();
        self.into_axum(post_route)
            .with_state(())
            .route(ws_route, axum::routing::any(crate::pubsub::ajj_websocket))
            .with_state(ws_cfg)
    }

    /// Like `into_axum_with_ws`, but the HTTP adapter is given the tokio
    /// runtime `handle`.
    #[cfg(all(feature = "axum", feature = "pubsub"))]
    pub fn into_axum_with_ws_and_handle(
        self,
        post_route: &str,
        ws_route: &str,
        handle: tokio::runtime::Handle,
    ) -> axum::Router<()> {
        // The websocket config must be built before `self` is consumed below.
        let ws_cfg = self.to_axum_cfg();
        self.into_axum_with_handle(post_route, handle)
            .with_state(())
            .route(ws_route, axum::routing::any(crate::pubsub::ajj_websocket))
            .with_state(ws_cfg)
    }

    /// Dispatch a single request. No state is needed because `S = ()`.
    pub fn handle_request(&self, args: HandlerArgs) -> RouteFuture {
        Self::call_with_state(self, args, ())
    }

    /// Dispatch a batch of requests. No state is needed because `S = ()`.
    pub fn handle_request_batch(&self, ctx: HandlerCtx, batch: InboundData) -> BatchFuture {
        Self::call_batch_with_state(self, ctx, batch, ())
    }
}
impl<S> fmt::Debug for Router<S> {
    /// Print an opaque `Router { .. }`; the routing table is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Router");
        builder.finish_non_exhaustive()
    }
}
impl tower::Service<HandlerArgs> for Router<()> {
type Response = Option<Box<RawValue>>;
type Error = Infallible;
type Future = RouteFuture;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, args: HandlerArgs) -> Self::Future {
self.handle_request(args)
}
}
impl tower::Service<HandlerArgs> for &Router<()> {
type Response = Option<Box<RawValue>>;
type Error = Infallible;
type Future = RouteFuture;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, args: HandlerArgs) -> Self::Future {
self.handle_request(args)
}
}
/// The routing table behind a [`Router`].
///
/// Methods are stored by numeric [`MethodId`]; two maps translate between ids
/// and method names in both directions.
pub(crate) struct RouterInner<S> {
    // Handlers, keyed by the id assigned at registration.
    routes: BTreeMap<MethodId, Method<S>>,
    // The most recently assigned id; `get_id` increments it before use.
    last_id: MethodId,
    // Used when a request's method name is not found in `name_to_id`.
    fallback: Method<S>,
    // Optional service name; `service_name()` falls back to "ajj" when `None`.
    service_name: Option<&'static str>,
    // Method name -> id. Checked for duplicates on registration.
    name_to_id: BTreeMap<Cow<'static, str>, MethodId>,
    // Id -> method name. Filled alongside `name_to_id`; used when nesting and
    // merging routers to recover names.
    id_to_name: BTreeMap<MethodId, Cow<'static, str>>,
}
impl Default for RouterInner<()> {
fn default() -> Self {
Self::new()
}
}
impl<S> fmt::Debug for RouterInner<S> {
    /// Print an opaque `RouterInner { .. }`; the tables are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RouterInner");
        builder.finish_non_exhaustive()
    }
}
impl<S> RouterInner<S> {
    /// Create an empty table with the default fallback and no service name.
    pub(crate) fn new() -> Self {
        Self {
            routes: BTreeMap::new(),
            last_id: Default::default(),
            fallback: Method::Ready(Route::default_fallback()),
            service_name: None,
            name_to_id: BTreeMap::new(),
            id_to_name: BTreeMap::new(),
        }
    }

    /// Apply `state` to every route and to the fallback, changing the state
    /// type to `S2`. Ids, names and the service name carry over unchanged.
    pub(crate) fn with_state<S2>(self, state: &S) -> RouterInner<S2>
    where
        S: Clone,
    {
        RouterInner {
            routes: self
                .routes
                .into_iter()
                .map(|(id, method)| (id, method.with_state(state)))
                .collect(),
            fallback: self.fallback.with_state(state),
            last_id: self.last_id,
            service_name: self.service_name,
            name_to_id: self.name_to_id,
            id_to_name: self.id_to_name,
        }
    }

    /// The configured service name, or `"ajj"` if none was set.
    fn service_name(&self) -> &'static str {
        self.service_name.unwrap_or("ajj")
    }

    /// Allocate the next method id by incrementing `last_id`.
    fn get_id(&mut self) -> MethodId {
        self.last_id += 1;
        self.last_id
    }

    /// Look up a registered method by name.
    fn method_by_name(&self, name: &str) -> Option<&Method<S>> {
        self.name_to_id.get(name).and_then(|id| self.routes.get(id))
    }

    /// Reserve a fresh id for `method` and record the name <-> id mapping.
    ///
    /// # Errors
    ///
    /// Returns `RegistrationError::method_already_registered` if `method` is
    /// already in use. Nothing is modified in that case.
    #[track_caller]
    fn enroll_method_name(
        &mut self,
        method: Cow<'static, str>,
    ) -> Result<MethodId, RegistrationError> {
        if self.name_to_id.contains_key(&method) {
            return Err(RegistrationError::method_already_registered(method));
        }
        let id = self.get_id();
        self.name_to_id.insert(method.clone(), id);
        // Last use of `method`: move it rather than cloning again.
        self.id_to_name.insert(id, method);
        Ok(id)
    }

    /// Register `handler` under `method`. On a name collision nothing is
    /// inserted and the error is returned.
    fn enroll_method(
        &mut self,
        method: Cow<'static, str>,
        handler: Method<S>,
    ) -> Result<MethodId, RegistrationError> {
        self.enroll_method_name(method).inspect(|id| {
            self.routes.insert(*id, handler);
        })
    }

    /// Register a type-erased handler under `method`.
    ///
    /// # Panics
    ///
    /// Panics if `method` is already registered.
    #[track_caller]
    fn route_erased<E>(mut self, method: impl Into<Cow<'static, str>>, handler: E) -> Self
    where
        E: ErasedIntoRoute<S>,
    {
        let method = method.into();
        let handler = handler.clone_box();
        add_method_inner(&mut self, method, handler);

        // Non-generic over `E`, so this part is not re-instantiated for every
        // handler type. It must also be `#[track_caller]`: otherwise the panic
        // location would be this helper rather than the caller of `route_erased`.
        #[track_caller]
        fn add_method_inner<S>(
            this: &mut RouterInner<S>,
            method: Cow<'static, str>,
            handler: Box<dyn ErasedIntoRoute<S>>,
        ) {
            panic_on_err!(this.enroll_method(method, Method::Needs(BoxedIntoRoute(handler))));
        }

        self
    }

    /// Register `handler` under `method`.
    ///
    /// # Panics
    ///
    /// Panics if `method` is already registered.
    #[track_caller]
    pub(crate) fn route<H, T>(self, method: impl Into<Cow<'static, str>>, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        self.route_erased(method, MakeErasedHandler::from_handler(handler))
    }

    /// Record the call in metrics, then dispatch to the handler registered for
    /// the request's method name, or to the fallback if there is none.
    #[track_caller]
    pub(crate) fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
        let method = args.req().method();
        crate::metrics::record_call(self.service_name(), method);
        self.method_by_name(method)
            .unwrap_or(&self.fallback)
            .call_with_state(args, state)
    }
}