saphir/
router.rs

//! The router is responsible for dispatching requests to handlers.
//!
//! *SAFETY NOTICE*
//!
//! Inside the router we need a little bit of unsafe code. This code allows us
//! to treat the futures generated by the handlers as `'static`. This is
//! considered safe since all data lives within the server stack, which has a
//! static lifetime over your application. We plan to remove this unsafe code
//! as soon as we find another solution.
10#![allow(clippy::type_complexity)]
11
12use crate::{
13    body::Body,
14    controller::{Controller, DynControllerHandler},
15    error::SaphirError,
16    guard::{Builder as GuardBuilder, GuardChain, GuardChainEnd},
17    handler::DynHandler,
18    http_context::{HandlerMetadata, HttpContext, RouteId, State},
19    request::Request,
20    responder::{DynResponder, Responder},
21    utils::{EndpointResolver, EndpointResolverResult},
22};
23use futures::{future::BoxFuture, FutureExt};
24use http::Method;
25use std::{collections::HashMap, sync::Arc};
26
27/// Builder type for the router
28pub struct Builder<Chain: RouterChain + Send + Unpin + 'static + Sync> {
29    resolver: HashMap<String, EndpointResolver>,
30    chain: Chain,
31}
32
33impl Default for Builder<RouterChainEnd> {
34    fn default() -> Self {
35        Self {
36            resolver: Default::default(),
37            chain: RouterChainEnd { handlers: Default::default() },
38        }
39    }
40}
41
42impl<Controllers: 'static + RouterChain + Unpin + Send + Sync> Builder<Controllers> {
43    /// Add a simple request handle to a given path
44    ///
45    /// ```rust
46    /// # use saphir::router::Builder as RBuilder;
47    /// # use saphir::prelude::*;
48    /// #
49    /// # let builder = RBuilder::default();
50    /// // Simply declare a handler fn
51    /// async fn simple_handler(req: Request<Body>) -> impl Responder {200}
52    ///
53    /// // Then while building your server
54    /// // ...
55    /// builder.route("/simple", Method::GET, simple_handler);
56    /// // ...
57    /// ```
58    pub fn route<H>(mut self, route: &str, method: Method, handler: H) -> Self
59    where
60        H: 'static + DynHandler<Body> + Send + Sync,
61    {
62        let endpoint_id = if let Some(er) = self.resolver.get_mut(route) {
63            er.add_method(method.clone());
64            er.id()
65        } else {
66            let er = EndpointResolver::new(route, method.clone()).expect("Unable to construct endpoint resolver");
67            let er_id = er.id();
68            self.resolver.insert(route.to_string(), er);
69            er_id
70        };
71
72        self.chain
73            .add_handler(endpoint_id, method, Box::new(handler), crate::guard::Builder::default().build());
74
75        self
76    }
77
78    /// Add a request handler to a given path behind guards
79    ///
80    /// ```rust
81    /// # use saphir::router::Builder as RBuilder;
82    /// # use saphir::prelude::*;
83    /// #
84    /// # let builder = RBuilder::default();
85    ///
86    /// async fn handler(req: Request<Body>) -> impl Responder { 200 }
87    ///
88    /// async fn value_guard(req: Request<Body>) -> Result<Request<Body>, u16> {
89    ///    match req.captures().get("value") {
90    ///        Some(v) if v.eq("allowed_value") => Ok(req),
91    ///        Some(_) => Err(403),
92    ///        None => Err(400),
93    ///    }
94    /// }
95    ///
96    /// builder.route_with_guards("/handler/{value}", Method::GET, handler, |g| {
97    ///     g.apply(value_guard)
98    /// });
99    /// // ...
100    /// ```
101    pub fn route_with_guards<H, F, Chain>(mut self, route: &str, method: Method, handler: H, guards: F) -> Self
102    where
103        H: 'static + DynHandler<Body> + Send + Sync,
104        F: FnOnce(GuardBuilder<GuardChainEnd>) -> GuardBuilder<Chain>,
105        Chain: GuardChain + 'static,
106    {
107        let endpoint_id = if let Some(er) = self.resolver.get_mut(route) {
108            er.add_method(method.clone());
109            er.id()
110        } else {
111            let er = EndpointResolver::new(route, method.clone()).expect("Unable to construct endpoint resolver");
112            let er_id = er.id();
113            self.resolver.insert(route.to_string(), er);
114            er_id
115        };
116
117        self.chain
118            .add_handler(endpoint_id, method, Box::new(handler), guards(GuardBuilder::default()).build());
119
120        self
121    }
122
123    /// Add a simple request handle to a given path
124    ///
125    /// ```rust
126    /// # use saphir::router::Builder as RBuilder;
127    /// # use saphir::prelude::*;
128    /// #
129    /// # let builder = RBuilder::default();
130    /// // Implement controller for your struct
131    /// struct SimpleController;
132    /// # impl Controller for SimpleController {
133    /// #    const BASE_PATH: &'static str = "/basic";
134    /// #    fn handlers(&self) -> Vec<ControllerEndpoint<Self>> where Self: Sized {EndpointsBuilder::new().build()}
135    /// # }
136    /// // Then while building your server
137    /// // ...
138    /// builder.controller(SimpleController);
139    /// // ...
140    /// ```
141    pub fn controller<C: Controller + Send + Unpin + Sync>(mut self, controller: C) -> Builder<RouterChainLink<C, Controllers>> {
142        let mut handlers = HashMap::new();
143        for (name, method, subroute, handler, guard_chain) in controller.handlers() {
144            let route = format!("{}{}", C::BASE_PATH, subroute);
145            let meta = name.map(|name| HandlerMetadata {
146                route_id: Default::default(),
147                name: Some(name),
148            });
149            let endpoint_id = if let Some(er) = self.resolver.get_mut(&route) {
150                er.add_method_with_metadata(method.clone(), meta);
151                er.id()
152            } else {
153                let er = EndpointResolver::new_with_metadata(&route, method.clone(), meta).expect("Unable to construct endpoint resolver");
154                let er_id = er.id();
155                self.resolver.insert(route, er);
156                er_id
157            };
158
159            handlers.insert((endpoint_id, method), (handler, guard_chain));
160        }
161
162        Builder {
163            resolver: self.resolver,
164            chain: RouterChainLink {
165                controller,
166                handlers,
167                rest: self.chain,
168            },
169        }
170    }
171
172    pub(crate) fn build(self) -> Router {
173        let Builder { resolver, chain: controllers } = self;
174
175        let mut resolvers: Vec<_> = resolver.into_values().collect();
176        resolvers.sort_unstable();
177
178        Router {
179            inner: Arc::new(RouterInner {
180                resolvers,
181                chain: Box::new(controllers),
182            }),
183        }
184    }
185}
186
187struct RouterInner {
188    resolvers: Vec<EndpointResolver>,
189    chain: Box<dyn RouterChain + Send + Unpin + Sync>,
190}
191
192#[doc(hidden)]
193#[derive(Clone)]
194pub struct Router {
195    inner: Arc<RouterInner>,
196}
197
198impl Router {
199    pub fn builder() -> Builder<RouterChainEnd> {
200        Builder::default()
201    }
202
203    pub fn resolve(&self, req: &mut Request<Body>) -> Result<u64, u16> {
204        let mut method_not_allowed = false;
205        for endpoint_resolver in &self.inner.resolvers {
206            match endpoint_resolver.resolve(req) {
207                EndpointResolverResult::InvalidPath => continue,
208                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
209                EndpointResolverResult::Match(_) => return Ok(endpoint_resolver.id()),
210            }
211        }
212
213        if method_not_allowed {
214            Err(405)
215        } else {
216            Err(404)
217        }
218    }
219
220    pub fn resolve_metadata(&self, req: &mut Request) -> HandlerMetadata {
221        let mut method_not_allowed = false;
222
223        for endpoint_resolver in &self.inner.resolvers {
224            match endpoint_resolver.resolve(req) {
225                EndpointResolverResult::InvalidPath => continue,
226                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
227                EndpointResolverResult::Match(meta) => return meta.clone(),
228            }
229        }
230
231        if method_not_allowed {
232            HandlerMetadata::not_allowed()
233        } else {
234            HandlerMetadata::not_found()
235        }
236    }
237
238    pub async fn dispatch(&self, mut ctx: HttpContext) -> Result<HttpContext, SaphirError> {
239        let req = ctx.state.take_request().ok_or(SaphirError::RequestMovedBeforeHandler)?;
240        // # SAFETY #
241        // The router is initialized in static memory when calling run on Server.
242        let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) };
243        let b = crate::response::Builder::new();
244        let route_id = match ctx.metadata.route_id {
245            RouteId::Id(id) => id,
246            RouteId::Error(e) => {
247                return e.respond_with_builder(b, &ctx).build().map(|r| {
248                    ctx.state = State::After(Box::new(r));
249                    ctx
250                });
251            }
252        };
253        let res = if let Some(responder) = static_self.inner.chain.dispatch(route_id, req) {
254            responder.await.dyn_respond(b, &ctx)
255        } else {
256            404.respond_with_builder(b, &ctx)
257        }
258        .build();
259
260        res.map(|r| {
261            ctx.state = State::After(Box::new(r));
262            ctx
263        })
264    }
265}
266
267#[doc(hidden)]
268pub trait RouterChain {
269    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>>;
270    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>);
271}
272
273#[doc(hidden)]
274pub struct RouterChainEnd {
275    handlers: HashMap<(u64, Method), (Box<dyn DynHandler<Body> + Send + Sync>, Box<dyn GuardChain>)>,
276}
277
278impl RouterChain for RouterChainEnd {
279    #[inline]
280    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
281        if let Some(handler) = self.handlers.get(&(resolver_id, req.method().clone())) {
282            if handler.1.is_end() {
283                Some(handler.0.dyn_handle(req))
284            } else {
285                let fut = handler.1.validate(req).then(move |req| async move {
286                    match req {
287                        Ok(req) => handler.0.dyn_handle(req).await,
288                        Err(resp) => resp,
289                    }
290                });
291                Some(fut.boxed())
292            }
293        } else {
294            None
295        }
296    }
297
298    #[inline]
299    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
300        self.handlers.insert((endpoint_id, method), (handler, guards));
301    }
302}
303
304#[doc(hidden)]
305pub struct RouterChainLink<C, Rest: RouterChain> {
306    controller: C,
307    handlers: HashMap<(u64, Method), (Box<dyn DynControllerHandler<C, Body> + Send + Sync>, Box<dyn GuardChain>)>,
308    rest: Rest,
309}
310
311impl<C: Sync + Send, Rest: RouterChain + Sync + Send> RouterChain for RouterChainLink<C, Rest> {
312    #[inline]
313    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
314        if let Some(handler) = self.handlers.get(&(resolver_id, req.method().clone())) {
315            if handler.1.is_end() {
316                Some(handler.0.dyn_handle(&self.controller, req))
317            } else {
318                let fut = handler.1.validate(req).then(move |req| async move {
319                    match req {
320                        Ok(req) => handler.0.dyn_handle(&self.controller, req).await,
321                        Err(resp) => resp,
322                    }
323                });
324                Some(fut.boxed())
325            }
326        } else {
327            self.rest.dispatch(resolver_id, req)
328        }
329    }
330
331    #[inline]
332    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
333        self.rest.add_handler(endpoint_id, method, handler, guards);
334    }
335}