use std::cell::RefCell;
use std::future::Future;
use std::pin::pin;
use std::rc::Rc;
use std::task::Poll;
use crate::{
event::{Exchange, ExchangeComplete, RequestHeaders, ResponseHeaders, Start},
extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContext, FromContextOnce},
handler::{ExtractionError, Handler, IntoHandler},
hl::state::{CreateHandler, DoneHandler, EmptyCreateHandler},
host::Host,
reactor::http::{FlowStatus, HttpReactor},
BoxFuture,
};
#[cfg(feature = "experimental_websocket")]
use crate::BoxError;
#[cfg(feature = "experimental_websocket")]
use crate::extract::context::{UpgradeDownstreamContext, UpgradeUpstreamContext};
use super::{
context::{RequestContext, ResponseContext},
dynamic_exchange::DynamicExchange,
request_data::RequestData,
Flow, IntoFlow,
};
/// A filter that runs a handler on the request phase of every HTTP exchange.
///
/// Created by [`on_request`] in this module; [`RequestFilter::on_response`] adds a response
/// handler and turns it into a [`DualFilter`].
pub struct RequestFilter<ReqHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with the request context; its output is converted into a [`Flow`].
    pub(super) request_handler: ReqHnd,
    /// Factory called once per exchange to create the state handed to the handlers.
    pub(super) state_factory: Sf,
    /// Receives the per-exchange state after the exchange has completed.
    pub(super) done_handler: Done,
}
impl<ReqHnd, Sf, Done> RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    Sf: CreateHandler,
{
    /// Attaches a response handler, turning this request-only filter into a [`DualFilter`].
    ///
    /// The response handler receives the request data produced by the request handler's
    /// [`IntoFlow`] output, together with the same state type. The request handler, state
    /// factory and done handler are carried over unchanged.
    pub fn on_response<ResHnd, I>(
        self,
        response_handler: ResHnd,
    ) -> DualFilter<ReqHnd, ResHnd::Handler, Sf, Done>
    where
        ResHnd: IntoHandler<
            ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>,
            I,
            Output = (),
        >,
    {
        let Self {
            request_handler,
            state_factory,
            done_handler,
        } = self;
        DualFilter {
            request_handler,
            response_handler: response_handler.into_handler(),
            state_factory,
            done_handler,
        }
    }
}
/// Runs the request handler of a [`RequestFilter`] once per HTTP exchange.
///
/// The handler's [`IntoFlow`] output decides whether the request goes on or is answered
/// immediately with the response carried by [`Flow::Break`].
impl<ReqHnd, T, Sf, Done> Handler<FilterContext> for RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow<RequestData = T>,
    // `Clone`: the done handler gets an owned state even if other `Rc`s to it survive.
    // `'static`: required by the `Rc<dyn Any>` upcast in `experimental_websocket` builds.
    Sf: CreateHandler<State: Clone + 'static>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Handles one exchange:
    ///
    /// 1. Takes exclusive ownership of the request-headers [`Exchange`]; if that fails the
    ///    error is returned immediately and nothing else runs.
    /// 2. Creates a fresh state with the state factory (with `experimental_websocket`, it is
    ///    also published in the parent context's `shared_state`).
    /// 3. Runs the request handler; a [`Flow::Break`] result is sent back as a response.
    ///    If the request flow is observed as suspended, the handler future is dropped
    ///    unfinished and the outcome counts as `Ok(())`.
    /// 4. Waits for [`ExchangeComplete`], then hands the state to the done handler.
    ///
    /// # Errors
    ///
    /// Returns the request handler's [`ExtractionError`], if any, after step 4.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // NOTE(review): no `RefCell` borrow visibly spans an `.await` in this block; the
        // allow may be precautionary — confirm before removing.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|e| ExtractionError(e.into()))?;
            let reactor = Rc::clone(&exchange.reactor);
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let state = Rc::new(self.state_factory.create());
            // Make the state reachable from the parent context (read back by the
            // WebSocket upgrade closures).
            #[cfg(feature = "experimental_websocket")]
            {
                *context.parent_rc().shared_state.borrow_mut() =
                    Some(Rc::clone(&state) as Rc<dyn std::any::Any>);
            }
            // NOTE(review): assumes `Transitive` extraction of the host cannot fail here.
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .unwrap();
            // `None` (flow suspended) is treated as success.
            let result = may_suspend_request(&reactor, async {
                let request_context =
                    RequestContext::new(context, exchange.clone(), Rc::clone(&state));
                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                // Short-circuit: answer the request directly with the handler's response.
                if let Flow::Break(response) = flow {
                    let mut exchange = exchange.borrow_mut();
                    exchange.send_response(
                        response.status_code(),
                        response.headers(),
                        response.body(),
                    );
                }
                Ok(())
            })
            .await
            .unwrap_or(Ok(()));
            wait_for_on_done(Rc::clone(&reactor), host).await;
            // Hand the done handler an owned state: move it out when this is the last `Rc`,
            // otherwise clone it (another `Rc` may still exist, e.g. in `shared_state`).
            match Rc::try_unwrap(state) {
                Ok(state) => self.done_handler.done(state),
                Err(rc) => self.done_handler.done((*rc).clone()),
            }
            result
        })
    }
}
/// Creates a [`RequestFilter`] from a request handler.
///
/// The filter uses [`EmptyCreateHandler`] as its state factory and `()` as its done handler,
/// so the handler's [`RequestContext`] carries `()` as its state type. The handler's output
/// must implement [`IntoFlow`].
pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
where
    ReqHnd: IntoHandler<RequestContext<()>, I>,
    ReqHnd::Output: IntoFlow,
{
    let handler = request_handler.into_handler();
    RequestFilter {
        state_factory: EmptyCreateHandler,
        done_handler: (),
        request_handler: handler,
    }
}
/// A filter that runs a handler on the response phase of every HTTP exchange.
///
/// Created by [`on_response`] in this module.
pub struct ResponseFilter<ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with the response context.
    pub(super) response_handler: ResHnd,
    /// Factory called once per exchange to create the state handed to the handler.
    pub(super) state_factory: Sf,
    /// Receives the per-exchange state after the exchange has completed.
    pub(super) done_handler: Done,
}
/// Runs the response handler of a [`ResponseFilter`] once per HTTP exchange.
impl<ResHnd, Sf, Done> Handler<FilterContext> for ResponseFilter<ResHnd, Sf, Done>
where
    ResHnd: Handler<ResponseContext<(), Sf::State>, Output = ()>,
    // `Clone`: the done handler gets an owned state even if other `Rc`s to it survive.
    // `'static`: required by the `Rc<dyn Any>` upcast in `experimental_websocket` builds.
    Sf: CreateHandler<State: Clone + 'static>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Handles one exchange:
    ///
    /// 1. Takes exclusive ownership of the response-headers [`Exchange`]; if that fails the
    ///    error is returned immediately and nothing else runs.
    /// 2. Creates a fresh state with the state factory (with `experimental_websocket`, it is
    ///    also published in the parent context's `shared_state`).
    /// 3. Runs the response handler. If the response flow is observed as suspended, the
    ///    handler future is dropped unfinished and the outcome counts as `Ok(())`.
    /// 4. Waits for [`ExchangeComplete`], then hands the state to the done handler.
    ///
    /// # Errors
    ///
    /// Returns the response handler's [`ExtractionError`], if any, after step 4.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // NOTE(review): no `RefCell` borrow visibly spans an `.await` in this block; the
        // allow may be precautionary — confirm before removing.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|e| ExtractionError(e.into()))?;
            let reactor = Rc::clone(&exchange.reactor);
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let state = Rc::new(self.state_factory.create());
            // Make the state reachable from the parent context (read back by the
            // WebSocket upgrade closures).
            #[cfg(feature = "experimental_websocket")]
            {
                *context.parent_rc().shared_state.borrow_mut() =
                    Some(Rc::clone(&state) as Rc<dyn std::any::Any>);
            }
            // NOTE(review): assumes `Transitive` extraction of the host cannot fail here.
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .unwrap();
            // `None` (flow suspended) is treated as success.
            let result = may_suspend_response(&reactor, async {
                // This filter has no request handler, so `RequestData::Break` is passed as
                // the request data. NOTE(review): confirm this is the intended marker.
                let response_context = ResponseContext::new(
                    context,
                    exchange.clone(),
                    RequestData::Break,
                    Rc::clone(&state),
                );
                self.response_handler.call(response_context).await?;
                Ok(())
            })
            .await
            .unwrap_or(Ok(()));
            wait_for_on_done(Rc::clone(&reactor), host).await;
            // Hand the done handler an owned state: move it out when this is the last `Rc`,
            // otherwise clone it (another `Rc` may still exist, e.g. in `shared_state`).
            match Rc::try_unwrap(state) {
                Ok(state) => self.done_handler.done(state),
                Err(rc) => self.done_handler.done((*rc).clone()),
            }
            result
        })
    }
}
/// Creates a [`ResponseFilter`] from a response handler.
///
/// The filter uses [`EmptyCreateHandler`] as its state factory and `()` as its done handler,
/// so the handler's [`ResponseContext`] carries `()` both as request data and as state.
pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
where
    ResHnd: IntoHandler<ResponseContext<(), ()>, I, Output = ()>,
{
    let handler = response_handler.into_handler();
    ResponseFilter {
        state_factory: EmptyCreateHandler,
        done_handler: (),
        response_handler: handler,
    }
}
/// A filter that runs a request handler and then a response handler on every HTTP exchange.
///
/// Created by [`RequestFilter::on_response`]. The request data produced by the request
/// handler is forwarded to the response handler.
pub struct DualFilter<ReqHnd, ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with the request context; its output is converted into a [`Flow`].
    pub(super) request_handler: ReqHnd,
    /// Handler called with the response context and the request handler's data.
    pub(super) response_handler: ResHnd,
    /// Factory called once per exchange to create the state shared by both handlers.
    pub(super) state_factory: Sf,
    /// Receives the per-exchange state after the exchange has completed.
    pub(super) done_handler: Done,
}
/// Runs both the request and the response handler of a [`DualFilter`] for every exchange.
///
/// The request data produced by the request handler (via [`IntoFlow`]) is forwarded to the
/// response handler as [`RequestData`], together with the shared per-exchange state.
impl<ReqHnd, ResHnd, Sf, Done> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    ResHnd:
        Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>, Output = ()>,
    // `Clone`: the done handler gets an owned state even if other `Rc`s to it survive.
    // `'static`: required by the `Rc<dyn Any>` upcast in `experimental_websocket` builds.
    Sf: CreateHandler<State: Clone + 'static>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Handles one exchange:
    ///
    /// 1. Takes exclusive ownership of the request-headers [`Exchange`]; failure is reported
    ///    as [`AlreadyExtracted`] and nothing else runs.
    /// 2. Creates a fresh state (with `experimental_websocket`, it is also published in the
    ///    parent context's `shared_state`).
    /// 3. Runs the request handler. [`Flow::Break`] sends the carried response and yields
    ///    `RequestData::Break`; [`Flow::Continue`] yields `RequestData::Continue`; a
    ///    suspended request flow drops the handler and yields `RequestData::Cancel`.
    /// 4. Takes the response-headers [`Exchange`] and runs the response handler with that
    ///    request data. A suspended response flow drops the handler and counts as `Ok(())`.
    /// 5. Waits for [`ExchangeComplete`], then hands the state to the done handler.
    ///
    /// Once step 1 has succeeded, step 5 always runs, even when step 3 or 4 fails. This
    /// matches `RequestFilter` / `ResponseFilter`, so the done handler sees every exchange.
    ///
    /// # Errors
    ///
    /// Returns the first [`ExtractionError`] raised in step 1, 3 or 4. Errors from steps 3
    /// and 4 are returned after step 5.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // NOTE(review): no `RefCell` borrow visibly spans an `.await` in this block; the
        // allow may be precautionary — confirm before removing.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
                })?;
            let reactor = Rc::clone(&exchange.reactor);
            let state = Rc::new(self.state_factory.create());
            // Make the state reachable from the parent context (read back by the
            // WebSocket upgrade closures).
            #[cfg(feature = "experimental_websocket")]
            {
                *context.parent_rc().shared_state.borrow_mut() =
                    Some(Rc::clone(&state) as Rc<dyn std::any::Any>);
            }
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            // NOTE(review): assumes `Transitive` extraction of the host cannot fail here.
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .unwrap();

            // Owns the response-phase exchange so that it stays alive until this filter
            // returns (after the done handler), not just until the inner block finishes.
            let mut response_exchange = None;

            // Both phases run inside this block so that a failure in either one is captured
            // in `result` instead of leaving the filter early with `?`. That way
            // `wait_for_on_done` and the done handler below still run.
            let result: Result<(), ExtractionError> = async {
                let request_data = may_suspend_request(&reactor, async {
                    let request_context =
                        RequestContext::new(context.clone(), exchange.clone(), Rc::clone(&state));
                    let flow = self
                        .request_handler
                        .call(request_context)
                        .await?
                        .into_flow();
                    match flow {
                        Flow::Break(response) => {
                            exchange.borrow_mut().send_response(
                                response.status_code(),
                                response.headers(),
                                response.body(),
                            );
                            Ok(RequestData::Break)
                        }
                        Flow::Continue(data) => Ok(RequestData::Continue(data)),
                    }
                })
                .await
                // A suspended request never produced data: tell the response handler so.
                .unwrap_or(Ok(RequestData::Cancel))?;

                let exclusive_context = Exclusive::new(context.as_ref());
                let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                    .await
                    .map_err(|_| {
                        ExtractionError(
                            AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into(),
                        )
                    })?;
                let exchange = response_exchange
                    .insert(Rc::new(RefCell::new(DynamicExchange::new(exchange))));
                may_suspend_response(
                    &reactor,
                    self.response_handler.call(ResponseContext::new(
                        context.clone(),
                        Rc::clone(exchange),
                        request_data,
                        state.clone(),
                    )),
                )
                .await
                // `None` (response flow suspended) is treated as success.
                .unwrap_or(Ok(()))
            }
            .await;

            wait_for_on_done(Rc::clone(&reactor), host).await;
            // Hand the done handler an owned state: move it out when this is the last `Rc`,
            // otherwise clone it (another `Rc` may still exist, e.g. in `shared_state`).
            match Rc::try_unwrap(state) {
                Ok(state) => self.done_handler.done(state),
                Err(rc) => self.done_handler.done((*rc).clone()),
            }
            result
        })
    }
}
/// Resolves once `reactor` delivers an [`ExchangeComplete`] event.
///
/// A fresh `Exchange<Start>` is built solely to subscribe to that event; whatever
/// `wait_for_event` returns is discarded, so only the completion itself matters.
async fn wait_for_on_done(reactor: Rc<HttpReactor>, host: Rc<dyn Host>) {
    let start_exchange = Exchange::<Start>::new(reactor, host, None);
    let _ = start_exchange
        .wait_for_event::<ExchangeComplete>()
        .await;
}
/// Drives `task` for as long as the reactor's request flow stays unsuspended.
///
/// The request status is checked before every poll of `task`. Returns `Some(output)` when
/// `task` completes, or `None` as soon as a poll observes [`FlowStatus::Suspended`]; in that
/// case `task` is dropped without finishing and a debug line is logged.
async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);
    std::future::poll_fn(move |cx| match reactor.request_status() {
        FlowStatus::Unsuspended => pinned_task.as_mut().poll(cx).map(Some),
        FlowStatus::Suspended => {
            let id: u32 = reactor.context_id().into();
            log::debug!("Request for filter with context id {id} has been suspended.");
            Poll::Ready(None)
        }
    })
    .await
}
/// Drives `task` for as long as the reactor's response flow stays unsuspended.
///
/// The response status is checked before every poll of `task`. Returns `Some(output)` when
/// `task` completes, or `None` as soon as a poll observes [`FlowStatus::Suspended`]; in that
/// case `task` is dropped without finishing and a debug line is logged.
async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);
    std::future::poll_fn(move |cx| match reactor.response_status() {
        FlowStatus::Unsuspended => pinned_task.as_mut().poll(cx).map(Some),
        FlowStatus::Suspended => {
            let id: u32 = reactor.context_id().into();
            log::debug!("Response for filter with context id {id} has been suspended.");
            Poll::Ready(None)
        }
    })
    .await
}
#[cfg(feature = "experimental_websocket")]
use crate::extract::context::WebSocketHandlerFn;
#[cfg(feature = "experimental_websocket")]
/// Wraps a base filter and adds optional WebSocket upgrade handlers.
///
/// A handler type of `()` (the default) means "not configured": the `Handler` impl of this
/// type only registers handlers whose type differs from `()`.
pub struct WebSocketFilter<BaseFilter, Sf, WsUp = (), WsDown = ()> {
    /// Filter that handles the regular HTTP exchange; always delegated to.
    pub(super) base_filter: BaseFilter,
    /// Handler for the upstream upgrade, or `()` when unset.
    pub(super) websocket_upstream: WsUp,
    /// Handler for the downstream upgrade, or `()` when unset.
    pub(super) websocket_downstream: WsDown,
    /// Ties the state-factory type `Sf` to the filter without storing a value of it.
    pub(super) _sf: std::marker::PhantomData<Sf>,
}
#[cfg(feature = "experimental_websocket")]
impl<BaseFilter, Sf, WsUp, WsDown> WebSocketFilter<BaseFilter, Sf, WsUp, WsDown>
where
    Sf: CreateHandler,
{
    /// Sets the upstream upgrade handler, replacing any previously configured one.
    ///
    /// The handler receives an [`UpgradeUpstreamContext`] carrying the filter state and
    /// must produce `Result<(), BoxError>`. All other parts of the filter are kept.
    pub fn on_upgrade_upstream<WsUpHnd, I>(
        self,
        websocket_upstream: WsUpHnd,
    ) -> WebSocketFilter<BaseFilter, Sf, WsUpHnd::Handler, WsDown>
    where
        WsUpHnd: IntoHandler<UpgradeUpstreamContext<Sf::State>, I, Output = Result<(), BoxError>>,
    {
        let Self {
            base_filter,
            websocket_downstream,
            ..
        } = self;
        WebSocketFilter {
            base_filter,
            websocket_upstream: websocket_upstream.into_handler(),
            websocket_downstream,
            _sf: std::marker::PhantomData,
        }
    }

    /// Sets the downstream upgrade handler, replacing any previously configured one.
    ///
    /// The handler receives an [`UpgradeDownstreamContext`] carrying the filter state and
    /// must produce `Result<(), BoxError>`. All other parts of the filter are kept.
    pub fn on_upgrade_downstream<WsDownHnd, I>(
        self,
        websocket_downstream: WsDownHnd,
    ) -> WebSocketFilter<BaseFilter, Sf, WsUp, WsDownHnd::Handler>
    where
        WsDownHnd:
            IntoHandler<UpgradeDownstreamContext<Sf::State>, I, Output = Result<(), BoxError>>,
    {
        let Self {
            base_filter,
            websocket_upstream,
            ..
        } = self;
        WebSocketFilter {
            base_filter,
            websocket_upstream,
            websocket_downstream: websocket_downstream.into_handler(),
            _sf: std::marker::PhantomData,
        }
    }
}
#[cfg(feature = "experimental_websocket")]
/// Registers the configured WebSocket upgrade handlers on the parent context, then runs the
/// wrapped base filter.
impl<BaseFilter, Sf, WsUp, WsDown> Handler<FilterContext>
    for WebSocketFilter<BaseFilter, Sf, WsUp, WsDown>
where
    BaseFilter: Handler<FilterContext, Output = ()>,
    Sf: CreateHandler,
    Sf::State: Clone + 'static,
    // `Clone`: each handler is cloned into, and again inside, its registered closure.
    // `'static`: required by `TypeId::of` and by the `BoxFuture<'static, _>` the closures build.
    WsUp:
        Handler<UpgradeUpstreamContext<Sf::State>, Output = Result<(), BoxError>> + Clone + 'static,
    WsDown: Handler<UpgradeDownstreamContext<Sf::State>, Output = Result<(), BoxError>>
        + Clone
        + 'static,
{
    type Output = ();
    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Stores one closure per configured upgrade handler in the parent context, then awaits
    /// `base_filter.call(context)` and returns its result.
    ///
    /// Each closure, when invoked with a reactor, looks up the filter state in the parent
    /// context's `shared_state`. If that state is missing or has a different type, the
    /// returned future fails with a "state unavailable" error instead of calling the handler.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        Box::pin(async move {
            let configure_context = Rc::clone(context.parent_rc());
            // `()` is the default, unset handler type; register only real handlers.
            if std::any::TypeId::of::<WsUp>() != std::any::TypeId::of::<()>() {
                let handler = self.websocket_upstream.clone();
                let parent_ctx = Rc::clone(&configure_context);
                let upstream_closure: WebSocketHandlerFn = Box::new(move |reactor| {
                    // The state is read when the closure runs, not now. The filters in this
                    // file publish it into `shared_state` inside their own `call`, and the
                    // base filter only runs after registration (end of this block).
                    let state: Option<Rc<Sf::State>> = parent_ctx
                        .shared_state
                        .borrow()
                        .as_ref()
                        .and_then(|s| s.clone().downcast::<Sf::State>().ok());
                    let ctx = state.map(|state| {
                        UpgradeUpstreamContext::new(
                            Rc::clone(&parent_ctx),
                            state,
                            Rc::clone(&reactor),
                        )
                    });
                    // Cloned again so the closure itself can be called more than once.
                    let handler = handler.clone();
                    Box::pin(async move {
                        match ctx {
                            // Extraction errors are boxed; the handler's own
                            // `Result<(), BoxError>` is returned as is.
                            Some(ctx) => handler
                                .call(ctx)
                                .await
                                .map_err(|e| Box::new(e) as BoxError)?,
                            None => {
                                Err(Box::from("WebSocket upstream: state unavailable") as BoxError)?
                            }
                        }
                    }) as BoxFuture<'static, Result<(), BoxError>>
                });
                *configure_context.websocket_upstream_handler.borrow_mut() = Some(upstream_closure);
            }
            // Same registration for the downstream direction.
            if std::any::TypeId::of::<WsDown>() != std::any::TypeId::of::<()>() {
                let handler = self.websocket_downstream.clone();
                let parent_ctx = Rc::clone(&configure_context);
                let downstream_closure: WebSocketHandlerFn =
                    Box::new(move |reactor| {
                        let state: Option<Rc<Sf::State>> = parent_ctx
                            .shared_state
                            .borrow()
                            .as_ref()
                            .and_then(|s| s.clone().downcast::<Sf::State>().ok());
                        let ctx = state.map(|state| {
                            UpgradeDownstreamContext::new(
                                Rc::clone(&parent_ctx),
                                state,
                                Rc::clone(&reactor),
                            )
                        });
                        let handler = handler.clone();
                        Box::pin(async move {
                            match ctx {
                                Some(ctx) => handler
                                    .call(ctx)
                                    .await
                                    .map_err(|e| Box::new(e) as BoxError)?,
                                None => Err(Box::from("WebSocket downstream: state unavailable")
                                    as BoxError)?,
                            }
                        }) as BoxFuture<'static, Result<(), BoxError>>
                    });
                *configure_context.websocket_downstream_handler.borrow_mut() =
                    Some(downstream_closure);
            }
            self.base_filter.call(context).await
        })
    }
}