pdk_classy/hl/bootstrap.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use std::cell::RefCell;
6use std::future::Future;
7use std::pin::pin;
8use std::rc::Rc;
9use std::task::Poll;
10
11use crate::{
12    event::{Exchange, RequestHeaders, ResponseHeaders},
13    extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContextOnce},
14    handler::{ExtractionError, Handler, IntoHandler},
15    reactor::http::{FlowStatus, HttpReactor},
16    BoxFuture,
17};
18
19use super::{
20    context::{RequestContext, ResponseContext},
21    dynamic_exchange::DynamicExchange,
22    request_data::RequestData,
23    Flow, IntoFlow,
24};
25
26pub struct RequestFilter<ReqHnd> {
27    request_handler: ReqHnd,
28}
29
30impl<ReqHnd> RequestFilter<ReqHnd>
31where
32    ReqHnd: Handler<RequestContext>,
33    ReqHnd::Output: IntoFlow,
34{
35    /// Creates a Response filter from a handler.
36    pub fn on_response<ResHnd, I>(
37        self,
38        response_handler: ResHnd,
39    ) -> DualFilter<ReqHnd, ResHnd::Handler>
40    where
41        ResHnd:
42            IntoHandler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, I, Output = ()>,
43    {
44        DualFilter {
45            request_handler: self.request_handler,
46            response_handler: response_handler.into_handler(),
47        }
48    }
49}
50
51impl<ReqHnd, T> Handler<FilterContext> for RequestFilter<ReqHnd>
52where
53    ReqHnd: Handler<RequestContext>,
54    ReqHnd::Output: IntoFlow<RequestData = T>,
55{
56    type Output = ();
57
58    type Future<'h>
59        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
60    where
61        Self: 'h;
62
63    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
64    where
65        Self: 'h,
66    {
67        Box::pin(async move {
68            let context = Rc::new(context);
69            let exclusive_context = Exclusive::new(context.as_ref());
70            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
71                .await
72                .map_err(|e| ExtractionError(e.into()))?;
73
74            let reactor = Rc::clone(&exchange.reactor);
75
76            may_suspend_request(&reactor, async {
77                let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
78
79                let request_context = RequestContext::new(context, exchange.clone());
80
81                let flow = self
82                    .request_handler
83                    .call(request_context)
84                    .await?
85                    .into_flow();
86                if let Flow::Break(response) = flow {
87                    let mut exchange = exchange.borrow_mut();
88                    exchange.send_response(
89                        response.status_code(),
90                        response.headers(),
91                        response.body(),
92                    );
93                }
94                Ok(())
95            })
96            .await
97            .unwrap_or(Ok(()))
98        })
99    }
100}
101
102/// Creates a Request filter from a handler.
103pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
104where
105    ReqHnd: IntoHandler<RequestContext, I>,
106    ReqHnd::Output: IntoFlow,
107{
108    RequestFilter {
109        request_handler: request_handler.into_handler(),
110    }
111}
112
113pub struct ResponseFilter<ResHnd> {
114    response_handler: ResHnd,
115}
116
117impl<ResHnd> Handler<FilterContext> for ResponseFilter<ResHnd>
118where
119    ResHnd: Handler<ResponseContext<()>, Output = ()>,
120{
121    type Output = ();
122
123    type Future<'h>
124        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
125    where
126        Self: 'h;
127
128    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
129    where
130        Self: 'h,
131    {
132        Box::pin(async move {
133            let context = Rc::new(context);
134            let exclusive_context = Exclusive::new(context.as_ref());
135            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
136                .await
137                .map_err(|e| ExtractionError(e.into()))?;
138            let reactor = Rc::clone(&exchange.reactor);
139
140            may_suspend_response(&reactor, async {
141                let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
142                let response_context = ResponseContext::new(context, exchange, RequestData::Break);
143                self.response_handler.call(response_context).await?;
144                Ok(())
145            })
146            .await
147            .unwrap_or(Ok(()))
148        })
149    }
150}
151
152/// Creates a Response filter from a handler.
153pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
154where
155    ResHnd: IntoHandler<ResponseContext<()>, I, Output = ()>,
156{
157    ResponseFilter {
158        response_handler: response_handler.into_handler(),
159    }
160}
161
/// A filter that runs a request handler and then a response handler on the same HTTP exchange.
///
/// Built by [`RequestFilter::on_response`]. The outcome of the request phase is forwarded to
/// the response handler as a [`RequestData`].
pub struct DualFilter<ReqHnd, ResHnd> {
    request_handler: ReqHnd,
    response_handler: ResHnd,
}

impl<ReqHnd, ResHnd> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd>
where
    ReqHnd: Handler<RequestContext>,
    ReqHnd::Output: IntoFlow,
    ResHnd: Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData>, Output = ()>,
{
    type Output = ();

    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the request handler, then the response handler, for one HTTP exchange.
    ///
    /// Request phase: the request-headers [`Exchange`] is extracted and the request handler
    /// is run. The outcome is recorded as follows:
    /// - [`Flow::Break`]: the carried response is sent immediately and [`RequestData::Break`]
    ///   is recorded.
    /// - [`Flow::Continue`]: its data is recorded as [`RequestData::Continue`].
    /// - Request flow suspended before the handler finishes: [`RequestData::Cancel`] is
    ///   recorded.
    ///
    /// Response phase: in all of those cases, the response-headers [`Exchange`] is then
    /// extracted. The response handler is run with the recorded request data. If the response
    /// flow is suspended, the handler is abandoned and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtractionError`] if either exchange cannot be extracted, or if either
    /// handler fails.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        Box::pin(async move {
            // Shared so that both the request context and the response context can hold it.
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            // NOTE(review): the underlying extraction error is discarded and replaced by
            // `AlreadyExtracted`, whereas `RequestFilter`/`ResponseFilter` propagate `e.into()`.
            // Confirm this is intended.
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
                })?;
            // One reactor handle drives the suspension checks for both phases, and it outlives
            // the request-phase `exchange`, which is shadowed below.
            let reactor = Rc::clone(&exchange.reactor);

            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));

            let request_data = may_suspend_request(&reactor, async {
                let request_context = RequestContext::new(context.clone(), exchange.clone());

                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                match flow {
                    Flow::Break(response) => {
                        exchange.borrow_mut().send_response(
                            response.status_code(),
                            response.headers(),
                            response.body(),
                        );
                        Ok(RequestData::Break)
                    }
                    Flow::Continue(data) => Ok(RequestData::Continue(data)),
                }
            })
            .await
            // Suspension of the request phase is not an error; it is reported to the response
            // handler as `Cancel`. A request-handler error is still propagated by `?`.
            .unwrap_or(Ok(RequestData::Cancel))?;

            // The response phase runs regardless of how the request phase ended (Break,
            // Continue or Cancel).
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into())
                })?;

            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));

            // The handler future is created here, but it is only polled while the response
            // flow is not suspended.
            may_suspend_response(
                &reactor,
                self.response_handler
                    .call(ResponseContext::new(context, exchange, request_data)),
            )
            .await
            .unwrap_or(Ok(()))
        })
    }
}
238
239async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
240    let mut task = pin!(task);
241
242    std::future::poll_fn(move |cx| match reactor.request_status() {
243        FlowStatus::Suspended => {
244            let id: u32 = reactor.context_id().into();
245            log::debug!("Request for filter with context id {id} has been suspended.");
246
247            Poll::Ready(None)
248        }
249        FlowStatus::Unsuspended => task.as_mut().poll(cx).map(Some),
250    })
251    .await
252}
253
254async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
255    let mut task = pin!(task);
256
257    std::future::poll_fn(move |cx| match reactor.response_status() {
258        FlowStatus::Suspended => {
259            let id: u32 = reactor.context_id().into();
260            log::debug!("Response for filter with context id {id} has been suspended.");
261
262            Poll::Ready(None)
263        }
264        FlowStatus::Unsuspended => task.as_mut().poll(cx).map(Some),
265    })
266    .await
267}