use crate::backends::ldap::LdapBackend;
use crate::backends::{anypoint::AnypointBackend, os::OSBackend};
use crate::host::implementation::Call;
use crate::host::{facade::HostFacade, implementation::ProxyWasmStub};
use crate::tester::io::{RequestResponse, UnitHttpRequest, UnitHttpResponse};
use crate::tester::unit_test_request::InnerUnitTestRequest;
use crate::{Backend, GrpcBackend, UnitGrpcRequest, UnitLdapConfig};
use classy::Entrypoint;
use non_exhaustive::non_exhaustive;
use pdk_core::host::context::root::RootContextAdapter;
use pdk_core::init::configure;
use pdk_core::policy_context::api::{
ApiMetadata, FlexMetadata, Metadata, PlatformMetadata, PolicyMetadata,
};
use pdk_core::policy_context::metadata::{
AnypointContext, Api, ApiContext, ApiSla as CoreApiSla, EnvironmentContext,
};
use pdk_core::policy_context::metadata::{IdentityManagementContext, Tier as CoreTier};
use proxy_wasm_stub::stub::set_host;
use proxy_wasm_stub::traits::{Context, RootContext};
use proxy_wasm_stub::types::{BufferType, MapType};
use std::backtrace::Backtrace;
use std::cell::RefCell;
use std::collections::HashMap;
use std::panic;
use std::rc::Rc;
use std::task::Poll;
use std::time::Duration;
/// Service name placed in the identity-management context published by `setup_metadata`
/// (presumably the upstream the policy calls for identity management — confirm).
pub(super) const IDENTITY_MANAGEMENT_SVC: &str = "__identity_management_svc";
/// Default chunk size handed to every new request; can be changed with `set_chunk_size`.
const CHUNK_SIZE: usize = 3;
/// Harness that runs a policy against a stub proxy-wasm host.
///
/// Owns the stub host, the policy's root context, the HTTP requests still in flight and
/// the backends that answer the HTTP/gRPC calls the policy dispatches.
pub struct UnitTest {
    // Stub host; also installed as the global host facade and shared with every request.
    host: Rc<RefCell<ProxyWasmStub>>,
    // Root context (context id 0); populated by `init`.
    context: Option<RootContextAdapter>,
    // Next context id to hand out; id 0 is reserved for the root context.
    context_count: u32,
    // Requests whose first poll did not complete; re-polled on every tick.
    requests: Vec<UnitTestRequest>,
    // Backends answering outgoing calls, shared with every request.
    backends: Rc<RefCell<Backends>>,
    // Built-in Anypoint backend, kept so contract data can be added after construction.
    anypoint: Rc<AnypointBackend>,
    // Built-in LDAP backend, kept so user/password data can be added after construction.
    ldap: Rc<LdapBackend>,
    // Passed to every new request; `None` unless the `enable_stop_iteration` feature is on.
    stop_mode: Option<StopIterationMode>,
    // Chunk size passed to every new request (see `set_chunk_size`).
    chunk_size: usize,
    // Configuration replayed on every (re)initialization.
    config: UnitTestConfig,
    // Builds a fresh root context; invoked by every `init`.
    factory: Box<dyn Fn() -> RootContextAdapter>,
}
/// Backends that answer the calls a policy dispatches, keyed by upstream name.
pub(crate) struct Backends {
    /// NOTE(review): not referenced in this file — presumably the backend behind the
    /// request's own upstream (used by `InnerUnitTestRequest`); confirm.
    pub backend: Box<dyn Backend>,
    /// HTTP backends keyed by upstream name; every key is registered with the host.
    pub upstreams: HashMap<String, Rc<dyn Backend>>,
    /// gRPC backends keyed by upstream name; every key is registered with the host.
    pub grpc_upstreams: HashMap<String, Rc<dyn GrpcBackend>>,
}
/// How the stub host sequences work when a policy pauses ("stops iteration" on) a request.
///
/// The harness only applies a mode when the `enable_stop_iteration` feature is enabled,
/// where it defaults to [`StopIterationMode::BodyThenRequests`].
/// NOTE(review): the exact semantics live in `InnerUnitTestRequest`; the variant names
/// suggest the order in which dispatched requests and body data are processed.
///
/// Ordering follows declaration order: `RequestsThenBody < BodyThenRequests`.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum StopIterationMode {
    /// Process pending requests first, then the body.
    RequestsThenBody,
    /// Process the body first, then pending requests.
    BodyThenRequests,
}
/// Inputs used to initialize the policy under test.
pub(crate) struct UnitTestConfig {
    /// Raw policy configuration, loaded as the root context's plugin configuration buffer.
    pub(crate) policy_config: String,
    /// Flex, policy, API and platform metadata, published as host properties.
    pub(crate) metadata: Metadata,
    /// Identity-management URL; when set, an identity-management context is published.
    pub(crate) identity_management: Option<String>,
}
impl Default for UnitTestConfig {
    /// Builds a configuration with an empty (`{}`) policy configuration, placeholder
    /// Flex/policy/API/platform metadata, no SLAs and no identity management.
    fn default() -> Self {
        let policy = String::from("test_policy_id");
        let namespace = String::from("test_policy_namespace");
        let api = String::from("test_api_id");
        // Filter names follow the `<policy>.<namespace>.<api>` convention.
        let filter = [policy.as_str(), namespace.as_str(), api.as_str()].join(".");

        let metadata = non_exhaustive!(Metadata {
            flex_metadata: non_exhaustive!(FlexMetadata {
                flex_name: String::from("test_flex_name"),
                flex_version: String::from("1.0.0"),
            }),
            policy_metadata: non_exhaustive!(PolicyMetadata {
                policy_name: policy,
                policy_namespace: namespace,
                filter_name: filter,
            }),
            api_metadata: non_exhaustive!(ApiMetadata {
                id: Some(String::from("1")),
                name: Some(api),
                version: Some(String::from("1.0.0")),
                base_path: Some(String::from("/")),
                slas: None,
            }),
            platform_metadata: non_exhaustive!(PlatformMetadata {
                organization_id: String::from("test-org-id"),
                environment_id: String::from("test-env-id"),
                root_organization_id: String::from("test-root-org-id"),
            }),
        });

        Self {
            policy_config: String::from("{}"),
            metadata,
            identity_management: None,
        }
    }
}
impl UnitTest {
    /// Panic message for the invariant that `init` always installs the root context.
    const MISSING_ROOT_CONTEXT: &'static str =
        "root context missing: UnitTest::init always creates it before use";

    /// Creates a harness for the policy `entrypoint` and initializes its root context.
    ///
    /// A fresh stub host is created and installed as the global proxy-wasm host.
    /// Built-in backends are registered for the `anypoint_service_name`,
    /// `x-flex-services` (LDAP) and `x-flex-keyvalue-store` upstreams unless `backends`
    /// already has an entry under that name.
    ///
    /// NOTE(review): if the caller overrides `anypoint_service_name` or
    /// `x-flex-services`, data added through `add_contract_data` / `add_ldap_data` still
    /// goes to the built-in instances, which are then not registered as upstreams.
    pub(crate) fn new<C, T, E: Entrypoint<C, T> + Clone + 'static>(
        entrypoint: E,
        config: UnitTestConfig,
        mut backends: Backends,
    ) -> Self {
        let host = Rc::new(RefCell::new(ProxyWasmStub::default()));
        set_host(HostFacade::new(Rc::clone(&host)));
        // Kept so `init` can rebuild the root context on every `restart`.
        let factory = Box::new(move || {
            RootContextAdapter::new(
                configure(0)
                    .entrypoint(entrypoint.clone())
                    .create_root_context(0),
            )
        });
        let anypoint = Rc::new(AnypointBackend::default());
        let anypoint_ref = Rc::clone(&anypoint);
        backends
            .upstreams
            .entry("anypoint_service_name".to_string())
            .or_insert(anypoint_ref);
        let ldap = Rc::new(LdapBackend::default());
        let ldap_ref = Rc::clone(&ldap);
        backends
            .upstreams
            .entry("x-flex-services".to_string())
            .or_insert(ldap_ref);
        backends
            .upstreams
            .entry("x-flex-keyvalue-store".to_string())
            .or_insert(Rc::new(OSBackend::default()));
        let mut test = Self {
            host,
            context: None,
            context_count: 0,
            requests: Vec::new(),
            backends: Rc::new(RefCell::new(backends)),
            anypoint,
            ldap,
            #[cfg(feature = "enable_stop_iteration")]
            stop_mode: Some(StopIterationMode::BodyThenRequests),
            #[cfg(not(feature = "enable_stop_iteration"))]
            stop_mode: None,
            chunk_size: CHUNK_SIZE,
            config,
            factory,
        };
        test.init();
        test
    }

    /// Creates and configures the root context (context id 0).
    ///
    /// Loads `config.policy_config` as the plugin configuration, publishes the metadata
    /// properties, builds the root context from the factory, registers every HTTP and
    /// gRPC upstream name with the host and calls `on_configure`. Any calls the policy
    /// dispatched while configuring are then answered.
    fn init(&mut self) {
        let host = &self.host;
        // Context id 0 is the root context; HTTP contexts are numbered from 1.
        self.context_count = 1;
        host.borrow_mut().create_context(0);
        host.borrow_mut().create_buffer(
            0,
            BufferType::PluginConfiguration,
            self.config.policy_config.as_bytes().to_vec(),
        );
        host.borrow_mut().set_context(0);
        setup_metadata(host, &self.config);
        self.context = Some((self.factory)());
        enrich_panic_hook();
        {
            // HTTP upstreams first, then gRPC upstreams.
            // Presumably the stub only accepts calls to registered upstreams.
            let backends = self.backends.borrow();
            for name in backends
                .upstreams
                .keys()
                .chain(backends.grpc_upstreams.keys())
            {
                host.borrow_mut().add_upstream(name.to_string());
            }
        }
        self.context
            .as_mut()
            .expect(Self::MISSING_ROOT_CONTEXT)
            .on_configure(self.config.policy_config.len());
        self.respond_calls();
    }

    /// Discards in-flight requests and rebuilds the host stub and root context.
    ///
    /// Only the stub's clock carries over to the new host. The backends, and any data
    /// added to them, are kept.
    pub fn restart(&mut self) {
        self.requests.clear();
        let mut host = ProxyWasmStub::default();
        host.clock = self.host.borrow().clock;
        let host = Rc::new(RefCell::new(host));
        set_host(HostFacade::new(Rc::clone(&host)));
        self.host = host;
        self.init();
    }

    /// Sets the stop-iteration mode passed to requests created from now on.
    #[cfg(feature = "enable_stop_iteration")]
    pub fn set_host_mode(&mut self, mode: StopIterationMode) {
        self.stop_mode = Some(mode);
    }

    /// Returns the metrics recorded by the host, keyed by metric name.
    #[cfg(feature = "experimental")]
    pub fn get_metrics(&mut self) -> HashMap<String, u64> {
        self.host
            .borrow()
            .get_metrics()
            .into_iter()
            .map(|(_id, (name, value))| (name, value))
            .collect()
    }

    /// Sets the chunk size passed to requests created from now on.
    pub fn set_chunk_size(&mut self, chunk_size: usize) {
        self.chunk_size = chunk_size;
    }

    /// Adds a contract with an optional client secret and SLA id to the built-in
    /// Anypoint backend.
    pub fn add_contract_data<I, N, S, Sla>(
        &mut self,
        id: I,
        name: N,
        secret: Option<S>,
        sla_id: Option<Sla>,
    ) where
        I: Into<String>,
        N: Into<String>,
        S: Into<String>,
        Sla: Into<String>,
    {
        self.anypoint.add_contract(
            id.into(),
            name.into(),
            secret.map(|s| s.into()),
            sla_id.map(|sla| sla.into()),
        );
    }

    /// Removes the contract with id `id` from the built-in Anypoint backend.
    pub fn remove_contract_data<I>(&mut self, id: I)
    where
        I: Into<String>,
    {
        self.anypoint.remove_contract(id.into());
    }

    /// Adds a user/password pair, with an optional LDAP configuration, to the built-in
    /// LDAP backend.
    pub fn add_ldap_data<U, P>(&mut self, config: Option<UnitLdapConfig>, user: U, pass: P)
    where
        U: Into<String>,
        P: Into<String>,
    {
        self.ldap.add_data(config, user, pass);
    }

    /// Starts `request` in a new HTTP context and polls it once.
    ///
    /// Default request properties are added first (see `add_request_properties`). If the
    /// first poll does not finish the request, the harness keeps a shared handle and
    /// re-polls it on every tick. The returned handle can be polled by the caller.
    ///
    /// # Panics
    /// Panics if the root context does not produce an HTTP context.
    pub fn request_partial(&mut self, request: UnitHttpRequest) -> UnitTestRequest {
        let request = request.inner;
        let context_id = self.context_count;
        self.context_count += 1;
        self.host.borrow_mut().create_context(context_id);
        let request = add_request_properties(request, context_id);
        let props = request.properties();
        self.host.borrow_mut().set_properties(context_id, props);
        self.host.borrow_mut().set_context(context_id);
        let http_context = self
            .context
            .as_ref()
            .expect(Self::MISSING_ROOT_CONTEXT)
            .create_http_context(context_id)
            .expect("root context failed to create an HTTP context");
        let mut handle = UnitTestRequest::new(InnerUnitTestRequest::new(
            context_id,
            request,
            http_context,
            Rc::clone(&self.backends),
            Rc::clone(&self.host),
            self.stop_mode,
            self.chunk_size,
        ));
        if !handle.poll().is_ready() {
            self.requests.push(handle.clone());
        }
        handle
    }

    /// Polls every tracked request and stops tracking those that completed.
    fn forward_requests(&mut self) {
        self.requests.retain_mut(|req| !req.poll().is_ready());
    }

    /// Runs the root context's `on_tick`, then drives pending requests and answers the
    /// root context's outstanding calls.
    fn do_tick(&mut self) {
        self.host.borrow_mut().set_context(0);
        self.context
            .as_mut()
            .expect(Self::MISSING_ROOT_CONTEXT)
            .on_tick();
        self.forward_requests();
        self.respond_calls();
    }

    /// Asks the host to tick and runs one harness tick if it reports elapsed time.
    ///
    /// A zero duration from the host is treated as "nothing to tick".
    pub fn tick(&mut self) {
        if !self.host.borrow_mut().tick().is_zero() {
            self.do_tick();
        }
    }

    /// Advances simulated time by `duration`, running a harness tick for every host tick.
    ///
    /// Once the host reports no elapsed time, the rest of `duration` is forwarded on the
    /// clock directly. NOTE(review): the last tick is not clamped, so the total elapsed
    /// time can exceed `duration` — confirm that this is intended.
    pub fn sleep(&mut self, duration: Duration) {
        let mut accumulated = Duration::new(0, 0);
        while accumulated < duration {
            let elapsed = self.host.borrow_mut().tick();
            if elapsed.is_zero() {
                self.host.borrow_mut().forward(duration - accumulated);
                return;
            }
            accumulated += elapsed;
            self.do_tick();
        }
    }

    /// Sends `request` through the policy and ticks until its response is ready.
    ///
    /// NOTE: this loops forever if the request never completes.
    pub fn request(&mut self, request: UnitHttpRequest) -> UnitHttpResponse {
        let mut response = self.request_partial(request);
        loop {
            match response.poll() {
                Poll::Ready(value) => return value,
                Poll::Pending => self.tick(),
            }
        }
    }

    /// Returns a copy of the log lines captured by the host.
    #[cfg(feature = "experimental_logs")]
    pub fn logs(&self) -> Vec<String> {
        self.host.borrow().logs.borrow().clone()
    }

    /// Answers the root context's pending calls until the host reports none are left.
    ///
    /// Responding can make the policy dispatch further calls, which is why the host is
    /// queried repeatedly.
    fn respond_calls(&mut self) {
        loop {
            let pending_calls = self.host.borrow_mut().pending_calls(0);
            if pending_calls.is_empty() {
                break;
            }
            for (id, upstream, call) in pending_calls {
                respond_call(
                    self.context.as_mut().expect(Self::MISSING_ROOT_CONTEXT),
                    &self.host,
                    &self.backends,
                    0,
                    id,
                    upstream,
                    call,
                );
            }
        }
    }
}
/// Handle to an HTTP request flowing through the policy under test.
///
/// Clones share the same underlying request state (`Rc<RefCell<_>>`), so the harness and
/// the caller can both poll the same request.
#[derive(Clone)]
pub struct UnitTestRequest {
    inner: Rc<RefCell<InnerUnitTestRequest>>,
}
impl UnitTestRequest {
    /// Wraps `inner` in shared storage so that every clone of the handle sees the same
    /// request state.
    pub(crate) fn new(inner: InnerUnitTestRequest) -> Self {
        let shared = Rc::new(RefCell::new(inner));
        Self { inner: shared }
    }

    /// Polls the underlying request once. Returns `Poll::Ready` with the response once
    /// the request has completed.
    pub fn poll(&mut self) -> Poll<UnitHttpResponse> {
        let mut request = self.inner.borrow_mut();
        request.poll()
    }
}
/// Publishes the policy metadata in `config` as properties of the root context (id 0).
///
/// Three properties are written, in this order:
/// - `node.id`: the Flex name;
/// - `plugin_name`: `<policy_name>.<policy_namespace>.<api_name>`;
/// - `listener_metadata.filter_metadata.<api_name>.context`: the JSON-serialized
///   `ApiContext`. This carries the API, SLA tiers, optional identity management, the
///   environment and a fixed test Anypoint context.
///
/// NOTE(review): `policy_metadata.filter_name` is not read here; the plugin name is
/// recomputed from its parts.
///
/// # Panics
/// Panics if the API context cannot be serialized to JSON.
fn setup_metadata(host: &Rc<RefCell<ProxyWasmStub>>, config: &UnitTestConfig) {
    let metadata = &config.metadata;
    let api_meta = &metadata.api_metadata;
    let policy_meta = &metadata.policy_metadata;
    let platform_meta = &metadata.platform_metadata;

    let api_name = api_meta.name.clone().unwrap_or_default();
    let plugin_name = format!(
        "{}.{}.{}",
        policy_meta.policy_name, policy_meta.policy_namespace, api_name
    );

    host.borrow_mut().create_property(
        0,
        vec!["node", "id"],
        Some(metadata.flex_metadata.flex_name.clone()),
    );
    host.borrow_mut()
        .create_property(0, vec!["plugin_name"], Some(plugin_name));

    // Convert the configured SLAs into core SLA/tier values; no SLAs yields an empty list.
    let mut slas = Vec::new();
    if let Some(configured) = api_meta.slas.as_ref() {
        for sla in configured.iter() {
            let tiers = sla
                .tiers
                .iter()
                .map(|tier| CoreTier::new(tier.requests, tier.period_in_millis))
                .collect();
            slas.push(CoreApiSla::new(sla.id.clone(), sla.name.clone(), tiers));
        }
    }

    let mut api = Api::new(
        api_meta.id.clone().unwrap_or_default(),
        api_name.clone(),
        "v1".to_string(),
        api_meta.version.clone().unwrap_or_default(),
        None,
    );
    if let Some(base_path) = &api_meta.base_path {
        api.set_base_path(base_path.clone());
    }

    let anypoint = AnypointContext::new(
        "test_client".to_string(),
        "test_secret".to_string(),
        "anypoint_service_name".to_string(),
        "https://anypoint.mulesoft.com".to_string(),
    );
    let environment = EnvironmentContext::new(
        platform_meta.organization_id.clone(),
        platform_meta.environment_id.clone(),
        platform_meta.root_organization_id.clone(),
        "test_cluster_id".to_string(),
        Some(anypoint),
        None,
    );
    let identity = match config.identity_management.as_ref() {
        Some(url) => Some(IdentityManagementContext::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            url.clone(),
            IDENTITY_MANAGEMENT_SVC.to_string(),
        )),
        None => None,
    };

    let api_context = ApiContext::new(
        None,
        Some(api),
        Some(slas),
        identity,
        Some(environment),
        None,
    );
    let serialized = serde_json::to_string(&api_context).unwrap();
    host.borrow_mut().create_property(
        0,
        vec![
            "listener_metadata",
            "filter_metadata",
            api_name.as_str(),
            "context",
        ],
        Some(serialized),
    );
}
/// Extends the process-wide panic hook so a captured backtrace is printed after the
/// original panic report.
///
/// The hook is installed at most once per process. This function runs from
/// `UnitTest::init`, so it is called for every harness created and every `restart`.
/// Re-wrapping the hook on each call would print one extra backtrace per call. It would
/// also race when parallel tests take and set the global hook at the same time.
fn enrich_panic_hook() {
    static INSTALL: std::sync::Once = std::sync::Once::new();
    INSTALL.call_once(|| {
        let previous = panic::take_hook();
        panic::set_hook(Box::new(move |panic_info| {
            previous(panic_info);
            // `Backtrace::capture` honours RUST_BACKTRACE / RUST_LIB_BACKTRACE.
            println!("{}", Backtrace::capture());
        }));
    });
}
/// Fills in default request properties that the test request did not set itself.
///
/// The tracing id and `request.id` both default to `context_id`. The remaining defaults
/// describe an `http` scheme, protocol `1.1` request from `127.0.0.1` to `127.0.0.2`.
pub(crate) fn add_request_properties(request: RequestResponse, context_id: u32) -> RequestResponse {
    let id = context_id.to_string();
    let identified = request
        .with_property_if_missing(&["anypoint/mulesoft/tracing_id"], id.clone())
        .with_property_if_missing(&["request", "id"], id);
    identified
        .with_property_if_missing(&["source", "address"], "127.0.0.1")
        .with_property_if_missing(&["destination", "address"], "127.0.0.2")
        .with_property_if_missing(&["request", "scheme"], "http")
        .with_property_if_missing(&["request", "protocol"], "1.1")
}
/// Serves an HTTP call dispatched by `context` with the backend registered for
/// `upstream`, then delivers the response to the context.
///
/// The response headers and body are stored in the host stub under `context_id`
/// (`HttpCallResponseHeaders` / `HttpCallResponseBody`). `on_http_call_response` is
/// then invoked with the header count, the body length and zero trailers.
///
/// # Panics
/// Panics, naming the upstream, if no HTTP backend is registered for `upstream`.
pub(crate) fn respond_http<C: Context + ?Sized>(
    context: &mut C,
    host: &Rc<RefCell<ProxyWasmStub>>,
    backends: &Rc<RefCell<Backends>>,
    context_id: u32,
    id: u32,
    upstream: String,
    req: RequestResponse,
) {
    // Clone the handle so the `Backends` borrow is released before backend code runs.
    let backend = backends
        .borrow()
        .upstreams
        .get(&upstream)
        .cloned()
        .unwrap_or_else(|| panic!("no HTTP backend registered for upstream `{upstream}`"));
    let response = backend.call(req.into()).inner;
    let response_headers = response.headers.len();
    let response_body = response.body.len();
    host.borrow_mut().create_map(
        context_id,
        MapType::HttpCallResponseHeaders,
        response
            .headers
            .into_iter()
            .map(|(k, v)| (k, v.into_bytes()))
            .collect(),
    );
    host.borrow_mut()
        .create_buffer(context_id, BufferType::HttpCallResponseBody, response.body);
    context.on_http_call_response(id, response_headers, response_body, 0);
}
/// Serves a gRPC call dispatched by `context` with the backend registered for
/// `upstream`, then delivers the response to the context.
///
/// The response's status is recorded as the host's gRPC status and the message is
/// stored as the `GrpcReceiveBuffer` of `context_id`. `on_grpc_call_response` is then
/// invoked with the status code and the message length.
///
/// # Panics
/// Panics, naming the upstream, if no gRPC backend is registered for `upstream`.
pub(crate) fn respond_grpc<C: Context + ?Sized>(
    context: &mut C,
    host: &Rc<RefCell<ProxyWasmStub>>,
    backends: &Rc<RefCell<Backends>>,
    context_id: u32,
    id: u32,
    upstream: String,
    req: UnitGrpcRequest,
) {
    // Clone the handle so the `Backends` borrow is released before backend code runs.
    let backend = backends
        .borrow()
        .grpc_upstreams
        .get(&upstream)
        .cloned()
        .unwrap_or_else(|| panic!("no gRPC backend registered for upstream `{upstream}`"));
    let response = backend.call(req);
    host.borrow_mut()
        .set_grpc_status((response.status_code, response.status));
    let body_len = response.message.len();
    host.borrow_mut()
        .create_buffer(context_id, BufferType::GrpcReceiveBuffer, response.message);
    context.on_grpc_call_response(id, response.status_code, body_len)
}
/// Answers a call dispatched by `context` by routing it to the HTTP or gRPC responder
/// that matches its kind.
pub(crate) fn respond_call<C: Context + ?Sized>(
    context: &mut C,
    host: &Rc<RefCell<ProxyWasmStub>>,
    backends: &Rc<RefCell<Backends>>,
    context_id: u32,
    id: u32,
    upstream: String,
    call: Call,
) {
    match call {
        Call::Grpc(grpc_request) => {
            respond_grpc(context, host, backends, context_id, id, upstream, grpc_request)
        }
        Call::Http(http_request) => {
            respond_http(context, host, backends, context_id, id, upstream, http_request)
        }
    }
}