Skip to main content

pdk_unit/tester/
unit_test.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4use crate::backends::ldap::LdapBackend;
5use crate::backends::{anypoint::AnypointBackend, os::OSBackend};
6use crate::host::implementation::Call;
7use crate::host::{facade::HostFacade, implementation::ProxyWasmStub};
8use crate::tester::io::{RequestResponse, UnitHttpRequest, UnitHttpResponse};
9use crate::tester::unit_test_request::InnerUnitTestRequest;
10use crate::{Backend, GrpcBackend, UnitGrpcRequest, UnitLdapConfig};
11use classy::Entrypoint;
12use non_exhaustive::non_exhaustive;
13use pdk_core::host::context::root::RootContextAdapter;
14use pdk_core::init::configure;
15use pdk_core::policy_context::api::{
16    ApiMetadata, FlexMetadata, Metadata, PlatformMetadata, PolicyMetadata,
17};
18use pdk_core::policy_context::metadata::{
19    AnypointContext, Api, ApiContext, ApiSla as CoreApiSla, EnvironmentContext,
20};
21use pdk_core::policy_context::metadata::{IdentityManagementContext, Tier as CoreTier};
22use proxy_wasm_stub::stub::set_host;
23use proxy_wasm_stub::traits::{Context, RootContext};
24use proxy_wasm_stub::types::{BufferType, MapType};
25use std::backtrace::Backtrace;
26use std::cell::RefCell;
27use std::collections::HashMap;
28use std::panic;
29use std::rc::Rc;
30use std::task::Poll;
31use std::time::Duration;
32
/// Service name handed to `IdentityManagementContext` when an identity-management URL is
/// configured (see `setup_metadata`).
pub(super) const IDENTITY_MANAGEMENT_SVC: &str = "__identity_management_svc";
/// Default chunk size used for body processing of new requests; override it with
/// `UnitTest::set_chunk_size`.
const CHUNK_SIZE: usize = 3;
35
/// The main test orchestrator for running PDK policy unit tests.
///
/// This struct manages the Proxy-Wasm host stub, handles request lifecycles,
/// and coordinates interactions between the policy under test and mock backends.
///
/// Created via [`UnitTestBuilder`](crate::UnitTestBuilder).
pub struct UnitTest {
    /// Host stub; also installed as the global host facade and shared with every request.
    host: Rc<RefCell<ProxyWasmStub>>,
    /// The policy's root context; `None` only until `init` has run.
    context: Option<RootContextAdapter>,
    /// Next context id to hand out. Id 0 is the root context, requests start at 1.
    context_count: u32,
    /// Requests not finished after their first poll; re-polled on every tick.
    requests: Vec<UnitTestRequest>,
    /// Mock HTTP/gRPC backends keyed by upstream name, shared with every request.
    backends: Rc<RefCell<Backends>>,
    /// Built-in Anypoint mock that `add_contract_data` / `remove_contract_data` update.
    anypoint: Rc<AnypointBackend>,
    /// Built-in LDAP mock that `add_ldap_data` updates.
    ldap: Rc<LdapBackend>,
    /// Stop-iteration mode handed to new requests (`None` without `enable_stop_iteration`).
    stop_mode: Option<StopIterationMode>,
    /// Body chunk size handed to new requests.
    chunk_size: usize,
    /// Policy configuration and metadata applied by `init`.
    config: UnitTestConfig,
    /// Builds a fresh root context for the policy; invoked by `init` (and thus `restart`).
    factory: Box<dyn Fn() -> RootContextAdapter>,
}
55
/// Mock backends that answer the calls a policy under test dispatches.
pub(crate) struct Backends {
    /// Not used in this file; presumably the backend that answers the proxied request
    /// itself — TODO confirm against `InnerUnitTestRequest`.
    pub backend: Box<dyn Backend>,
    /// HTTP backends keyed by upstream name; every key is registered with the host in `init`.
    pub upstreams: HashMap<String, Rc<dyn Backend>>,
    /// gRPC backends keyed by upstream name; every key is registered with the host in `init`.
    pub grpc_upstreams: HashMap<String, Rc<dyn GrpcBackend>>,
}
61
/// Controls the order in which the test framework processes forwarded call responses and
/// the event body while a request is paused in stop-iteration mode.
///
/// The type itself is always available, but the mode only takes effect when the
/// `enable_stop_iteration` feature is enabled. That feature provides
/// `UnitTest::set_host_mode` and makes `BodyThenRequests` the default. Without the feature,
/// requests run with no stop-iteration mode at all.
// Variant order is significant: the derived `PartialOrd` compares by declaration order.
#[derive(PartialOrd, PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub enum StopIterationMode {
    /// Process the call responses first and then the event body.
    RequestsThenBody,
    /// Process the event body first and then the call responses.
    BodyThenRequests,
}
73
/// Configuration applied to the policy every time the test harness (re)initialises.
pub(crate) struct UnitTestConfig {
    /// Raw policy configuration: stored as the plugin-configuration buffer, and its length
    /// is passed to `on_configure`.
    pub(crate) policy_config: String,
    /// Flex, policy, API and platform metadata published as host properties by `setup_metadata`.
    pub(crate) metadata: Metadata,
    /// Optional identity-management URL; when set, an `IdentityManagementContext` pointing at
    /// it is added to the generated API context.
    pub(crate) identity_management: Option<String>,
}
79
80impl Default for UnitTestConfig {
81    fn default() -> Self {
82        let policy_name = "test_policy_id".to_string();
83        let policy_namespace = "test_policy_namespace".to_string();
84        let api_name = "test_api_id".to_string();
85        let filter_name = format!("{policy_name}.{policy_namespace}.{api_name}");
86
87        Self {
88            policy_config: "{}".to_string(),
89            metadata: non_exhaustive!(Metadata {
90                flex_metadata: non_exhaustive!(FlexMetadata {
91                    flex_name: "test_flex_name".to_string(),
92                    flex_version: "1.0.0".to_string(),
93                }),
94                policy_metadata: non_exhaustive!(PolicyMetadata {
95                    policy_name: policy_name,
96                    policy_namespace: policy_namespace,
97                    filter_name: filter_name,
98                }),
99                api_metadata: non_exhaustive!(ApiMetadata {
100                    id: Some("1".to_string()),
101                    name: Some(api_name),
102                    version: Some("1.0.0".to_string()),
103                    base_path: Some("/".to_string()),
104                    slas: None,
105                }),
106                platform_metadata: non_exhaustive!(PlatformMetadata {
107                    organization_id: "test-org-id".to_string(),
108                    environment_id: "test-env-id".to_string(),
109                    root_organization_id: "test-root-org-id".to_string(),
110                }),
111            }),
112            identity_management: None,
113        }
114    }
115}
116
117impl UnitTest {
118    pub(crate) fn new<C, T, E: Entrypoint<C, T> + Clone + 'static>(
119        entrypoint: E,
120        config: UnitTestConfig,
121        mut backends: Backends,
122    ) -> Self {
123        let host = Rc::new(RefCell::new(ProxyWasmStub::default()));
124        set_host(HostFacade::new(Rc::clone(&host)));
125
126        let factory = Box::new(move || {
127            RootContextAdapter::new(
128                configure(0)
129                    .entrypoint(entrypoint.clone())
130                    .create_root_context(0),
131            )
132        });
133
134        // Register the upstreams before on_configure to be able to fire request from on_configure.
135        let anypoint = Rc::new(AnypointBackend::default());
136        let anypoint_ref = Rc::clone(&anypoint);
137        backends
138            .upstreams
139            .entry("anypoint_service_name".to_string())
140            .or_insert(anypoint_ref);
141
142        let ldap = Rc::new(LdapBackend::default());
143        let ldap_ref = Rc::clone(&ldap);
144        backends
145            .upstreams
146            .entry("x-flex-services".to_string())
147            .or_insert(ldap_ref);
148
149        backends
150            .upstreams
151            .entry("x-flex-keyvalue-store".to_string())
152            .or_insert(Rc::new(OSBackend::default()));
153
154        let mut test = Self {
155            host,
156            context: None,
157            context_count: 0,
158            requests: Vec::new(),
159            backends: Rc::new(RefCell::new(backends)),
160            anypoint,
161            ldap,
162            #[cfg(feature = "enable_stop_iteration")]
163            stop_mode: Some(StopIterationMode::BodyThenRequests),
164            #[cfg(not(feature = "enable_stop_iteration"))]
165            stop_mode: None,
166            chunk_size: CHUNK_SIZE,
167            config,
168            factory,
169        };
170
171        test.init();
172
173        test
174    }
175
176    fn init(&mut self) {
177        let host = &self.host;
178
179        // Create factory context
180        self.context_count = 1;
181        host.borrow_mut().create_context(0);
182        host.borrow_mut().create_buffer(
183            0,
184            BufferType::PluginConfiguration,
185            self.config.policy_config.as_bytes().to_vec(),
186        );
187        host.borrow_mut().set_context(0);
188
189        setup_metadata(host, &self.config);
190
191        // Create & initialize factory
192        let factory = &self.factory;
193        self.context = Some(factory());
194
195        enrich_panic_hook();
196
197        self.backends
198            .borrow()
199            .upstreams
200            .keys()
201            .for_each(|key| host.borrow_mut().add_upstream(key.to_string()));
202        self.backends
203            .borrow()
204            .grpc_upstreams
205            .keys()
206            .for_each(|key| host.borrow_mut().add_upstream(key.to_string()));
207
208        self.context
209            .as_mut()
210            .unwrap()
211            .on_configure(self.config.policy_config.len());
212
213        // Respond to any pending calls triggered during on_configure.
214        self.respond_calls();
215    }
216
217    /// Simulate a system restart by cleaning all contexts and keeping the configured upstreams.
218    pub fn restart(&mut self) {
219        // Clear pending requests
220        self.requests.clear();
221
222        // Create a new host
223        let mut host = ProxyWasmStub::default();
224        host.clock = self.host.borrow().clock;
225        let host = Rc::new(RefCell::new(host));
226        set_host(HostFacade::new(Rc::clone(&host)));
227        self.host = host;
228
229        self.init();
230    }
231
232    /// Sets the stop iteration mode for handling paused requests.
233    ///
234    /// Only available when the `enable_stop_iteration` feature is enabled.
235    #[cfg(feature = "enable_stop_iteration")]
236    pub fn set_host_mode(&mut self, mode: StopIterationMode) {
237        self.stop_mode = Some(mode);
238    }
239
240    #[cfg(feature = "experimental")]
241    pub fn get_metrics(&mut self) -> HashMap<String, u64> {
242        self.host
243            .borrow()
244            .get_metrics()
245            .into_iter()
246            .map(|(_id, (name, value))| (name, value))
247            .collect()
248    }
249
250    /// Sets the chunk size for body processing.
251    ///
252    /// Bodies larger than this size will be processed in multiple chunks.
253    pub fn set_chunk_size(&mut self, chunk_size: usize) {
254        self.chunk_size = chunk_size;
255    }
256
257    /// Adds contract data for client ID enforcement testing.
258    ///
259    /// This simulates registered API contracts in the Anypoint Platform.
260    pub fn add_contract_data<I, N, S, Sla>(
261        &mut self,
262        id: I,
263        name: N,
264        secret: Option<S>,
265        sla_id: Option<Sla>,
266    ) where
267        I: Into<String>,
268        N: Into<String>,
269        S: Into<String>,
270        Sla: Into<String>,
271    {
272        self.anypoint.add_contract(
273            id.into(),
274            name.into(),
275            secret.map(|s| s.into()),
276            sla_id.map(|sla| sla.into()),
277        );
278    }
279
280    /// Removes contract data for client ID enforcement testing.
281    pub fn remove_contract_data<I>(&mut self, id: I)
282    where
283        I: Into<String>,
284    {
285        self.anypoint.remove_contract(id.into());
286    }
287
288    /// Registers a valid LDAP credential pair for use during testing.
289    ///
290    /// If `config` is [`Some`], the pair is matched only when the policy uses
291    /// LDAP connection parameters equal to that config. If `config` is [`None`],
292    /// the pair acts as a wildcard and matches regardless of the LDAP config.
293    ///
294    /// # Arguments
295    ///
296    /// * `config` - Optional LDAP server configuration to scope this credential to.
297    /// * `user` - The username that should be considered valid.
298    /// * `pass` - The password that should be considered valid for `user`.
299    pub fn add_ldap_data<U, P>(&mut self, config: Option<UnitLdapConfig>, user: U, pass: P)
300    where
301        U: Into<String>,
302        P: Into<String>,
303    {
304        self.ldap.add_data(config, user, pass);
305    }
306
307    /// Sends a request through the policy and returns a handle to track its progress.
308    ///
309    /// The returned [`UnitTestRequest`] can be polled to advance the request
310    /// through the policy lifecycle and eventually retrieve the response.
311    pub fn request_partial(&mut self, request: UnitHttpRequest) -> UnitTestRequest {
312        // Create new request context
313        let request = request.inner;
314        let context_id = self.context_count;
315        self.context_count += 1;
316
317        self.host.borrow_mut().create_context(context_id);
318
319        let request = add_request_properties(request, context_id);
320        let props = request.properties();
321
322        self.host.borrow_mut().set_properties(context_id, props);
323        self.host.borrow_mut().set_context(context_id);
324
325        let http_context = self
326            .context
327            .as_ref()
328            .unwrap()
329            .create_http_context(context_id)
330            .unwrap();
331
332        let mut inner = UnitTestRequest::new(InnerUnitTestRequest::new(
333            context_id,
334            request,
335            http_context,
336            Rc::clone(&self.backends),
337            Rc::clone(&self.host),
338            self.stop_mode,
339            self.chunk_size,
340        ));
341
342        if !inner.poll().is_ready() {
343            self.requests.push(inner.clone())
344        }
345
346        inner
347    }
348
349    fn forward_requests(&mut self) {
350        self.requests.retain_mut(|req| !req.poll().is_ready());
351    }
352
353    fn do_tick(&mut self) {
354        self.host.borrow_mut().set_context(0);
355        self.context.as_mut().unwrap().on_tick();
356        self.forward_requests();
357        self.respond_calls();
358    }
359
360    /// Advances the simulated time by one tick interval.
361    ///
362    /// This triggers `on_tick` callbacks and processes any pending requests.
363    pub fn tick(&mut self) {
364        if !self.host.borrow_mut().tick().is_zero() {
365            self.do_tick();
366        }
367    }
368
369    /// Advances the simulated time by the specified duration.
370    ///
371    /// This will trigger multiple ticks if the duration spans multiple tick intervals.
372    pub fn sleep(&mut self, duration: Duration) {
373        let mut accumulated = Duration::new(0, 0);
374        while accumulated < duration {
375            let elapsed = self.host.borrow_mut().tick();
376            if elapsed.is_zero() {
377                self.host.borrow_mut().forward(duration - accumulated);
378                return;
379            }
380            accumulated += elapsed;
381            self.do_tick();
382        }
383    }
384
385    /// Sends a request and blocks until the full response is received.
386    ///
387    /// This is a convenience method that combines `request_partial()` with polling
388    /// until completion. Use this for simple synchronous test scenarios.
389    pub fn request(&mut self, request: UnitHttpRequest) -> UnitHttpResponse {
390        let mut response = self.request_partial(request);
391
392        loop {
393            if let Poll::Ready(value) = response.poll() {
394                return value;
395            } else {
396                self.tick()
397            }
398        }
399    }
400
401    #[cfg(feature = "experimental_logs")]
402    pub fn logs(&self) -> Vec<String> {
403        self.host.borrow().logs.borrow().clone()
404    }
405
406    fn respond_calls(&mut self) {
407        let mut pending_calls = self.host.borrow_mut().pending_calls(0);
408        while !pending_calls.is_empty() {
409            for (id, upstream, call) in pending_calls.into_iter() {
410                respond_call(
411                    self.context.as_mut().unwrap(),
412                    &self.host,
413                    &self.backends,
414                    0,
415                    id,
416                    upstream,
417                    call,
418                );
419            }
420            pending_calls = self.host.borrow_mut().pending_calls(0);
421        }
422    }
423}
424
/// A handle to an in-flight request being processed by the policy.
///
/// Use `poll()` to advance the request through the policy lifecycle.
///
/// Cloning the handle is cheap, and every clone refers to the same underlying request.
/// `UnitTest` keeps its own clone of requests that are still pending so that ticks can keep
/// advancing them.
#[derive(Clone)]
pub struct UnitTestRequest {
    // Shared so the clone held by `UnitTest` and the caller's handle observe the same state.
    inner: Rc<RefCell<InnerUnitTestRequest>>,
}
432
433impl UnitTestRequest {
434    pub(crate) fn new(inner: InnerUnitTestRequest) -> Self {
435        Self {
436            inner: Rc::new(RefCell::new(inner)),
437        }
438    }
439
440    /// Advances the request processing and returns the current state.
441    ///
442    /// Returns `Poll::Ready(response)` when the request is complete,
443    /// or `Poll::Pending` if more processing is needed.
444    pub fn poll(&mut self) -> Poll<UnitHttpResponse> {
445        self.inner.borrow_mut().poll()
446    }
447}
448
449fn setup_metadata(host: &Rc<RefCell<ProxyWasmStub>>, config: &UnitTestConfig) {
450    let api_name = config
451        .metadata
452        .api_metadata
453        .name
454        .clone()
455        .unwrap_or_default();
456    let policy_name = &config.metadata.policy_metadata.policy_name;
457    let policy_namespace = &config.metadata.policy_metadata.policy_namespace;
458
459    let filter_name = format!("{policy_name}.{policy_namespace}.{api_name}");
460
461    host.borrow_mut().create_property(
462        0,
463        vec!["node", "id"],
464        Some(config.metadata.flex_metadata.flex_name.clone()),
465    );
466    host.borrow_mut()
467        .create_property(0, vec!["plugin_name"], Some(filter_name));
468
469    let tiers = config
470        .metadata
471        .api_metadata
472        .slas
473        .as_ref()
474        .map(|slas| {
475            slas.iter()
476                .map(|sla| {
477                    CoreApiSla::new(
478                        sla.id.clone(),
479                        sla.name.clone(),
480                        sla.tiers
481                            .iter()
482                            .map(|tier| CoreTier::new(tier.requests, tier.period_in_millis))
483                            .collect(),
484                    )
485                })
486                .collect::<Vec<_>>()
487        })
488        .unwrap_or_default();
489
490    let mut api = Api::new(
491        config.metadata.api_metadata.id.clone().unwrap_or_default(),
492        api_name,
493        "v1".to_string(),
494        config
495            .metadata
496            .api_metadata
497            .version
498            .clone()
499            .unwrap_or_default(),
500        None,
501    );
502
503    if let Some(path) = config.metadata.api_metadata.base_path.as_ref() {
504        api.set_base_path(path.clone())
505    }
506
507    let anypoint = AnypointContext::new(
508        "test_client".to_string(),
509        "test_secret".to_string(),
510        "anypoint_service_name".to_string(),
511        "https://anypoint.mulesoft.com".to_string(),
512    );
513
514    let environment = EnvironmentContext::new(
515        config.metadata.platform_metadata.organization_id.clone(),
516        config.metadata.platform_metadata.environment_id.clone(),
517        config
518            .metadata
519            .platform_metadata
520            .root_organization_id
521            .clone(),
522        "test_cluster_id".to_string(),
523        Some(anypoint),
524        None,
525    );
526
527    let identity = config.identity_management.as_ref().map(|url| {
528        IdentityManagementContext::new(
529            "client_id".to_string(),
530            "client_secret".to_string(),
531            url.clone(),
532            IDENTITY_MANAGEMENT_SVC.to_string(),
533        )
534    });
535
536    let context = ApiContext::new(
537        None,
538        Some(api),
539        Some(tiers),
540        identity,
541        Some(environment),
542        None,
543    );
544    let context = serde_json::to_string(&context).unwrap();
545
546    host.borrow_mut().create_property(
547        0,
548        vec![
549            "listener_metadata",
550            "filter_metadata",
551            config
552                .metadata
553                .api_metadata
554                .name
555                .as_deref()
556                .unwrap_or_default(),
557            "context",
558        ],
559        Some(context),
560    )
561}
562
/// Wraps the current panic hook so that a backtrace is printed after the regular panic
/// message.
///
/// The wrapper is installed at most once per process. `UnitTest::init` calls this for
/// every new test harness and on every `restart`. Re-wrapping the hook on each call would
/// stack wrappers and print one extra backtrace per earlier call.
///
/// `Backtrace::capture` honours `RUST_BACKTRACE` / `RUST_LIB_BACKTRACE`. When neither
/// enables backtraces, the printed backtrace is reported as disabled rather than listing
/// frames.
fn enrich_panic_hook() {
    static INSTALL: std::sync::Once = std::sync::Once::new();

    INSTALL.call_once(|| {
        let previous = panic::take_hook();

        panic::set_hook(Box::new(move |panic_info| {
            previous(panic_info);
            println!("{}", Backtrace::capture());
        }));
    });
}
571
572pub(crate) fn add_request_properties(request: RequestResponse, context_id: u32) -> RequestResponse {
573    request
574        .with_property_if_missing(&["anypoint/mulesoft/tracing_id"], context_id.to_string())
575        .with_property_if_missing(&["request", "id"], context_id.to_string())
576        .with_property_if_missing(&["source", "address"], "127.0.0.1")
577        .with_property_if_missing(&["destination", "address"], "127.0.0.2")
578        .with_property_if_missing(&["request", "scheme"], "http")
579        .with_property_if_missing(&["request", "protocol"], "1.1")
580}
581
582pub(crate) fn respond_http<C: Context + ?Sized>(
583    context: &mut C,
584    host: &Rc<RefCell<ProxyWasmStub>>,
585    backends: &Rc<RefCell<Backends>>,
586    context_id: u32,
587    id: u32,
588    upstream: String,
589    req: RequestResponse,
590) {
591    let response = backends
592        .borrow()
593        .upstreams
594        .get(&upstream)
595        .unwrap()
596        .call(req.into())
597        .inner;
598    let response_headers = response.headers.len();
599    let response_body = response.body.len();
600
601    host.borrow_mut().create_map(
602        context_id,
603        MapType::HttpCallResponseHeaders,
604        response
605            .headers
606            .into_iter()
607            .map(|(k, v)| (k, v.into_bytes()))
608            .collect(),
609    );
610    host.borrow_mut()
611        .create_buffer(context_id, BufferType::HttpCallResponseBody, response.body);
612    context.on_http_call_response(id, response_headers, response_body, 0);
613}
614
615pub(crate) fn respond_grpc<C: Context + ?Sized>(
616    context: &mut C,
617    host: &Rc<RefCell<ProxyWasmStub>>,
618    backends: &Rc<RefCell<Backends>>,
619    context_id: u32,
620    id: u32,
621    upstream: String,
622    req: UnitGrpcRequest,
623) {
624    let response = backends
625        .borrow()
626        .grpc_upstreams
627        .get(&upstream)
628        .unwrap()
629        .call(req);
630
631    host.borrow_mut()
632        .set_grpc_status((response.status_code, response.status));
633
634    let body_len = response.message.len();
635    host.borrow_mut()
636        .create_buffer(context_id, BufferType::GrpcReceiveBuffer, response.message);
637
638    context.on_grpc_call_response(id, response.status_code, body_len)
639}
640
641pub(crate) fn respond_call<C: Context + ?Sized>(
642    context: &mut C,
643    host: &Rc<RefCell<ProxyWasmStub>>,
644    backends: &Rc<RefCell<Backends>>,
645    context_id: u32,
646    id: u32,
647    upstream: String,
648    call: Call,
649) {
650    match call {
651        Call::Http(req) => respond_http(context, host, backends, context_id, id, upstream, req),
652        Call::Grpc(req) => respond_grpc(context, host, backends, context_id, id, upstream, req),
653    }
654}