1use std::{cell::Cell, cell::RefCell, fmt, marker::PhantomData, rc::Rc};
2
3use crate::http::Request;
4use crate::router::ResourceDef;
5use crate::service::boxed::{self, BoxServiceFactory};
6use crate::service::cfg::SharedCfg;
7use crate::service::{Identity, Middleware, Service, ServiceCtx, ServiceFactory};
8use crate::service::{IntoServiceFactory, chain_factory, dev::ServiceChainFactory};
9use crate::util::{BoxFuture, Extensions};
10
11use super::app_service::{AppFactory, AppService};
12use super::config::ServiceConfig;
13use super::request::WebRequest;
14use super::resource::Resource;
15use super::response::WebResponse;
16use super::route::Route;
17use super::service::{AppServiceFactory, ServiceFactoryWrapper, WebServiceFactory};
18use super::stack::WebStack;
19use super::{DefaultError, ErrorRenderer};
20
21type HttpNewService<Err: ErrorRenderer> =
22 BoxServiceFactory<SharedCfg, WebRequest<Err>, WebResponse, Err::Container, ()>;
23type FnStateFactory = Box<dyn Fn(Extensions) -> BoxFuture<'static, Result<Extensions, ()>>>;
24
/// Application builder.
///
/// Collects services/routes, application state, state factories, request
/// filters and middleware. It is converted into a service factory with
/// [`App::finish`] or through its `IntoServiceFactory` implementation.
pub struct App<M, F, Err: ErrorRenderer = DefaultError> {
    // Middleware stack: `Identity` initially, nested in a `WebStack` by every `middleware()` call.
    middleware: M,
    // Chain of filter services mapping a `WebRequest` to a (possibly modified) `WebRequest`.
    filter: ServiceChainFactory<F, WebRequest<Err>, SharedCfg>,
    // Services registered via `service()`, `route()` and `configure()`.
    services: Vec<Box<dyn AppServiceFactory<Err>>>,
    // Fallback service set by `default_service()`; `None` until set.
    default: Option<Rc<HttpNewService<Err>>>,
    // Named external resources (`external_resource()`, `configure()`), used for URL generation.
    external: Vec<ResourceDef>,
    // State values added by `state()` and `configure()`.
    extensions: Extensions,
    // Async state factories added by `state_factory()`.
    state_factories: Vec<FnStateFactory>,
    // Renderer instance given to `with()` (`DefaultError` for `new()`).
    // NOTE(review): `into_factory` does not forward this value; it seems to be kept for its type.
    error_renderer: Err,
    // Set by `case_insensitive_routing()`.
    case_insensitive: bool,
}
38
39impl Default for App<Identity, Filter<DefaultError>, DefaultError> {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl App<Identity, Filter<DefaultError>, DefaultError> {
46 #[must_use]
47 pub fn new() -> Self {
49 App {
50 middleware: Identity,
51 filter: chain_factory(Filter::new()),
52 state_factories: Vec::new(),
53 services: Vec::new(),
54 default: None,
55 external: Vec::new(),
56 extensions: Extensions::new(),
57 error_renderer: DefaultError,
58 case_insensitive: false,
59 }
60 }
61}
62
63impl<Err: ErrorRenderer> App<Identity, Filter<Err>, Err> {
64 #[must_use]
65 pub fn with(err: Err) -> Self {
67 App {
68 middleware: Identity,
69 filter: chain_factory(Filter::new()),
70 state_factories: Vec::new(),
71 services: Vec::new(),
72 default: None,
73 external: Vec::new(),
74 extensions: Extensions::new(),
75 error_renderer: err,
76 case_insensitive: false,
77 }
78 }
79}
80
81impl<M, T, Err> App<M, T, Err>
82where
83 T: ServiceFactory<
84 WebRequest<Err>,
85 SharedCfg,
86 Response = WebRequest<Err>,
87 Error = Err::Container,
88 InitError = (),
89 >,
90 Err: ErrorRenderer,
91{
92 #[must_use]
93 pub fn state<U: 'static>(mut self, state: U) -> Self {
126 self.extensions.insert(state);
127 self
128 }
129
130 #[must_use]
131 pub fn state_factory<F, D, E>(mut self, state: F) -> Self
136 where
137 F: AsyncFnOnce() -> Result<D, E> + 'static,
138 D: 'static,
139 E: fmt::Debug,
140 {
141 let state = Cell::new(Some(state));
142
143 self.state_factories.push(Box::new(move |mut ext| {
144 let mut state = state.take();
145
146 Box::pin(async move {
147 if let Some(state) = state.take() {
148 match state().await {
149 Err(e) => {
150 log::error!("Cannot construct state instance: {e:?}");
151 Err(())
152 }
153 Ok(st) => {
154 ext.insert(st);
155 Ok(ext)
156 }
157 }
158 } else {
159 log::error!("Cannot construct state instance");
160 Err(())
161 }
162 })
163 }));
164 self
165 }
166
167 #[must_use]
168 pub fn configure<F>(mut self, f: F) -> Self
194 where
195 F: FnOnce(&mut ServiceConfig<Err>),
196 {
197 let mut cfg = ServiceConfig::new();
198 f(&mut cfg);
199 self.services.extend(cfg.services);
200 self.external.extend(cfg.external);
201 self.extensions.extend(cfg.state);
202 self
203 }
204
205 #[must_use]
206 pub fn route(self, path: &str, mut route: Route<Err>) -> Self {
226 self.service(
227 Resource::new(path)
228 .add_guards(route.take_guards())
229 .route(route),
230 )
231 }
232
233 #[must_use]
234 pub fn service<F>(mut self, factory: F) -> Self
244 where
245 F: WebServiceFactory<Err> + 'static,
246 {
247 self.services
248 .push(Box::new(ServiceFactoryWrapper::new(factory)));
249 self
250 }
251
252 #[must_use]
253 pub fn default_service<F, U>(mut self, f: F) -> Self
288 where
289 F: IntoServiceFactory<U, WebRequest<Err>, SharedCfg>,
290 U: ServiceFactory<
291 WebRequest<Err>,
292 SharedCfg,
293 Response = WebResponse,
294 Error = Err::Container,
295 > + 'static,
296 U::InitError: fmt::Debug,
297 {
298 self.default = Some(Rc::new(boxed::factory(
300 chain_factory(f)
301 .map_init_err(|e| log::error!("Cannot construct default service: {e:?}")),
302 )));
303
304 self
305 }
306
307 #[must_use]
308 pub fn external_resource<N, U>(mut self, name: N, url: U) -> Self
331 where
332 N: AsRef<str>,
333 U: AsRef<str>,
334 {
335 let mut rdef = ResourceDef::new(url.as_ref());
336 *rdef.name_mut() = name.as_ref().to_string();
337 self.external.push(rdef);
338 self
339 }
340
341 #[must_use]
342 pub fn filter<S, U>(
368 self,
369 filter: U,
370 ) -> App<
371 M,
372 impl ServiceFactory<
373 WebRequest<Err>,
374 SharedCfg,
375 Response = WebRequest<Err>,
376 Error = Err::Container,
377 InitError = (),
378 >,
379 Err,
380 >
381 where
382 S: ServiceFactory<
383 WebRequest<Err>,
384 SharedCfg,
385 Response = WebRequest<Err>,
386 Error = Err::Container,
387 >,
388 U: IntoServiceFactory<S, WebRequest<Err>, SharedCfg>,
389 {
390 App {
391 filter: self
392 .filter
393 .and_then(filter.into_factory().map_init_err(|_| ())),
394 middleware: self.middleware,
395 state_factories: self.state_factories,
396 services: self.services,
397 default: self.default,
398 external: self.external,
399 extensions: self.extensions,
400 error_renderer: self.error_renderer,
401 case_insensitive: self.case_insensitive,
402 }
403 }
404
405 #[must_use]
406 pub fn middleware<U>(self, mw: U) -> App<WebStack<M, U, Err>, T, Err> {
434 App {
435 middleware: WebStack::new(self.middleware, mw),
436 filter: self.filter,
437 state_factories: self.state_factories,
438 services: self.services,
439 default: self.default,
440 external: self.external,
441 extensions: self.extensions,
442 error_renderer: self.error_renderer,
443 case_insensitive: self.case_insensitive,
444 }
445 }
446
447 #[deprecated(since = "3.2.0", note = "use `middleware()` instead")]
448 #[doc(hidden)]
449 pub fn wrap<U>(self, mw: U) -> App<WebStack<M, U, Err>, T, Err> {
450 self.middleware(mw)
451 }
452
453 #[must_use]
454 pub fn case_insensitive_routing(mut self) -> Self {
458 self.case_insensitive = true;
459 self
460 }
461}
462
463impl<M, F, Err> App<M, F, Err>
464where
465 M: Middleware<AppService<F::Service, Err>, SharedCfg> + 'static,
466 M::Service: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
467 F: ServiceFactory<
468 WebRequest<Err>,
469 SharedCfg,
470 Response = WebRequest<Err>,
471 Error = Err::Container,
472 InitError = (),
473 >,
474 Err: ErrorRenderer,
475{
476 pub fn finish(
495 self,
496 ) -> impl ServiceFactory<
497 Request,
498 SharedCfg,
499 Response = WebResponse,
500 Error = Err::Container,
501 InitError = (),
502 > {
503 IntoServiceFactory::<AppFactory<M, F, Err>, Request, SharedCfg>::into_factory(self)
504 }
505}
506
507impl<M, F, Err> IntoServiceFactory<AppFactory<M, F, Err>, Request, SharedCfg>
508 for App<M, F, Err>
509where
510 M: Middleware<AppService<F::Service, Err>, SharedCfg> + 'static,
511 M::Service: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
512 F: ServiceFactory<
513 WebRequest<Err>,
514 SharedCfg,
515 Response = WebRequest<Err>,
516 Error = Err::Container,
517 InitError = (),
518 >,
519 Err: ErrorRenderer,
520{
521 fn into_factory(self) -> AppFactory<M, F, Err> {
522 AppFactory {
523 filter: self.filter,
524 middleware: Rc::new(self.middleware),
525 state_factories: Rc::new(self.state_factories),
526 services: Rc::new(RefCell::new(self.services)),
527 external: RefCell::new(self.external),
528 default: self.default,
529 extensions: RefCell::new(Some(self.extensions)),
530 case_insensitive: self.case_insensitive,
531 }
532 }
533}
534
/// Default pass-through request filter: every `WebRequest` is returned unchanged.
///
/// The same type acts as both the filter factory and the filter service
/// (see its `ServiceFactory` and `Service` implementations).
pub struct Filter<Err>(PhantomData<Err>);
536
537impl<Err: ErrorRenderer> Filter<Err> {
538 pub(super) fn new() -> Self {
539 Filter(PhantomData)
540 }
541}
542
543impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>, SharedCfg> for Filter<Err> {
544 type Response = WebRequest<Err>;
545 type Error = Err::Container;
546 type InitError = ();
547 type Service = Filter<Err>;
548
549 async fn create(&self, _: SharedCfg) -> Result<Self::Service, Self::InitError> {
550 Ok(Filter(PhantomData))
551 }
552}
553
554impl<Err: ErrorRenderer> Service<WebRequest<Err>> for Filter<Err> {
555 type Response = WebRequest<Err>;
556 type Error = Err::Container;
557
558 async fn call(
559 &self,
560 req: WebRequest<Err>,
561 _: ServiceCtx<'_, Self>,
562 ) -> Result<WebRequest<Err>, Err::Container> {
563 Ok(req)
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::{Method, StatusCode, header, header::HeaderValue};
    use crate::web::test::{TestRequest, call_service, init_service, read_body};
    use crate::web::{self, HttpRequest, HttpResponse, middleware::DefaultHeaders};
    use crate::{service::fn_service, util::Ready};

    // Unmatched paths get 404 unless an app-level default service is set;
    // a resource-level default service handles methods the resource has no route for.
    #[crate::rt_test]
    async fn test_default_resource() {
        let app = App::new()
            .service(web::resource("/test").to(|| async { HttpResponse::Ok() }))
            .finish()
            .pipeline(SharedCfg::default())
            .await
            .unwrap();

        let found = app
            .call(TestRequest::with_uri("/test").to_request())
            .await
            .unwrap();
        assert_eq!(found.status(), StatusCode::OK);

        let missing = app
            .call(TestRequest::with_uri("/blah").to_request())
            .await
            .unwrap();
        assert_eq!(missing.status(), StatusCode::NOT_FOUND);

        let app = App::new()
            .service(web::resource("/test").to(|| async { HttpResponse::Ok() }))
            .service(
                web::resource("/test2")
                    .default_service(|r: WebRequest<DefaultError>| async move {
                        Ok(r.into_response(HttpResponse::Created()))
                    })
                    .route(web::get().to(|| async { HttpResponse::Ok() })),
            )
            .default_service(|r: WebRequest<DefaultError>| async move {
                Ok(r.into_response(HttpResponse::MethodNotAllowed()))
            })
            .finish()
            .pipeline(SharedCfg::default())
            .await
            .unwrap();

        let fallback = app
            .call(TestRequest::with_uri("/blah").to_request())
            .await
            .unwrap();
        assert_eq!(fallback.status(), StatusCode::METHOD_NOT_ALLOWED);

        let routed = app
            .call(TestRequest::with_uri("/test2").to_request())
            .await
            .unwrap();
        assert_eq!(routed.status(), StatusCode::OK);

        let resource_default = app
            .call(
                TestRequest::with_uri("/test2")
                    .method(Method::POST)
                    .to_request(),
            )
            .await
            .unwrap();
        assert_eq!(resource_default.status(), StatusCode::CREATED);
    }

    // A state factory's value is available to extractors of the same type only.
    #[crate::rt_test]
    async fn test_state_factory() {
        let app = init_service(
            App::new()
                .state_factory(|| async { Ok::<_, ()>(10usize) })
                .service(
                    web::resource("/")
                        .to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
                ),
        )
        .await;
        let response = app.call(TestRequest::default().to_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The factory yields `u32` while the handler extracts `usize`.
        let app = init_service(
            App::new()
                .state_factory(|| async { Ok::<_, ()>(10u32) })
                .service(
                    web::resource("/")
                        .to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
                ),
        )
        .await;
        let response = app.call(TestRequest::default().to_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // State registered with `state()` is visible to both filters and handlers.
    #[crate::rt_test]
    async fn test_extension() {
        let app = init_service(
            App::new()
                .state(10usize)
                .filter(fn_service(move |req: WebRequest<_>| {
                    assert_eq!(*req.app_state::<usize>().unwrap(), 10);
                    Ready::Ok(req)
                }))
                .service(web::resource("/").to(|req: HttpRequest| async move {
                    assert_eq!(*req.app_state::<usize>().unwrap(), 10);
                    HttpResponse::Ok()
                })),
        )
        .await;
        let response = app.call(TestRequest::default().to_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A registered filter is invoked for routed requests.
    #[crate::rt_test]
    async fn test_filter() {
        let called = Rc::new(Cell::new(false));
        let flag = Rc::clone(&called);
        let app = init_service(
            App::new()
                .filter(fn_service(move |req: WebRequest<_>| {
                    flag.set(true);
                    Ready::Ok(req)
                }))
                .route("/test", web::get().to(|| async { HttpResponse::Ok() })),
        )
        .await;
        let response = call_service(&app, TestRequest::with_uri("/test").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(called.get());
    }

    // Middleware registered before routes applies to them.
    #[crate::rt_test]
    async fn test_wrap() {
        let app = init_service(
            App::new()
                .middleware(
                    DefaultHeaders::new()
                        .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")),
                )
                .route("/test", web::get().to(|| async { HttpResponse::Ok() })),
        )
        .await;
        let response = call_service(&app, TestRequest::with_uri("/test").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("0001")
        );
    }

    // Middleware registered after routes still applies to them.
    #[crate::rt_test]
    async fn test_router_wrap() {
        let app = init_service(
            App::new()
                .route("/test", web::get().to(|| async { HttpResponse::Ok() }))
                .middleware(
                    DefaultHeaders::new()
                        .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")),
                ),
        )
        .await;
        let response = call_service(&app, TestRequest::with_uri("/test").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("0001")
        );
    }

    // With case-insensitive routing, differently-cased paths hit the same route.
    #[crate::rt_test]
    async fn test_case_insensitive_router() {
        let app = init_service(
            App::new()
                .case_insensitive_routing()
                .route("/test", web::get().to(|| async { HttpResponse::Ok() })),
        )
        .await;
        for uri in ["/test", "/Test"] {
            let response = call_service(&app, TestRequest::with_uri(uri).to_request()).await;
            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    // External resources can be used to generate URLs by name.
    #[cfg(feature = "url")]
    #[crate::rt_test]
    async fn test_external_resource() {
        use crate::util::Bytes;

        let app = init_service(
            App::new()
                .external_resource("youtube", "https://youtube.com/watch/{video_id}")
                .route(
                    "/test",
                    web::get().to(|req: HttpRequest| async move {
                        let url = req.url_for("youtube", ["12345"]).unwrap();
                        HttpResponse::Ok().body(url.to_string())
                    }),
                ),
        )
        .await;
        let response = call_service(&app, TestRequest::with_uri("/test").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        let body = read_body(response).await;
        assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
    }
}