1#[cfg(feature = "pdk")]
5use crate::backends::{anypoint::AnypointBackend, ldap::LdapBackend, os::OSBackend};
6use crate::host::implementation::{Call, FlowType};
7use crate::host::{facade::HostFacade, implementation::ProxyWasmStub};
8use crate::tester::io::UnitLogLevel;
9use crate::tester::io::{RequestResponse, UnitHttpRequest, UnitHttpResponse};
10use crate::tester::unit_test_request::InnerUnitTestRequest;
11use crate::tester::websocket::{
12 ConnectionInner, Direction, InnerUnitUpgrade, UnitTestUpgrade, UpgradeConnection,
13};
14use crate::{Backend, GrpcBackend, UnitGrpcRequest};
15#[cfg(feature = "pdk")]
16use classy::Entrypoint;
17#[cfg(feature = "pdk")]
18use non_exhaustive::non_exhaustive;
19#[cfg(feature = "pdk")]
20use pdk_core::host::context::root::RootContextAdapter;
21#[cfg(feature = "pdk")]
22use pdk_core::init::configure;
23#[cfg(feature = "pdk")]
24use pdk_core::policy_context::api::{
25 ApiMetadata, FlexMetadata, Metadata, PlatformMetadata, PolicyMetadata,
26};
27#[cfg(feature = "pdk")]
28use pdk_core::policy_context::metadata::{
29 AnypointContext, Api, ApiContext, ApiSla as CoreApiSla, EnvironmentContext,
30 IdentityManagementContext, Tier as CoreTier,
31};
32use proxy_wasm_stub::stub::set_host;
33use proxy_wasm_stub::traits::{Context, RootContext};
34use proxy_wasm_stub::types::{BufferType, MapType};
35use std::backtrace::Backtrace;
36use std::cell::RefCell;
37use std::collections::HashMap;
38use std::panic;
39use std::rc::Rc;
40use std::rc::Weak;
41use std::task::Poll;
42use std::thread;
43use std::time::Duration;
44
/// Service name handed to the identity-management context built for connected mode
/// (presumably the upstream it dispatches to — confirm against the PDK).
#[cfg(feature = "pdk")]
pub(super) const IDENTITY_MANAGEMENT_SVC: &str = "__identity_management_svc";

/// Initial value of `UnitTest::chunk_size`, which is forwarded to every request and upgrade
/// the harness creates (presumably the body chunk size in bytes).
const CHUNK_SIZE: usize = 30000;
48
49pub struct UnitTest {
56 host: Rc<RefCell<ProxyWasmStub>>,
57 context: Option<Box<dyn RootContext>>,
58 context_count: u32,
59 requests: Vec<UnitTestRequest>,
60 upgrades: Vec<UnitTestUpgrade>,
61 connections: Vec<Weak<RefCell<ConnectionInner>>>,
62 backends: Rc<RefCell<Backends>>,
63 #[cfg(feature = "pdk")]
64 anypoint: Rc<AnypointBackend>,
65 #[cfg(feature = "pdk")]
66 ldap: Rc<LdapBackend>,
67 stop_mode: Option<StopIterationMode>,
68 chunk_size: usize,
69 log_level: UnitLogLevel,
70 config: UnitTestConfig,
71 factory: Box<dyn Fn() -> Box<dyn RootContext>>,
72}
73
74pub(crate) struct Backends {
75 pub backend: Box<dyn Backend>,
76 pub upstreams: HashMap<String, Rc<dyn Backend>>,
77 pub grpc_upstreams: HashMap<String, Rc<dyn GrpcBackend>>,
78}
79
/// Order in which the stub host resumes a request whose filter paused iteration
/// (selected through `UnitTest::set_host_mode` when `enable_stop_iteration` is on).
///
/// NOTE(review): the variant descriptions below follow the variant names; confirm the exact
/// semantics against the request driver.
#[derive(PartialOrd, Ord, PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub enum StopIterationMode {
    /// Resume the paused request phases before delivering the body.
    RequestsThenBody,
    /// Deliver the body before resuming the paused request phases.
    BodyThenRequests,
}
91
/// Configuration published to the root context whenever the harness (re)initialises.
pub(crate) struct UnitTestConfig {
    /// Raw policy configuration. It is exposed as the root context's plugin-configuration
    /// buffer, and its byte length is passed to `on_configure`.
    pub(crate) policy_config: String,
    /// Flex, policy, API and platform metadata published to the policy as properties.
    #[cfg(feature = "pdk")]
    pub(crate) metadata: Metadata,
    /// Identity-management URL. When set (and not in local mode), an identity-management
    /// context pointing at it is added to the published policy context.
    #[cfg(feature = "pdk")]
    pub(crate) identity_management: Option<String>,
    /// When `true`, only the environment context is published (no API, SLA or identity data).
    #[cfg(feature = "pdk")]
    pub(crate) local_mode: bool,
}
101
102impl Default for UnitTestConfig {
103 fn default() -> Self {
104 #[cfg(feature = "pdk")]
105 {
106 let policy_name = "test_policy_id".to_string();
107 let policy_namespace = "test_policy_namespace".to_string();
108 let api_name = "test_api_id".to_string();
109 let filter_name = format!("{policy_name}.{policy_namespace}.{api_name}");
110
111 Self {
112 policy_config: "{}".to_string(),
113 metadata: non_exhaustive!(Metadata {
114 flex_metadata: non_exhaustive!(FlexMetadata {
115 flex_name: "test_flex_name".to_string(),
116 flex_version: "1.0.0".to_string(),
117 }),
118 policy_metadata: non_exhaustive!(PolicyMetadata {
119 policy_name: policy_name,
120 policy_namespace: policy_namespace,
121 filter_name: filter_name,
122 }),
123 api_metadata: non_exhaustive!(ApiMetadata {
124 id: Some("1".to_string()),
125 name: Some(api_name),
126 version: Some("1.0.0".to_string()),
127 base_path: Some("/".to_string()),
128 slas: None,
129 }),
130 platform_metadata: non_exhaustive!(PlatformMetadata {
131 organization_id: "test-org-id".to_string(),
132 environment_id: "test-env-id".to_string(),
133 root_organization_id: "test-root-org-id".to_string(),
134 }),
135 }),
136 identity_management: None,
137 local_mode: false,
138 }
139 }
140
141 #[cfg(not(feature = "pdk"))]
142 Self {
143 policy_config: "{}".to_string(),
144 }
145 }
146}
147
148impl UnitTest {
149 #[cfg(feature = "pdk")]
150 pub(crate) fn new<C, T, E: Entrypoint<C, T> + Clone + 'static>(
151 entrypoint: E,
152 config: UnitTestConfig,
153 mut backends: Backends,
154 ) -> Self {
155 let host = Rc::new(RefCell::new(ProxyWasmStub::default()));
156 set_host(HostFacade::new(Rc::clone(&host)));
157
158 let factory: Box<dyn Fn() -> Box<dyn RootContext>> = Box::new(move || {
159 Box::new(RootContextAdapter::new(
160 configure(0)
161 .entrypoint(entrypoint.clone())
162 .create_root_context(0),
163 ))
164 });
165
166 let anypoint = Rc::new(AnypointBackend::default());
168 let anypoint_ref = Rc::clone(&anypoint);
169 backends
170 .upstreams
171 .entry("anypoint_service_name".to_string())
172 .or_insert(anypoint_ref);
173
174 let ldap = Rc::new(LdapBackend::default());
175 let ldap_ref = Rc::clone(&ldap);
176 backends
177 .upstreams
178 .entry("x-flex-services".to_string())
179 .or_insert(ldap_ref);
180
181 backends
182 .upstreams
183 .entry("x-flex-keyvalue-store".to_string())
184 .or_insert(Rc::new(OSBackend::default()));
185
186 let mut test = Self {
187 host,
188 context: None,
189 context_count: 0,
190 requests: Vec::new(),
191 upgrades: Vec::new(),
192 connections: Vec::new(),
193 backends: Rc::new(RefCell::new(backends)),
194 anypoint,
195 ldap,
196 #[cfg(feature = "enable_stop_iteration")]
197 stop_mode: Some(StopIterationMode::BodyThenRequests),
198 #[cfg(not(feature = "enable_stop_iteration"))]
199 stop_mode: None,
200 chunk_size: CHUNK_SIZE,
201 log_level: UnitLogLevel::Trace,
202 config,
203 factory,
204 };
205
206 test.init();
207
208 test
209 }
210
211 #[cfg(feature = "proxy-wasm-rust-sdk")]
216 pub(crate) fn new_with_context<F>(
217 factory: F,
218 config: UnitTestConfig,
219 backends: Backends,
220 ) -> Self
221 where
222 F: Fn() -> Box<dyn RootContext> + 'static,
223 {
224 let host = Rc::new(RefCell::new(ProxyWasmStub::default()));
225 set_host(HostFacade::new(Rc::clone(&host)));
226
227 let mut test = Self {
228 host,
229 context: None,
230 context_count: 0,
231 requests: Vec::new(),
232 upgrades: Vec::new(),
233 connections: Vec::new(),
234 backends: Rc::new(RefCell::new(backends)),
235 #[cfg(feature = "pdk")]
236 anypoint: Rc::new(AnypointBackend::default()),
237 #[cfg(feature = "pdk")]
238 ldap: Rc::new(LdapBackend::default()),
239 #[cfg(feature = "enable_stop_iteration")]
240 stop_mode: Some(StopIterationMode::BodyThenRequests),
241 #[cfg(not(feature = "enable_stop_iteration"))]
242 stop_mode: None,
243 chunk_size: CHUNK_SIZE,
244 log_level: UnitLogLevel::Trace,
245 config,
246 factory: Box::new(factory),
247 };
248
249 test.init();
250
251 test
252 }
253
254 fn init(&mut self) {
255 let host = &self.host;
256
257 host.borrow_mut().set_log_level(self.log_level);
258
259 self.context_count = 1;
261 host.borrow_mut().create_context(0);
262 host.borrow_mut().create_buffer(
263 0,
264 BufferType::PluginConfiguration,
265 self.config.policy_config.as_bytes().to_vec(),
266 );
267 host.borrow_mut().set_context(0);
268
269 #[cfg(feature = "pdk")]
270 setup_metadata(host, &self.config);
271
272 let factory = &self.factory;
274 self.context = Some(factory());
275
276 enrich_panic_hook();
277
278 self.backends
279 .borrow()
280 .upstreams
281 .keys()
282 .for_each(|key| host.borrow_mut().add_upstream(key.to_string()));
283 self.backends
284 .borrow()
285 .grpc_upstreams
286 .keys()
287 .for_each(|key| host.borrow_mut().add_upstream(key.to_string()));
288
289 self.context
290 .as_mut()
291 .unwrap()
292 .on_configure(self.config.policy_config.len());
293
294 self.respond_calls();
296 }
297
298 pub fn restart(&mut self) {
300 self.requests.clear();
302 self.upgrades.clear();
303 self.connections.clear();
304
305 let mut host = ProxyWasmStub::default();
307 host.clock = self.host.borrow().clock;
308 let host = Rc::new(RefCell::new(host));
309 set_host(HostFacade::new(Rc::clone(&host)));
310 self.host = host;
311
312 self.init();
313 }
314
315 pub fn set_log_level(&mut self, level: UnitLogLevel) {
317 self.log_level = level;
318 self.host.borrow_mut().set_log_level(level);
319 }
320
321 #[cfg(feature = "enable_stop_iteration")]
325 pub fn set_host_mode(&mut self, mode: StopIterationMode) {
326 self.stop_mode = Some(mode);
327 }
328
329 #[cfg(feature = "experimental")]
330 pub fn get_metrics(&mut self) -> HashMap<String, u64> {
331 self.host
332 .borrow()
333 .get_metrics()
334 .into_iter()
335 .map(|(_id, (name, value))| (name, value))
336 .collect()
337 }
338
339 pub fn set_chunk_size(&mut self, chunk_size: usize) {
343 self.chunk_size = chunk_size;
344 }
345
346 pub fn request_partial(&mut self, request: UnitHttpRequest) -> UnitTestRequest {
351 let request = request.inner;
353 let context_id = self.context_count;
354 self.context_count += 1;
355
356 self.host.borrow_mut().create_context(context_id);
357
358 let request = add_request_properties(request, context_id);
359 let props = request.properties();
360
361 self.host.borrow_mut().set_properties(context_id, props);
362 self.host.borrow_mut().set_context(context_id);
363
364 let http_context = self
365 .context
366 .as_ref()
367 .unwrap()
368 .create_http_context(context_id)
369 .unwrap();
370
371 let mut inner = UnitTestRequest::new(InnerUnitTestRequest::new(
372 context_id,
373 request,
374 http_context,
375 Rc::clone(&self.backends),
376 Rc::clone(&self.host),
377 self.stop_mode,
378 self.chunk_size,
379 ));
380
381 if !inner.poll().is_ready() {
382 self.requests.push(inner.clone())
383 }
384
385 inner
386 }
387
388 fn forward_requests(&mut self) {
389 self.requests.retain_mut(|req| !req.poll().is_ready());
390 self.upgrades.retain_mut(|req| match req.poll() {
391 Poll::Ready(Ok(conn)) => {
392 self.connections.push(conn.weak_inner());
393 false
394 }
395 Poll::Ready(Err(_)) => false,
396 Poll::Pending => true,
397 });
398 self.connections.retain(|conn| {
399 if let Some(conn) = conn.upgrade() {
400 conn.borrow_mut().resume(Direction::Upstream);
401 conn.borrow_mut().resume(Direction::Downstream);
402 true
403 } else {
404 false
405 }
406 })
407 }
408
409 fn do_tick(&mut self) {
410 self.host.borrow_mut().set_context(0);
411 self.host.borrow_mut().set_flow_mode(FlowType::Async);
412 self.context.as_mut().unwrap().on_tick();
413 self.forward_requests();
414 self.respond_calls();
415 }
416
417 pub fn tick(&mut self) {
421 if !self.host.borrow_mut().tick().is_zero() {
422 self.do_tick();
423 }
424 }
425
426 pub fn sleep(&mut self, duration: Duration) {
430 let mut accumulated = Duration::new(0, 0);
431 while accumulated < duration {
432 let elapsed = self.host.borrow_mut().tick();
433 if elapsed.is_zero() {
434 self.host.borrow_mut().forward(duration - accumulated);
435 return;
436 }
437 accumulated += elapsed;
438 self.do_tick();
439 }
440 }
441
442 pub fn upgrade_partial(&mut self, request: UnitHttpRequest) -> UnitTestUpgrade {
449 let request = request.inner;
450 let context_id = self.context_count;
451 self.context_count += 1;
452
453 self.host.borrow_mut().create_context(context_id);
454
455 let request = add_request_properties(request, context_id);
456 let props = request.properties();
457 self.host.borrow_mut().set_properties(context_id, props);
458 self.host.borrow_mut().set_context(context_id);
459
460 let http_context = self
461 .context
462 .as_ref()
463 .unwrap()
464 .create_http_context(context_id)
465 .unwrap();
466
467 let upgrade = UnitTestUpgrade::new(InnerUnitUpgrade::new(
468 context_id,
469 request,
470 http_context,
471 Rc::clone(&self.backends),
472 Rc::clone(&self.host),
473 self.chunk_size,
474 ));
475
476 self.upgrades.push(upgrade.clone());
477
478 upgrade
479 }
480
481 pub fn upgrade(
487 &mut self,
488 request: UnitHttpRequest,
489 ) -> Result<UpgradeConnection, UnitHttpResponse> {
490 let mut handle = self.upgrade_partial(request);
491
492 loop {
493 match handle.poll() {
494 Poll::Ready(res) => {
495 return res;
496 }
497 Poll::Pending => self.tick(),
498 }
499 }
500 }
501
502 pub fn request(&mut self, request: UnitHttpRequest) -> UnitHttpResponse {
507 let mut response = self.request_partial(request);
508
509 loop {
510 if let Poll::Ready(value) = response.poll() {
511 return value;
512 } else {
513 self.tick()
514 }
515 }
516 }
517
518 #[cfg(feature = "experimental_logs")]
519 pub fn logs(&self) -> Vec<String> {
520 self.host.borrow().logs.borrow().clone()
521 }
522
523 fn respond_calls(&mut self) {
524 let mut pending_calls = self.host.borrow_mut().pending_calls(0);
525 while !pending_calls.is_empty() {
526 for (id, upstream, call) in pending_calls.into_iter() {
527 respond_call(
528 self.context.as_deref_mut().unwrap(),
529 &self.host,
530 &self.backends,
531 0,
532 id,
533 upstream,
534 call,
535 );
536 }
537 pending_calls = self.host.borrow_mut().pending_calls(0);
538 }
539 }
540}
541
#[cfg(feature = "pdk")]
impl UnitTest {
    /// Registers a contract with the mocked Anypoint backend.
    ///
    /// The contract is identified by client `id` and `name`, and may have a client `secret`
    /// and an SLA id.
    pub fn add_contract_data<I, N, S, Sla>(
        &mut self,
        id: I,
        name: N,
        secret: Option<S>,
        sla_id: Option<Sla>,
    ) where
        I: Into<String>,
        N: Into<String>,
        S: Into<String>,
        Sla: Into<String>,
    {
        self.anypoint.add_contract(
            id.into(),
            name.into(),
            secret.map(|value| value.into()),
            sla_id.map(|value| value.into()),
        );
    }

    /// Removes the contract registered under client `id` from the mocked Anypoint backend.
    pub fn remove_contract_data<I>(&mut self, id: I)
    where
        I: Into<String>,
    {
        let id = id.into();
        self.anypoint.remove_contract(id);
    }

    /// Adds a `user`/`pass` entry (optionally scoped to `config`) to the mocked LDAP backend.
    pub fn add_ldap_data<U, P>(&mut self, config: Option<crate::UnitLdapConfig>, user: U, pass: P)
    where
        U: Into<String>,
        P: Into<String>,
    {
        self.ldap.add_data(config, user, pass);
    }
}
595
596#[derive(Clone)]
600pub struct UnitTestRequest {
601 inner: Rc<RefCell<InnerUnitTestRequest>>,
602}
603
604impl UnitTestRequest {
605 pub(crate) fn new(inner: InnerUnitTestRequest) -> Self {
606 Self {
607 inner: Rc::new(RefCell::new(inner)),
608 }
609 }
610
611 pub fn poll(&mut self) -> Poll<UnitHttpResponse> {
616 self.inner.borrow_mut().poll()
617 }
618}
619
/// Publishes the policy metadata as root-context (id 0) properties. The properties are:
/// - `node.id`: the flex name;
/// - `plugin_name`: `<policy>.<namespace>.<api>`;
/// - `listener_metadata.filter_metadata.<api>.context`: the JSON-serialised policy context
///   (local or connected, depending on `config.local_mode`).
#[cfg(feature = "pdk")]
fn setup_metadata(host: &Rc<RefCell<ProxyWasmStub>>, config: &UnitTestConfig) {
    let metadata = &config.metadata;
    let api_name = metadata.api_metadata.name.as_deref().unwrap_or_default();
    let policy = &metadata.policy_metadata;
    let plugin_name = format!(
        "{}.{}.{}",
        policy.policy_name, policy.policy_namespace, api_name
    );

    host.borrow_mut().create_property(
        0,
        vec!["node", "id"],
        Some(metadata.flex_metadata.flex_name.clone()),
    );
    host.borrow_mut()
        .create_property(0, vec!["plugin_name"], Some(plugin_name));

    let api_context = if config.local_mode {
        local_context(config)
    } else {
        connected_context(config)
    };
    let serialized = serde_json::to_string(&api_context).unwrap();

    host.borrow_mut().create_property(
        0,
        vec!["listener_metadata", "filter_metadata", api_name, "context"],
        Some(serialized),
    );
}
664
/// Builds the local-mode policy context: only the environment (organisation, environment,
/// root organisation and a fixed test cluster id) is populated.
#[cfg(feature = "pdk")]
fn local_context(config: &UnitTestConfig) -> ApiContext {
    let platform = &config.metadata.platform_metadata;

    let environment = EnvironmentContext::new(
        platform.organization_id.clone(),
        platform.environment_id.clone(),
        platform.root_organization_id.clone(),
        String::from("test_cluster_id"),
        None,
        None,
    );

    ApiContext::new(None, None, None, None, Some(environment), None)
}
682
/// Builds the connected-mode policy context, which contains:
/// - API identity and SLA tiers;
/// - Anypoint credentials pointing at the `anypoint_service_name` upstream;
/// - the environment;
/// - an identity-management context when `config.identity_management` is set.
#[cfg(feature = "pdk")]
fn connected_context(config: &UnitTestConfig) -> ApiContext {
    let api_metadata = &config.metadata.api_metadata;
    let platform = &config.metadata.platform_metadata;

    // Missing SLAs become an empty list rather than no list.
    let tiers = match api_metadata.slas.as_ref() {
        Some(slas) => slas
            .iter()
            .map(|sla| {
                let sla_tiers = sla
                    .tiers
                    .iter()
                    .map(|tier| CoreTier::new(tier.requests, tier.period_in_millis))
                    .collect();
                CoreApiSla::new(sla.id.clone(), sla.name.clone(), sla_tiers)
            })
            .collect::<Vec<_>>(),
        None => Vec::new(),
    };

    #[cfg(feature = "experimental")]
    let exchange_context = api_metadata.asset.as_ref().map(|asset| {
        pdk_core::policy_context::metadata::ExchangeContext::new(
            asset.service.cluster_name().to_string(),
            asset.service.uri().to_string(),
        )
    });

    #[cfg(not(feature = "experimental"))]
    let exchange_context = None;

    let mut api = Api::new(
        api_metadata.id.clone().unwrap_or_default(),
        api_metadata.name.clone().unwrap_or_default(),
        String::from("v1"),
        api_metadata.version.clone().unwrap_or_default(),
        exchange_context,
    );

    if let Some(base_path) = &api_metadata.base_path {
        api.set_base_path(base_path.clone());
    }

    let anypoint = AnypointContext::new(
        String::from("test_client"),
        String::from("test_secret"),
        String::from("anypoint_service_name"),
        String::from("https://anypoint.mulesoft.com"),
    );

    let environment = EnvironmentContext::new(
        platform.organization_id.clone(),
        platform.environment_id.clone(),
        platform.root_organization_id.clone(),
        String::from("test_cluster_id"),
        Some(anypoint),
        None,
    );

    let identity = match &config.identity_management {
        Some(url) => Some(IdentityManagementContext::new(
            String::from("client_id"),
            String::from("client_secret"),
            url.clone(),
            IDENTITY_MANAGEMENT_SVC.to_string(),
        )),
        None => None,
    };

    ApiContext::new(
        None,
        Some(api),
        Some(tiers),
        identity,
        Some(environment),
        None,
    )
}
777
/// Chains a panic hook that prints a backtrace after the previously installed hook has run.
///
/// The backtrace is only printed for panics on the thread that installed the hook. The hook
/// is installed at most once per thread. Earlier versions wrapped the hook again on every
/// harness (re)initialisation, which printed duplicate backtraces and grew the global hook
/// chain without bound. A global lock serialises the take/set pair so that concurrent
/// installs from different test threads do not drop each other's hooks.
fn enrich_panic_hook() {
    thread_local! {
        static INSTALLED: std::cell::Cell<bool> = std::cell::Cell::new(false);
    }
    static HOOK_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    // `replace` returns the previous value: `true` means this thread already installed it.
    if INSTALLED.with(|installed| installed.replace(true)) {
        return;
    }

    let _guard = HOOK_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    let previous = panic::take_hook();
    let owner = thread::current().id();
    panic::set_hook(Box::new(move |panic_info| {
        previous(panic_info);
        if thread::current().id() == owner {
            println!("{}", Backtrace::capture());
        }
    }));
}
789
790pub(crate) fn add_request_properties(request: RequestResponse, context_id: u32) -> RequestResponse {
791 let request = request
792 .with_property_if_missing(&["request", "id"], context_id.to_string())
793 .with_property_if_missing(&["source", "address"], "127.0.0.1")
794 .with_property_if_missing(&["destination", "address"], "127.0.0.2")
795 .with_property_if_missing(&["request", "scheme"], "http")
796 .with_property_if_missing(&["request", "protocol"], "1.1");
797
798 #[cfg(feature = "pdk")]
799 let request =
800 request.with_property_if_missing(&["anypoint/mulesoft/tracing_id"], context_id.to_string());
801
802 request
803}
804
805pub(crate) fn respond_http<C: Context + ?Sized>(
806 context: &mut C,
807 host: &Rc<RefCell<ProxyWasmStub>>,
808 backends: &Rc<RefCell<Backends>>,
809 context_id: u32,
810 id: u32,
811 upstream: String,
812 req: RequestResponse,
813) {
814 let response = backends
815 .borrow()
816 .upstreams
817 .get(&upstream)
818 .unwrap()
819 .call(req.into())
820 .inner;
821 let response_headers = response.headers.len();
822 let response_body = response.body.len();
823
824 host.borrow_mut().create_map(
825 context_id,
826 MapType::HttpCallResponseHeaders,
827 response
828 .headers
829 .into_iter()
830 .map(|(k, v)| (k, v.into_bytes()))
831 .collect(),
832 );
833 host.borrow_mut()
834 .create_buffer(context_id, BufferType::HttpCallResponseBody, response.body);
835 context.on_http_call_response(id, response_headers, response_body, 0);
836}
837
838pub(crate) fn respond_grpc<C: Context + ?Sized>(
839 context: &mut C,
840 host: &Rc<RefCell<ProxyWasmStub>>,
841 backends: &Rc<RefCell<Backends>>,
842 context_id: u32,
843 id: u32,
844 upstream: String,
845 req: UnitGrpcRequest,
846) {
847 let response = backends
848 .borrow()
849 .grpc_upstreams
850 .get(&upstream)
851 .unwrap()
852 .call(req);
853
854 host.borrow_mut()
855 .set_grpc_status((response.status_code, response.status));
856
857 let body_len = response.message.len();
858 host.borrow_mut()
859 .create_buffer(context_id, BufferType::GrpcReceiveBuffer, response.message);
860
861 context.on_grpc_call_response(id, response.status_code, body_len)
862}
863
864pub(crate) fn respond_call<C: Context + ?Sized>(
865 context: &mut C,
866 host: &Rc<RefCell<ProxyWasmStub>>,
867 backends: &Rc<RefCell<Backends>>,
868 context_id: u32,
869 id: u32,
870 upstream: String,
871 call: Call,
872) {
873 match call {
874 Call::Http(req) => respond_http(context, host, backends, context_id, id, upstream, req),
875 Call::Grpc(req) => respond_grpc(context, host, backends, context_id, id, upstream, req),
876 }
877}