rama_http/service/web/
service.rs

1use super::{IntoEndpointService, endpoint::Endpoint};
2use crate::{
3    Body, Request, Response, StatusCode, Uri,
4    matcher::{HttpMatcher, UriParams},
5    service::fs::ServeDir,
6    service::web::endpoint::response::IntoResponse,
7};
8use rama_core::{
9    Context,
10    context::Extensions,
11    matcher::Matcher,
12    service::{BoxService, Service, service_fn},
13};
14use std::{convert::Infallible, fmt, marker::PhantomData, sync::Arc};
15
16/// A basic web service that can be used to serve HTTP requests.
17///
18/// Note that this service boxes all the internal services, so it is not as efficient as it could be.
19/// For those locations where you need do not desire the convenience over performance,
20/// you can instead use a tuple of `(M, S)` tuples, where M is a matcher and S is a service,
21/// e.g. `((MethodMatcher::GET, service_a), (MethodMatcher::POST, service_b), service_fallback)`.
22pub struct WebService<State> {
23    endpoints: Vec<Arc<Endpoint<State>>>,
24    not_found: Arc<BoxService<State, Request, Response, Infallible>>,
25    _phantom: PhantomData<State>,
26}
27
28impl<State> std::fmt::Debug for WebService<State> {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("WebService").finish()
31    }
32}
33
34impl<State> Clone for WebService<State> {
35    fn clone(&self) -> Self {
36        Self {
37            endpoints: self.endpoints.clone(),
38            not_found: self.not_found.clone(),
39            _phantom: PhantomData,
40        }
41    }
42}
43
44impl<State> WebService<State>
45where
46    State: Clone + Send + Sync + 'static,
47{
48    /// create a new web service
49    pub(crate) fn new() -> Self {
50        Self {
51            endpoints: Vec::new(),
52            not_found: Arc::new(
53                service_fn(async || Ok(StatusCode::NOT_FOUND.into_response())).boxed(),
54            ),
55            _phantom: PhantomData,
56        }
57    }
58
59    /// add a GET route to the web service, using the given service.
60    pub fn get<I, T>(self, path: &str, service: I) -> Self
61    where
62        I: IntoEndpointService<State, T>,
63    {
64        let matcher = HttpMatcher::method_get().and_path(path);
65        self.on(matcher, service)
66    }
67
68    /// add a POST route to the web service, using the given service.
69    pub fn post<I, T>(self, path: &str, service: I) -> Self
70    where
71        I: IntoEndpointService<State, T>,
72    {
73        let matcher = HttpMatcher::method_post().and_path(path);
74        self.on(matcher, service)
75    }
76
77    /// add a PUT route to the web service, using the given service.
78    pub fn put<I, T>(self, path: &str, service: I) -> Self
79    where
80        I: IntoEndpointService<State, T>,
81    {
82        let matcher = HttpMatcher::method_put().and_path(path);
83        self.on(matcher, service)
84    }
85
86    /// add a DELETE route to the web service, using the given service.
87    pub fn delete<I, T>(self, path: &str, service: I) -> Self
88    where
89        I: IntoEndpointService<State, T>,
90    {
91        let matcher = HttpMatcher::method_delete().and_path(path);
92        self.on(matcher, service)
93    }
94
95    /// add a PATCH route to the web service, using the given service.
96    pub fn patch<I, T>(self, path: &str, service: I) -> Self
97    where
98        I: IntoEndpointService<State, T>,
99    {
100        let matcher = HttpMatcher::method_patch().and_path(path);
101        self.on(matcher, service)
102    }
103
104    /// add a HEAD route to the web service, using the given service.
105    pub fn head<I, T>(self, path: &str, service: I) -> Self
106    where
107        I: IntoEndpointService<State, T>,
108    {
109        let matcher = HttpMatcher::method_head().and_path(path);
110        self.on(matcher, service)
111    }
112
113    /// add a OPTIONS route to the web service, using the given service.
114    pub fn options<I, T>(self, path: &str, service: I) -> Self
115    where
116        I: IntoEndpointService<State, T>,
117    {
118        let matcher = HttpMatcher::method_options().and_path(path);
119        self.on(matcher, service)
120    }
121
122    /// add a TRACE route to the web service, using the given service.
123    pub fn trace<I, T>(self, path: &str, service: I) -> Self
124    where
125        I: IntoEndpointService<State, T>,
126    {
127        let matcher = HttpMatcher::method_trace().and_path(path);
128        self.on(matcher, service)
129    }
130
131    /// nest a web service under the given path.
132    ///
133    /// The nested service will receive a request with the path prefix removed.
134    pub fn nest<I, T>(self, prefix: &str, service: I) -> Self
135    where
136        I: IntoEndpointService<State, T>,
137    {
138        let prefix = format!("{}/*", prefix.trim_end_matches(['/', '*']));
139        let matcher = HttpMatcher::path(prefix);
140        let service = NestedService(service.into_endpoint_service());
141        self.on(matcher, service)
142    }
143
144    /// serve the given directory under the given path.
145    pub fn dir(self, prefix: &str, dir: &str) -> Self {
146        let service = ServeDir::new(dir).fallback(self.not_found.clone());
147        self.nest(prefix, service)
148    }
149
150    /// add a route to the web service which matches the given matcher, using the given service.
151    pub fn on<I, T>(mut self, matcher: HttpMatcher<State, Body>, service: I) -> Self
152    where
153        I: IntoEndpointService<State, T>,
154    {
155        let endpoint = Endpoint {
156            matcher,
157            service: service.into_endpoint_service().boxed(),
158        };
159        self.endpoints.push(Arc::new(endpoint));
160        self
161    }
162
163    /// use the given service in case no match could be found.
164    pub fn not_found<I, T>(mut self, service: I) -> Self
165    where
166        I: IntoEndpointService<State, T>,
167    {
168        self.not_found = Arc::new(service.into_endpoint_service().boxed());
169        self
170    }
171}
172
/// Wrapper that strips the matched `nest` prefix from the request path before
/// delegating to the inner service.
///
/// `Debug` and `Clone` are derived. This produces exactly the bounds (`S: Debug`,
/// `S: Clone`) and the `NestedService(..)` tuple formatting of the previous
/// hand-written impls.
#[derive(Debug, Clone)]
struct NestedService<S>(S);
186
187impl<S, State> Service<State, Request> for NestedService<S>
188where
189    S: Service<State, Request>,
190{
191    type Response = S::Response;
192    type Error = S::Error;
193
194    fn serve(
195        &self,
196        ctx: Context<State>,
197        req: Request,
198    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
199        // get nested path
200        let path = ctx.get::<UriParams>().unwrap().glob().unwrap();
201
202        // set the nested path
203        let (mut parts, body) = req.into_parts();
204        let mut uri_parts = parts.uri.into_parts();
205        let path_and_query = uri_parts.path_and_query.take().unwrap();
206        match path_and_query.query() {
207            Some(query) => {
208                uri_parts.path_and_query = Some(format!("{}?{}", path, query).parse().unwrap());
209            }
210            None => {
211                uri_parts.path_and_query = Some(path.parse().unwrap());
212            }
213        }
214        parts.uri = Uri::from_parts(uri_parts).unwrap();
215        let req = Request::from_parts(parts, body);
216
217        // make the actual request
218        self.0.serve(ctx, req)
219    }
220}
221
222impl<State> Default for WebService<State>
223where
224    State: Clone + Send + Sync + 'static,
225{
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231impl<State> Service<State, Request> for WebService<State>
232where
233    State: Clone + Send + Sync + 'static,
234{
235    type Response = Response;
236    type Error = Infallible;
237
238    async fn serve(
239        &self,
240        mut ctx: Context<State>,
241        req: Request,
242    ) -> Result<Self::Response, Self::Error> {
243        let mut ext = Extensions::new();
244        for endpoint in &self.endpoints {
245            if endpoint.matcher.matches(Some(&mut ext), &ctx, &req) {
246                // insert the extensions that might be generated by the matcher(s) into the context
247                ctx.extend(ext);
248                return endpoint.service.serve(ctx, req).await;
249            }
250            // clear the extensions for the next matcher
251            ext.clear();
252        }
253        self.not_found.serve(ctx, req).await
254    }
255}
256
#[doc(hidden)]
#[macro_export]
/// Create a new [`Service`] from a chain of matcher-service tuples.
///
/// Think of it like the Rust `match` statement, but for HTTP services.
/// It is nothing more than a convenient wrapper that builds a `MatcherRouter` over a tuple
/// of `(matcher, service)` tuples, with the last element (the `_ => ...` arm) being the
/// fallback service. Every service must implement the [`IntoEndpointService`] trait; the
/// macro converts each of them into an endpoint service for you.
///
/// # Example
///
/// ```rust
/// use rama_http::matcher::{HttpMatcher, MethodMatcher};
/// use rama_http::{Body, Request, Response, StatusCode};
/// use rama_http::dep::http_body_util::BodyExt;
/// use rama_core::{Context, Service};
///
/// #[tokio::main]
/// async fn main() {
///   let svc = rama_http::service::web::match_service! {
///     HttpMatcher::get("/hello") => "hello",
///     HttpMatcher::post("/world") => "world",
///     MethodMatcher::CONNECT => "connect",
///     _ => StatusCode::NOT_FOUND,
///   };
///
///   let resp = svc.serve(
///       Context::default(),
///       Request::post("https://www.test.io/world").body(Body::empty()).unwrap(),
///   ).await.unwrap();
///   assert_eq!(resp.status(), StatusCode::OK);
///   let body = resp.into_body().collect().await.unwrap().to_bytes();
///   assert_eq!(body, "world");
/// }
/// ```
///
/// Which is short for the following:
///
/// ```rust
/// use rama_http::matcher::{HttpMatcher, MethodMatcher};
/// use rama_http::{Body, Request, Response, StatusCode};
/// use rama_http::dep::http_body_util::BodyExt;
/// use rama_http::service::web::IntoEndpointService;
/// use rama_core::{Context, Service};
/// use rama_core::matcher::MatcherRouter;
///
/// #[tokio::main]
/// async fn main() {
///   let svc = MatcherRouter((
///     (HttpMatcher::get("/hello"), "hello".into_endpoint_service()),
///     (HttpMatcher::post("/world"), "world".into_endpoint_service()),
///     (MethodMatcher::CONNECT, "connect".into_endpoint_service()),
///     StatusCode::NOT_FOUND.into_endpoint_service(),
///   ));
///
///   let resp = svc.serve(
///      Context::default(),
///      Request::post("https://www.test.io/world").body(Body::empty()).unwrap(),
///   ).await.unwrap();
///   assert_eq!(resp.status(), StatusCode::OK);
///   let body = resp.into_body().collect().await.unwrap().to_bytes();
///   assert_eq!(body, "world");
/// }
/// ```
///
/// As you can see it is pretty much the same, except that you need to explicitly ensure
/// that each service is an actual endpoint service.
macro_rules! __match_service {
    // `expr_2021` is used for the repeated arms because, since edition 2024, the plain
    // `expr` fragment also matches the `_` expression. That would make the repetition
    // ambiguous with the final `_ => $F` fallback arm.
    ($($M:expr_2021 => $S:expr_2021),+, _ => $F:expr $(,)?) => {{
        use $crate::service::web::IntoEndpointService;
        use $crate::dep::core::matcher::MatcherRouter;
        MatcherRouter(($(($M, $S.into_endpoint_service())),+, $F.into_endpoint_service()))
    }};
}

#[doc(inline)]
pub use crate::__match_service as match_service;
334
#[cfg(test)]
mod test {
    use crate::Body;
    use crate::dep::http_body_util::BodyExt;
    use crate::matcher::MethodMatcher;

    use super::*;

    /// Serve `req` with a default context; the services under test never fail.
    async fn send<S>(service: &S, req: Request) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Send an empty-bodied `GET` request to `uri`.
    async fn get_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::get(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Send an empty-bodied `POST` request to `uri`.
    async fn post_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::post(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Send an empty-bodied `CONNECT` request to `uri`.
    async fn connect_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::connect(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Read the whole response body as UTF-8 text.
    async fn body_text(res: Response) -> String {
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Assert that `res` is `200 OK` with exactly `expected` as its body.
    async fn assert_ok_body(res: Response, expected: &str) {
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(body_text(res).await, expected);
    }

    /// Assert that `res` is `404 Not Found`.
    fn assert_not_found(res: &Response) {
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service() {
        let svc = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_ok_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_ok_body(res, "world").await;

        assert_not_found(&get_response(&svc, "https://www.test.io/world").await);
        assert_not_found(&get_response(&svc, "https://www.test.io").await);
    }

    #[tokio::test]
    async fn test_web_service_not_found() {
        let svc = WebService::new().not_found("not found");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_ok_body(res, "not found").await;
    }

    #[tokio::test]
    async fn test_web_service_nest() {
        let api = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");
        let svc = WebService::new().nest("/api", api);

        let res = get_response(&svc, "https://www.test.io/api/hello").await;
        assert_ok_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/api/world").await;
        assert_ok_body(res, "world").await;

        assert_not_found(&get_response(&svc, "https://www.test.io/api/world").await);
        assert_not_found(&get_response(&svc, "https://www.test.io").await);
    }

    #[tokio::test]
    async fn test_web_service_dir() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let root = tmp_dir.path();
        std::fs::write(root.join("index.html"), "<h1>Hello, World!</h1>").unwrap();
        let style_dir = root.join("style");
        std::fs::create_dir(&style_dir).unwrap();
        std::fs::write(style_dir.join("main.css"), "body { background-color: red }").unwrap();

        let svc = WebService::new()
            .get("/api/version", "v1")
            .post("/api", StatusCode::FORBIDDEN)
            .dir("/", root.to_str().unwrap());

        let res = get_response(&svc, "https://www.test.io/index.html").await;
        assert_ok_body(res, "<h1>Hello, World!</h1>").await;

        let res = get_response(&svc, "https://www.test.io/style/main.css").await;
        assert_ok_body(res, "body { background-color: red }").await;

        let res = get_response(&svc, "https://www.test.io/api/version").await;
        assert_ok_body(res, "v1").await;

        let res = post_response(&svc, "https://www.test.io/api").await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        assert_not_found(&get_response(&svc, "https://www.test.io/notfound.html").await);

        let res = get_response(&svc, "https://www.test.io/").await;
        assert_ok_body(res, "<h1>Hello, World!</h1>").await;
    }

    #[tokio::test]
    async fn test_matcher_service_tuples() {
        let svc = match_service! {
            HttpMatcher::get("/hello") => "hello",
            HttpMatcher::post("/world") => "world",
            MethodMatcher::CONNECT => "connect",
            _ => StatusCode::NOT_FOUND,
        };

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_ok_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_ok_body(res, "world").await;

        let res = connect_response(&svc, "https://www.test.io").await;
        assert_ok_body(res, "connect").await;

        assert_not_found(&get_response(&svc, "https://www.test.io/world").await);
        assert_not_found(&get_response(&svc, "https://www.test.io").await);
    }
}