use std::cell::RefCell;
use std::future::Future;
use std::pin::pin;
use std::rc::Rc;
use std::task::Poll;
use crate::{
event::{Exchange, ExchangeComplete, RequestHeaders, ResponseHeaders, Start},
extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContext, FromContextOnce},
handler::{ExtractionError, Handler, IntoHandler},
hl::state::{CreateHandler, DoneHandler, EmptyCreateHandler},
host::Host,
reactor::http::{FlowStatus, HttpReactor},
BoxFuture,
};
use super::{
context::{RequestContext, ResponseContext},
dynamic_exchange::DynamicExchange,
request_data::RequestData,
Flow, IntoFlow,
};
/// Filter that runs a single handler during the request phase of an HTTP exchange.
///
/// Created by `on_request` with no state (`EmptyCreateHandler`) and no done handler (`()`).
/// `RequestFilter::on_response` converts it into a `DualFilter` that also handles the
/// response phase.
pub struct RequestFilter<ReqHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with a `RequestContext`; its output is converted into a `Flow`.
    pub(super) request_handler: ReqHnd,
    /// Produces the state value that is shared (through an `Rc`) with the handler context.
    pub(super) state_factory: Sf,
    /// Receives the state after the exchange has completed, if no other reference remains.
    pub(super) done_handler: Done,
}
impl<ReqHnd, Sf, Done> RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    Sf: CreateHandler,
{
    /// Attaches a response handler, turning this request-only filter into a [`DualFilter`].
    ///
    /// The response handler's `ResponseContext` carries the request handler's
    /// `IntoFlow::RequestData` type and the same state type as the request handler. The state
    /// factory and the done handler are carried over unchanged.
    pub fn on_response<ResHnd, I>(
        self,
        response_handler: ResHnd,
    ) -> DualFilter<ReqHnd, ResHnd::Handler, Sf, Done>
    where
        ResHnd: IntoHandler<
            ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>,
            I,
            Output = (),
        >,
    {
        let Self {
            request_handler,
            state_factory,
            done_handler,
        } = self;
        DualFilter {
            request_handler,
            response_handler: response_handler.into_handler(),
            state_factory,
            done_handler,
        }
    }
}
impl<ReqHnd, T, Sf, Done> Handler<FilterContext> for RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow<RequestData = T>,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Drives a request-only filter for one HTTP exchange.
    ///
    /// 1. Extracts the request-headers exchange; an extraction failure is returned at once.
    /// 2. Creates a fresh state from the state factory and runs the request handler unless the
    ///    request is suspended. A `Flow::Break` response is sent through the exchange.
    /// 3. Waits for `ExchangeComplete`, then passes the state to the done handler if no other
    ///    reference to it remains.
    ///
    /// Resolves to the request handler's result, or `Ok(())` when the handler was skipped
    /// because the request was suspended.
    ///
    /// # Panics
    ///
    /// Panics if a `Host` cannot be extracted from the filter context.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // The only `RefCell` borrow below is a statement-scoped temporary with no `.await`
        // inside it; the allowance is kept as in the sibling filters.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let filter_context = Rc::new(context);
            let headers_exchange = <Exchange<RequestHeaders>>::from_context_once(
                Exclusive::new(filter_context.as_ref()),
            )
            .await
            .map_err(|err| ExtractionError(err.into()))?;
            let reactor = Rc::clone(&headers_exchange.reactor);
            let shared_exchange = Rc::new(RefCell::new(DynamicExchange::new(headers_exchange)));
            let shared_state = Rc::new(self.state_factory.create());
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*filter_context,
                )
                .unwrap();

            // The filter context is moved into the handler's context here.
            let request_phase = async {
                let request_context = RequestContext::new(
                    filter_context,
                    Rc::clone(&shared_exchange),
                    Rc::clone(&shared_state),
                );
                let output = self.request_handler.call(request_context).await?;
                if let Flow::Break(response) = output.into_flow() {
                    shared_exchange.borrow_mut().send_response(
                        response.status_code(),
                        response.headers(),
                        response.body(),
                    );
                }
                Ok(())
            };
            let outcome = match may_suspend_request(&reactor, request_phase).await {
                Some(handled) => handled,
                None => Ok(()),
            };

            wait_for_on_done(Rc::clone(&reactor), host).await;
            if let Ok(state) = Rc::try_unwrap(shared_state) {
                self.done_handler.done(state);
            }
            outcome
        })
    }
}
/// Creates a request-only filter from `request_handler`.
///
/// The filter starts with no state (`EmptyCreateHandler`, so the handler sees
/// `RequestContext<()>`) and no done handler (`()`).
pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
where
    ReqHnd: IntoHandler<RequestContext<()>, I>,
    ReqHnd::Output: IntoFlow,
{
    let handler = request_handler.into_handler();
    RequestFilter {
        request_handler: handler,
        state_factory: EmptyCreateHandler,
        done_handler: (),
    }
}
/// Filter that runs a single handler during the response phase of an HTTP exchange.
///
/// Created by `on_response` with no state (`EmptyCreateHandler`) and no done handler (`()`).
pub struct ResponseFilter<ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with a `ResponseContext`; it must produce `()`.
    pub(super) response_handler: ResHnd,
    /// Produces the state value that is shared (through an `Rc`) with the handler context.
    pub(super) state_factory: Sf,
    /// Receives the state after the exchange has completed, if no other reference remains.
    pub(super) done_handler: Done,
}
impl<ResHnd, Sf, Done> Handler<FilterContext> for ResponseFilter<ResHnd, Sf, Done>
where
    ResHnd: Handler<ResponseContext<(), Sf::State>, Output = ()>,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Drives a response-only filter for one HTTP exchange.
    ///
    /// 1. Extracts the response-headers exchange; an extraction failure is returned at once.
    /// 2. Creates a fresh state from the state factory and runs the response handler unless
    ///    the response is suspended.
    /// 3. Waits for `ExchangeComplete`, then passes the state to the done handler if no other
    ///    reference to it remains.
    ///
    /// Resolves to the response handler's result, or `Ok(())` when the handler was skipped
    /// because the response was suspended.
    ///
    /// # Panics
    ///
    /// Panics if a `Host` cannot be extracted from the filter context.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // This body takes no `RefCell` borrow itself; the allowance mirrors the sibling filters.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let filter_context = Rc::new(context);
            let headers_exchange = <Exchange<ResponseHeaders>>::from_context_once(
                Exclusive::new(filter_context.as_ref()),
            )
            .await
            .map_err(|err| ExtractionError(err.into()))?;
            let reactor = Rc::clone(&headers_exchange.reactor);
            let shared_exchange = Rc::new(RefCell::new(DynamicExchange::new(headers_exchange)));
            let shared_state = Rc::new(self.state_factory.create());
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*filter_context,
                )
                .unwrap();

            // NOTE(review): with no request handler, the response handler is given
            // `RequestData::Break`; confirm this sentinel is intended (vs. `Continue(())`).
            let response_phase = async {
                let response_context = ResponseContext::new(
                    filter_context,
                    Rc::clone(&shared_exchange),
                    RequestData::Break,
                    Rc::clone(&shared_state),
                );
                self.response_handler.call(response_context).await?;
                Ok(())
            };
            let outcome = match may_suspend_response(&reactor, response_phase).await {
                Some(handled) => handled,
                None => Ok(()),
            };

            wait_for_on_done(Rc::clone(&reactor), host).await;
            if let Ok(state) = Rc::try_unwrap(shared_state) {
                self.done_handler.done(state);
            }
            outcome
        })
    }
}
/// Creates a response-only filter from `response_handler`.
///
/// The filter starts with no state (`EmptyCreateHandler`, so the handler sees
/// `ResponseContext<(), ()>`) and no done handler (`()`). The handler must produce `()`.
pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
where
    ResHnd: IntoHandler<ResponseContext<(), ()>, I, Output = ()>,
{
    let handler = response_handler.into_handler();
    ResponseFilter {
        response_handler: handler,
        state_factory: EmptyCreateHandler,
        done_handler: (),
    }
}
/// Filter that runs one handler in the request phase and another in the response phase of
/// the same HTTP exchange, with a single state value shared between them.
///
/// Created by `RequestFilter::on_response`.
pub struct DualFilter<ReqHnd, ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler for the request phase; its `Flow` result determines the `RequestData` that is
    /// passed on to the response handler.
    pub(super) request_handler: ReqHnd,
    /// Handler for the response phase; it must produce `()`.
    pub(super) response_handler: ResHnd,
    /// Produces the state value that both handler contexts share (through an `Rc`).
    pub(super) state_factory: Sf,
    /// Receives the state after the exchange has completed, if no other reference remains.
    pub(super) done_handler: Done,
}
impl<ReqHnd, ResHnd, Sf, Done> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    ResHnd:
        Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>, Output = ()>,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Drives a filter that has both a request handler and a response handler.
    ///
    /// 1. Extracts the request-headers exchange and creates a fresh state from the state
    ///    factory.
    /// 2. Runs the request handler unless the request is suspended. A `Flow::Break` response
    ///    is sent through the exchange and recorded as `RequestData::Break`.
    ///    `Flow::Continue(data)` becomes `RequestData::Continue(data)`. A suspended request
    ///    yields `RequestData::Cancel`.
    /// 3. Extracts the response-headers exchange and runs the response handler with that
    ///    request data. If the response is suspended, the handler is not invoked at all.
    /// 4. Waits for `ExchangeComplete`, then passes the state to the done handler if no other
    ///    reference to it remains.
    ///
    /// # Errors
    ///
    /// Fails immediately, before any handler runs, if the request-headers exchange cannot be
    /// extracted. An error from the request handler, or a failure to extract the
    /// response-headers exchange, skips the response handler. Such errors, and errors from the
    /// response handler, are returned only after step 4, so the done handler still runs.
    ///
    /// # Panics
    ///
    /// Panics if a `Host` cannot be extracted from the filter context.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // Every `RefCell` borrow below is a statement-scoped temporary with no `.await`
        // inside it; the allowance is kept as in the sibling filters.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
                })?;
            let reactor = Rc::clone(&exchange.reactor);
            let state = Rc::new(self.state_factory.create());
            let request_exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .unwrap();

            // Holds the response exchange (once extracted) until this future finishes, so it
            // outlives `wait_for_on_done` just like `request_exchange` does.
            let mut response_slot = None;

            // Both phases run inside one fallible block. An error from either phase is kept in
            // `result` instead of returning early, so the exchange-complete wait and the done
            // handler below always run, as they do in `RequestFilter` and `ResponseFilter`.
            let result: Result<(), ExtractionError> = async {
                let request_data = may_suspend_request(&reactor, async {
                    let request_context = RequestContext::new(
                        Rc::clone(&context),
                        Rc::clone(&request_exchange),
                        Rc::clone(&state),
                    );
                    let flow = self
                        .request_handler
                        .call(request_context)
                        .await?
                        .into_flow();
                    match flow {
                        Flow::Break(response) => {
                            request_exchange.borrow_mut().send_response(
                                response.status_code(),
                                response.headers(),
                                response.body(),
                            );
                            Ok(RequestData::Break)
                        }
                        Flow::Continue(data) => Ok(RequestData::Continue(data)),
                    }
                })
                .await
                // A request suspended before the handler finished is reported as `Cancel`.
                .unwrap_or(Ok(RequestData::Cancel))?;

                let exclusive_context = Exclusive::new(context.as_ref());
                let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                    .await
                    .map_err(|_| {
                        ExtractionError(
                            AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into(),
                        )
                    })?;
                let response_exchange: &Rc<_> =
                    response_slot.insert(Rc::new(RefCell::new(DynamicExchange::new(exchange))));

                // The handler call is deferred into this future so that it is not invoked at
                // all when the response is already suspended.
                may_suspend_response(&reactor, async {
                    self.response_handler
                        .call(ResponseContext::new(
                            Rc::clone(&context),
                            Rc::clone(response_exchange),
                            request_data,
                            Rc::clone(&state),
                        ))
                        .await
                })
                .await
                .unwrap_or(Ok(()))
            }
            .await;

            wait_for_on_done(Rc::clone(&reactor), host).await;
            if let Ok(state) = Rc::try_unwrap(state) {
                self.done_handler.done(state);
            }
            result
        })
    }
}
/// Resolves once an `ExchangeComplete` event is observed.
///
/// Builds a fresh `Exchange<Start>` from `reactor` and `host` solely to wait for that event.
/// The event's result is discarded.
async fn wait_for_on_done(reactor: Rc<HttpReactor>, host: Rc<dyn Host>) {
    let watcher = <Exchange<Start>>::new(reactor, host, None);
    let _ = watcher.wait_for_event::<ExchangeComplete>().await;
}
/// Polls `task` only while the reactor's request flow is not suspended.
///
/// The request status is checked before every poll of `task`. If it reports
/// `FlowStatus::Suspended`, a debug message is logged and the result is `None`; `task` is
/// then dropped without being polled again. Otherwise the result is `Some(output)` once
/// `task` completes.
async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);
    let guarded = std::future::poll_fn(move |cx| {
        if let FlowStatus::Suspended = reactor.request_status() {
            let id: u32 = reactor.context_id().into();
            log::debug!("Request for filter with context id {id} has been suspended.");
            return Poll::Ready(None);
        }
        pinned_task.as_mut().poll(cx).map(Some)
    });
    guarded.await
}
/// Polls `task` only while the reactor's response flow is not suspended.
///
/// The response status is checked before every poll of `task`. If it reports
/// `FlowStatus::Suspended`, a debug message is logged and the result is `None`; `task` is
/// then dropped without being polled again. Otherwise the result is `Some(output)` once
/// `task` completes.
async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);
    let guarded = std::future::poll_fn(move |cx| {
        if let FlowStatus::Suspended = reactor.response_status() {
            let id: u32 = reactor.context_id().into();
            log::debug!("Response for filter with context id {id} has been suspended.");
            return Poll::Ready(None);
        }
        pinned_task.as_mut().poll(cx).map(Some)
    });
    guarded.await
}