1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use neco_server_core::{Method, Request, Response, StatusCode};
6
7use crate::Extensions;
8
/// Pinned, heap-allocated, type-erased future returned by every handler and middleware.
///
/// The `Send + 'static` bounds mean the future owns everything it uses and may be moved
/// to another thread before it is polled.
type BoxFuture<T> = Pin<Box<dyn Send + Future<Output = T> + 'static>>;
10
11pub struct RoutedRequest {
13 pub request: Request,
15 pub extensions: Extensions,
17}
18
19impl RoutedRequest {
20 pub fn new(request: Request) -> Self {
22 Self {
23 request,
24 extensions: Extensions::new(),
25 }
26 }
27}
28
29pub type Handler<S, R = Response> = Arc<dyn Fn(RoutedRequest, S) -> BoxFuture<R> + Send + Sync>;
31
32pub type Middleware<S, R = Response> =
34 Arc<dyn Fn(RoutedRequest, S, Next<S, R>) -> BoxFuture<R> + Send + Sync>;
35
36#[derive(Clone)]
37enum RouteMethod {
38 Exact(Method),
39 Any,
40}
41
42impl RouteMethod {
43 fn matches(&self, method: &Method) -> bool {
44 match self {
45 Self::Exact(expected) => expected == method,
46 Self::Any => true,
47 }
48 }
49}
50
51struct Route<S, R> {
52 method: RouteMethod,
53 path: String,
54 handler: Handler<S, R>,
55 middleware: Vec<Middleware<S, R>>,
56}
57
58impl<S, R> Clone for Route<S, R> {
59 fn clone(&self) -> Self {
60 Self {
61 method: self.method.clone(),
62 path: self.path.clone(),
63 handler: self.handler.clone(),
64 middleware: self.middleware.clone(),
65 }
66 }
67}
68
69pub struct Next<S, R = Response> {
71 middleware: Arc<Vec<Middleware<S, R>>>,
72 handler: Handler<S, R>,
73 index: usize,
74}
75
76impl<S, R> Clone for Next<S, R> {
77 fn clone(&self) -> Self {
78 Self {
79 middleware: self.middleware.clone(),
80 handler: self.handler.clone(),
81 index: self.index,
82 }
83 }
84}
85
86impl<S, R> Next<S, R>
87where
88 S: Clone + Send + Sync + 'static,
89 R: Send + 'static,
90{
91 pub fn run(&self, request: RoutedRequest, state: S) -> BoxFuture<R> {
93 if let Some(middleware) = self.middleware.get(self.index).cloned() {
94 let next = Self {
95 middleware: self.middleware.clone(),
96 handler: self.handler.clone(),
97 index: self.index + 1,
98 };
99 middleware(request, state, next)
100 } else {
101 (self.handler)(request, state)
102 }
103 }
104}
105
106pub struct Router<S, R = Response> {
108 state: S,
109 routes: Vec<Route<S, R>>,
110 pending_middleware: Vec<Middleware<S, R>>,
111}
112
113impl<S, R> Clone for Router<S, R>
114where
115 S: Clone,
116{
117 fn clone(&self) -> Self {
118 Self {
119 state: self.state.clone(),
120 routes: self.routes.clone(),
121 pending_middleware: self.pending_middleware.clone(),
122 }
123 }
124}
125
126impl<S, R> Router<S, R>
127where
128 S: Clone + Send + Sync + 'static,
129 R: From<Response> + Send + 'static,
130{
131 pub fn new(state: S) -> Self {
133 Self {
134 state,
135 routes: Vec::new(),
136 pending_middleware: Vec::new(),
137 }
138 }
139
140 pub fn get<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
142 where
143 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
144 Fut: Future<Output = R> + Send + 'static,
145 {
146 self.route(RouteMethod::Exact(Method::Get), path, handler)
147 }
148
149 pub fn post<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
151 where
152 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
153 Fut: Future<Output = R> + Send + 'static,
154 {
155 self.route(RouteMethod::Exact(Method::Post), path, handler)
156 }
157
158 pub fn put<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
160 where
161 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
162 Fut: Future<Output = R> + Send + 'static,
163 {
164 self.route(RouteMethod::Exact(Method::Put), path, handler)
165 }
166
167 pub fn delete<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
169 where
170 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
171 Fut: Future<Output = R> + Send + 'static,
172 {
173 self.route(RouteMethod::Exact(Method::Delete), path, handler)
174 }
175
176 pub fn patch<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
178 where
179 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
180 Fut: Future<Output = R> + Send + 'static,
181 {
182 self.route(RouteMethod::Exact(Method::Patch), path, handler)
183 }
184
185 pub fn head<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
187 where
188 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
189 Fut: Future<Output = R> + Send + 'static,
190 {
191 self.route(RouteMethod::Exact(Method::Head), path, handler)
192 }
193
194 pub fn options<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
196 where
197 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
198 Fut: Future<Output = R> + Send + 'static,
199 {
200 self.route(RouteMethod::Exact(Method::Options), path, handler)
201 }
202
203 pub fn any<F, Fut>(self, path: impl Into<String>, handler: F) -> Self
205 where
206 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
207 Fut: Future<Output = R> + Send + 'static,
208 {
209 self.route(RouteMethod::Any, path, handler)
210 }
211
212 pub fn on<F, Fut>(self, method: Method, path: impl Into<String>, handler: F) -> Self
214 where
215 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
216 Fut: Future<Output = R> + Send + 'static,
217 {
218 self.route(RouteMethod::Exact(method), path, handler)
219 }
220
221 fn route<F, Fut>(mut self, method: RouteMethod, path: impl Into<String>, handler: F) -> Self
222 where
223 F: Fn(RoutedRequest, S) -> Fut + Send + Sync + 'static,
224 Fut: Future<Output = R> + Send + 'static,
225 {
226 let handler: Handler<S, R> =
227 Arc::new(move |request, state| Box::pin(handler(request, state)));
228 self.routes.push(Route {
229 method,
230 path: path.into(),
231 handler,
232 middleware: self.pending_middleware.clone(),
233 });
234 self
235 }
236
237 pub fn middleware<F, Fut>(mut self, middleware: F) -> Self
243 where
244 F: Fn(RoutedRequest, S, Next<S, R>) -> Fut + Send + Sync + 'static,
245 Fut: Future<Output = R> + Send + 'static,
246 {
247 let middleware: Middleware<S, R> =
248 Arc::new(move |request, state, next| Box::pin(middleware(request, state, next)));
249 for route in &mut self.routes {
250 route.middleware.push(middleware.clone());
251 }
252 self.pending_middleware.push(middleware);
253 self
254 }
255
256 pub fn merge(mut self, other: Self) -> Self {
258 self.routes.extend(other.routes);
259 self
260 }
261
262 pub async fn handle(&self, request: Request) -> R {
264 self.dispatch_routed(RoutedRequest::new(request)).await
265 }
266
267 pub async fn handle_routed(&self, request: RoutedRequest) -> R {
269 self.dispatch_routed(request).await
270 }
271
272 async fn dispatch_routed(&self, request: RoutedRequest) -> R {
273 let path_exists = self
274 .routes
275 .iter()
276 .any(|route| route.path == request.request.path);
277 let route = match self.routes.iter().find(|route| {
278 route.path == request.request.path && route.method.matches(&request.request.method)
279 }) {
280 Some(route) => route,
281 None if path_exists => return not_found_or_method::<R>(StatusCode::METHOD_NOT_ALLOWED),
282 None => return not_found_or_method::<R>(StatusCode::NOT_FOUND),
283 };
284
285 let next = Next {
286 middleware: Arc::new(route.middleware.clone()),
287 handler: route.handler.clone(),
288 index: 0,
289 };
290 next.run(request, self.state.clone()).await
291 }
292}
293
294fn not_found_or_method<R>(status: StatusCode) -> R
295where
296 R: From<Response>,
297{
298 Response::new(status).into()
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

    /// Router state used by these tests. Handlers copy `prefix` into the response body so
    /// assertions can see that the state reached them.
    #[derive(Clone)]
    struct TestState {
        prefix: &'static str,
    }

    /// Minimal executor: repeatedly polls `future` with a waker that does nothing.
    ///
    /// This busy-polls, which is sufficient for the futures used in these tests.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        fn raw_waker() -> RawWaker {
            fn clone(_: *const ()) -> RawWaker {
                raw_waker()
            }
            fn wake(_: *const ()) {}
            fn wake_by_ref(_: *const ()) {}
            fn drop(_: *const ()) {}

            RawWaker::new(
                std::ptr::null(),
                &RawWakerVTable::new(clone, wake, wake_by_ref, drop),
            )
        }

        // SAFETY: every vtable function ignores its (null) data pointer and owns no
        // resources, so cloning, waking and dropping this waker can never be unsound.
        let waker = unsafe { Waker::from_raw(raw_waker()) };
        let mut future = Box::pin(future);
        let mut context = Context::from_waker(&waker);

        loop {
            match Pin::as_mut(&mut future).poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn router_dispatches_exact_method_and_path() {
        let router =
            Router::new(TestState { prefix: "echo:" }).get("/echo", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });

        let response = block_on(
            router.handle(Request::new(Method::Get, "/echo").with_body(b"hello".to_vec())),
        );

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"echo:hello");
    }

    #[test]
    fn router_dispatches_custom_method_route() {
        let router = Router::new(TestState { prefix: "patch:" }).on(
            Method::Other("PATCH".into()),
            "/echo",
            |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            },
        );

        let response = block_on(
            router.handle(Request::new(Method::Other("PATCH".into()), "/echo").with_body(b"ok")),
        );

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"patch:ok");
    }

    #[test]
    fn router_dispatches_put_route() {
        let router =
            Router::new(TestState { prefix: "put:" }).put("/item", |request, state| async move {
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(&request.request.body);
                Response::new(StatusCode::OK).with_body(body)
            });

        let response = block_on(router.handle(Request::new(Method::Put, "/item").with_body(b"ok")));

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"put:ok");
    }

    #[test]
    fn router_returns_method_not_allowed_when_path_exists() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move { Response::new(StatusCode::OK) });

        let response = block_on(router.handle(Request::new(Method::Post, "/echo")));
        assert_eq!(response.status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn router_returns_not_found_for_unknown_path() {
        let router = Router::new(TestState { prefix: "x" })
            .get("/echo", |_request, _state| async move { Response::new(StatusCode::OK) });

        let response = block_on(router.handle(Request::new(Method::Get, "/missing")));
        assert_eq!(response.status, StatusCode::NOT_FOUND);
    }

    #[test]
    fn middleware_wraps_handler() {
        let router = Router::new(TestState { prefix: "core:" })
            .get("/x", |_request, _state| async move {
                Response::new(StatusCode::OK).with_body(b"body".to_vec())
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                let mut response = next.run(request, state).await;
                response.headers.insert("x-middleware", "yes");
                response
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.headers.get("X-Middleware"), Some("yes"));
    }

    #[test]
    fn middleware_extensions_reach_handler() {
        let router = Router::new(TestState { prefix: "ext:" })
            .get("/x", |mut request, state| async move {
                let marker = request.extensions.remove::<u64>().unwrap_or_default();
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(marker.to_string().as_bytes());
                Response::new(StatusCode::OK).with_body(body)
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(7);
                next.run(request, state).await
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"ext:7");
    }

    #[test]
    fn first_registered_middleware_runs_outermost() {
        // The outer layer stores 1. The inner layer turns whatever it sees into
        // `seen * 10 + 2`. Only the order outer -> inner -> handler yields 12.
        let router = Router::new(TestState { prefix: "order:" })
            .get("/x", |mut request, state| async move {
                let trail = request.extensions.remove::<u64>().unwrap_or_default();
                let mut body = state.prefix.as_bytes().to_vec();
                body.extend_from_slice(trail.to_string().as_bytes());
                Response::new(StatusCode::OK).with_body(body)
            })
            .middleware(|mut request, state, next| async move {
                request.extensions.insert::<u64>(1);
                next.run(request, state).await
            })
            .middleware(|mut request, state, next| async move {
                let seen = request.extensions.remove::<u64>().unwrap_or_default();
                request.extensions.insert::<u64>(seen * 10 + 2);
                next.run(request, state).await
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, b"order:12");
    }

    #[test]
    fn middleware_applies_to_routes_added_after_layer() {
        let router = Router::new(TestState { prefix: "late:" })
            .middleware(|request, state, next| async move {
                let mut response: Response = next.run(request, state).await;
                response.headers.insert("x-layered", "yes");
                response
            })
            .get("/x", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });

        let response = block_on(router.handle(Request::new(Method::Get, "/x")));
        assert_eq!(response.headers.get("x-layered"), Some("yes"));
    }

    #[test]
    fn merged_router_does_not_leak_middleware_to_later_routes() {
        let public = Router::new(TestState { prefix: "public:" }).get(
            "/public",
            |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            },
        );
        let protected = Router::new(TestState { prefix: "auth:" })
            .get("/protected", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            })
            .middleware(|request, state, next| async move {
                let mut response = next.run(request, state).await;
                response.headers.insert("x-auth", "yes");
                response
            });
        let router = public
            .merge(protected)
            .get("/later", |_request, state| async move {
                Response::new(StatusCode::OK).with_body(state.prefix.as_bytes().to_vec())
            });

        let protected_response = block_on(router.handle(Request::new(Method::Get, "/protected")));
        assert_eq!(protected_response.headers.get("x-auth"), Some("yes"));

        let later_response = block_on(router.handle(Request::new(Method::Get, "/later")));
        assert_eq!(later_response.headers.get("x-auth"), None);
    }
}