1use super::{IntoEndpointService, endpoint::Endpoint};
2use crate::{
3 Body, Request, Response, StatusCode, Uri,
4 matcher::{HttpMatcher, UriParams},
5 service::fs::ServeDir,
6 service::web::endpoint::response::IntoResponse,
7};
8use rama_core::{
9 Context,
10 context::Extensions,
11 matcher::Matcher,
12 service::{BoxService, Service, service_fn},
13};
14use std::{convert::Infallible, fmt, marker::PhantomData, sync::Arc};
15
/// An HTTP router that dispatches each request to the first registered endpoint whose
/// matcher accepts it, and to a not-found service when none does.
pub struct WebService<State> {
    // Registered endpoints, pushed in registration order and tried in that order by `serve`.
    endpoints: Vec<Arc<Endpoint<State>>>,
    // Fallback used when no endpoint matches; `new` installs a `404 Not Found` responder.
    not_found: Arc<BoxService<State, Request, Response, Infallible>>,
    // NOTE(review): `State` already appears in the other field types, so this marker looks
    // redundant — kept as-is.
    _phantom: PhantomData<State>,
}
27
28impl<State> std::fmt::Debug for WebService<State> {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("WebService").finish()
31 }
32}
33
34impl<State> Clone for WebService<State> {
35 fn clone(&self) -> Self {
36 Self {
37 endpoints: self.endpoints.clone(),
38 not_found: self.not_found.clone(),
39 _phantom: PhantomData,
40 }
41 }
42}
43
44impl<State> WebService<State>
45where
46 State: Clone + Send + Sync + 'static,
47{
48 pub(crate) fn new() -> Self {
50 Self {
51 endpoints: Vec::new(),
52 not_found: Arc::new(
53 service_fn(async || Ok(StatusCode::NOT_FOUND.into_response())).boxed(),
54 ),
55 _phantom: PhantomData,
56 }
57 }
58
59 pub fn get<I, T>(self, path: &str, service: I) -> Self
61 where
62 I: IntoEndpointService<State, T>,
63 {
64 let matcher = HttpMatcher::method_get().and_path(path);
65 self.on(matcher, service)
66 }
67
68 pub fn post<I, T>(self, path: &str, service: I) -> Self
70 where
71 I: IntoEndpointService<State, T>,
72 {
73 let matcher = HttpMatcher::method_post().and_path(path);
74 self.on(matcher, service)
75 }
76
77 pub fn put<I, T>(self, path: &str, service: I) -> Self
79 where
80 I: IntoEndpointService<State, T>,
81 {
82 let matcher = HttpMatcher::method_put().and_path(path);
83 self.on(matcher, service)
84 }
85
86 pub fn delete<I, T>(self, path: &str, service: I) -> Self
88 where
89 I: IntoEndpointService<State, T>,
90 {
91 let matcher = HttpMatcher::method_delete().and_path(path);
92 self.on(matcher, service)
93 }
94
95 pub fn patch<I, T>(self, path: &str, service: I) -> Self
97 where
98 I: IntoEndpointService<State, T>,
99 {
100 let matcher = HttpMatcher::method_patch().and_path(path);
101 self.on(matcher, service)
102 }
103
104 pub fn head<I, T>(self, path: &str, service: I) -> Self
106 where
107 I: IntoEndpointService<State, T>,
108 {
109 let matcher = HttpMatcher::method_head().and_path(path);
110 self.on(matcher, service)
111 }
112
113 pub fn options<I, T>(self, path: &str, service: I) -> Self
115 where
116 I: IntoEndpointService<State, T>,
117 {
118 let matcher = HttpMatcher::method_options().and_path(path);
119 self.on(matcher, service)
120 }
121
122 pub fn trace<I, T>(self, path: &str, service: I) -> Self
124 where
125 I: IntoEndpointService<State, T>,
126 {
127 let matcher = HttpMatcher::method_trace().and_path(path);
128 self.on(matcher, service)
129 }
130
131 pub fn nest<I, T>(self, prefix: &str, service: I) -> Self
135 where
136 I: IntoEndpointService<State, T>,
137 {
138 let prefix = format!("{}/*", prefix.trim_end_matches(['/', '*']));
139 let matcher = HttpMatcher::path(prefix);
140 let service = NestedService(service.into_endpoint_service());
141 self.on(matcher, service)
142 }
143
144 pub fn dir(self, prefix: &str, dir: &str) -> Self {
146 let service = ServeDir::new(dir).fallback(self.not_found.clone());
147 self.nest(prefix, service)
148 }
149
150 pub fn on<I, T>(mut self, matcher: HttpMatcher<State, Body>, service: I) -> Self
152 where
153 I: IntoEndpointService<State, T>,
154 {
155 let endpoint = Endpoint {
156 matcher,
157 service: service.into_endpoint_service().boxed(),
158 };
159 self.endpoints.push(Arc::new(endpoint));
160 self
161 }
162
163 pub fn not_found<I, T>(mut self, service: I) -> Self
165 where
166 I: IntoEndpointService<State, T>,
167 {
168 self.not_found = Arc::new(service.into_endpoint_service().boxed());
169 self
170 }
171}
172
/// Wrapper installed by `WebService::nest`. It rewrites the request URI's path to the glob
/// remainder captured by the matched `<prefix>/*` pattern, then forwards to the inner service.
struct NestedService<S>(S);
174
175impl<S: fmt::Debug> fmt::Debug for NestedService<S> {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 f.debug_tuple("NestedService").field(&self.0).finish()
178 }
179}
180
181impl<S: Clone> Clone for NestedService<S> {
182 fn clone(&self) -> Self {
183 NestedService(self.0.clone())
184 }
185}
186
187impl<S, State> Service<State, Request> for NestedService<S>
188where
189 S: Service<State, Request>,
190{
191 type Response = S::Response;
192 type Error = S::Error;
193
194 fn serve(
195 &self,
196 ctx: Context<State>,
197 req: Request,
198 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
199 let path = ctx.get::<UriParams>().unwrap().glob().unwrap();
201
202 let (mut parts, body) = req.into_parts();
204 let mut uri_parts = parts.uri.into_parts();
205 let path_and_query = uri_parts.path_and_query.take().unwrap();
206 match path_and_query.query() {
207 Some(query) => {
208 uri_parts.path_and_query = Some(format!("{}?{}", path, query).parse().unwrap());
209 }
210 None => {
211 uri_parts.path_and_query = Some(path.parse().unwrap());
212 }
213 }
214 parts.uri = Uri::from_parts(uri_parts).unwrap();
215 let req = Request::from_parts(parts, body);
216
217 self.0.serve(ctx, req)
219 }
220}
221
222impl<State> Default for WebService<State>
223where
224 State: Clone + Send + Sync + 'static,
225{
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231impl<State> Service<State, Request> for WebService<State>
232where
233 State: Clone + Send + Sync + 'static,
234{
235 type Response = Response;
236 type Error = Infallible;
237
238 async fn serve(
239 &self,
240 mut ctx: Context<State>,
241 req: Request,
242 ) -> Result<Self::Response, Self::Error> {
243 let mut ext = Extensions::new();
244 for endpoint in &self.endpoints {
245 if endpoint.matcher.matches(Some(&mut ext), &ctx, &req) {
246 ctx.extend(ext);
248 return endpoint.service.serve(ctx, req).await;
249 }
250 ext.clear();
252 }
253 self.not_found.serve(ctx, req).await
254 }
255}
256
/// Build a `MatcherRouter` from one or more `matcher => service` arms followed by a
/// mandatory `_ => fallback` arm.
///
/// Each service, including the fallback, is converted with `into_endpoint_service()`. The
/// result is wrapped as a single tuple: the `(matcher, service)` pairs in the order written,
/// then the fallback last.
#[doc(hidden)]
#[macro_export]
macro_rules! __match_service {
    ($($matcher:expr_2021 => $service:expr_2021),+, _ => $fallback:expr $(,)?) => {{
        use $crate::dep::core::matcher::MatcherRouter;
        use $crate::service::web::IntoEndpointService;
        MatcherRouter((
            $(($matcher, $service.into_endpoint_service())),+,
            $fallback.into_endpoint_service()
        ))
    }};
}
331
332#[doc(inline)]
333pub use crate::__match_service as match_service;
334
#[cfg(test)]
mod test {
    use crate::Body;
    use crate::dep::http_body_util::BodyExt;
    use crate::matcher::MethodMatcher;

    use super::*;

    /// Run `req` through `service` with a default unit-state context.
    async fn send<S>(service: &S, req: Request) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Send an empty-bodied `GET` request for `uri`.
    async fn get_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::get(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Send an empty-bodied `POST` request for `uri`.
    async fn post_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::post(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Send an empty-bodied `CONNECT` request for `uri`.
    async fn connect_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::connect(uri).body(Body::empty()).unwrap();
        send(service, req).await
    }

    /// Collect the whole body of `res` and assert it equals `expected`.
    async fn assert_body(res: Response, expected: &str) {
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(bytes, expected);
    }

    #[tokio::test]
    async fn test_web_service() {
        let svc = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "world").await;

        for uri in ["https://www.test.io/world", "https://www.test.io"] {
            let res = get_response(&svc, uri).await;
            assert_eq!(res.status(), StatusCode::NOT_FOUND);
        }
    }

    #[tokio::test]
    async fn test_web_service_not_found() {
        let svc = WebService::new().not_found("not found");

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "not found").await;
    }

    #[tokio::test]
    async fn test_web_service_nest() {
        let svc = WebService::new().nest(
            "/api",
            WebService::new()
                .get("/hello", "hello")
                .post("/world", "world"),
        );

        let res = get_response(&svc, "https://www.test.io/api/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/api/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "world").await;

        for uri in ["https://www.test.io/api/world", "https://www.test.io"] {
            let res = get_response(&svc, uri).await;
            assert_eq!(res.status(), StatusCode::NOT_FOUND);
        }
    }

    #[tokio::test]
    async fn test_web_service_dir() {
        let tmp_dir = tempfile::tempdir().unwrap();
        std::fs::write(tmp_dir.path().join("index.html"), "<h1>Hello, World!</h1>").unwrap();
        let style_dir = tmp_dir.path().join("style");
        std::fs::create_dir(&style_dir).unwrap();
        std::fs::write(style_dir.join("main.css"), "body { background-color: red }").unwrap();

        let svc = WebService::new()
            .get("/api/version", "v1")
            .post("/api", StatusCode::FORBIDDEN)
            .dir("/", tmp_dir.path().to_str().unwrap());

        let res = get_response(&svc, "https://www.test.io/index.html").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "<h1>Hello, World!</h1>").await;

        let res = get_response(&svc, "https://www.test.io/style/main.css").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "body { background-color: red }").await;

        let res = get_response(&svc, "https://www.test.io/api/version").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "v1").await;

        let res = post_response(&svc, "https://www.test.io/api").await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        let res = get_response(&svc, "https://www.test.io/notfound.html").await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let res = get_response(&svc, "https://www.test.io/").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "<h1>Hello, World!</h1>").await;
    }

    #[tokio::test]
    async fn test_matcher_service_tuples() {
        let svc = match_service! {
            HttpMatcher::get("/hello") => "hello",
            HttpMatcher::post("/world") => "world",
            MethodMatcher::CONNECT => "connect",
            _ => StatusCode::NOT_FOUND,
        };

        let res = get_response(&svc, "https://www.test.io/hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "hello").await;

        let res = post_response(&svc, "https://www.test.io/world").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "world").await;

        let res = connect_response(&svc, "https://www.test.io").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_body(res, "connect").await;

        for uri in ["https://www.test.io/world", "https://www.test.io"] {
            let res = get_response(&svc, uri).await;
            assert_eq!(res.status(), StatusCode::NOT_FOUND);
        }
    }
}