use std::{cell::RefCell, marker, rc::Rc, task::Context};
use crate::http::{Request, Response};
use crate::router::{Path, ResourceDef, Router};
use crate::service::boxed::{self, BoxService, BoxServiceFactory};
use crate::service::cfg::SharedCfg;
use crate::service::dev::ServiceChainFactory;
use crate::service::{Middleware, Service, ServiceCtx, ServiceFactory, fn_service};
use crate::util::{BoxFuture, Extensions, join};
use super::error::ErrorRenderer;
use super::guard::Guard;
use super::httprequest::HttpRequest;
use super::request::WebRequest;
use super::response::WebResponse;
use super::rmap::ResourceMap;
use super::service::{AppServiceFactory, AppState, WebServiceConfig};
type Guards = Vec<Box<dyn Guard>>;
type HttpService<Err: ErrorRenderer> =
BoxService<WebRequest<Err>, WebResponse, Err::Container>;
type HttpNewService<Err: ErrorRenderer> =
BoxServiceFactory<SharedCfg, WebRequest<Err>, WebResponse, Err::Container, ()>;
type FnStateFactory = Box<dyn Fn(Extensions) -> BoxFuture<'static, Result<Extensions, ()>>>;
#[derive(derive_more::Debug)]
#[debug("AppFactory")]
pub struct AppFactory<T, F, Err: ErrorRenderer>
where
F: ServiceFactory<
WebRequest<Err>,
SharedCfg,
Response = WebRequest<Err>,
Error = Err::Container,
InitError = (),
>,
Err: ErrorRenderer,
{
pub(super) middleware: Rc<T>,
pub(super) filter: ServiceChainFactory<F, WebRequest<Err>, SharedCfg>,
pub(super) extensions: RefCell<Option<Extensions>>,
pub(super) state_factories: Rc<Vec<FnStateFactory>>,
pub(super) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory<Err>>>>>,
pub(super) default: Option<Rc<HttpNewService<Err>>>,
pub(super) external: RefCell<Vec<ResourceDef>>,
pub(super) case_insensitive: bool,
}
impl<T, F, Err> ServiceFactory<Request, SharedCfg> for AppFactory<T, F, Err>
where
T: Middleware<AppService<F::Service, Err>, SharedCfg> + 'static,
T::Service: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
F: ServiceFactory<
WebRequest<Err>,
SharedCfg,
Response = WebRequest<Err>,
Error = Err::Container,
InitError = (),
>,
Err: ErrorRenderer,
{
type Response = WebResponse;
type Error = Err::Container;
type InitError = ();
type Service = AppFactoryService<T::Service, Err>;
async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
let services = std::mem::take(&mut *self.services.borrow_mut());
let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::factory(fn_service(
|req: WebRequest<Err>| async move {
Ok(req.into_response(Response::NotFound().finish()))
},
)))
});
let filter_fut = self.filter.create(cfg.clone());
let state_factories = self.state_factories.clone();
let mut extensions = self.extensions.borrow_mut().take().unwrap_or_default();
let middleware = self.middleware.clone();
let external = std::mem::take(&mut *self.external.borrow_mut());
let mut router = Router::build();
if self.case_insensitive {
router.case_insensitive();
}
for fut in state_factories.iter() {
extensions = fut(extensions)
.await
.map_err(|()| log::error!("Cannot initialize state factory"))?;
}
let state = AppState::new(extensions, None, cfg.get());
let mut config = WebServiceConfig::new(state.clone(), default.clone());
for mut srv in services {
srv.register(&mut config);
}
let services = config.into_services();
let mut rmap = ResourceMap::new(ResourceDef::new(""));
for mut rdef in external {
rmap.add(&mut rdef, None);
}
let services: Vec<_> = services
.into_iter()
.map(|(mut rdef, srv, guards, nested)| {
rmap.add(&mut rdef, nested);
(rdef, srv, RefCell::new(guards))
})
.collect();
let rmap = Rc::new(rmap);
rmap.finish(&rmap);
for (path, factory, guards) in &mut services.iter() {
let service = factory
.create(cfg.clone())
.await
.map_err(|()| log::error!("Cannot construct app service"))?;
router.rdef(path.clone(), service).2 = guards.borrow_mut().take();
}
let routing = AppRouting {
router: router.finish(),
default: Some(
default
.create(cfg.clone())
.await
.map_err(|()| log::error!("Cannot construct default service"))?,
),
};
let service = AppService {
routing,
filter: filter_fut
.await
.map_err(|()| log::error!("Cannot construct app filter"))?,
};
Ok(AppFactoryService {
rmap,
state,
service: middleware.create(service, cfg),
_t: marker::PhantomData,
})
}
}
#[derive(derive_more::Debug)]
#[debug("AppFactoryService")]
pub struct AppFactoryService<T, Err>
where
T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
Err: ErrorRenderer,
{
service: T,
rmap: Rc<ResourceMap>,
state: AppState,
_t: marker::PhantomData<Err>,
}
impl<T, Err> Service<Request> for AppFactoryService<T, Err>
where
T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
Err: ErrorRenderer,
{
type Response = WebResponse;
type Error = T::Error;
crate::forward_poll!(service);
crate::forward_ready!(service);
crate::forward_shutdown!(service);
async fn call(
&self,
req: Request,
ctx: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
let (head, payload) = req.into_parts();
let req = if let Some(mut req) = self.state.config().get_request() {
let inner = Rc::get_mut(&mut req.0).unwrap();
inner.path.set(head.uri.clone());
inner.head = head;
inner.payload = payload;
inner.app_state = self.state.clone();
req
} else {
HttpRequest::new(
Path::new(head.uri.clone()),
head,
payload,
self.rmap.clone(),
self.state.clone(),
)
};
ctx.call(&self.service, WebRequest::new(req)).await
}
}
impl<T, Err> Drop for AppFactoryService<T, Err>
where
T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
Err: ErrorRenderer,
{
fn drop(&mut self) {
self.state.config().clear_requests();
}
}
struct AppRouting<Err: ErrorRenderer> {
router: Router<HttpService<Err>, Guards>,
default: Option<HttpService<Err>>,
}
impl<Err: ErrorRenderer> Service<WebRequest<Err>> for AppRouting<Err> {
type Response = WebResponse;
type Error = Err::Container;
async fn call(
&self,
mut req: WebRequest<Err>,
ctx: ServiceCtx<'_, Self>,
) -> Result<WebResponse, Err::Container> {
let res = self.router.recognize_checked(&mut req, |req, guards| {
if let Some(guards) = guards {
for f in guards {
if !f.check(req.head()) {
return false;
}
}
}
true
});
if let Some((srv, _info)) = res {
ctx.call(srv, req).await
} else if let Some(ref default) = self.default {
ctx.call(default, req).await
} else {
let req = req.into_parts().0;
Ok(WebResponse::new(Response::NotFound().finish(), req))
}
}
}
#[derive(derive_more::Debug)]
#[debug("AppService")]
pub struct AppService<F, Err: ErrorRenderer> {
filter: F,
routing: AppRouting<Err>,
}
impl<F, Err> Service<WebRequest<Err>> for AppService<F, Err>
where
F: Service<WebRequest<Err>, Response = WebRequest<Err>, Error = Err::Container>,
Err: ErrorRenderer,
{
type Response = WebResponse;
type Error = Err::Container;
#[inline]
async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
let (ready1, ready2) =
join(ctx.ready(&self.filter), ctx.ready(&self.routing)).await;
ready1?;
ready2
}
#[inline]
fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
self.filter.poll(cx)?;
self.routing.poll(cx)
}
async fn call(
&self,
req: WebRequest<Err>,
ctx: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
let req = ctx.call(&self.filter, req).await?;
ctx.call(&self.routing, req).await
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::web::test::{TestRequest, init_service};
use crate::web::{self, App, HttpResponse};
struct DropData(Arc<AtomicBool>);
impl Drop for DropData {
fn drop(&mut self) {
self.0.store(true, Ordering::Relaxed);
}
}
#[crate::rt_test]
async fn test_drop_data() {
let data = Arc::new(AtomicBool::new(false));
{
let app = init_service(
App::new()
.state(DropData(data.clone()))
.service(web::resource("/test").to(|| async { HttpResponse::Ok() })),
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let _ = app.call(req).await.unwrap();
}
assert!(data.load(Ordering::Relaxed));
}
}