// pdk-classy 1.9.0-alpha.3
//
// PDK Classy
// Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use std::cell::RefCell;
use std::future::Future;
use std::pin::pin;
use std::rc::Rc;
use std::task::Poll;

use crate::{
    event::{Exchange, ExchangeComplete, RequestHeaders, ResponseHeaders, Start},
    extract::{context::FilterContext, AlreadyExtracted, Exclusive, FromContext, FromContextOnce},
    handler::{ExtractionError, Handler, IntoHandler},
    hl::state::{CreateHandler, DoneHandler, EmptyCreateHandler},
    host::Host,
    reactor::http::{FlowStatus, HttpReactor},
    BoxFuture,
};

use super::{
    context::{RequestContext, ResponseContext},
    dynamic_exchange::DynamicExchange,
    request_data::RequestData,
    Flow, IntoFlow,
};

/// A filter that runs a handler on the request-headers phase of an exchange only.
///
/// [`on_request`] creates one without custom state; [`RequestFilter::on_response`]
/// extends it with a response handler, producing a [`DualFilter`].
pub struct RequestFilter<ReqHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with a [`RequestContext`] once the request-headers exchange is available.
    pub(super) request_handler: ReqHnd,
    /// Factory whose `create()` produces the per-exchange state shared with the handler.
    pub(super) state_factory: Sf,
    /// Called with the state after the exchange completes, if the state can be reclaimed
    /// (i.e. no other `Rc` to it is still alive).
    pub(super) done_handler: Done,
}

impl<ReqHnd, Sf, Done> RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    Sf: CreateHandler,
{
    /// Adds a response handler, turning this request filter into a [`DualFilter`].
    ///
    /// The response handler is called with a [`ResponseContext`] parameterized by the
    /// request data type of the request handler's flow
    /// (`<ReqHnd::Output as IntoFlow>::RequestData`) and by the same state type.
    /// This filter's request handler, state factory and done handler are moved into
    /// the returned [`DualFilter`] unchanged.
    pub fn on_response<ResHnd, I>(
        self,
        response_handler: ResHnd,
    ) -> DualFilter<ReqHnd, ResHnd::Handler, Sf, Done>
    where
        ResHnd: IntoHandler<
            ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>,
            I,
            Output = (),
        >,
    {
        DualFilter {
            request_handler: self.request_handler,
            response_handler: response_handler.into_handler(),
            state_factory: self.state_factory,
            done_handler: self.done_handler,
        }
    }
}

impl<ReqHnd, Sf, Done> Handler<FilterContext> for RequestFilter<ReqHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();

    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the request phase of one exchange, then waits for the exchange to complete.
    ///
    /// 1. Extracts the request-headers [`Exchange`] from `context`.
    /// 2. Creates the per-exchange state and calls the request handler with a
    ///    [`RequestContext`]. When the handler's [`Flow`] is `Break`, the carried
    ///    response is sent right away.
    /// 3. Waits for the [`ExchangeComplete`] event, then calls the done handler if no
    ///    other `Rc` to the state is still alive.
    ///
    /// If the reactor reports the request as suspended, the handler future is not
    /// polled any further and the phase counts as `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtractionError`] if the request-headers exchange cannot be
    /// extracted (nothing else runs in that case), or the error produced by the
    /// request handler (returned after step 3).
    ///
    /// # Panics
    ///
    /// Panics if the [`Host`] cannot be extracted from the filter context.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // The only `RefCell` borrow in this future (`borrow_mut` for `send_response`)
        // is released before any later `.await`.
        // NOTE(review): the allow presumably silences a false positive — confirm.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|e| ExtractionError(e.into()))?;

            let reactor = Rc::clone(&exchange.reactor);
            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let state = Rc::new(self.state_factory.create());

            // The host is needed by `wait_for_on_done` to listen for `ExchangeComplete`.
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .expect("the host must be extractable from the filter context");

            let result = may_suspend_request(&reactor, async {
                let request_context =
                    RequestContext::new(context, Rc::clone(&exchange), Rc::clone(&state));

                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                if let Flow::Break(response) = flow {
                    let mut exchange = exchange.borrow_mut();
                    exchange.send_response(
                        response.status_code(),
                        response.headers(),
                        response.body(),
                    );
                }

                Ok(())
            })
            .await
            // A suspended request abandons the handler; that is not an error.
            .unwrap_or(Ok(()));

            // Wait for ExchangeComplete event (triggered by on_done from proxy-wasm)
            wait_for_on_done(Rc::clone(&reactor), host).await;

            // Run the done handler only if nothing else still holds the state.
            if let Ok(state) = Rc::try_unwrap(state) {
                self.done_handler.done(state);
            }

            result
        })
    }
}

/// Builds a [`RequestFilter`] that runs `request_handler` on the request phase.
///
/// The handler is called with a `RequestContext<()>`, i.e. without custom state. The
/// filter uses [`EmptyCreateHandler`] as its state factory and `()` as its done handler.
pub fn on_request<ReqHnd, I>(request_handler: ReqHnd) -> RequestFilter<ReqHnd::Handler>
where
    ReqHnd: IntoHandler<RequestContext<()>, I>,
    ReqHnd::Output: IntoFlow,
{
    let request_handler = request_handler.into_handler();
    RequestFilter {
        done_handler: (),
        state_factory: EmptyCreateHandler,
        request_handler,
    }
}

/// A filter that runs a handler on the response-headers phase of an exchange only.
///
/// [`on_response`] creates one without custom state. Because there is no request
/// handler, the response handler is given `RequestData::Break` as its request data.
pub struct ResponseFilter<ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with a [`ResponseContext`] once the response-headers exchange is available.
    pub(super) response_handler: ResHnd,
    /// Factory whose `create()` produces the per-exchange state shared with the handler.
    pub(super) state_factory: Sf,
    /// Called with the state after the exchange completes, if the state can be reclaimed
    /// (i.e. no other `Rc` to it is still alive).
    pub(super) done_handler: Done,
}

impl<ResHnd, Sf, Done> Handler<FilterContext> for ResponseFilter<ResHnd, Sf, Done>
where
    ResHnd: Handler<ResponseContext<(), Sf::State>, Output = ()>,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();

    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    /// Runs the response phase of one exchange, then waits for the exchange to complete.
    ///
    /// 1. Extracts the response-headers [`Exchange`] from `context`.
    /// 2. Creates the per-exchange state and calls the response handler with a
    ///    [`ResponseContext`] whose request data is `RequestData::Break`.
    /// 3. Waits for the [`ExchangeComplete`] event, then calls the done handler if no
    ///    other `Rc` to the state is still alive.
    ///
    /// If the reactor reports the response as suspended, the handler future is not
    /// polled any further and the phase counts as `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtractionError`] if the response-headers exchange cannot be
    /// extracted (nothing else runs in that case), or the error produced by the
    /// response handler (returned after step 3).
    ///
    /// # Panics
    ///
    /// Panics if the [`Host`] cannot be extracted from the filter context.
    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        // NOTE(review): this future takes no `RefCell` borrow itself (the exchange is
        // only `Rc`-cloned); the allow mirrors the sibling filters — confirm it is needed.
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|e| ExtractionError(e.into()))?;
            let reactor = Rc::clone(&exchange.reactor);

            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));
            let state = Rc::new(self.state_factory.create());

            // The host is needed by `wait_for_on_done` to listen for `ExchangeComplete`.
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .expect("the host must be extractable from the filter context");

            let result = may_suspend_response(&reactor, async {
                let response_context = ResponseContext::new(
                    context,
                    Rc::clone(&exchange),
                    RequestData::Break,
                    Rc::clone(&state),
                );
                self.response_handler.call(response_context).await?;

                Ok(())
            })
            .await
            // A suspended response abandons the handler; that is not an error.
            .unwrap_or(Ok(()));

            // Wait for ExchangeComplete event (triggered by on_done from proxy-wasm)
            wait_for_on_done(Rc::clone(&reactor), host).await;

            // Run the done handler only if nothing else still holds the state.
            if let Ok(state) = Rc::try_unwrap(state) {
                self.done_handler.done(state);
            }

            result
        })
    }
}

/// Builds a [`ResponseFilter`] that runs `response_handler` on the response phase.
///
/// The handler is called with a `ResponseContext<(), ()>`, i.e. unit request data and
/// no custom state. The filter uses [`EmptyCreateHandler`] as its state factory and
/// `()` as its done handler.
pub fn on_response<ResHnd, I>(response_handler: ResHnd) -> ResponseFilter<ResHnd::Handler>
where
    ResHnd: IntoHandler<ResponseContext<(), ()>, I, Output = ()>,
{
    let response_handler = response_handler.into_handler();
    ResponseFilter {
        done_handler: (),
        state_factory: EmptyCreateHandler,
        response_handler,
    }
}

/// A filter that runs a request handler and then a response handler on one exchange.
///
/// Produced by [`RequestFilter::on_response`]. The request data yielded by the request
/// phase is handed to the response handler, and both handlers share the same state.
pub struct DualFilter<ReqHnd, ResHnd, Sf = EmptyCreateHandler, Done = ()> {
    /// Handler called with a [`RequestContext`] during the request-headers phase.
    pub(super) request_handler: ReqHnd,
    /// Handler called with a [`ResponseContext`] during the response-headers phase.
    pub(super) response_handler: ResHnd,
    /// Factory whose `create()` produces the per-exchange state shared by both handlers.
    pub(super) state_factory: Sf,
    /// Called with the state after the exchange completes, if the state can be reclaimed
    /// (i.e. no other `Rc` to it is still alive).
    pub(super) done_handler: Done,
}

impl<ReqHnd, ResHnd, Sf, Done> Handler<FilterContext> for DualFilter<ReqHnd, ResHnd, Sf, Done>
where
    ReqHnd: Handler<RequestContext<Sf::State>>,
    ReqHnd::Output: IntoFlow,
    ResHnd:
        Handler<ResponseContext<<ReqHnd::Output as IntoFlow>::RequestData, Sf::State>, Output = ()>,
    Sf: CreateHandler<State: Clone>,
    Done: DoneHandler<Sf::State>,
{
    type Output = ();

    type Future<'h>
        = BoxFuture<'h, Result<Self::Output, ExtractionError>>
    where
        Self: 'h;

    fn call<'h>(&'h self, context: FilterContext) -> Self::Future<'h>
    where
        Self: 'h,
    {
        #[allow(clippy::await_holding_refcell_ref)]
        Box::pin(async move {
            let context = Rc::new(context);
            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<RequestHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<RequestHeaders>>::default().into())
                })?;
            let reactor = Rc::clone(&exchange.reactor);

            let state = Rc::new(self.state_factory.create());

            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));

            // Extract host for potential use in ExchangeComplete waiting
            let host: Rc<dyn Host> =
                FromContext::<_, crate::extract::extractability::Transitive>::from_context(
                    &*context,
                )
                .unwrap();

            let request_data = may_suspend_request(&reactor, async {
                let request_context =
                    RequestContext::new(context.clone(), exchange.clone(), Rc::clone(&state));

                let flow = self
                    .request_handler
                    .call(request_context)
                    .await?
                    .into_flow();
                match flow {
                    Flow::Break(response) => {
                        exchange.borrow_mut().send_response(
                            response.status_code(),
                            response.headers(),
                            response.body(),
                        );
                        Ok(RequestData::Break)
                    }
                    Flow::Continue(data) => Ok(RequestData::Continue(data)),
                }
            })
            .await
            .unwrap_or(Ok(RequestData::Cancel))?;

            let exclusive_context = Exclusive::new(context.as_ref());
            let exchange = <Exchange<ResponseHeaders>>::from_context_once(exclusive_context)
                .await
                .map_err(|_| {
                    ExtractionError(AlreadyExtracted::<Exchange<ResponseHeaders>>::default().into())
                })?;

            let exchange = Rc::new(RefCell::new(DynamicExchange::new(exchange)));

            let result = may_suspend_response(
                &reactor,
                self.response_handler.call(ResponseContext::new(
                    context.clone(),
                    exchange.clone(),
                    request_data,
                    state.clone(),
                )),
            )
            .await
            .unwrap_or(Ok(()));

            // Wait for ExchangeComplete event (triggered by on_done from proxy-wasm)
            wait_for_on_done(Rc::clone(&reactor), host).await;

            // Execute done handler if state can be unwrapped
            if let Ok(state) = Rc::try_unwrap(state) {
                self.done_handler.done(state);
            }

            result
        })
    }
}

/// Waits for the [`ExchangeComplete`] event of the current exchange.
///
/// A fresh [`Exchange`] in the [`Start`] state is built from `reactor` and `host` just
/// to await that event; whatever the wait yields is discarded.
async fn wait_for_on_done(reactor: Rc<HttpReactor>, host: Rc<dyn Host>) {
    let _ = <Exchange<Start>>::new(reactor, host, None)
        .wait_for_event::<ExchangeComplete>()
        .await;
}

/// Drives `task` for as long as the reactor's request flow is not suspended.
///
/// The request status is checked before every poll:
/// - while unsuspended, `task` is polled, and its output is returned as `Some` once ready;
/// - once suspended, a debug message is logged and `None` is returned without polling
///   `task` again.
async fn may_suspend_request<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);

    let outcome = std::future::poll_fn(move |cx| match reactor.request_status() {
        FlowStatus::Unsuspended => pinned_task.as_mut().poll(cx).map(Some),
        FlowStatus::Suspended => {
            let id: u32 = reactor.context_id().into();
            log::debug!("Request for filter with context id {id} has been suspended.");
            Poll::Ready(None)
        }
    })
    .await;

    outcome
}

/// Drives `task` for as long as the reactor's response flow is not suspended.
///
/// The response status is checked before every poll:
/// - while unsuspended, `task` is polled, and its output is returned as `Some` once ready;
/// - once suspended, a debug message is logged and `None` is returned without polling
///   `task` again.
async fn may_suspend_response<F: Future>(reactor: &HttpReactor, task: F) -> Option<F::Output> {
    let mut pinned_task = pin!(task);

    let outcome = std::future::poll_fn(move |cx| match reactor.response_status() {
        FlowStatus::Unsuspended => pinned_task.as_mut().poll(cx).map(Some),
        FlowStatus::Suspended => {
            let id: u32 = reactor.context_id().into();
            log::debug!("Response for filter with context id {id} has been suspended.");
            Poll::Ready(None)
        }
    })
    .await;

    outcome
}