1use super::{endpoint::Endpoint, IntoEndpointService};
2use crate::{
3 matcher::{HttpMatcher, UriParams},
4 service::fs::ServeDir,
5 Body, IntoResponse, Request, Response, StatusCode, Uri,
6};
7use rama_core::{
8 context::Extensions,
9 matcher::Matcher,
10 service::{service_fn, BoxService, Service},
11 Context,
12};
13use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, sync::Arc};
14
15pub struct WebService<State> {
22 endpoints: Vec<Arc<Endpoint<State>>>,
23 not_found: Arc<BoxService<State, Request, Response, Infallible>>,
24 _phantom: PhantomData<State>,
25}
26
27impl<State> std::fmt::Debug for WebService<State> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("WebService").finish()
30 }
31}
32
33impl<State> Clone for WebService<State> {
34 fn clone(&self) -> Self {
35 Self {
36 endpoints: self.endpoints.clone(),
37 not_found: self.not_found.clone(),
38 _phantom: PhantomData,
39 }
40 }
41}
42
43impl<State> WebService<State>
44where
45 State: Clone + Send + Sync + 'static,
46{
47 pub(crate) fn new() -> Self {
49 Self {
50 endpoints: Vec::new(),
51 not_found: Arc::new(
52 service_fn(|| async { Ok(StatusCode::NOT_FOUND.into_response()) }).boxed(),
53 ),
54 _phantom: PhantomData,
55 }
56 }
57
58 pub fn get<I, T>(self, path: &str, service: I) -> Self
60 where
61 I: IntoEndpointService<State, T>,
62 {
63 let matcher = HttpMatcher::method_get().and_path(path);
64 self.on(matcher, service)
65 }
66
67 pub fn post<I, T>(self, path: &str, service: I) -> Self
69 where
70 I: IntoEndpointService<State, T>,
71 {
72 let matcher = HttpMatcher::method_post().and_path(path);
73 self.on(matcher, service)
74 }
75
76 pub fn put<I, T>(self, path: &str, service: I) -> Self
78 where
79 I: IntoEndpointService<State, T>,
80 {
81 let matcher = HttpMatcher::method_put().and_path(path);
82 self.on(matcher, service)
83 }
84
85 pub fn delete<I, T>(self, path: &str, service: I) -> Self
87 where
88 I: IntoEndpointService<State, T>,
89 {
90 let matcher = HttpMatcher::method_delete().and_path(path);
91 self.on(matcher, service)
92 }
93
94 pub fn patch<I, T>(self, path: &str, service: I) -> Self
96 where
97 I: IntoEndpointService<State, T>,
98 {
99 let matcher = HttpMatcher::method_patch().and_path(path);
100 self.on(matcher, service)
101 }
102
103 pub fn head<I, T>(self, path: &str, service: I) -> Self
105 where
106 I: IntoEndpointService<State, T>,
107 {
108 let matcher = HttpMatcher::method_head().and_path(path);
109 self.on(matcher, service)
110 }
111
112 pub fn options<I, T>(self, path: &str, service: I) -> Self
114 where
115 I: IntoEndpointService<State, T>,
116 {
117 let matcher = HttpMatcher::method_options().and_path(path);
118 self.on(matcher, service)
119 }
120
121 pub fn trace<I, T>(self, path: &str, service: I) -> Self
123 where
124 I: IntoEndpointService<State, T>,
125 {
126 let matcher = HttpMatcher::method_trace().and_path(path);
127 self.on(matcher, service)
128 }
129
130 pub fn nest<I, T>(self, prefix: &str, service: I) -> Self
134 where
135 I: IntoEndpointService<State, T>,
136 {
137 let prefix = format!("{}/*", prefix.trim_end_matches(['/', '*']));
138 let matcher = HttpMatcher::path(prefix);
139 let service = NestedService(service.into_endpoint_service());
140 self.on(matcher, service)
141 }
142
143 pub fn dir(self, prefix: &str, dir: &str) -> Self {
145 let service = ServeDir::new(dir).fallback(self.not_found.clone());
146 self.nest(prefix, service)
147 }
148
149 pub fn on<I, T>(mut self, matcher: HttpMatcher<State, Body>, service: I) -> Self
151 where
152 I: IntoEndpointService<State, T>,
153 {
154 let endpoint = Endpoint {
155 matcher,
156 service: service.into_endpoint_service().boxed(),
157 };
158 self.endpoints.push(Arc::new(endpoint));
159 self
160 }
161
162 pub fn not_found<I, T>(mut self, service: I) -> Self
164 where
165 I: IntoEndpointService<State, T>,
166 {
167 self.not_found = Arc::new(service.into_endpoint_service().boxed());
168 self
169 }
170}
171
/// Wrapper around a service mounted with `WebService::nest`.
///
/// Before delegating to the inner service it rewrites the request URI so the
/// inner service sees the path captured by the glob of the nesting matcher.
struct NestedService<S>(S);
173
174impl<S: fmt::Debug> fmt::Debug for NestedService<S> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_tuple("NestedService").field(&self.0).finish()
177 }
178}
179
180impl<S: Clone> Clone for NestedService<S> {
181 fn clone(&self) -> Self {
182 NestedService(self.0.clone())
183 }
184}
185
186impl<S, State> Service<State, Request> for NestedService<S>
187where
188 S: Service<State, Request>,
189{
190 type Response = S::Response;
191 type Error = S::Error;
192
193 fn serve(
194 &self,
195 ctx: Context<State>,
196 req: Request,
197 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
198 let path = ctx.get::<UriParams>().unwrap().glob().unwrap();
200
201 let (mut parts, body) = req.into_parts();
203 let mut uri_parts = parts.uri.into_parts();
204 let path_and_query = uri_parts.path_and_query.take().unwrap();
205 match path_and_query.query() {
206 Some(query) => {
207 uri_parts.path_and_query = Some(format!("{}?{}", path, query).parse().unwrap());
208 }
209 None => {
210 uri_parts.path_and_query = Some(path.parse().unwrap());
211 }
212 }
213 parts.uri = Uri::from_parts(uri_parts).unwrap();
214 let req = Request::from_parts(parts, body);
215
216 self.0.serve(ctx, req)
218 }
219}
220
221impl<State> Default for WebService<State>
222where
223 State: Clone + Send + Sync + 'static,
224{
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl<State> Service<State, Request> for WebService<State>
231where
232 State: Clone + Send + Sync + 'static,
233{
234 type Response = Response;
235 type Error = Infallible;
236
237 async fn serve(
238 &self,
239 mut ctx: Context<State>,
240 req: Request,
241 ) -> Result<Self::Response, Self::Error> {
242 let mut ext = Extensions::new();
243 for endpoint in &self.endpoints {
244 if endpoint.matcher.matches(Some(&mut ext), &ctx, &req) {
245 ctx.extend(ext);
247 return endpoint.service.serve(ctx, req).await;
248 }
249 ext.clear();
251 }
252 self.not_found.serve(ctx, req).await
253 }
254}
255
/// Builds a tuple of `(matcher, service)` pairs followed by a fallback
/// service.
///
/// Every service expression, including the fallback after `_ =>`, is
/// converted with `IntoEndpointService::into_endpoint_service`. At least one
/// `matcher => service` arm and the final `_ => fallback` arm are required; a
/// trailing comma is optional.
///
/// ```ignore
/// let svc = match_service! {
///     HttpMatcher::get("/hello") => "hello",
///     _ => StatusCode::NOT_FOUND,
/// };
/// ```
#[doc(hidden)]
#[macro_export]
macro_rules! __match_service {
    ($($matcher:expr => $service:expr),+, _ => $fallback:expr $(,)?) => {{
        use $crate::service::web::IntoEndpointService;
        (
            $(($matcher, $service.into_endpoint_service())),+,
            $fallback.into_endpoint_service()
        )
    }};
}
328
329#[doc(inline)]
330pub use crate::__match_service as match_service;
331
#[cfg(test)]
mod test {
    use crate::dep::http_body_util::BodyExt;
    use crate::matcher::MethodMatcher;
    use crate::Body;

    use super::*;

    /// Sends `req` through `service` using a default context.
    async fn dispatch<S>(service: &S, req: Request) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        service.serve(Context::default(), req).await.unwrap()
    }

    /// Issues a body-less `GET` request for `uri`.
    async fn get_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::get(uri).body(Body::empty()).unwrap();
        dispatch(service, req).await
    }

    /// Issues a body-less `POST` request for `uri`.
    async fn post_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::post(uri).body(Body::empty()).unwrap();
        dispatch(service, req).await
    }

    /// Issues a body-less `CONNECT` request for `uri`.
    async fn connect_response<S>(service: &S, uri: &str) -> Response
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let req = Request::connect(uri).body(Body::empty()).unwrap();
        dispatch(service, req).await
    }

    /// Asserts that `res` is `200 OK` and that its full body equals `expected`.
    async fn assert_ok_with_body(res: Response, expected: &str) {
        assert_eq!(res.status(), StatusCode::OK);
        let payload = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(payload, expected);
    }

    #[tokio::test]
    async fn test_web_service() {
        let svc = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/hello").await,
            "hello",
        )
        .await;

        assert_ok_with_body(
            post_response(&svc, "https://www.test.io/world").await,
            "world",
        )
        .await;

        // Right path, wrong method.
        let wrong_method = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(wrong_method.status(), StatusCode::NOT_FOUND);

        let unknown = get_response(&svc, "https://www.test.io").await;
        assert_eq!(unknown.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service_not_found() {
        let svc = WebService::new().not_found("not found");

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/hello").await,
            "not found",
        )
        .await;
    }

    #[tokio::test]
    async fn test_web_service_nest() {
        let api = WebService::new()
            .get("/hello", "hello")
            .post("/world", "world");
        let svc = WebService::new().nest("/api", api);

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/api/hello").await,
            "hello",
        )
        .await;

        assert_ok_with_body(
            post_response(&svc, "https://www.test.io/api/world").await,
            "world",
        )
        .await;

        // Right nested path, wrong method.
        let wrong_method = get_response(&svc, "https://www.test.io/api/world").await;
        assert_eq!(wrong_method.status(), StatusCode::NOT_FOUND);

        let outside = get_response(&svc, "https://www.test.io").await;
        assert_eq!(outside.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_web_service_dir() {
        // Layout: <tmp>/index.html and <tmp>/style/main.css
        let tmp_dir = tempfile::tempdir().unwrap();
        let index_path = tmp_dir.path().join("index.html");
        std::fs::write(&index_path, "<h1>Hello, World!</h1>").unwrap();
        let style_dir = tmp_dir.path().join("style");
        std::fs::create_dir(&style_dir).unwrap();
        let css_path = style_dir.join("main.css");
        std::fs::write(&css_path, "body { background-color: red }").unwrap();

        let svc = WebService::new()
            .get("/api/version", "v1")
            .post("/api", StatusCode::FORBIDDEN)
            .dir("/", tmp_dir.path().to_str().unwrap());

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/index.html").await,
            "<h1>Hello, World!</h1>",
        )
        .await;

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/style/main.css").await,
            "body { background-color: red }",
        )
        .await;

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/api/version").await,
            "v1",
        )
        .await;

        let forbidden = post_response(&svc, "https://www.test.io/api").await;
        assert_eq!(forbidden.status(), StatusCode::FORBIDDEN);

        let missing = get_response(&svc, "https://www.test.io/notfound.html").await;
        assert_eq!(missing.status(), StatusCode::NOT_FOUND);

        // The directory root serves the index file.
        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/").await,
            "<h1>Hello, World!</h1>",
        )
        .await;
    }

    #[tokio::test]
    async fn test_matcher_service_tuples() {
        let svc = match_service! {
            HttpMatcher::get("/hello") => "hello",
            HttpMatcher::post("/world") => "world",
            MethodMatcher::CONNECT => "connect",
            _ => StatusCode::NOT_FOUND,
        };

        assert_ok_with_body(
            get_response(&svc, "https://www.test.io/hello").await,
            "hello",
        )
        .await;

        assert_ok_with_body(
            post_response(&svc, "https://www.test.io/world").await,
            "world",
        )
        .await;

        assert_ok_with_body(
            connect_response(&svc, "https://www.test.io").await,
            "connect",
        )
        .await;

        // Right path, wrong method.
        let wrong_method = get_response(&svc, "https://www.test.io/world").await;
        assert_eq!(wrong_method.status(), StatusCode::NOT_FOUND);

        let unknown = get_response(&svc, "https://www.test.io").await;
        assert_eq!(unknown.status(), StatusCode::NOT_FOUND);
    }
}