1use std::cell::RefCell;
6use std::future::Future;
7use std::pin::pin;
8use std::rc::Rc;
9use std::task::Poll;
10
11use crate::{
12 event::{Exchange, RequestHeaders, ResponseHeaders},
13 extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContextOnce},
14 handler::{ExtractionError, Handler, IntoHandler},
15 reactor::http::{FlowStatus, HttpReactor},
16 BoxFuture,
17};
18
19use super::{
20 context::{RequestContext, ResponseContext},
21 dynamic_exchange::DynamicExchange,
22 request_data::RequestData,
23 Flow, IntoFlow,
24};
25
/// A filter that only handles the request phase.
///
/// Built by [`on_request`]. As a [`Handler<FilterContext>`] it runs `request_handler`
/// after extracting the request-headers [`Exchange`]; call
/// [`RequestFilter::on_response`] to also attach a response handler.
#[must_use = "a filter does nothing until it is used as a handler"]
pub struct RequestFilter<ReqHnd> {
    /// Handler invoked with a [`RequestContext`] for the request phase.
    request_handler: ReqHnd,
}
29
30impl<ReqHnd> RequestFilter<ReqHnd>
31where
32 ReqHnd: Handler<RequestContext>,
33 ReqHnd::Output: IntoFlow,
34{
35 pub fn on_response<ResHnd, I>(
37 self,
38 response_handler: ResHnd,
39 ) -> DualFilter<ReqHnd, ResHnd::Handler>
40 where
41 ResHnd:
42 IntoHandler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, I, Output = ()>,
43 {
44 DualFilter {
45 request_handler: self.request_handler,
46 response_handler: response_handler.into_handler(),
47 }
48 }
49}
50
51impl<ReqHnd, T> Handler<FilterContext> for RequestFilter<ReqHnd>
52where
53 ReqHnd: Handler<RequestContext>,
54 ReqHnd::Output: IntoFlow<RequestData = T>,
55{
56 type Output = ();
57
58 type Future<'h>
59 = BoxFuture<'h, Result<Self::Output, ExtractionError>>
60 where
61 Self: 'h;
62
63 fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
64 where
65 Self: 'h,
66 {
67 Box::pin(async move {
68 let context = Rc::new(context);
69 let exclusive_context = Exclusive::new(context.as_ref());
70 let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
71 .await
72 .map_err(|e| ExtractionError(e.into()))?;
73
74 let reactor = Rc::clone(&exchange.reactor);
75
76 may_suspend_request(&reactor, async {
77 let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
78
79 let request_context = RequestContext::new(context, exchange.clone());
80
81 let flow = self
82 .request_handler
83 .call(request_context)
84 .await?
85 .into_flow();
86 if let Flow::Break(response) = flow {
87 let mut exchange = exchange.borrow_mut();
88 exchange.send_response(
89 response.status_code(),
90 response.headers(),
91 response.body(),
92 );
93 }
94 Ok(())
95 })
96 .await
97 .unwrap_or(Ok(()))
98 })
99 }
100}
101
102pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
104where
105 ReqHnd: IntoHandler<RequestContext, I>,
106 ReqHnd::Output: IntoFlow,
107{
108 RequestFilter {
109 request_handler: request_handler.into_handler(),
110 }
111}
112
/// A filter that only handles the response phase.
///
/// Built by [`on_response`]. Its handler receives a [`ResponseContext<()>`] because no
/// request handler runs to produce request data.
#[must_use = "a filter does nothing until it is used as a handler"]
pub struct ResponseFilter<ResHnd> {
    /// Handler invoked with a [`ResponseContext`] after the response-headers exchange is extracted.
    response_handler: ResHnd,
}
116
117impl<ResHnd> Handler<FilterContext> for ResponseFilter<ResHnd>
118where
119 ResHnd: Handler<ResponseContext<()>, Output = ()>,
120{
121 type Output = ();
122
123 type Future<'h>
124 = BoxFuture<'h, Result<Self::Output, ExtractionError>>
125 where
126 Self: 'h;
127
128 fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
129 where
130 Self: 'h,
131 {
132 Box::pin(async move {
133 let context = Rc::new(context);
134 let exclusive_context = Exclusive::new(context.as_ref());
135 let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
136 .await
137 .map_err(|e| ExtractionError(e.into()))?;
138 let reactor = Rc::clone(&exchange.reactor);
139
140 may_suspend_response(&reactor, async {
141 let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
142 let response_context = ResponseContext::new(context, exchange, RequestData::Break);
143 self.response_handler.call(response_context).await?;
144 Ok(())
145 })
146 .await
147 .unwrap_or(Ok(()))
148 })
149 }
150}
151
152pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
154where
155 ResHnd: IntoHandler<ResponseContext<()>, I, Output = ()>,
156{
157 ResponseFilter {
158 response_handler: response_handler.into_handler(),
159 }
160}
161
/// A filter with both a request handler and a response handler.
///
/// Built with [`RequestFilter::on_response`]. The request data produced by the request
/// handler (through [`IntoFlow`]) is passed on to the response handler.
#[must_use = "a filter does nothing until it is used as a handler"]
pub struct DualFilter<ReqHnd, ResHnd> {
    /// Handler invoked with a [`RequestContext`] after the request-headers exchange is extracted.
    request_handler: ReqHnd,
    /// Handler invoked with a [`ResponseContext`] after the response-headers exchange is extracted.
    response_handler: ResHnd,
}
166
167impl<ReqHnd, ResHnd> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd>
168where
169 ReqHnd: Handler<RequestContext>,
170 ReqHnd::Output: IntoFlow,
171 ResHnd: Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, Output = ()>,
172{
173 type Output = ();
174
175 type Future<'h>
176 = BoxFuture<'h, Result<Self::Output, ExtractionError>>
177 where
178 Self: 'h;
179
180 fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
181 where
182 Self: 'h,
183 {
184 Box::pin(async move {
185 let context = Rc::new(context);
186 let exclusive_context = Exclusive::new(context.as_ref());
187 let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
188 .await
189 .map_err(|_| {
190 ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
191 })?;
192 let reactor = Rc::clone(&exchange.reactor);
193
194 let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
195
196 let request_data = may_suspend_request(&reactor, async {
197 let request_context = RequestContext::new(context.clone(), exchange.clone());
198
199 let flow = self
200 .request_handler
201 .call(request_context)
202 .await?
203 .into_flow();
204 match flow {
205 Flow::Break(response) => {
206 exchange.borrow_mut().send_response(
207 response.status_code(),
208 response.headers(),
209 response.body(),
210 );
211 Ok(RequestData::Break)
212 }
213 Flow::Continue(data) => Ok(RequestData::Continue(data)),
214 }
215 })
216 .await
217 .unwrap_or(Ok(RequestData::Cancel))?;
218
219 let exclusive_context = Exclusive::new(context.as_ref());
220 let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
221 .await
222 .map_err(|_| {
223 ExtractionError(AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into())
224 })?;
225
226 let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
227
228 may_suspend_response(
229 &reactor,
230 self.response_handler
231 .call(ResponseContext::new(context, exchange, request_data)),
232 )
233 .await
234 .unwrap_or(Ok(()))
235 })
236 }
237}
238
239async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
240 let mut task = pin!(task);
241
242 std::future::poll_fn(move |cx| match reactor.request_status() {
243 FlowStatus::Suspended => {
244 let id: u32 = reactor.context_id().into();
245 log::debug!("Request for filter with context id {id} has been suspended.");
246
247 Poll::Ready(None)
248 }
249 FlowStatus::Unsuspended => task.as_mut().poll(cx).map(Some),
250 })
251 .await
252}
253
254async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
255 let mut task = pin!(task);
256
257 std::future::poll_fn(move |cx| match reactor.response_status() {
258 FlowStatus::Suspended => {
259 let id: u32 = reactor.context_id().into();
260 log::debug!("Response for filter with context id {id} has been suspended.");
261
262 Poll::Ready(None)
263 }
264 FlowStatus::Unsuspended => task.as_mut().poll(cx).map(Some),
265 })
266 .await
267}