use std::task::{Context, Poll};
use std::{cell::RefCell, future::Future, marker::PhantomData, pin::Pin, rc::Rc};
use crate::http::{Request, Response};
use crate::router::{Path, ResourceDef, Router};
use crate::service::boxed::{self, BoxService, BoxServiceFactory};
use crate::service::{fn_service, Middleware, PipelineFactory, Service, ServiceFactory};
use crate::util::{BoxFuture, Extensions};
use super::config::AppConfig;
use super::error::ErrorRenderer;
use super::guard::Guard;
use super::httprequest::{HttpRequest, HttpRequestPool};
use super::request::WebRequest;
use super::response::WebResponse;
use super::rmap::ResourceMap;
use super::service::{AppServiceFactory, AppState, WebServiceConfig};
/// Guards attached to a registered route. `AppRouting` only selects the route
/// when every guard's `check` returns `true` for the request head.
type Guards = Vec<Box<dyn Guard>>;
/// Type-erased service that handles a routed [`WebRequest`].
// NOTE: bounds on type aliases are not enforced; `Err: ErrorRenderer` is kept so
// the `Err::Container` shorthand resolves.
type HttpService<Err: ErrorRenderer> =
BoxService<WebRequest<Err>, WebResponse, Err::Container>;
/// Type-erased factory producing [`HttpService`] instances.
type HttpNewService<Err: ErrorRenderer> =
BoxServiceFactory<(), WebRequest<Err>, WebResponse, Err::Container, ()>;
/// Boxed future resolving to a [`WebResponse`] or the renderer's error container.
type BoxResponse<'f, Err: ErrorRenderer> =
BoxFuture<'f, Result<WebResponse, Err::Container>>;
/// Asynchronous application-state constructor. It receives the current
/// `Extensions` and resolves to the updated set, or to `Err(())`, which aborts
/// application construction.
type FnStateFactory = Box<dyn Fn(Extensions) -> BoxFuture<'static, Result<Extensions, ()>>>;
/// Factory that builds the application service ([`AppFactoryService`]).
///
/// Several fields are one-shot: `extensions`, `services` and `external` are
/// moved out of their `RefCell`s by `create`, so only the first build sees them.
pub struct AppFactory<T, F, Err: ErrorRenderer>
where
F: ServiceFactory<
WebRequest<Err>,
Response = WebRequest<Err>,
Error = Err::Container,
InitError = (),
>,
Err: ErrorRenderer,
{
/// Middleware wrapped around the assembled `AppService`.
pub(super) middleware: Rc<T>,
/// Request filter run before routing; its output request is what gets routed.
pub(super) filter: PipelineFactory<WebRequest<Err>, F>,
/// Initial application extensions (taken on first `create`).
pub(super) extensions: RefCell<Option<Extensions>>,
/// State constructors, applied in order to the extensions during `create`.
pub(super) state_factories: Rc<Vec<FnStateFactory>>,
/// Registered service factories (drained on first `create`).
pub(super) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory<Err>>>>>,
/// Fallback service; when `None`, `create` installs a `404 Not Found` responder.
pub(super) default: Option<Rc<HttpNewService<Err>>>,
/// External resource definitions added to the resource map (drained on first `create`).
pub(super) external: RefCell<Vec<ResourceDef>>,
/// Enables case-insensitive matching in the router.
pub(super) case_insensitive: bool,
}
impl<T, F, Err> ServiceFactory<Request> for AppFactory<T, F, Err>
where
T: Middleware<AppService<F::Service, Err>> + 'static,
T::Service: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
F: ServiceFactory<
WebRequest<Err>,
Response = WebRequest<Err>,
Error = Err::Container,
InitError = (),
>,
Err: ErrorRenderer,
{
type Response = WebResponse;
type Error = Err::Container;
type InitError = ();
type Service = AppFactoryService<T::Service, Err>;
type Future<'f> = BoxFuture<'f, Result<Self::Service, Self::InitError>> where Self: 'f;
fn create(&self, _: ()) -> Self::Future<'_> {
ServiceFactory::create(self, AppConfig::default())
}
}
impl<T, F, Err> ServiceFactory<Request, AppConfig> for AppFactory<T, F, Err>
where
T: Middleware<AppService<F::Service, Err>> + 'static,
T::Service: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
F: ServiceFactory<
WebRequest<Err>,
Response = WebRequest<Err>,
Error = Err::Container,
InitError = (),
>,
Err: ErrorRenderer,
{
type Response = WebResponse;
type Error = Err::Container;
type InitError = ();
type Service = AppFactoryService<T::Service, Err>;
type Future<'f> = BoxFuture<'f, Result<Self::Service, Self::InitError>> where Self: 'f;
fn create(&self, config: AppConfig) -> Self::Future<'_> {
let services = std::mem::take(&mut *self.services.borrow_mut());
let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::factory(fn_service(
|req: WebRequest<Err>| async move {
Ok(req.into_response(Response::NotFound().finish()))
},
)))
});
let filter_fut = self.filter.create(());
let state_factories = self.state_factories.clone();
let mut extensions = self
.extensions
.borrow_mut()
.take()
.unwrap_or_else(Extensions::new);
let middleware = self.middleware.clone();
let external = std::mem::take(&mut *self.external.borrow_mut());
let mut router = Router::build();
if self.case_insensitive {
router.case_insensitive();
}
Box::pin(async move {
for fut in state_factories.iter() {
extensions = fut(extensions).await?;
}
let state = AppState::new(extensions, None, config.clone());
let mut config = WebServiceConfig::new(state.clone(), default.clone());
services
.into_iter()
.for_each(|mut srv| srv.register(&mut config));
let services = config.into_services();
let mut rmap = ResourceMap::new(ResourceDef::new(""));
for mut rdef in external {
rmap.add(&mut rdef, None);
}
let services: Vec<_> = services
.into_iter()
.map(|(mut rdef, srv, guards, nested)| {
rmap.add(&mut rdef, nested);
(rdef, srv, RefCell::new(guards))
})
.collect();
let rmap = Rc::new(rmap);
rmap.finish(rmap.clone());
for (path, factory, guards) in &mut services.iter() {
let service = factory.create(()).await?;
router.rdef(path.clone(), service).2 = guards.borrow_mut().take();
}
let routing = AppRouting {
router: router.finish(),
default: Some(default.create(()).await?),
};
let service = AppService {
routing,
filter: filter_fut.await?,
};
Ok(AppFactoryService {
rmap,
state,
service: middleware.create(service),
pool: HttpRequestPool::create(),
_t: PhantomData,
})
})
}
}
/// Running application service: turns raw [`Request`]s into [`WebRequest`]s and
/// forwards them to the middleware-wrapped application service.
pub struct AppFactoryService<T, Err>
where
T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
Err: ErrorRenderer,
{
/// Application service after middleware wrapping.
service: T,
/// Resource map handed to each newly constructed `HttpRequest`.
rmap: Rc<ResourceMap>,
/// Application state; cloned into every request.
state: AppState,
/// Pool of reusable `HttpRequest` objects. The `Drop` impl clears it before
/// the fields themselves are dropped.
pool: &'static HttpRequestPool,
/// Ties `Err` to the type without storing a value.
_t: PhantomData<Err>,
}
impl<T, Err> Service<Request> for AppFactoryService<T, Err>
where
    T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
    Err: ErrorRenderer,
{
    type Response = WebResponse;
    type Error = T::Error;
    type Future<'f> = T::Future<'f> where T: 'f;

    crate::forward_poll_ready!(service);
    crate::forward_poll_shutdown!(service);

    /// Wraps the incoming request in an [`HttpRequest`] and dispatches it.
    ///
    /// When the pool has a spare request object, it is reused: its path, head,
    /// payload and app state are overwritten. Otherwise a fresh request is
    /// allocated.
    ///
    /// # Panics
    ///
    /// Panics if a pooled request's inner `Rc` is shared (`Rc::get_mut` fails).
    fn call(&self, req: Request) -> Self::Future<'_> {
        let (head, payload) = req.into_parts();
        let http_req = match self.pool.get_request() {
            Some(mut pooled) => {
                let inner = Rc::get_mut(&mut pooled.0).unwrap();
                // The path must be set from the URI before `head` is moved in.
                inner.path.set(head.uri.clone());
                inner.head = head;
                inner.payload = payload;
                inner.app_state = self.state.clone();
                pooled
            }
            None => {
                let path = Path::new(head.uri.clone());
                HttpRequest::new(
                    path,
                    head,
                    payload,
                    self.rmap.clone(),
                    self.state.clone(),
                    self.pool,
                )
            }
        };
        self.service.call(WebRequest::new(http_req))
    }
}
impl<T, Err> Drop for AppFactoryService<T, Err>
where
T: Service<WebRequest<Err>, Response = WebResponse, Error = Err::Container>,
Err: ErrorRenderer,
{
fn drop(&mut self) {
self.pool.clear();
}
}
/// Routing stage of the application: resolves a request to a registered
/// service, falling back to `default` when nothing matches.
struct AppRouting<Err: ErrorRenderer> {
/// Route table; each entry carries optional guards that must all pass.
router: Router<HttpService<Err>, Guards>,
/// Fallback service for unmatched requests. If `None`, an empty 404 is returned.
default: Option<HttpService<Err>>,
}
impl<Err: ErrorRenderer> Service<WebRequest<Err>> for AppRouting<Err> {
    type Response = WebResponse;
    type Error = Err::Container;
    type Future<'f> = BoxResponse<'f, Err>;

    /// Dispatches the request to the first route whose pattern matches and
    /// whose guards, if any, all accept the request head.
    ///
    /// Unmatched requests go to the `default` service. If there is none, the
    /// result is an empty `404 Not Found`.
    fn call(&self, mut req: WebRequest<Err>) -> Self::Future<'_> {
        let matched = self.router.recognize_checked(&mut req, |r, guards| match guards {
            // `all` short-circuits on the first rejecting guard.
            Some(guards) => guards.iter().all(|guard| guard.check(r.head())),
            None => true,
        });

        match (matched, &self.default) {
            (Some((srv, _)), _) => srv.call(req),
            (None, Some(default)) => default.call(req),
            (None, None) => {
                let http_req = req.into_parts().0;
                Box::pin(async { Ok(WebResponse::new(Response::NotFound().finish(), http_req)) })
            }
        }
    }
}
/// Core application service: runs the request `filter`, then routes the
/// filtered request through `routing`.
pub struct AppService<F, Err: ErrorRenderer> {
/// Filter service; its output request is what gets routed.
filter: F,
/// Router plus default service.
routing: AppRouting<Err>,
}
impl<F, Err> Service<WebRequest<Err>> for AppService<F, Err>
where
    F: Service<WebRequest<Err>, Response = WebRequest<Err>, Error = Err::Container>,
    Err: ErrorRenderer,
{
    type Response = WebResponse;
    type Error = Err::Container;
    type Future<'f> = AppServiceResponse<'f, F, Err> where F: 'f;

    /// Ready only when both the filter and the router report readiness.
    ///
    /// The filter is polled first and the router second; both are polled even
    /// if the filter is pending, so each can register the waker. An error from
    /// either is returned immediately.
    #[inline]
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match (self.filter.poll_ready(cx)?, self.routing.poll_ready(cx)?) {
            (Poll::Ready(()), Poll::Ready(())) => Poll::Ready(Ok(())),
            _ => Poll::Pending,
        }
    }

    /// Starts the filter on `req`. Routing happens once the filter completes.
    fn call(&self, req: WebRequest<Err>) -> Self::Future<'_> {
        let filter = self.filter.call(req);
        AppServiceResponse {
            filter,
            routing: &self.routing,
            endpoint: None,
        }
    }
}
// Future returned by `AppService::call`. It first drives the filter future,
// then hands the filtered request to the router and drives the routed endpoint.
pin_project_lite::pin_project! {
pub struct AppServiceResponse<'f, F: Service<WebRequest<Err>>, Err: ErrorRenderer>
where F: 'f
{
// In-flight filter call; polled until it yields the request to route.
#[pin]
filter: F::Future<'f>,
// Router used once the filter has produced a request.
routing: &'f AppRouting<Err>,
// Routed response future; `None` until the filter has completed.
endpoint: Option<BoxResponse<'f, Err>>,
}
}
impl<'f, F, Err> Future for AppServiceResponse<'f, F, Err>
where
    F: Service<WebRequest<Err>, Response = WebRequest<Err>, Error = Err::Container>,
    Err: ErrorRenderer,
{
    type Output = Result<WebResponse, Err::Container>;

    /// Drives the filter until it yields a request, routes that request once,
    /// then forwards polling to the routed endpoint.
    ///
    /// A filter error is returned as-is and the request is not routed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        loop {
            if let Some(endpoint) = this.endpoint.as_mut() {
                return endpoint.as_mut().poll(cx);
            }
            // `as_mut` reborrows the pinned filter, so `this` stays usable
            // on the next iteration without re-projecting.
            let filtered = match this.filter.as_mut().poll(cx) {
                Poll::Ready(res) => res?,
                Poll::Pending => return Poll::Pending,
            };
            *this.endpoint = Some(this.routing.call(filtered));
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use crate::service::Service;
    use crate::web::test::{init_service, TestRequest};
    use crate::web::{self, App, HttpResponse};

    /// Sets the shared flag when dropped, so the test can observe destruction.
    struct DropData(Arc<AtomicBool>);

    impl Drop for DropData {
        fn drop(&mut self) {
            let DropData(flag) = self;
            flag.store(true, Ordering::Relaxed);
        }
    }

    /// Application state must be released once the app service is dropped,
    /// even after a request has passed through the request pool.
    #[crate::rt_test]
    async fn test_drop_data() {
        let dropped = Arc::new(AtomicBool::new(false));

        let app = init_service(
            App::new()
                .state(DropData(Arc::clone(&dropped)))
                .service(web::resource("/test").to(|| async { HttpResponse::Ok() })),
        )
        .await;
        let req = TestRequest::with_uri("/test").to_request();
        let _ = app.call(req).await.unwrap();
        drop(app);

        assert!(dropped.load(Ordering::Relaxed));
    }
}