use crate::{
routes::{MakeErasedHandler, RouteFuture},
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, Method, MethodId, RegistrationError,
Route,
};
use core::fmt;
use serde_json::value::RawValue;
use std::{borrow::Cow, collections::BTreeMap, convert::Infallible, sync::Arc, task::Poll};
use tower::Service;
use tracing::{debug_span, trace};
/// A method router: maps method names to handlers, with a fallback handler
/// for names that are not registered.
///
/// The routing tables live behind an [`Arc`], so cloning a `Router` is cheap.
/// Consuming builders such as `nest` and `merge` take the tables out of the
/// `Arc` via `into_inner`, cloning them only when another handle shares them.
#[must_use = "Routers do nothing unless served."]
pub struct Router<S> {
    /// Shared routing tables (`RouterInner`).
    inner: Arc<RouterInner<S>>,
}
impl<S> Clone for Router<S> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<S> Default for Router<S>
where
S: Send + Sync + Clone + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<S> Router<S>
where
    S: Send + Sync + Clone + 'static,
{
    /// Create an empty router.
    ///
    /// No methods are registered; every request goes to the default fallback
    /// (`Route::default_fallback`) until methods are added.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RouterInner::new()),
        }
    }

    /// Take ownership of the routing tables.
    ///
    /// When this is the only handle, the tables are moved out of the `Arc`
    /// without copying. When other clones of this router exist, the routes,
    /// fallback and name/id maps are cloned so those clones are unaffected.
    fn into_inner(self) -> RouterInner<S> {
        match Arc::try_unwrap(self.inner) {
            Ok(inner) => inner,
            Err(arc) => RouterInner {
                routes: arc.routes.clone(),
                last_id: arc.last_id,
                fallback: arc.fallback.clone(),
                name_to_id: arc.name_to_id.clone(),
                id_to_name: arc.id_to_name.clone(),
            },
        }
    }

    /// Supply the router's state, producing a router with state type `S2`.
    ///
    /// Every registered method and the fallback are converted by passing a
    /// reference to `state` to their `with_state` conversion.
    pub fn with_state<S2>(self, state: S) -> Router<S2> {
        map_inner!(self, inner => inner.with_state(&state))
    }

    /// Replace the fallback with an already-ready (stateless) route.
    pub(crate) fn fallback_stateless(self, route: Route) -> Self {
        tap_inner!(self, mut this => {
            this.fallback = Method::Ready(route);
        })
    }

    /// Replace the fallback with a type-erased handler that still needs the
    /// router state before it can be called.
    fn fallback_erased<E>(self, handler: E) -> Self
    where
        E: ErasedIntoRoute<S>,
    {
        tap_inner!(self, mut this => {
            this.fallback = Method::Needs(BoxedIntoRoute(handler.clone_box()));
        })
    }

    /// Set the handler invoked when a request's method name is not registered.
    pub fn fallback<H, T>(self, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        self.fallback_erased(MakeErasedHandler::from_handler(handler))
    }

    /// Set a `tower::Service` as the fallback for unregistered method names.
    ///
    /// The service does not receive the router state.
    pub fn fallback_service<T>(self, service: T) -> Self
    where
        T: Service<
                HandlerArgs,
                Response = Box<RawValue>,
                Error = Infallible,
                Future: Send + 'static,
            > + Clone
            + Send
            + Sync
            + 'static,
    {
        self.fallback_stateless(Route::new(service))
    }

    /// Register `handler` under the method name `method`.
    ///
    /// # Panics
    ///
    /// Panics (via `panic_on_err!`) if `method` is already registered. The
    /// panic is reported at the caller's location.
    #[track_caller]
    pub fn route<H, T>(self, method: impl Into<Cow<'static, str>>, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        tap_inner!(self, mut this => {
            this = this.route(method, handler);
        })
    }

    /// Register every method of `other` as `"{prefix}_{name}"`.
    ///
    /// Trailing underscores are trimmed from `prefix`, so `"eth"` and `"eth_"`
    /// produce the same names. `other`'s fallback is discarded.
    ///
    /// # Panics
    ///
    /// Panics if a prefixed name is already registered in `self`, or if
    /// `other`'s name table is missing an entry for one of its method ids.
    #[track_caller]
    pub fn nest(self, prefix: impl Into<Cow<'static, str>>, other: Self) -> Self {
        let prefix: Cow<'static, str> = prefix.into();
        let prefix = prefix.trim_end_matches('_');

        let mut this = self.into_inner();
        let RouterInner {
            routes, id_to_name, ..
        } = other.into_inner();

        // `routes` is ordered by id, so methods are enrolled in `other`'s
        // original registration order.
        for (id, handler) in routes {
            let existing_name = id_to_name
                .get(&id)
                .expect("nested router has missing name for existing method");
            let method = format!("{prefix}_{existing_name}");
            panic_on_err!(this.enroll_method(method.into(), handler));
        }

        Self {
            inner: Arc::new(this),
        }
    }

    /// Register every method of `other` in `self` under its existing name.
    ///
    /// `other`'s fallback is discarded.
    ///
    /// # Panics
    ///
    /// Panics if any of `other`'s method names is already registered in
    /// `self`, or if `other`'s name table is missing an entry for one of its
    /// method ids.
    #[track_caller]
    pub fn merge(self, other: Self) -> Self {
        let mut this = self.into_inner();
        let RouterInner {
            routes,
            mut id_to_name,
            ..
        } = other.into_inner();

        for (id, handler) in routes {
            let existing_name = id_to_name
                .remove(&id)
                .expect("merged router has missing name for existing method");
            panic_on_err!(this.enroll_method(existing_name, handler));
        }

        Self {
            inner: Arc::new(this),
        }
    }

    /// Dispatch `args` to the handler registered for its method name, or to
    /// the fallback, providing `state`.
    ///
    /// The returned future is instrumented with a debug span that records
    /// the request's method and id.
    pub fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
        let id = args.req.id_owned();
        let method = args.req.method();

        let span = debug_span!("Router::call_with_state", %method, %id);
        // The span is not entered here, so parent the params event to it
        // explicitly. Otherwise the event is not associated with this request.
        trace!(parent: &span, params = args.req.params());

        self.inner.call_with_state(args, state).with_span(span)
    }

    /// Mount this router as a POST route at `path` on a new `axum::Router`.
    #[cfg(feature = "axum")]
    pub fn into_axum(self, path: &str) -> axum::Router<S> {
        axum::Router::new().route(path, axum::routing::post(self))
    }
}
impl Router<()> {
    /// Dispatch a request on a stateless router.
    ///
    /// Shorthand for `call_with_state` with the unit state.
    pub fn handle_request(&self, args: HandlerArgs) -> RouteFuture {
        Self::call_with_state(self, args, ())
    }
}
impl<S> fmt::Debug for Router<S> {
    /// Prints only the type name; the routing tables are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Router");
        builder.finish_non_exhaustive()
    }
}
impl tower::Service<HandlerArgs> for Router<()> {
type Response = Box<RawValue>;
type Error = Infallible;
type Future = RouteFuture;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, args: HandlerArgs) -> Self::Future {
self.handle_request(args)
}
}
impl tower::Service<HandlerArgs> for &Router<()> {
    type Response = Box<RawValue>;
    type Error = Infallible;
    type Future = RouteFuture;

    /// A borrowed router is likewise always ready.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let ready: Result<(), Self::Error> = Ok(());
        Poll::Ready(ready)
    }

    /// Route the request through the borrowed router with the unit state.
    fn call(&mut self, args: HandlerArgs) -> Self::Future {
        let router: &Router<()> = *self;
        router.handle_request(args)
    }
}
/// Routing tables shared behind a `Router`.
pub(crate) struct RouterInner<S> {
    /// Handlers keyed by their assigned method id.
    routes: BTreeMap<MethodId, Method<S>>,
    /// Most recently assigned method id; `get_id` increments it before use.
    last_id: MethodId,
    /// Handler used when the request's method name is not registered.
    fallback: Method<S>,
    /// Forward lookup: method name -> id.
    name_to_id: BTreeMap<Cow<'static, str>, MethodId>,
    /// Reverse lookup: id -> method name, used when nesting or merging routers.
    id_to_name: BTreeMap<MethodId, Cow<'static, str>>,
}
impl Default for RouterInner<()> {
fn default() -> Self {
Self::new()
}
}
impl<S> fmt::Debug for RouterInner<S> {
    /// Prints only the type name; handlers are type-erased and not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RouterInner");
        builder.finish_non_exhaustive()
    }
}
impl<S> RouterInner<S> {
    /// Create empty routing tables whose fallback is `Route::default_fallback`.
    pub(crate) fn new() -> Self {
        Self {
            routes: BTreeMap::new(),
            last_id: Default::default(),
            fallback: Method::Ready(Route::default_fallback()),
            name_to_id: BTreeMap::new(),
            id_to_name: BTreeMap::new(),
        }
    }

    /// Convert to tables with state type `S2`.
    ///
    /// Every route and the fallback are converted by their `with_state`,
    /// which receives a reference to `state`. `last_id` and the name/id maps
    /// carry over unchanged.
    pub(crate) fn with_state<S2>(self, state: &S) -> RouterInner<S2>
    where
        S: Clone,
    {
        RouterInner {
            routes: self
                .routes
                .into_iter()
                .map(|(id, method)| (id, method.with_state(state)))
                .collect(),
            fallback: self.fallback.with_state(state),
            last_id: self.last_id,
            name_to_id: self.name_to_id,
            id_to_name: self.id_to_name,
        }
    }

    /// Allocate the next method id by incrementing `last_id`.
    fn get_id(&mut self) -> MethodId {
        self.last_id += 1;
        self.last_id
    }

    /// Look up the handler registered under `name`, if any.
    fn method_by_name(&self, name: &str) -> Option<&Method<S>> {
        self.name_to_id.get(name).and_then(|id| self.routes.get(id))
    }

    /// Reserve a fresh id for `method` and record it in both name/id maps.
    ///
    /// # Errors
    ///
    /// Returns `RegistrationError::method_already_registered` if `method` is
    /// already present. In that case no id is consumed.
    #[track_caller]
    fn enroll_method_name(
        &mut self,
        method: Cow<'static, str>,
    ) -> Result<MethodId, RegistrationError> {
        if self.name_to_id.contains_key(&method) {
            return Err(RegistrationError::method_already_registered(method));
        }
        let id = self.get_id();
        self.name_to_id.insert(method.clone(), id);
        // Last use of `method`: move it rather than cloning a second time.
        self.id_to_name.insert(id, method);
        Ok(id)
    }

    /// Reserve an id for `method` and store `handler` under it.
    ///
    /// # Errors
    ///
    /// Same as `enroll_method_name`. On error, `handler` is dropped and
    /// `routes` is untouched.
    fn enroll_method(
        &mut self,
        method: Cow<'static, str>,
        handler: Method<S>,
    ) -> Result<MethodId, RegistrationError> {
        self.enroll_method_name(method).inspect(|id| {
            self.routes.insert(*id, handler);
        })
    }

    /// Register a type-erased handler under `method`.
    ///
    /// # Panics
    ///
    /// Panics (via `panic_on_err!`) if `method` is already registered.
    #[track_caller]
    fn route_erased<E>(mut self, method: impl Into<Cow<'static, str>>, handler: E) -> Self
    where
        E: ErasedIntoRoute<S>,
    {
        let method = method.into();
        let handler = handler.clone_box();
        add_method_inner(&mut self, method, handler);

        // Non-generic over `E`, so this body is instantiated once per `S`.
        // `#[track_caller]` is required here too: without it the panic
        // location stops at this function instead of reaching the caller.
        #[track_caller]
        fn add_method_inner<S>(
            this: &mut RouterInner<S>,
            method: Cow<'static, str>,
            handler: Box<dyn ErasedIntoRoute<S>>,
        ) {
            panic_on_err!(this.enroll_method(method, Method::Needs(BoxedIntoRoute(handler))));
        }

        self
    }

    /// Register `handler` under `method`.
    ///
    /// # Panics
    ///
    /// Panics if `method` is already registered. The panic is reported at
    /// the caller's location.
    #[track_caller]
    pub(crate) fn route<H, T>(self, method: impl Into<Cow<'static, str>>, handler: H) -> Self
    where
        H: Handler<T, S>,
        T: Send + 'static,
        S: Clone + Send + Sync + 'static,
    {
        self.route_erased(method, MakeErasedHandler::from_handler(handler))
    }

    /// Call the handler registered for the request's method name, or the
    /// fallback when the name is unknown, passing `state` along.
    #[track_caller]
    pub(crate) fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
        let method = args.req.method();
        self.method_by_name(method)
            .unwrap_or(&self.fallback)
            .call_with_state(args, state)
    }
}