rama_http/service/web/
service.rs

1use super::{endpoint::Endpoint, IntoEndpointService};
2use crate::{
3    matcher::{HttpMatcher, UriParams},
4    service::fs::ServeDir,
5    Body, IntoResponse, Request, Response, StatusCode, Uri,
6};
7use rama_core::{
8    context::Extensions,
9    matcher::Matcher,
10    service::{service_fn, BoxService, Service},
11    Context,
12};
13use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, sync::Arc};
14
15/// A basic web service that can be used to serve HTTP requests.
16///
17/// Note that this service boxes all the internal services, so it is not as efficient as it could be.
18/// For those locations where you need do not desire the convenience over performance,
19/// you can instead use a tuple of `(M, S)` tuples, where M is a matcher and S is a service,
20/// e.g. `((MethodMatcher::GET, service_a), (MethodMatcher::POST, service_b), service_fallback)`.
21pub struct WebService<State> {
22    endpoints: Vec<Arc<Endpoint<State>>>,
23    not_found: Arc<BoxService<State, Request, Response, Infallible>>,
24    _phantom: PhantomData<State>,
25}
26
27impl<State> std::fmt::Debug for WebService<State> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("WebService").finish()
30    }
31}
32
33impl<State> Clone for WebService<State> {
34    fn clone(&self) -> Self {
35        Self {
36            endpoints: self.endpoints.clone(),
37            not_found: self.not_found.clone(),
38            _phantom: PhantomData,
39        }
40    }
41}
42
43impl<State> WebService<State>
44where
45    State: Clone + Send + Sync + 'static,
46{
47    /// create a new web service
48    pub(crate) fn new() -> Self {
49        Self {
50            endpoints: Vec::new(),
51            not_found: Arc::new(
52                service_fn(|| async { Ok(StatusCode::NOT_FOUND.into_response()) }).boxed(),
53            ),
54            _phantom: PhantomData,
55        }
56    }
57
58    /// add a GET route to the web service, using the given service.
59    pub fn get<I, T>(self, path: &str, service: I) -> Self
60    where
61        I: IntoEndpointService<State, T>,
62    {
63        let matcher = HttpMatcher::method_get().and_path(path);
64        self.on(matcher, service)
65    }
66
67    /// add a POST route to the web service, using the given service.
68    pub fn post<I, T>(self, path: &str, service: I) -> Self
69    where
70        I: IntoEndpointService<State, T>,
71    {
72        let matcher = HttpMatcher::method_post().and_path(path);
73        self.on(matcher, service)
74    }
75
76    /// add a PUT route to the web service, using the given service.
77    pub fn put<I, T>(self, path: &str, service: I) -> Self
78    where
79        I: IntoEndpointService<State, T>,
80    {
81        let matcher = HttpMatcher::method_put().and_path(path);
82        self.on(matcher, service)
83    }
84
85    /// add a DELETE route to the web service, using the given service.
86    pub fn delete<I, T>(self, path: &str, service: I) -> Self
87    where
88        I: IntoEndpointService<State, T>,
89    {
90        let matcher = HttpMatcher::method_delete().and_path(path);
91        self.on(matcher, service)
92    }
93
94    /// add a PATCH route to the web service, using the given service.
95    pub fn patch<I, T>(self, path: &str, service: I) -> Self
96    where
97        I: IntoEndpointService<State, T>,
98    {
99        let matcher = HttpMatcher::method_patch().and_path(path);
100        self.on(matcher, service)
101    }
102
103    /// add a HEAD route to the web service, using the given service.
104    pub fn head<I, T>(self, path: &str, service: I) -> Self
105    where
106        I: IntoEndpointService<State, T>,
107    {
108        let matcher = HttpMatcher::method_head().and_path(path);
109        self.on(matcher, service)
110    }
111
112    /// add a OPTIONS route to the web service, using the given service.
113    pub fn options<I, T>(self, path: &str, service: I) -> Self
114    where
115        I: IntoEndpointService<State, T>,
116    {
117        let matcher = HttpMatcher::method_options().and_path(path);
118        self.on(matcher, service)
119    }
120
121    /// add a TRACE route to the web service, using the given service.
122    pub fn trace<I, T>(self, path: &str, service: I) -> Self
123    where
124        I: IntoEndpointService<State, T>,
125    {
126        let matcher = HttpMatcher::method_trace().and_path(path);
127        self.on(matcher, service)
128    }
129
130    /// nest a web service under the given path.
131    ///
132    /// The nested service will receive a request with the path prefix removed.
133    pub fn nest<I, T>(self, prefix: &str, service: I) -> Self
134    where
135        I: IntoEndpointService<State, T>,
136    {
137        let prefix = format!("{}/*", prefix.trim_end_matches(['/', '*']));
138        let matcher = HttpMatcher::path(prefix);
139        let service = NestedService(service.into_endpoint_service());
140        self.on(matcher, service)
141    }
142
143    /// serve the given directory under the given path.
144    pub fn dir(self, prefix: &str, dir: &str) -> Self {
145        let service = ServeDir::new(dir).fallback(self.not_found.clone());
146        self.nest(prefix, service)
147    }
148
149    /// add a route to the web service which matches the given matcher, using the given service.
150    pub fn on<I, T>(mut self, matcher: HttpMatcher<State, Body>, service: I) -> Self
151    where
152        I: IntoEndpointService<State, T>,
153    {
154        let endpoint = Endpoint {
155            matcher,
156            service: service.into_endpoint_service().boxed(),
157        };
158        self.endpoints.push(Arc::new(endpoint));
159        self
160    }
161
162    /// use the given service in case no match could be found.
163    pub fn not_found<I, T>(mut self, service: I) -> Self
164    where
165        I: IntoEndpointService<State, T>,
166    {
167        self.not_found = Arc::new(service.into_endpoint_service().boxed());
168        self
169    }
170}
171
/// Wrapper installed by `WebService::nest` that strips the nest prefix from the
/// request path before delegating to the inner service `S`.
///
/// The derived impls carry the same `S: Debug` / `S: Clone` bounds and produce the
/// same `NestedService(..)` tuple formatting as hand-written ones would.
#[derive(Debug, Clone)]
struct NestedService<S>(S);
185
186impl<S, State> Service<State, Request> for NestedService<S>
187where
188    S: Service<State, Request>,
189{
190    type Response = S::Response;
191    type Error = S::Error;
192
193    fn serve(
194        &self,
195        ctx: Context<State>,
196        req: Request,
197    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
198        // get nested path
199        let path = ctx.get::<UriParams>().unwrap().glob().unwrap();
200
201        // set the nested path
202        let (mut parts, body) = req.into_parts();
203        let mut uri_parts = parts.uri.into_parts();
204        let path_and_query = uri_parts.path_and_query.take().unwrap();
205        match path_and_query.query() {
206            Some(query) => {
207                uri_parts.path_and_query = Some(format!("{}?{}", path, query).parse().unwrap());
208            }
209            None => {
210                uri_parts.path_and_query = Some(path.parse().unwrap());
211            }
212        }
213        parts.uri = Uri::from_parts(uri_parts).unwrap();
214        let req = Request::from_parts(parts, body);
215
216        // make the actual request
217        self.0.serve(ctx, req)
218    }
219}
220
221impl<State> Default for WebService<State>
222where
223    State: Clone + Send + Sync + 'static,
224{
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230impl<State> Service<State, Request> for WebService<State>
231where
232    State: Clone + Send + Sync + 'static,
233{
234    type Response = Response;
235    type Error = Infallible;
236
237    async fn serve(
238        &self,
239        mut ctx: Context<State>,
240        req: Request,
241    ) -> Result<Self::Response, Self::Error> {
242        let mut ext = Extensions::new();
243        for endpoint in &self.endpoints {
244            if endpoint.matcher.matches(Some(&mut ext), &ctx, &req) {
245                // insert the extensions that might be generated by the matcher(s) into the context
246                ctx.extend(ext);
247                return endpoint.service.serve(ctx, req).await;
248            }
249            // clear the extensions for the next matcher
250            ext.clear();
251        }
252        self.not_found.serve(ctx, req).await
253    }
254}
255
#[doc(hidden)]
#[macro_export]
/// Create a new [`Service`] from a chain of matcher-service tuples.
///
/// Think of it like the Rust `match` expression, but for HTTP services.
/// It is nothing more than a convenient wrapper to create a tuple of matcher-service tuples,
/// with the last element being the fallback service. Every service, including the fallback,
/// is converted through the [`IntoEndpointService`] trait, so it must implement that trait.
///
/// At least one `matcher => service` arm is required before the `_ => fallback` arm.
///
/// # Example
///
/// ```rust
/// use rama_http::matcher::{HttpMatcher, MethodMatcher};
/// use rama_http::{Body, Request, Response, StatusCode};
/// use rama_http::dep::http_body_util::BodyExt;
/// use rama_core::{Context, Service};
///
/// #[tokio::main]
/// async fn main() {
///   let svc = rama_http::service::web::match_service! {
///     HttpMatcher::get("/hello") => "hello",
///     HttpMatcher::post("/world") => "world",
///     MethodMatcher::CONNECT => "connect",
///     _ => StatusCode::NOT_FOUND,
///   };
///
///   let resp = svc.serve(
///       Context::default(),
///       Request::post("https://www.test.io/world").body(Body::empty()).unwrap(),
///   ).await.unwrap();
///   assert_eq!(resp.status(), StatusCode::OK);
///   let body = resp.into_body().collect().await.unwrap().to_bytes();
///   assert_eq!(body, "world");
/// }
/// ```
///
/// This is shorthand for the following:
///
/// ```rust
/// use rama_http::matcher::{HttpMatcher, MethodMatcher};
/// use rama_http::{Body, Request, Response, StatusCode};
/// use rama_http::dep::http_body_util::BodyExt;
/// use rama_http::service::web::IntoEndpointService;
/// use rama_core::{Context, Service};
///
/// #[tokio::main]
/// async fn main() {
///   let svc = (
///     (HttpMatcher::get("/hello"), "hello".into_endpoint_service()),
///     (HttpMatcher::post("/world"), "world".into_endpoint_service()),
///     (MethodMatcher::CONNECT, "connect".into_endpoint_service()),
///     StatusCode::NOT_FOUND.into_endpoint_service(),
///   );
///
///   let resp = svc.serve(
///      Context::default(),
///      Request::post("https://www.test.io/world").body(Body::empty()).unwrap(),
///   ).await.unwrap();
///   assert_eq!(resp.status(), StatusCode::OK);
///   let body = resp.into_body().collect().await.unwrap().to_bytes();
///   assert_eq!(body, "world");
/// }
/// ```
///
/// As you can see it is pretty much the same, except that without the macro you need to
/// explicitly convert each service into an actual endpoint service.
macro_rules! __match_service {
    ($($M:expr => $S:expr),+, _ => $F:expr $(,)?) => {{
        use $crate::service::web::IntoEndpointService;
        ($(($M, $S.into_endpoint_service())),+, $F.into_endpoint_service())
    }};
}

#[doc(inline)]
pub use crate::__match_service as match_service;
331
#[cfg(test)]
mod test {
    use crate::dep::http_body_util::BodyExt;
    use crate::matcher::MethodMatcher;
    use crate::Body;

    use super::*;

    /// Serve `req` with a default `()` context; every service under test is infallible.
    async fn send<S>(service: &S, req: Request) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Send a body-less `GET` request to `uri`.
    async fn get_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        send(service, Request::get(uri).body(Body::empty()).unwrap()).await
    }

    /// Send a body-less `POST` request to `uri`.
    async fn post_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        send(service, Request::post(uri).body(Body::empty()).unwrap()).await
    }

    /// Send a body-less `CONNECT` request to `uri`.
    async fn connect_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        send(service, Request::connect(uri).body(Body::empty()).unwrap()).await
    }

    /// Assert that `res` has the given status and exactly the given body.
    async fn assert_body(res: Response, status: StatusCode, expected: &str) {
        assert_eq!(res.status(), status);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, expected);
    }

    #[tokio::test]
    async fn test_web_service() {
        let svc = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_body(res, StatusCode::OK, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_body(res, StatusCode::OK, "world").await;

        // `/world` is only routed for POST.
        let res = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service_not_found() {
        let svc = WebService::new().not_found("not found");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_body(res, StatusCode::OK, "not found").await;
    }

    #[tokio::test]
    async fn test_web_service_nest() {
        let inner = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");
        let svc = WebService::new().nest("/api", inner);

        let res = get_response(&svc, "https://www.test.io/api/hello").await;
        assert_body(res, StatusCode::OK, "hello").await;

        let res = post_response(&svc, "https://www.test.io/api/world").await;
        assert_body(res, StatusCode::OK, "world").await;

        let res = get_response(&svc, "https://www.test.io/api/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service_dir() {
        let tmp_dir = tempfile::tempdir().unwrap();
        std::fs::write(tmp_dir.path().join("index.html"), "<h1>Hello, World!</h1>").unwrap();
        let style_dir = tmp_dir.path().join("style");
        std::fs::create_dir(&style_dir).unwrap();
        std::fs::write(style_dir.join("main.css"), "body { background-color: red }").unwrap();

        let svc = WebService::new()
            .get("/api/version", "v1")
            .post("/api", StatusCode::FORBIDDEN)
            .dir("/", tmp_dir.path().to_str().unwrap());

        let res = get_response(&svc, "https://www.test.io/index.html").await;
        assert_body(res, StatusCode::OK, "<h1>Hello, World!</h1>").await;

        let res = get_response(&svc, "https://www.test.io/style/main.css").await;
        assert_body(res, StatusCode::OK, "body { background-color: red }").await;

        let res = get_response(&svc, "https://www.test.io/api/version").await;
        assert_body(res, StatusCode::OK, "v1").await;

        let res = post_response(&svc, "https://www.test.io/api").await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        let res = get_response(&svc, "https://www.test.io/notfound.html").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        // The directory root falls back to its index file.
        let res = get_response(&svc, "https://www.test.io/").await;
        assert_body(res, StatusCode::OK, "<h1>Hello, World!</h1>").await;
    }

    #[tokio::test]
    async fn test_matcher_service_tuples() {
        let svc = match_service! {
            HttpMatcher::get("/hello") => "hello",
            HttpMatcher::post("/world") => "world",
            MethodMatcher::CONNECT => "connect",
            _ => StatusCode::NOT_FOUND,
        };

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_body(res, StatusCode::OK, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_body(res, StatusCode::OK, "world").await;

        let res = connect_response(&svc, "https://www.test.io").await;
        assert_body(res, StatusCode::OK, "connect").await;

        let res = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }
}