// pdk-unit 1.8.0
//
// PDK Unit Test Framework
// Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use crate::host::implementation::ProxyWasmStub;
use crate::tester::io::{RequestResponse, UnitHttpResponse};
use crate::tester::unit_test::{add_request_properties, respond_call, Backends};
use crate::StopIterationMode;
use proxy_wasm_stub::traits::HttpContext;
use proxy_wasm_stub::types::{Action, BufferType, MapType};
use std::cell::RefCell;
use std::rc::Rc;
use std::task::Poll;

/// Phases a simulated HTTP exchange moves through while it is being polled.
///
/// The declaration order is significant: the derived `PartialOrd` follows it,
/// and `state >= State::Done` is used to recognise an exchange that has finished.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum State {
    /// Initial state: the request headers have not been delivered yet.
    RequestHeaders,
    /// The request-headers callback returned `Pause`; waiting for a resume or a local response.
    RequestHeadersPaused,
    /// The request-headers callback returned `Pause`, but a stop mode is configured and the
    /// request has a body, so the body is still going to be delivered.
    /// This state is left either by a resume triggered from a call response or by the body event.
    RequestHeadersLimbo,
    /// The request headers were continued; the request body is forwarded next.
    RequestBody,
    /// The end-of-stream request-body callback returned `Pause`.
    RequestBodyPaused,
    /// A response is available and its headers are forwarded next.
    ResponseHeaders,
    /// The response-headers callback returned `Pause`; waiting for a resume or a local response.
    ResponseHeadersPaused,
    /// The response-headers callback returned `Pause`, but a stop mode is configured and the
    /// response has a body, so the body is still going to be delivered.
    /// This state is left either by a resume triggered from a call response or by the body event.
    ResponseHeadersLimbo,
    /// The response headers were continued; the response body is forwarded next.
    ResponseBody,
    /// The end-of-stream response-body callback returned `Pause`.
    ResponseBodyPaused,
    /// The final response has been loaded; the closing callbacks are still to run.
    Done,
    /// The closing callbacks have run; polling now yields the response.
    Exit,
}

/// Drives a single simulated HTTP exchange through an [`HttpContext`], one
/// [`State`] at a time, using the stub host to exchange headers, bodies and
/// pending calls with it.
pub struct InnerUnitTestRequest {
    /// Current position in the exchange.
    state: State,
    /// Context id passed to every host operation performed for this exchange.
    context_id: u32,
    /// Size of the pieces a body is split into before being handed to the body callbacks.
    chunk_size: usize,
    /// The context whose HTTP callbacks are invoked.
    http_context: Box<dyn HttpContext>,
    /// Shared backends: used to obtain the upstream response and to answer pending calls.
    backends: Rc<RefCell<Backends>>,
    /// Shared stub host holding maps, buffers, properties and the resume / send-response flags.
    host: Rc<RefCell<ProxyWasmStub>>,
    /// When set, a paused headers event with a body enters a "Limbo" state instead of a
    /// plain paused one; `BodyThenRequests` additionally defers answering pending calls
    /// until the body has been delivered.
    stop_mode: Option<StopIterationMode>,
    /// Holds the request and the response during their corresponding phases.
    request_response: RequestResponse,
    /// Represents the buffer chunks that were "Continued" by the filter.
    body_buffer: Vec<u8>,
}

impl InnerUnitTestRequest {
    /// Creates a driver for `request`, positioned at [`State::RequestHeaders`]
    /// and with nothing accumulated in the body buffer yet.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        context_id: u32,
        request: RequestResponse,
        http_context: Box<dyn HttpContext>,
        backends: Rc<RefCell<Backends>>,
        host: Rc<RefCell<ProxyWasmStub>>,
        stop_mode: Option<StopIterationMode>,
        chunk_size: usize,
    ) -> Self {
        let state = State::RequestHeaders;
        let body_buffer = Vec::new();

        Self {
            state,
            context_id,
            chunk_size,
            http_context,
            backends,
            host,
            stop_mode,
            request_response: request,
            body_buffer,
        }
    }

    /// Advances the exchange as far as it can go in a single call.
    ///
    /// The host is first pointed at this exchange's context and any resume or
    /// send-response flag already raised on the host is applied. Each phase is
    /// then guarded by the current [`State`] and runs in this order: request
    /// headers, request body, response headers, response body, closing
    /// callbacks. A phase whose guard state is not active is skipped.
    ///
    /// # Returns
    ///
    /// `Poll::Ready` with the converted response once the state is
    /// [`State::Exit`]; `Poll::Pending` in every other case.
    ///
    /// # Panics
    ///
    /// Panics whenever `resume` panics (both flags raised at once, or a resume
    /// flagged while the state is not a paused one).
    /// NOTE(review): a `chunk_size` of zero presumably panics in `chunks` while a
    /// body is forwarded (true if `body` is a `Vec`/slice) — confirm and validate upstream.
    pub fn poll(&mut self) -> Poll<UnitHttpResponse> {
        let context_id = self.context_id;
        // Make this exchange the host's active context before touching any host state.
        self.host.borrow_mut().set_context(context_id);

        // Apply a resume or local response that is already flagged on the host.
        self.resume();

        // Forward the request headers
        if self.state == State::RequestHeaders {
            self.set_map(MapType::HttpRequestHeaders);
            let headers_len = self.request_response.headers().len();
            let eos = self.eos();
            let action = self.http_context.on_http_request_headers(headers_len, eos);

            if action == Action::Pause && self.clear_send_response() {
                // Paused with a local response flagged: the backend is not called on this
                // path, the response is read from the host and the response phase starts.
                self.load_response();
                self.state = State::ResponseHeaders;
            } else if action == Action::Pause {
                // With a stop mode and a body still to deliver, the body event will be
                // fired despite the pause (Limbo); otherwise wait for an explicit resume.
                if self.stop_mode.is_some() && !eos {
                    self.state = State::RequestHeadersLimbo;
                } else {
                    self.state = State::RequestHeadersPaused;
                }
            } else {
                self.state = State::RequestBody;
            }

            // In `BodyThenRequests` mode pending calls are answered after the body phase
            // instead, unless there is no body at all.
            if self.stop_mode != Some(StopIterationMode::BodyThenRequests) || eos {
                self.respond_calls();
            }
        }

        // Skip request body state if there isn't a body in the request.
        if self.state == State::RequestBody && self.eos() {
            self.call_backend();
            self.state = State::ResponseHeaders;
        }

        // Forward the request body
        if self.state == State::RequestBody || self.state == State::RequestHeadersLimbo {
            self.body_buffer.clear();
            // Chunks the callback paused on; they are offered again together with the
            // next chunk until the callback continues.
            let mut buffered_chunks: Vec<u8> = vec![];
            for chunk in self.request_response.body.clone().chunks(self.chunk_size) {
                buffered_chunks.extend_from_slice(chunk);

                self.set_buffer(BufferType::HttpRequestBody, buffered_chunks.clone());

                let action = self
                    .http_context
                    .on_http_request_body(buffered_chunks.len(), false);

                if action == Action::Pause {
                    if self.clear_send_response() {
                        // Local response flagged mid-body: drop what was accumulated and
                        // switch to the response phase.
                        self.body_buffer.clear();
                        self.load_response();
                        self.state = State::ResponseHeaders;
                        break;
                    }
                } else {
                    // Continued: take the bytes back from the host buffer and start a
                    // fresh accumulation for the following chunks.
                    let continued_chunk = self.read_and_clear_buffer(BufferType::HttpRequestBody);
                    self.body_buffer.extend(continued_chunk);
                    buffered_chunks.clear();
                }
            }

            // The state is unchanged only when the loop did not switch to the response phase.
            if self.state == State::RequestBody || self.state == State::RequestHeadersLimbo {
                // End-of-stream event, reporting whatever is still buffered.
                let action = self
                    .http_context
                    .on_http_request_body(buffered_chunks.len(), true);

                if action == Action::Pause && self.clear_send_response() {
                    self.body_buffer.clear();
                    self.load_response();
                    self.state = State::ResponseHeaders;
                } else if action == Action::Pause {
                    self.state = State::RequestBodyPaused;
                } else {
                    let continued_chunk = self.read_and_clear_buffer(BufferType::HttpRequestBody);
                    self.body_buffer.extend(continued_chunk);
                    self.call_backend();
                    self.state = State::ResponseHeaders;
                }
            }

            self.respond_calls();
        }

        // Forward the response headers
        if self.state == State::ResponseHeaders {
            let headers_len = self.request_response.headers.len();
            let eos = self.eos();
            self.set_map(MapType::HttpResponseHeaders);

            let action = self.http_context.on_http_response_headers(headers_len, eos);

            // Unlike the request-headers phase, the send-response flag is consumed here
            // regardless of the action that was returned.
            if self.host.borrow_mut().clear_send_response(context_id) {
                self.body_buffer.clear();
                self.load_response();
                self.state = State::Done;
            } else if action == Action::Pause {
                if self.stop_mode.is_some() && !eos {
                    self.state = State::ResponseHeadersLimbo;
                } else {
                    self.state = State::ResponseHeadersPaused;
                }
            } else {
                self.state = State::ResponseBody;
            }

            // Same deferral rule as for the request headers.
            if self.stop_mode != Some(StopIterationMode::BodyThenRequests) || eos {
                self.respond_calls();
            }
        }

        // Skip response body state if there isn't a body in the response.
        if self.state == State::ResponseBody && self.eos() {
            self.body_buffer.clear();
            self.load_response();
            self.state = State::Done;
        }

        // Forward the response body
        if self.state == State::ResponseBody || self.state == State::ResponseHeadersLimbo {
            self.body_buffer.clear();
            // Same accumulation scheme as for the request body.
            let mut buffered_chunks = vec![];
            for chunk in self.request_response.body.clone().chunks(self.chunk_size) {
                buffered_chunks.extend_from_slice(chunk);

                self.set_buffer(BufferType::HttpResponseBody, buffered_chunks.clone());

                let action = self
                    .http_context
                    .on_http_response_body(buffered_chunks.len(), false);

                // A pause simply keeps accumulating; the send-response flag is not
                // inspected in this loop.
                if action == Action::Continue {
                    let continued_chunk = self.read_and_clear_buffer(BufferType::HttpResponseBody);
                    self.body_buffer.extend(continued_chunk);
                    buffered_chunks.clear();
                }
            }

            // End-of-stream event, reporting whatever is still buffered.
            let action = self
                .http_context
                .on_http_response_body(buffered_chunks.len(), true);

            if action == Action::Pause {
                self.state = State::ResponseBodyPaused;
            } else {
                let continued_chunk = self.read_and_clear_buffer(BufferType::HttpResponseBody);
                self.body_buffer.extend(continued_chunk);
                self.load_response();
                self.state = State::Done;
            }

            self.respond_calls();
        }

        // Call final callbacks
        if self.state == State::Done {
            self.http_context.on_log();
            self.http_context.on_done();
            self.state = State::Exit;
        }

        if self.state == State::Exit {
            return Poll::Ready(self.request_response.clone().into());
        }

        // Still paused somewhere: the caller has to poll again later.
        Poll::Pending
    }

    /// Answers the calls the host reports as pending for this context, batch
    /// after batch, until the host reports none.
    ///
    /// Does nothing once the exchange has reached [`State::Done`]. Returns
    /// early as soon as answering a call leads to a resume or a local response
    /// (`resume` returning `true`); the rest of the current batch is then not
    /// answered by this invocation.
    fn respond_calls(&mut self) {
        if self.state >= State::Done {
            return;
        }

        loop {
            // Answering a call may schedule new ones, so keep asking the host.
            let batch = self.host.borrow_mut().pending_calls(self.context_id);
            if batch.is_empty() {
                break;
            }

            for (id, upstream, call) in batch {
                respond_call(
                    self.http_context.as_mut(),
                    &self.host,
                    &self.backends,
                    self.context_id,
                    id,
                    upstream,
                    call,
                );

                let state_changed = self.resume();
                if state_changed {
                    return;
                }
            }
        }
    }

    /// Consumes the host's resume and send-response flags for this context and
    /// moves the state machine accordingly.
    ///
    /// * Resume: a paused (or Limbo) headers state moves on to the matching body
    ///   state; a paused body state flushes the buffered bytes and moves on to
    ///   the next phase (calling the backend for the request, loading the
    ///   response for the response).
    /// * Send response: the accumulated body is discarded, the response is read
    ///   from the host, and any paused *or Limbo* request state jumps to
    ///   [`State::ResponseHeaders`] while a paused or Limbo response-headers
    ///   state jumps to [`State::Done`]. Other states are left unchanged.
    ///
    /// Returns `true` when either flag was set, `false` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if both flags are set at once, or if a resume is flagged while
    /// the state is not one of the paused / Limbo states.
    fn resume(&mut self) -> bool {
        let resume = self.host.borrow_mut().clear_resume(self.context_id);
        let send_response = self.clear_send_response();

        assert!(
            !(resume && send_response),
            "resume and send response were both requested for the same context"
        );

        if resume {
            self.state = match self.state {
                State::RequestHeadersPaused | State::RequestHeadersLimbo => State::RequestBody,
                State::RequestBodyPaused => {
                    let held = self.read_and_clear_buffer(BufferType::HttpRequestBody);
                    self.body_buffer.extend(held);
                    self.call_backend();
                    State::ResponseHeaders
                }
                State::ResponseHeadersPaused | State::ResponseHeadersLimbo => State::ResponseBody,
                State::ResponseBodyPaused => {
                    let held = self.read_and_clear_buffer(BufferType::HttpResponseBody);
                    self.body_buffer.extend(held);
                    self.load_response();
                    State::Done
                }
                state => panic!("Called resume on non paused state {state:?}"),
            };
            true
        } else if send_response {
            self.body_buffer.clear();
            self.load_response();
            self.state = match self.state {
                // Limbo must be treated like the paused state here: `request_response`
                // now holds the local response, so letting the body phase run would
                // forward the local response body as if it were the original body.
                State::RequestHeadersPaused
                | State::RequestHeadersLimbo
                | State::RequestBodyPaused => State::ResponseHeaders,
                State::ResponseHeadersPaused | State::ResponseHeadersLimbo => State::Done,
                state => state,
            };
            true
        } else {
            false
        }
    }

    /// Asks the host to clear the send-response flag of this context and
    /// returns the host's answer.
    fn clear_send_response(&mut self) -> bool {
        let id = self.context_id;
        let mut host = self.host.borrow_mut();
        host.clear_send_response(id)
    }

    /// Publishes the headers of the current request/response on the host as the
    /// map `map_type`, with every value converted to its UTF-8 bytes.
    fn set_map(&self, map_type: MapType) {
        let entries = self
            .request_response
            .headers()
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string().into_bytes()))
            .collect();

        self.host
            .borrow_mut()
            .create_map(self.context_id, map_type, entries);
    }

    /// Stores `buffer` on the host as the `buffer_type` buffer of this context.
    fn set_buffer(&self, buffer_type: BufferType, buffer: Vec<u8>) {
        let mut host = self.host.borrow_mut();
        host.create_buffer(self.context_id, buffer_type, buffer);
    }

    /// Returns the current content of the host's `buffer_type` buffer for this
    /// context and replaces that buffer with an empty one.
    fn read_and_clear_buffer(&self, buffer_type: BufferType) -> Vec<u8> {
        // The shared borrow ends with this statement, before the buffer is rewritten.
        let contents = self.host.borrow().read_buffer(self.context_id, buffer_type);
        self.set_buffer(buffer_type, Vec::new());
        contents
    }

    /// Assembles a [`RequestResponse`] from the host state of this context.
    ///
    /// Headers come from the map `map_type`, the body is the already continued
    /// `body_buffer` followed by the host's `buffer_type` buffer, and the
    /// properties are the host's properties for the context.
    ///
    /// # Panics
    ///
    /// Panics if a header value read from the host is not valid UTF-8.
    fn read_request_response(&self, map_type: MapType, buffer_type: BufferType) -> RequestResponse {
        let id = self.context_id;

        let (headers, body, properties) = {
            let host = self.host.borrow();

            let headers = host
                .read_map(id, map_type)
                .into_iter()
                .map(|(name, value)| (name, String::from_utf8(value).unwrap()))
                .collect();

            let mut body = self.body_buffer.clone();
            body.extend(host.read_buffer(id, buffer_type));

            (headers, body, host.get_properties(id))
        };

        RequestResponse::create(headers, body, properties)
    }

    /// Reads the request (headers map and body buffer) back from the host.
    fn read_request(&mut self) -> RequestResponse {
        let headers_map = MapType::HttpRequestHeaders;
        let body_buffer = BufferType::HttpRequestBody;
        self.read_request_response(headers_map, body_buffer)
    }

    /// Replaces the held request/response with the response (headers map and
    /// body buffer) read back from the host.
    fn load_response(&mut self) {
        let response =
            self.read_request_response(MapType::HttpResponseHeaders, BufferType::HttpResponseBody);
        self.request_response = response;
    }

    /// Sends the request, as currently seen on the host, to the backend and
    /// makes the backend's answer the held request/response.
    ///
    /// The accumulated body buffer is emptied once the request has been read,
    /// and the properties of the new response are stored on the host.
    fn call_backend(&mut self) {
        let id = self.context_id;

        let outgoing = self.read_request();
        self.body_buffer.clear();

        let reply = self.backends.borrow().backend.call(outgoing.into()).inner;
        self.request_response = add_request_properties(reply, id);

        let properties = self.request_response.properties();
        self.host.borrow_mut().set_properties(id, properties);
    }

    /// Whether the held request/response has no body, i.e. the headers event
    /// is also the end of the stream.
    fn eos(&self) -> bool {
        let body = &self.request_response.body;
        body.is_empty()
    }
}