Skip to main content

neco_server_router/
router.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use neco_server_core::{Method, Request, Response, StatusCode};
6
7use crate::Extensions;
8
/// Heap-allocated, type-erased future that may be moved across threads.
///
/// No explicit lifetime is written: a boxed trait object defaults to `'static`, so the
/// future cannot borrow from the request that produced it and can be returned freely.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
10
11/// Request envelope used during router dispatch.
12pub struct RoutedRequest {
13    /// The pure HTTP request message.
14    pub request: Request,
15    /// Request-local typed storage for middleware and handlers.
16    pub extensions: Extensions,
17}
18
19impl RoutedRequest {
20    /// Creates a routed request from a pure HTTP request.
21    pub fn new(request: Request) -> Self {
22        Self {
23            request,
24            extensions: Extensions::new(),
25        }
26    }
27}
28
29/// Route handler function type.
30pub type Handler<S, R = Response> = Arc<dyn Fn(RoutedRequest, S) -> BoxFuture<R> + Send + Sync>;
31
32/// Middleware function type.
33pub type Middleware<S, R = Response> =
34    Arc<dyn Fn(RoutedRequest, S, Next<S, R>) -> BoxFuture<R> + Send + Sync>;
35
36#[derive(Clone)]
37enum RouteMethod {
38    Exact(Method),
39    Any,
40}
41
42impl RouteMethod {
43    fn matches(&self, method: &Method) -> bool {
44        match self {
45            Self::Exact(expected) => expected == method,
46            Self::Any => true,
47        }
48    }
49}
50
51struct Route<S, R> {
52    method: RouteMethod,
53    path: String,
54    handler: Handler<S, R>,
55    middleware: Vec<Middleware<S, R>>,
56}
57
58impl<S, R> Clone for Route<S, R> {
59    fn clone(&self) -> Self {
60        Self {
61            method: self.method.clone(),
62            path: self.path.clone(),
63            handler: self.handler.clone(),
64            middleware: self.middleware.clone(),
65        }
66    }
67}
68
69/// Middleware continuation.
70pub struct Next<S, R = Response> {
71    middleware: Arc<Vec<Middleware<S, R>>>,
72    handler: Handler<S, R>,
73    index: usize,
74}
75
76impl<S, R> Clone for Next<S, R> {
77    fn clone(&self) -> Self {
78        Self {
79            middleware: self.middleware.clone(),
80            handler: self.handler.clone(),
81            index: self.index,
82        }
83    }
84}
85
86impl<S, R> Next<S, R>
87where
88    S: Clone + Send + Sync + 'static,
89    R: Send + 'static,
90{
91    /// Runs the next middleware or the final handler.
92    pub fn run(&self, request: RoutedRequest, state: S) -> BoxFuture<R> {
93        if let Some(middleware) = self.middleware.get(self.index).cloned() {
94            let next = Self {
95                middleware: self.middleware.clone(),
96                handler: self.handler.clone(),
97                index: self.index + 1,
98            };
99            middleware(request, state, next)
100        } else {
101            (self.handler)(request, state)
102        }
103    }
104}
105
106/// Fixed-path router with middleware chain.
107pub struct Router<S, R = Response> {
108    state: S,
109    routes: Vec<Route<S, R>>,
110    pending_middleware: Vec<Middleware<S, R>>,
111}
112
113impl<S, R> Clone for Router<S, R>
114where
115    S: Clone,
116{
117    fn clone(&self) -> Self {
118        Self {
119            state: self.state.clone(),
120            routes: self.routes.clone(),
121            pending_middleware: self.pending_middleware.clone(),
122        }
123    }
124}
125
126impl<S, R> Router<S, R>
127where
128    S: Clone + Send + Sync + 'static,
129    R: From<Response> + Send + 'static,
130{
131    /// Creates an empty router bound to a clonable state value.
132    pub fn new(state: S) -> Self {
133        Self {
134            state,
135            routes: Vec::new(),
136            pending_middleware: Vec::new(),
137        }
138    }
139
140    /// Registers a GET route.
141    pub fn get<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
142    where
143        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
144        Fut: Future<Output = R> + Send + 'static,
145    {
146        self.route(RouteMethod::Exact(Method::Get), path, handler)
147    }
148
149    /// Registers a POST route.
150    pub fn post<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
151    where
152        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
153        Fut: Future<Output = R> + Send + 'static,
154    {
155        self.route(RouteMethod::Exact(Method::Post), path, handler)
156    }
157
158    /// Registers a PUT route.
159    pub fn put<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
160    where
161        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
162        Fut: Future<Output = R> + Send + 'static,
163    {
164        self.route(RouteMethod::Exact(Method::Put), path, handler)
165    }
166
167    /// Registers a DELETE route.
168    pub fn delete<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
169    where
170        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
171        Fut: Future<Output = R> + Send + 'static,
172    {
173        self.route(RouteMethod::Exact(Method::Delete), path, handler)
174    }
175
176    /// Registers a PATCH route.
177    pub fn patch<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
178    where
179        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
180        Fut: Future<Output = R> + Send + 'static,
181    {
182        self.route(RouteMethod::Exact(Method::Patch), path, handler)
183    }
184
185    /// Registers a HEAD route.
186    pub fn head<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
187    where
188        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
189        Fut: Future<Output = R> + Send + 'static,
190    {
191        self.route(RouteMethod::Exact(Method::Head), path, handler)
192    }
193
194    /// Registers an OPTIONS route.
195    pub fn options<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
196    where
197        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
198        Fut: Future<Output = R> + Send + 'static,
199    {
200        self.route(RouteMethod::Exact(Method::Options), path, handler)
201    }
202
203    /// Registers a route for any method.
204    pub fn any<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
205    where
206        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
207        Fut: Future<Output = R> + Send + 'static,
208    {
209        self.route(RouteMethod::Any, path, handler)
210    }
211
212    /// Registers a route for an explicit method token.
213    pub fn on<F, Fut>(self, method: Method, path: impl Into<String>, handler: F) -> Self
214    where
215        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
216        Fut: Future<Output = R> + Send + 'static,
217    {
218        self.route(RouteMethod::Exact(method), path, handler)
219    }
220
221    fn route<F, Fut>(mut self, method: RouteMethod, path: impl Into<String>, handler: F) -> Self
222    where
223        F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
224        Fut: Future<Output = R> + Send + 'static,
225    {
226        let handler: Handler<S, R> =
227            Arc::new(move |request, state| Box::pin(handler(request, state)));
228        self.routes.push(Route {
229            method,
230            path: path.into(),
231            handler,
232            middleware: self.pending_middleware.clone(),
233        });
234        self
235    }
236
237    /// Adds a middleware to the end of the chain.
238    ///
239    /// The middleware applies to all routes currently registered on this router and to
240    /// any routes added later on the same router value. Middleware attached to another
241    /// router does not leak across [`Self::merge`].
242    pub fn middleware<F, Fut>(mut self, middleware: F) -> Self
243    where
244        F: Fn(RoutedRequest, S, Next<S, R>) -> Fut + Send + Sync + 'static,
245        Fut: Future<Output = R> + Send + 'static,
246    {
247        let middleware: Middleware<S, R> =
248            Arc::new(move |request, state, next| Box::pin(middleware(request, state, next)));
249        for route in &mut self.routes {
250            route.middleware.push(middleware.clone());
251        }
252        self.pending_middleware.push(middleware);
253        self
254    }
255
256    /// Merges routes and middleware from another router with the same state.
257    pub fn merge(mut self, other: Self) -> Self {
258        self.routes.extend(other.routes);
259        self
260    }
261
262    /// Dispatches a request entirely in-process.
263    pub async fn handle(&self, request: Request) -> R {
264        self.dispatch_routed(RoutedRequest::new(request)).await
265    }
266
267    /// Dispatches a routed request entirely in-process.
268    pub async fn handle_routed(&self, request: RoutedRequest) -> R {
269        self.dispatch_routed(request).await
270    }
271
272    async fn dispatch_routed(&self, request: RoutedRequest) -> R {
273        let path_exists = self
274            .routes
275            .iter()
276            .any(|route| route.path == request.request.path);
277        let route = match self.routes.iter().find(|route| {
278            route.path == request.request.path && route.method.matches(&request.request.method)
279        }) {
280            Some(route) => route,
281            None if path_exists => return not_found_or_method::<R>(StatusCode::METHOD_NOT_ALLOWED),
282            None => return not_found_or_method::<R>(StatusCode::NOT_FOUND),
283        };
284
285        let next = Next {
286            middleware: Arc::new(route.middleware.clone()),
287            handler: route.handler.clone(),
288            index: 0,
289        };
290        next.run(request, self.state.clone()).await
291    }
292}
293
294fn not_found_or_method<R>(status: StatusCode) -> R
295where
296    R: From<Response>,
297{
298    Response::new(status).into()
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    #[derive(Clone)]
    struct TestState {
        prefix: &'static str,
    }

    /// Waker whose wake-up is a no-op; `block_on` re-polls in a loop instead of parking.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Minimal test executor: polls `future` until it completes, yielding the thread
    /// between polls.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        // A safe `Wake` impl replaces the hand-built `RawWaker` vtable and its `unsafe`.
        let waker = Waker::from(Arc::new(NoopWake));
        let mut context = Context::from_waker(&waker);
        let mut future = Box::pin(future);

        loop {
            match Pin::as_mut(&mut future).poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn router_dispatches_exact_method_and_path() {
        let router =
            Router::new(TestState { prefix: "echo:" }).get("/echo", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });

        let response = block_on(
            router.handle(Request::new(Method::Get, "/echo").with_body(b"hello".to_vec())),
        );

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"echo:hello");
    }

    #[test]
    fn router_dispatches_custom_method_route() {
        let router = Router::new(TestState { prefix: "patch:" }).on(
            Method::Other("PATCH".into()),
            "/echo",
            |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            },
        );

        let response = block_on(
            router.handle(Request::new(Method::Other("PATCH".into()), "/echo").with_body(b"ok")),
        );

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"patch:ok");
    }

    #[test]
    fn router_dispatches_put_route() {
        let router =
            Router::new(TestState { prefix: "put:" }).put("/item", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });

        let response = block_on(router.handle(Request::new(Method::Put, "/item").with_body(b"ok")));

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"put:ok");
    }

    #[test]
    fn router_returns_method_not_allowed_when_path_exists() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });

        let response = block_on(router.handle(Request::new(Method::Post, "/echo")));
        assert_eq!(response.status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn router_returns_not_found_for_unknown_path() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move {
                Response::new(StatusCode::OK)
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/missing")));
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn any_route_matches_every_method() {
        let router = Router::new(TestState { prefix: "any:" })
            .any("/x", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });

        for method in vec![Method::Get, Method::Post, Method::Delete] {
            let response = block_on(router.handle(Request::new(method, "/x")));
            assert_eq!(response.status, StatusCode::OK);
            assert_eq!(response.body, b"any:");
        }
    }

    #[test]
    fn middleware_wraps_handler() {
        let router = Router::new(TestState { prefix: "core:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"body".to_vec())
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                let mut response = next.run(request, state).await;
                response.headers.insert("x-middleware", "yes");
                response
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("X-Middleware"), Some("yes"));
    }

    #[test]
    fn middleware_extensions_reach_handler() {
        let router = Router::new(TestState { prefix: "ext:" })
            .get("/x", |mut request, state| async move {
                let marker = request.extensions.remove::<u64>().unwrap_or_default();
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(marker.to_string().as_bytes());
                Response::new(StatusCode::OK).with_body(body)
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                next.run(request, state).await
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"ext:7");
    }

    #[test]
    fn middleware_applies_to_routes_added_after_layer() {
        let router = Router::new(TestState { prefix: "late:" })
            .middleware(|request, state, next| async move {
                let mut response: Response = next.run(request, state).await;
                response.headers.insert("x-layered", "yes");
                response
            })
            .get("/x", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.headers.get("x-layered"), Some("yes"));
    }

    #[test]
    fn merged_router_does_not_leak_middleware_to_later_routes() {
        let public = Router::new(TestState { prefix: "public:" }).get(
            "/public",
            |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            },
        );
        let protected = Router::new(TestState { prefix: "auth:" })
            .get("/protected", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-auth", "yes");
                response
            });
        let router = public
            .merge(protected)
            .get("/later", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });

        let protected_response = block_on(router.handle(Request::new(Method::Get, "/protected")));
        assert_eq!(protected_response.headers.get("x-auth"), Some("yes"));

        let later_response = block_on(router.handle(Request::new(Method::Get, "/later")));
        assert_eq!(later_response.headers.get("x-auth"), None);
    }
}