use alloc::{
boxed::Box,
collections::BTreeMap,
string::{String, ToString},
sync::Arc,
vec::Vec,
};
use core::marker::PhantomData;
use crate::{
BoxFuture, RouteContext, RouteError, RouteResult,
handler::{Handler, SyncHandler},
};
/// Type-erased route handler stored by the router.
///
/// Receives the request's `RouteContext` and a clone of the router state, and returns a
/// boxed future resolving to the route's result. The `dyn Fn` carries no `Send`/`Sync`
/// bounds, so anything holding these handlers (such as `Router`) is neither `Send` nor
/// `Sync`.
type BoxedHandler<S, R> = Arc<dyn Fn(RouteContext, S) -> BoxFuture<RouteResult<R>>>;
#[derive(Clone)]
pub struct Router<R, S = ()> {
state: S,
routes: Vec<Route<S, R>>,
return_type: PhantomData<R>,
}
#[derive(Clone)]
struct Route<S, R> {
pattern: Pattern,
handler: BoxedHandler<S, R>,
}
/// A route pattern parsed from a `/`-separated string such as `users/:id`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Pattern {
/// Segments in order; empty segments are dropped during parsing.
segments: Vec<Segment>,
}
/// One segment of a `Pattern`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
/// Must equal the corresponding request segment exactly (case-sensitive).
Literal(String),
/// Written `:name` in the pattern; matches any single request segment and captures it
/// under `name` (the stored string excludes the leading `:`).
Param(String),
}
impl<R> Router<R, ()> {
pub fn new() -> Self {
Self {
state: (),
routes: Vec::new(),
return_type: PhantomData,
}
}
}
impl<R> Default for Router<R, ()> {
fn default() -> Self {
Self::new()
}
}
impl<R, S> Router<R, S>
where
S: Clone + 'static,
R: 'static,
{
pub fn with_state<T>(self, state: T) -> Router<R, T>
where
T: Clone + 'static,
{
Router {
state,
routes: Vec::new(),
return_type: PhantomData,
}
}
pub fn route<T, H>(mut self, pattern: &str, handler: H) -> Self
where
H: Handler<T, S, R>,
{
let handler = Arc::new(move |context, state| handler.clone().call(context, state));
self.routes.push(Route {
pattern: Pattern::parse(pattern),
handler,
});
self
}
pub fn route_sync<T, H>(mut self, pattern: &str, handler: H) -> Self
where
H: SyncHandler<T, S, R>,
{
let handler = Arc::new(move |context, state| {
let result = handler.clone().call_sync(context, state);
Box::pin(async move { result }) as BoxFuture<RouteResult<R>>
});
self.routes.push(Route {
pattern: Pattern::parse(pattern),
handler,
});
self
}
pub async fn call(&self, route: impl Into<String>) -> RouteResult<R> {
self.call_with(RouteContext::new(route)).await
}
pub async fn call_with(&self, mut context: RouteContext) -> RouteResult<R> {
for route in &self.routes {
if let Some(params) = route.pattern.matches(&context.route) {
context.params = params;
return (route.handler)(context, self.state.clone()).await;
}
}
Err(RouteError::NotFound {
route: context.route,
})
}
}
impl Pattern {
    /// Parses `path` into pattern segments.
    ///
    /// Empty segments (from leading, trailing, or repeated `/`) are dropped. A segment
    /// beginning with `:` becomes a named parameter whose name is the rest of the segment.
    /// Anything else is a literal.
    fn parse(path: &str) -> Self {
        let segments = split_path(path)
            .map(|raw| match raw.strip_prefix(':') {
                Some(name) => Segment::Param(String::from(name)),
                None => Segment::Literal(String::from(raw)),
            })
            .collect();
        Self { segments }
    }

    /// Matches `path` against this pattern, segment by segment.
    ///
    /// Returns the captured parameters (name → raw segment text) when both have the same
    /// number of segments and every literal lines up; otherwise returns `None`. If a
    /// parameter name appears more than once, the last occurrence wins.
    fn matches(&self, path: &str) -> Option<BTreeMap<String, String>> {
        let mut captured = BTreeMap::new();
        let mut remaining = split_path(path);
        for segment in &self.segments {
            // The path ran out before the pattern did.
            let actual = remaining.next()?;
            match segment {
                Segment::Literal(expected) => {
                    if expected != actual {
                        return None;
                    }
                }
                Segment::Param(name) => {
                    captured.insert(name.clone(), actual.to_string());
                }
            }
        }
        // Leftover path segments mean the path is longer than the pattern.
        if remaining.next().is_some() {
            None
        } else {
            Some(captured)
        }
    }
}
/// Splits `path` on `/` and yields only the non-empty segments.
///
/// Leading, trailing, and repeated slashes therefore never produce empty segments:
/// `"/a//b/"` yields `"a"` and then `"b"`.
fn split_path(path: &str) -> impl Iterator<Item = &str> {
    path.split('/').filter(|segment| !segment.is_empty())
}
#[cfg(test)]
mod tests {
// Unit tests for pattern matching and handler dispatch. Futures are driven by the
// minimal `block_on` executor defined at the bottom of this module.
use super::*;
use crate::{Input, Param, Params, State};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
#[test]
fn pattern_matches_named_params() {
let pattern = Pattern::parse("users/:id/posts/:post_id");
let params = pattern.matches("users/42/posts/7").unwrap();
assert_eq!(params["id"], "42");
assert_eq!(params["post_id"], "7");
// A path with fewer segments than the pattern must not match.
assert!(pattern.matches("users/42").is_none());
}
#[test]
fn routes_zero_argument_handlers() {
block_on(async {
async fn hello() -> &'static str {
"hello"
}
let app = Router::<String>::new().route("hello", hello);
let output = app.call("hello").await.unwrap();
assert_eq!(output, "hello");
});
}
#[test]
fn extracts_param_input_and_state() {
block_on(async {
#[derive(Clone)]
struct Payload {
suffix: String,
}
async fn user(
Param(id): Param<u64>,
Input(payload): Input<Payload>,
State(prefix): State<String>,
) -> String {
format!("{prefix}:{id}:{}", payload.suffix)
}
let app = Router::<String>::new()
.with_state("user".to_string())
.route("users/:id", user);
let output = app
.call_with(RouteContext::new("users/9").input(Payload {
suffix: "saved".to_string(),
}))
.await
.unwrap();
assert_eq!(output, "user:9:saved");
});
}
#[test]
fn supports_many_parameters() {
block_on(async {
async fn many(
Param(id): Param<u64>,
Params(params): Params,
State(prefix): State<String>,
context: RouteContext,
) -> String {
format!("{prefix}:{id}:{}:{}", params["id"], context.route)
}
let app = Router::<String>::new()
.with_state("many".to_string())
.route("items/:id", many);
let output = app.call("items/7").await.unwrap();
assert_eq!(output, "many:7:7:items/7");
});
}
// Nine extractor arguments; presumably this targets handler impls above arity eight.
// NOTE(review): confirm against the `Handler` impl generation in `handler`.
#[test]
fn supports_more_than_eight_parameters() {
block_on(async {
async fn many(
RouteContext { route, .. }: RouteContext,
State(prefix): State<String>,
Params(params): Params,
Param(id): Param<u64>,
RouteContext { params: p2, .. }: RouteContext,
State(prefix2): State<String>,
Params(params2): Params,
Param(id2): Param<u64>,
RouteContext { route: route2, .. }: RouteContext,
) -> String {
format!(
"{prefix}:{prefix2}:{id}:{id2}:{}:{}:{}:{}:{}",
params["id"], params2["id"], p2["id"], route, route2
)
}
let app = Router::<String>::new()
.with_state("wide".to_string())
.route("wide/:id", many);
let output = app.call("wide/11").await.unwrap();
assert_eq!(output, "wide:wide:11:11:11:11:11:wide/11:wide/11");
});
}
#[test]
fn routes_sync_handlers_on_a_separate_branch() {
block_on(async {
fn show(Param(id): Param<u64>) -> String {
format!("sync:{id}")
}
let app = Router::<String>::new().route_sync("sync/:id", show);
let output = app.call("sync/12").await.unwrap();
assert_eq!(output, "sync:12");
});
}
// The closure moves `prefix` out on each call, which only works because the router
// clones the handler before every invocation.
#[test]
fn handler_can_consume_a_cloned_capture() {
block_on(async {
let prefix = String::from("owned");
let app = Router::<String>::new().route_sync("consume", move || prefix);
let output = app.call("consume").await.unwrap();
assert_eq!(output, "owned");
});
}
/// Minimal test executor: polls `future` in a loop with a no-op waker and yields the
/// thread between polls.
///
/// It never parks, so a future that stays `Pending` without making progress spins
/// forever.
fn block_on<F>(future: F) -> F::Output
where
F: Future,
{
let waker = noop_waker();
let mut context = Context::from_waker(&waker);
let mut future = Box::pin(future);
loop {
match Pin::as_mut(&mut future).poll(&mut context) {
Poll::Ready(output) => return output,
Poll::Pending => std::thread::yield_now(),
}
}
}
/// Builds a `Waker` whose operations all do nothing.
///
/// This is sufficient here because `block_on` re-polls unconditionally instead of
/// waiting to be woken.
fn noop_waker() -> Waker {
// Vtable entries: each ignores the (null) data pointer. `clone` hands back another
// no-op `RawWaker`, and `wake` / `wake_by_ref` / `drop` do nothing.
unsafe fn clone(_: *const ()) -> RawWaker {
noop_raw_waker()
}
unsafe fn wake(_: *const ()) {}
unsafe fn wake_by_ref(_: *const ()) {}
unsafe fn drop(_: *const ()) {}
fn noop_raw_waker() -> RawWaker {
RawWaker::new(
std::ptr::null(),
&RawWakerVTable::new(clone, wake, wake_by_ref, drop),
)
}
// SAFETY: the data pointer is never dereferenced and every vtable function is a no-op,
// except `clone`, which returns another identical no-op `RawWaker`. The `RawWaker`
// contract is therefore trivially upheld, and the waker is safe to use from any thread.
unsafe { Waker::from_raw(noop_raw_waker()) }
}
}