// reset_router/lib.rs

/*!
A fast [`RegexSet`](https://docs.rs/regex/latest/regex/struct.RegexSet.html)-based router for use with async Hyper (0.13).

Individual handler functions should have the type `H`, where
```text
    H: Fn(Request) -> F,
    F: Future<Output = Result<S, E>> + Send,
    S: Into<Response>,
    E: Into<Response>,
```

You can return something as simple as `Ok(Response::new("hello world".into()))`. You don't have to worry about futures
unless you need to read the request body or interact with other future-aware things.

## Usage

The example starts a server that never returns, so it is compiled but not run as a doctest.

```rust,no_run
use reset_router::{Request, RequestExtensions, Response, Router, SharedService};
use std::sync::Arc;

pub struct Handler(Arc<String>);

impl SharedService for Handler {
    type Response = Response;
    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;

    fn call(&self, _: Request) -> Self::Future {
        futures::future::ready(Ok(http::Response::builder()
            .status(200)
            .body(format!("Hello, {}!", &self.0).into())
            .unwrap()))
    }
}

#[derive(Clone)]
pub struct State(pub i32);

async fn hello(req: Request) -> Result<Response, Response> {
    let (first_name, last_name) = req.parsed_captures::<(String, String)>()?;
    Ok(http::Response::builder()
        .status(200)
        .body(format!("Hello, {} {}!", first_name, last_name).into())
        .unwrap())
}

async fn add(req: Request) -> Result<Response, Response> {
    let (add1, add2) = req.parsed_captures::<(i32, i32)>()?;

    let state_num: i32 = req.data::<State>().map(|x| x.0).unwrap_or(0);

    Ok(http::Response::builder()
        .status(200)
        .body(
            format!("{} + {} + {} = {}\r\n", add1, add2, state_num, add1 + add2 + state_num).into(),
        )
        .unwrap())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let router = Router::build()
        .data(State(42))
        .add(http::Method::POST, r"^/hello/([^/]+)/(.+)$", hello)
        .add(http::Method::GET, r"^/hello/([^/]+)/(.+)$", hello)
        .add(http::Method::GET, r"^/add/([\d]+)/([\d]+)$", add)
        .add(http::Method::GET, r"^/other$", Handler(Arc::new(String::from("world"))))
        .add_not_found(|_| {
            async {
                Ok::<_, Response>(http::Response::builder().status(404).body("404".into()).unwrap())
            }
        })
        .finish()?;

    let addr = "0.0.0.0:3000".parse()?;

    let server = hyper::Server::bind(&addr).serve(router);

    server.await?;

    Ok(())
}
```
*/
85
86/// Error handling
87pub mod err {
88    /// The error enum
89    #[derive(Debug)]
90    pub enum Error {
91        CapturesMissing,
92        MethodNotSupported,
93        Http(http::Error),
94        Recognizer(reset_recognizer::err::Error),
95    }
96
97    impl std::fmt::Display for Error {
98        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
99            use Error::*;
100            match self {
101                CapturesMissing => "Captures missing".fmt(f),
102                MethodNotSupported => "Method not supported".fmt(f),
103                Http(ref inner) => inner.fmt(f),
104                Recognizer(ref inner) => inner.fmt(f),
105            }
106        }
107    }
108
109    impl std::error::Error for Error {
110        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
111            use Error::*;
112            match self {
113                Http(ref inner) => Some(inner),
114                Recognizer(ref inner) => Some(inner),
115                _ => None,
116            }
117        }
118    }
119
120    /// Result wrapper: `Result<T, Error>`
121    pub type Result<T> = std::result::Result<T, Error>;
122
123    impl From<Error> for http::Response<hyper::Body> {
124        fn from(t: Error) -> Self {
125            http::Response::builder().status(500).body(t.to_string().into()).unwrap()
126        }
127    }
128}
129
130use reset_recognizer as recognizer;
131
132use std::collections::HashMap;
133
134use std::pin::Pin;
135use std::sync::Arc;
136
137use futures::{
138    future::{ready, Ready},
139    ready, Future, FutureExt, TryFuture,
140};
141use std::task::{Context, Poll};
142
143/// Convenience wrapper for `http::Request<hyper::Body>`
144pub type Request = http::Request<hyper::Body>;
145
146/// Convenience wrapper for `http::Response<hyper::Body>`
147pub type Response = http::Response<hyper::Body>;
148
149struct MethodMap<T> {
150    options: Option<T>,
151    get: Option<T>,
152    post: Option<T>,
153    put: Option<T>,
154    delete: Option<T>,
155    head: Option<T>,
156    trace: Option<T>,
157    connect: Option<T>,
158    patch: Option<T>,
159    extension: Option<HashMap<http::Method, T>>,
160}
161
162impl<T> MethodMap<T> {
163    fn new() -> Self {
164        Self {
165            options: None,
166            get: None,
167            post: None,
168            put: None,
169            delete: None,
170            head: None,
171            trace: None,
172            connect: None,
173            patch: None,
174            extension: None,
175        }
176    }
177
178    fn get(&self, method: &http::Method) -> Option<&T> {
179        match *method {
180            http::Method::OPTIONS => self.options.as_ref(),
181            http::Method::GET => self.get.as_ref(),
182            http::Method::POST => self.post.as_ref(),
183            http::Method::PUT => self.put.as_ref(),
184            http::Method::DELETE => self.delete.as_ref(),
185            http::Method::HEAD => self.head.as_ref(),
186            http::Method::TRACE => self.trace.as_ref(),
187            http::Method::CONNECT => self.connect.as_ref(),
188            http::Method::PATCH => self.patch.as_ref(),
189            ref m => self.extension.as_ref().and_then(|e| e.get(m)),
190        }
191    }
192
193    fn insert(&mut self, method: http::Method, t: T) {
194        match method {
195            http::Method::OPTIONS => {
196                self.options = Some(t);
197            }
198            http::Method::GET => {
199                self.get = Some(t);
200            }
201            http::Method::POST => {
202                self.post = Some(t);
203            }
204            http::Method::PUT => {
205                self.put = Some(t);
206            }
207            http::Method::DELETE => {
208                self.delete = Some(t);
209            }
210            http::Method::HEAD => {
211                self.head = Some(t);
212            }
213            http::Method::TRACE => {
214                self.trace = Some(t);
215            }
216            http::Method::CONNECT => {
217                self.connect = Some(t);
218            }
219            http::Method::PATCH => {
220                self.patch = Some(t);
221            }
222            m => {
223                let mut extension = self.extension.take().unwrap_or_else(HashMap::new);
224                extension.insert(m, t);
225                self.extension = Some(extension);
226            }
227        }
228    }
229}
230
231/// Container for application `data`, available in request handler `fn`s
232pub struct Data<T>(Arc<T>);
233
234impl<T> Data<T> {
235    /// Create a new `data` container
236    pub fn new(t: T) -> Self {
237        Data(Arc::new(t))
238    }
239
240    /// Create a new `data` from an existing `Arc<T>`
241    pub fn from_arc(arc: Arc<T>) -> Self {
242        Data(arc)
243    }
244}
245
246impl<T> std::ops::Deref for Data<T> {
247    type Target = T;
248
249    fn deref(&self) -> &Self::Target {
250        &*self.0
251    }
252}
253
254impl<T> Clone for Data<T> {
255    fn clone(&self) -> Self {
256        Data(Arc::clone(&self.0))
257    }
258}
259
260/// Shared trait for route handlers. Similar to `tower::Service`, but takes `&self`.
261///
262/// Implemented for `H` where
263/// ```
264/// H: Fn(Request) -> F,
265/// F: Future<Output = Result<S, E>> + Send,
266/// S: Into<Response>,
267/// E: Into<Response>,
268/// ```
269pub trait SharedService {
270    type Response: Into<Response>;
271    type Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
272    type Future: Future<Output = Result<Self::Response, Self::Error>> + Send;
273
274    fn call(&self, request: Request) -> Self::Future;
275}
276
277#[doc(hidden)]
278#[pin_project::pin_project]
279pub struct HandlerFuture<F> {
280    #[pin]
281    inner: F,
282}
283
284impl<F, S, E> Future for HandlerFuture<F>
285where
286    F: Future<Output = Result<S, E>> + Send,
287    S: Into<Response>,
288    E: Into<Response>,
289{
290    type Output = Result<Response, Box<dyn std::error::Error + Send + Sync + 'static>>;
291
292    #[inline]
293    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
294        let out = match ready!(self.project().inner.try_poll(cx)) {
295            Ok(res) => Ok(res.into()),
296            Err(err) => Ok(err.into()),
297        };
298        Poll::Ready(out)
299    }
300}
301
302impl<H, F, S, E> SharedService for H
303where
304    F: Future<Output = Result<S, E>> + Send,
305    S: Into<Response>,
306    E: Into<Response>,
307    H: Fn(Request) -> F,
308{
309    type Response = Response;
310    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
311    type Future = HandlerFuture<F>;
312
313    fn call(&self, request: Request) -> Self::Future {
314        HandlerFuture { inner: self(request) }
315    }
316}
317
318struct BoxedSharedService(
319    Arc<
320        dyn Fn(
321                Request,
322            ) -> Pin<
323                Box<
324                    dyn Future<
325                            Output = Result<
326                                Response,
327                                Box<dyn std::error::Error + Send + Sync + 'static>,
328                            >,
329                        > + Send,
330                >,
331            > + Send
332            + Sync,
333    >,
334);
335
336impl BoxedSharedService {
337    fn new<T: SharedService + Send + Sync + 'static>(t: T) -> Self {
338        Self(Arc::new(move |req: Request| {
339            let rt = t.call(req).map(|res| res.map(|s| s.into()).map_err(|e| e.into()));
340            Box::pin(rt)
341        }))
342    }
343}
344
345#[derive(Clone)]
346struct DataMap(Data<type_map::concurrent::TypeMap>);
347
348struct InnerRouter {
349    data: Option<DataMap>,
350    not_found: BoxedSharedService,
351    routers: MethodMap<recognizer::Router<BoxedSharedService>>,
352}
353
354/// The router, impls `hyper::service::Service` and the equivalent of `MakeService`
355#[derive(Clone)]
356pub struct Router(Arc<InnerRouter>);
357
358impl Router {
359    /// Create a new `RouterBuilder`
360    pub fn build() -> RouterBuilder {
361        RouterBuilder::new()
362    }
363}
364
365/// Builder for a `Router`
366pub struct RouterBuilder {
367    data: Option<type_map::concurrent::TypeMap>,
368    not_found: Option<BoxedSharedService>,
369    inner: HashMap<http::Method, recognizer::RouterBuilder<BoxedSharedService>>,
370}
371
372impl RouterBuilder {
373    fn new() -> Self {
374        RouterBuilder { data: None, not_found: None, inner: HashMap::new() }
375    }
376}
377
378impl RouterBuilder {
379    fn default_not_found(_: Request) -> impl Future<Output = Result<Response, Response>> {
380        async { Ok(http::Response::builder().status(404).body("Not Found".into()).unwrap()) }
381    }
382
383    /// Add application `data` to router
384    pub fn data<T: Send + Sync + 'static>(self, data: T) -> Self {
385        self.wrapped_data(Data::new(data))
386    }
387
388    /// Add application `data` to router from an existing `Data<T>` object
389    pub fn wrapped_data<T: Send + Sync + 'static>(mut self, data: Data<T>) -> Self {
390        let mut map = self.data.take().unwrap_or_else(type_map::concurrent::TypeMap::new);
391        map.insert(data);
392        self.data = Some(map);
393        self
394    }
395
396    /// Set the `404: Not Found` handler
397    pub fn add_not_found<H>(mut self, handler: H) -> Self
398    where
399        H: SharedService + Send + Sync + 'static,
400    {
401        self.not_found = Some(BoxedSharedService::new(handler));
402        self
403    }
404
405    /// Add handler for method and regex. Highest priority wins. Priority is 0 by default.
406    pub fn add<H>(self, method: http::Method, regex: &str, handler: H) -> Self
407    where
408        H: SharedService + Send + Sync + 'static,
409    {
410        self.add_with_priority(method, regex, 0, handler)
411    }
412
413    /// Add handler for method, regex, and priority. Highest priority wins.
414    pub fn add_with_priority<H>(
415        mut self,
416        method: http::Method,
417        regex: &str,
418        priority: i8,
419        handler: H,
420    ) -> Self
421    where
422        H: SharedService + Send + Sync + 'static,
423    {
424        let handler = BoxedSharedService::new(handler);
425        let entry = self
426            .inner
427            .remove(&method)
428            .unwrap_or_else(recognizer::Router::build)
429            .add_with_priority(regex, priority, handler);
430        self.inner.insert(method, entry);
431
432        self
433    }
434
435    /// Consumes the builder, returning the finished `Router`
436    pub fn finish(self) -> err::Result<Router> {
437        let mut inner_router = InnerRouter {
438            data: self.data.map(Data::new).map(DataMap),
439            not_found: self
440                .not_found
441                .unwrap_or_else(|| BoxedSharedService::new(Self::default_not_found)),
442            routers: MethodMap::new(),
443        };
444
445        for (method, builder) in self.inner {
446            let router = builder.finish().map_err(err::Error::Recognizer)?;
447            inner_router.routers.insert(method, router);
448        }
449
450        Ok(Router(Arc::new(inner_router)))
451    }
452}
453
454impl tower::Service<Request> for Router {
455    type Response = Response;
456    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
457    type Future =
458        std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
459
460    fn poll_ready(
461        &mut self,
462        _cx: &mut std::task::Context<'_>,
463    ) -> std::task::Poll<Result<(), Self::Error>> {
464        std::task::Poll::Ready(Ok(()))
465    }
466
467    fn call(&mut self, mut request: Request) -> Self::Future {
468        let service = Arc::clone(&self.0);
469
470        if let Some(router) = service.routers.get(request.method()) {
471            if let Ok(recognizer::Match { handler, captures }) =
472                router.recognize(request.uri().path())
473            {
474                let extensions_mut = request.extensions_mut();
475
476                if let Some(ref data) = self.0.data {
477                    extensions_mut.insert(data.clone());
478                }
479
480                extensions_mut.insert(captures);
481
482                return handler.0(request);
483            }
484        }
485
486        if let Some(ref data) = self.0.data {
487            request.extensions_mut().insert(data.clone());
488        }
489
490        service.not_found.0(request)
491    }
492}
493
494impl<'a> tower::Service<&'a hyper::server::conn::AddrStream> for Router {
495    type Response = Router;
496    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
497    type Future = Ready<Result<Self::Response, Self::Error>>;
498
499    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
500        Poll::Ready(Ok(()))
501    }
502
503    fn call(&mut self, _: &'a hyper::server::conn::AddrStream) -> Self::Future {
504        ready(Ok(Router(Arc::clone(&self.0))))
505    }
506}
507
508/// Extensions to `http::Request` and `http::request::Parts` to support easy access to captures and `State` object
509pub trait RequestExtensions {
510    /// Any captures provided by the matching `Regex` for the current path
511    fn captures(&self) -> Option<&recognizer::Captures>;
512    /// Positional captures parsed into `FromStr` types, in tuple format
513    fn parsed_captures<C: recognizer::FromCaptures>(&self) -> err::Result<C> {
514        let captures = self.captures().ok_or_else(|| err::Error::CapturesMissing)?;
515        Ok(C::from_captures(&*captures).map_err(err::Error::Recognizer)?)
516    }
517    /// Copy of any `data` passed into the router builder using `data` or `wrapped_data`
518    fn data<T: Send + Sync + 'static>(&self) -> Option<Data<T>>;
519}
520
521impl RequestExtensions for Request {
522    fn captures(&self) -> Option<&recognizer::Captures> {
523        self.extensions().get::<recognizer::Captures>()
524    }
525
526    fn data<T: Send + Sync + 'static>(&self) -> Option<Data<T>> {
527        self.extensions().get::<DataMap>().and_then(|x| x.0.get::<Data<T>>()).cloned()
528    }
529}
530
531impl RequestExtensions for http::request::Parts {
532    fn captures(&self) -> Option<&recognizer::Captures> {
533        self.extensions.get::<recognizer::Captures>()
534    }
535
536    fn data<T: Send + Sync + 'static>(&self) -> Option<Data<T>> {
537        self.extensions.get::<DataMap>().and_then(|x| x.0.get::<Data<T>>()).cloned()
538    }
539}