use std::cell::RefCell;
use std::future::Future;
use std::pin::pin;
use std::rc::Rc;
use std::task::Poll;
use crate::{
event::{Exchange, RequestHeaders, ResponseHeaders},
extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContextOnce},
handler::{ExtractionError, Handler, IntoHandler},
reactor::http::{FlowStatus, HttpReactor},
BoxFuture,
};
use super::{
context::{RequestContext, ResponseContext},
dynamic_exchange::DynamicExchange,
request_data::RequestData,
Flow, IntoFlow,
};
/// A filter that only handles the request phase.
///
/// Built by [`on_request`]. Call [`RequestFilter::on_response`] to attach a response
/// handler and obtain a [`DualFilter`].
pub struct RequestFilter<ReqHnd> {
    /// Handler invoked with a [`RequestContext`] once request headers are available.
    request_handler: ReqHnd,
}
impl<ReqHnd> RequestFilter<ReqHnd>
where
    ReqHnd: Handler<RequestContext>,
    ReqHnd::Output: IntoFlow,
{
    /// Pairs this request filter with a response handler, yielding a [`DualFilter`].
    ///
    /// The response handler's context is typed by the request handler's
    /// `IntoFlow::RequestData`. This lets the response phase observe what the request
    /// phase produced.
    pub fn on_response<ResHnd, I>(
        self,
        response_handler: ResHnd,
    ) -> DualFilter<ReqHnd, ResHnd::Handler>
    where
        ResHnd:
            IntoHandler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, I, Output = ()>,
    {
        let Self { request_handler } = self;
        let response_handler = response_handler.into_handler();
        DualFilter {
            request_handler,
            response_handler,
        }
    }
}
/// Drives a [`RequestFilter`]: runs the request handler and, on `Flow::Break`, sends the
/// carried response.
///
/// The bounds mirror the inherent `impl RequestFilter` and the `DualFilter` handler impl.
/// No extra type parameter is needed to name the flow's data type.
impl<ReqHnd> Handler<FilterContext> for RequestFilter<ReqHnd>
where
    ReqHnd: Handler<RequestContext>,
    ReqHnd::Output: IntoFlow,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the request handler for one filter context.
    ///
    /// Extracts `Exchange<RequestHeaders>` from the context, then runs the request handler
    /// under `may_suspend_request`. If the handler's flow is `Flow::Break(response)`, that
    /// response's status code, headers and body are sent through the exchange. If the
    /// request is suspended before the handler finishes, this resolves to `Ok(())`.
    ///
    /// # Errors
    /// Returns an `ExtractionError` when the request exchange cannot be extracted. Any
    /// error returned by the request handler is propagated.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|e| ExtractionError(e.into()))?;
            let reactor = Rc::clone(&exchange.reactor);
            may_suspend_request(&reactor, async {
                let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
                let request_context = RequestContext::new(context, exchange.clone());
                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                if let Flow::Break(response) = flow {
                    // Short-circuit: answer the request directly with the handler's response.
                    exchange.borrow_mut().send_response(
                        response.status_code(),
                        response.headers(),
                        response.body(),
                    );
                }
                Ok(())
            })
            .await
            // `None` means the request was suspended; nothing left to do.
            .unwrap_or(Ok(()))
        })
    }
}
/// Builds a [`RequestFilter`] from anything convertible into a request handler.
///
/// The handler's output must implement [`IntoFlow`]. Chain
/// [`RequestFilter::on_response`] to also handle the response phase.
pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
where
    ReqHnd: IntoHandler<RequestContext, I>,
    ReqHnd::Output: IntoFlow,
{
    let request_handler = request_handler.into_handler();
    RequestFilter { request_handler }
}
/// A filter that only handles the response phase.
///
/// Built by the free function [`on_response`].
pub struct ResponseFilter<ResHnd> {
    /// Handler invoked with a [`ResponseContext`] once response headers are available.
    response_handler: ResHnd,
}
/// Drives a [`ResponseFilter`]: runs its response handler once response headers arrive.
impl<ResHnd> Handler<FilterContext> for ResponseFilter<ResHnd>
where
    ResHnd: Handler<ResponseContext<()>, Output = ()>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the response handler for one filter context.
    ///
    /// Extracts `Exchange<ResponseHeaders>` and runs the handler under
    /// `may_suspend_response`. A suspended response resolves to `Ok(())`. The handler
    /// is given `RequestData::Break` as its request data.
    /// NOTE(review): no request handler ran here, yet `Break` is reported. Confirm this
    /// is the intended sentinel.
    ///
    /// # Errors
    /// Returns an `ExtractionError` when the response exchange cannot be extracted. Any
    /// error returned by the response handler is propagated.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        Box::pin(async move {
            let shared_context = Rc::new(context);
            let response_exchange = <Exchange<ResponseHeaders>>::from_context_once(Exclusive::new(
                shared_context.as_ref(),
            ))
            .await
            .map_err(|err| ExtractionError(err.into()))?;
            let reactor = Rc::clone(&response_exchange.reactor);
            may_suspend_response(&reactor, async {
                let dynamic_exchange =
                    Rc::new(RefCell::new(DynamicExchange::new(response_exchange)));
                self.response_handler
                    .call(ResponseContext::new(
                        shared_context,
                        dynamic_exchange,
                        RequestData::Break,
                    ))
                    .await?;
                Ok(())
            })
            .await
            .unwrap_or(Ok(()))
        })
    }
}
/// Builds a response-only [`ResponseFilter`].
///
/// `response_handler` may be anything convertible into a handler that takes a
/// `ResponseContext<()>` and produces `()`.
pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
where
    ResHnd: IntoHandler<ResponseContext<()>, I, Output = ()>,
{
    let response_handler = response_handler.into_handler();
    ResponseFilter { response_handler }
}
/// A filter that handles both the request and the response phase.
///
/// Built by [`RequestFilter::on_response`].
pub struct DualFilter<ReqHnd, ResHnd> {
    /// Handler invoked with a [`RequestContext`] during the request phase.
    request_handler: ReqHnd,
    /// Handler invoked with a [`ResponseContext`] during the response phase; it also
    /// receives the request phase's outcome.
    response_handler: ResHnd,
}
/// Drives a [`DualFilter`] through the request phase and then the response phase.
///
/// The request handler's `IntoFlow::RequestData` is carried into the response handler's
/// [`ResponseContext`].
impl<ReqHnd, ResHnd> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd>
where
    ReqHnd: Handler<RequestContext>,
    ReqHnd::Output: IntoFlow,
    ResHnd: Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, Output = ()>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the request handler, then the response handler, for one filter context.
    ///
    /// Request phase: extracts `Exchange<RequestHeaders>` and runs the request handler
    /// under `may_suspend_request`. The outcome is recorded as follows:
    /// - `Flow::Break(response)` sends that response and records `RequestData::Break`.
    /// - `Flow::Continue(data)` records `RequestData::Continue(data)`.
    /// - If the request is suspended, `RequestData::Cancel` is recorded.
    ///
    /// Response phase: extracts `Exchange<ResponseHeaders>` and runs the response handler
    /// with the recorded request data under `may_suspend_response`. A suspended response
    /// resolves to `Ok(())`.
    ///
    /// # Errors
    /// Returns an `ExtractionError` when either exchange cannot be extracted. Errors from
    /// the request or response handler are propagated. A request-handler error skips the
    /// response phase entirely.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        Box::pin(async move {
            // Shared so both the request and the response context can hold a handle.
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                // NOTE(review): unlike RequestFilter/ResponseFilter, the original error is
                // discarded and replaced by `AlreadyExtracted`. Confirm this is intended.
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
                })?;
            let reactor = Rc::clone(&exchange.reactor);
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let request_data = may_suspend_request(&reactor, async {
                let request_context = RequestContext::new(context.clone(), exchange.clone());
                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                match flow {
                    Flow::Break(response) => {
                        // Answer the request directly with the handler's response.
                        exchange.borrow_mut().send_response(
                            response.status_code(),
                            response.headers(),
                            response.body(),
                        );
                        Ok(RequestData::Break)
                    }
                    Flow::Continue(data) => Ok(RequestData::Continue(data)),
                }
            })
            .await
            // `None` means the request was suspended; the response phase still runs.
            .unwrap_or(Ok(RequestData::Cancel))?;
            // Response phase: a second exclusive extraction from the same context.
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into())
                })?;
            // NOTE(review): this shadows, but does not drop, the request-phase `exchange`.
            // That `Rc` stays alive until this future completes. Confirm that is harmless.
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            may_suspend_response(
                &reactor,
                self.response_handler
                    .call(ResponseContext::new(context, exchange, request_data)),
            )
            .await
            .unwrap_or(Ok(()))
        })
    }
}
/// Polls `task` for as long as the reactor's request flow is unsuspended.
///
/// The request status is checked before every poll of `task`. Once it reports
/// [`FlowStatus::Suspended`], this resolves to `None`: the suspension is logged at debug
/// level and `task` is abandoned without further polling. Otherwise it resolves to
/// `Some(output)` when `task` completes.
async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    may_suspend(
        reactor,
        |reactor| matches!(reactor.request_status(), FlowStatus::Suspended),
        "Request",
        task,
    )
    .await
}

/// Polls `task` for as long as the reactor's response flow is unsuspended.
///
/// This is the response-phase counterpart of [`may_suspend_request`]. The check uses the
/// reactor's response status instead of its request status.
async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    may_suspend(
        reactor,
        |reactor| matches!(reactor.response_status(), FlowStatus::Suspended),
        "Response",
        task,
    )
    .await
}

/// Shared implementation of [`may_suspend_request`] and [`may_suspend_response`].
///
/// `is_suspended` is queried before every poll of `task`. `phase` is used only as the
/// prefix of the debug log message.
async fn may_suspend<F, S>(
    reactor: &HttpReactor,
    is_suspended: S,
    phase: &str,
    task: F,
) -> Option<F::Output>
where
    F: Future,
    S: Fn(&HttpReactor) -> bool,
{
    let mut task = pin!(task);
    std::future::poll_fn(move |cx| {
        if is_suspended(reactor) {
            let id: u32 = reactor.context_id().into();
            log::debug!("{phase} for filter with context id {id} has been suspended.");
            Poll::Ready(None)
        } else {
            task.as_mut().poll(cx).map(Some)
        }
    })
    .await
}