1use super::{IntoEndpointService, endpoint::Endpoint};
2use crate::{
3 Body, IntoResponse, Request, Response, StatusCode, Uri,
4 matcher::{HttpMatcher, UriParams},
5 service::fs::ServeDir,
6};
7use rama_core::{
8 Context,
9 context::Extensions,
10 matcher::Matcher,
11 service::{BoxService, Service, service_fn},
12};
13use std::{convert::Infallible, fmt, marker::PhantomData, sync::Arc};
14
/// A basic HTTP web service that routes requests to type-erased (boxed) endpoint services.
///
/// Endpoints are kept in registration order; when serving, the first endpoint whose
/// matcher accepts the request handles it, otherwise the `not_found` service is used.
pub struct WebService<State> {
    // Registered routes (matcher + boxed service), in the order they were added.
    endpoints: Vec<Arc<Endpoint<State>>>,
    // Fallback used when no endpoint matches; defaults to a plain `404 Not Found`.
    not_found: Arc<BoxService<State, Request, Response, Infallible>>,
    _phantom: PhantomData<State>,
}
32
33impl<State> Clone for WebService<State> {
34 fn clone(&self) -> Self {
35 Self {
36 endpoints: self.endpoints.clone(),
37 not_found: self.not_found.clone(),
38 _phantom: PhantomData,
39 }
40 }
41}
42
43impl<State> WebService<State>
44where
45 State: Clone + Send + Sync + 'static,
46{
47 pub(crate) fn new() -> Self {
49 Self {
50 endpoints: Vec::new(),
51 not_found: Arc::new(
52 service_fn(async || Ok(StatusCode::NOT_FOUND.into_response())).boxed(),
53 ),
54 _phantom: PhantomData,
55 }
56 }
57
58 pub fn get<I, T>(self, path: &str, service: I) -> Self
60 where
61 I: IntoEndpointService<State, T>,
62 {
63 let matcher = HttpMatcher::method_get().and_path(path);
64 self.on(matcher, service)
65 }
66
67 pub fn post<I, T>(self, path: &str, service: I) -> Self
69 where
70 I: IntoEndpointService<State, T>,
71 {
72 let matcher = HttpMatcher::method_post().and_path(path);
73 self.on(matcher, service)
74 }
75
76 pub fn put<I, T>(self, path: &str, service: I) -> Self
78 where
79 I: IntoEndpointService<State, T>,
80 {
81 let matcher = HttpMatcher::method_put().and_path(path);
82 self.on(matcher, service)
83 }
84
85 pub fn delete<I, T>(self, path: &str, service: I) -> Self
87 where
88 I: IntoEndpointService<State, T>,
89 {
90 let matcher = HttpMatcher::method_delete().and_path(path);
91 self.on(matcher, service)
92 }
93
94 pub fn patch<I, T>(self, path: &str, service: I) -> Self
96 where
97 I: IntoEndpointService<State, T>,
98 {
99 let matcher = HttpMatcher::method_patch().and_path(path);
100 self.on(matcher, service)
101 }
102
103 pub fn head<I, T>(self, path: &str, service: I) -> Self
105 where
106 I: IntoEndpointService<State, T>,
107 {
108 let matcher = HttpMatcher::method_head().and_path(path);
109 self.on(matcher, service)
110 }
111
112 pub fn options<I, T>(self, path: &str, service: I) -> Self
114 where
115 I: IntoEndpointService<State, T>,
116 {
117 let matcher = HttpMatcher::method_options().and_path(path);
118 self.on(matcher, service)
119 }
120
121 pub fn trace<I, T>(self, path: &str, service: I) -> Self
123 where
124 I: IntoEndpointService<State, T>,
125 {
126 let matcher = HttpMatcher::method_trace().and_path(path);
127 self.on(matcher, service)
128 }
129
130 pub fn nest<I, T>(self, prefix: &str, service: I) -> Self
134 where
135 I: IntoEndpointService<State, T>,
136 {
137 let prefix = format!("{}/*", prefix.trim_end_matches(['/', '*']));
138 let matcher = HttpMatcher::path(prefix);
139 let service = NestedService(service.into_endpoint_service());
140 self.on(matcher, service)
141 }
142
143 pub fn dir(self, prefix: &str, dir: &str) -> Self {
145 let service = ServeDir::new(dir).fallback(self.not_found.clone());
146 self.nest(prefix, service)
147 }
148
149 pub fn on<I, T>(mut self, matcher: HttpMatcher<State, Body>, service: I) -> Self
151 where
152 I: IntoEndpointService<State, T>,
153 {
154 let endpoint = Endpoint {
155 matcher,
156 service: service.into_endpoint_service().boxed(),
157 };
158 self.endpoints.push(Arc::new(endpoint));
159 self
160 }
161
162 pub fn not_found<I, T>(mut self, service: I) -> Self
164 where
165 I: IntoEndpointService<State, T>,
166 {
167 self.not_found = Arc::new(service.into_endpoint_service().boxed());
168 self
169 }
170}
171
/// Wrapper used by `WebService::nest`: it replaces the request path with the
/// remainder captured by the nest matcher's `/*` glob before calling the inner service.
struct NestedService<S>(S);
173
174impl<S: fmt::Debug> fmt::Debug for NestedService<S> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_tuple("NestedService").field(&self.0).finish()
177 }
178}
179
180impl<S: Clone> Clone for NestedService<S> {
181 fn clone(&self) -> Self {
182 NestedService(self.0.clone())
183 }
184}
185
186impl<S, State> Service<State, Request> for NestedService<S>
187where
188 S: Service<State, Request>,
189{
190 type Response = S::Response;
191 type Error = S::Error;
192
193 fn serve(
194 &self,
195 ctx: Context<State>,
196 req: Request,
197 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
198 let path = ctx.get::<UriParams>().unwrap().glob().unwrap();
200
201 let (mut parts, body) = req.into_parts();
203 let mut uri_parts = parts.uri.into_parts();
204 let path_and_query = uri_parts.path_and_query.take().unwrap();
205 match path_and_query.query() {
206 Some(query) => {
207 uri_parts.path_and_query = Some(format!("{}?{}", path, query).parse().unwrap());
208 }
209 None => {
210 uri_parts.path_and_query = Some(path.parse().unwrap());
211 }
212 }
213 parts.uri = Uri::from_parts(uri_parts).unwrap();
214 let req = Request::from_parts(parts, body);
215
216 self.0.serve(ctx, req)
218 }
219}
220
221impl<State> Default for WebService<State>
222where
223 State: Clone + Send + Sync + 'static,
224{
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl<State> Service<State, Request> for WebService<State>
231where
232 State: Clone + Send + Sync + 'static,
233{
234 type Response = Response;
235 type Error = Infallible;
236
237 async fn serve(
238 &self,
239 mut ctx: Context<State>,
240 req: Request,
241 ) -> Result<Self::Response, Self::Error> {
242 let mut ext = Extensions::new();
243 for endpoint in &self.endpoints {
244 if endpoint.matcher.matches(Some(&mut ext), &ctx, &req) {
245 ctx.extend(ext);
247 return endpoint.service.serve(ctx, req).await;
248 }
249 ext.clear();
251 }
252 self.not_found.serve(ctx, req).await
253 }
254}
255
/// Builds a router from one or more `matcher => service` arms followed by a
/// mandatory `_ => fallback` arm, in the style of a `match` expression.
///
/// Every service, including the fallback, is converted with
/// `IntoEndpointService::into_endpoint_service`. The result is a `MatcherRouter`
/// wrapping the tuple `((m1, s1), (m2, s2), ..., fallback)`. No boxing happens
/// here, unlike `WebService`. NOTE(review): arms are presumably tried top to
/// bottom; confirm against `MatcherRouter`.
#[doc(hidden)]
#[macro_export]
macro_rules! __match_service {
    // One or more `matcher => service` pairs, then the `_ => fallback` arm
    // (trailing comma optional).
    ($($M:expr_2021 => $S:expr_2021),+, _ => $F:expr $(,)?) => {{
        use $crate::service::web::IntoEndpointService;
        use $crate::dep::core::matcher::MatcherRouter;
        MatcherRouter(($(($M, $S.into_endpoint_service())),+, $F.into_endpoint_service()))
    }};
}

/// Public name of the `__match_service!` router-building macro.
#[doc(inline)]
pub use crate::__match_service as match_service;
333
#[cfg(test)]
mod test {
    use crate::Body;
    use crate::dep::http_body_util::BodyExt;
    use crate::matcher::MethodMatcher;

    use super::*;

    /// Sends an empty-bodied `GET` request for `uri` through `service`.
    async fn get_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::get(uri).body(Body::empty()).unwrap();
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Sends an empty-bodied `POST` request for `uri` through `service`.
    async fn post_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::post(uri).body(Body::empty()).unwrap();
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Sends an empty-bodied `CONNECT` request for `uri` through `service`.
    async fn connect_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::connect(uri).body(Body::empty()).unwrap();
        service.serve(Context::default(), req).await.unwrap()
    }

    #[tokio::test]
    async fn test_web_service() {
        let svc = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "hello");

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "world");

        let res = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service_not_found() {
        let svc = WebService::new().not_found("not found");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "not found");
    }

    #[tokio::test]
    async fn test_web_service_nest() {
        let svc = WebService::new().nest(
            "/api",
            WebService::new()
                .get("/hello", "hello")
                .post("/world", "world"),
        );

        let res = get_response(&svc, "https://www.test.io/api/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "hello");

        let res = post_response(&svc, "https://www.test.io/api/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "world");

        let res = get_response(&svc, "https://www.test.io/api/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    /// The nested path rewrite must keep the query string and still route by path.
    #[tokio::test]
    async fn test_web_service_nest_with_query() {
        let svc = WebService::new().nest("/api", WebService::new().get("/hello", "hello"));

        let res = get_response(&svc, "https://www.test.io/api/hello?foo=bar").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "hello");
    }

    /// Trailing `/` and `*` in the nest prefix are normalised away.
    #[tokio::test]
    async fn test_web_service_nest_prefix_normalization() {
        for prefix in ["/api", "/api/", "/api/*"] {
            let svc = WebService::new().nest(prefix, WebService::new().get("/hello", "hello"));

            let res = get_response(&svc, "https://www.test.io/api/hello").await;
            assert_eq!(res.status(), StatusCode::OK, "prefix: {prefix}");
            let body = res.into_body().collect().await.unwrap().to_bytes();
            assert_eq!(body, "hello", "prefix: {prefix}");
        }
    }

    #[tokio::test]
    async fn test_web_service_dir() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let file_path = tmp_dir.path().join("index.html");
        std::fs::write(&file_path, "<h1>Hello, World!</h1>").unwrap();
        let style_dir = tmp_dir.path().join("style");
        std::fs::create_dir(&style_dir).unwrap();
        let file_path = style_dir.join("main.css");
        std::fs::write(&file_path, "body { background-color: red }").unwrap();

        let svc = WebService::new()
            .get("/api/version", "v1")
            .post("/api", StatusCode::FORBIDDEN)
            .dir("/", tmp_dir.path().to_str().unwrap());

        let res = get_response(&svc, "https://www.test.io/index.html").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "<h1>Hello, World!</h1>");

        let res = get_response(&svc, "https://www.test.io/style/main.css").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "body { background-color: red }");

        let res = get_response(&svc, "https://www.test.io/api/version").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "v1");

        let res = post_response(&svc, "https://www.test.io/api").await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        let res = get_response(&svc, "https://www.test.io/notfound.html").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io/").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "<h1>Hello, World!</h1>");
    }

    #[tokio::test]
    async fn test_matcher_service_tuples() {
        let svc = match_service! {
            HttpMatcher::get("/hello") => "hello",
            HttpMatcher::post("/world") => "world",
            MethodMatcher::CONNECT => "connect",
            _ => StatusCode::NOT_FOUND,
        };

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "hello");

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "world");

        let res = connect_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "connect");

        let res = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }
}