use crate::tester::io::RequestResponse;
use crate::UnitGrpcRequest;
use proxy_wasm_stub::stub::Host;
use proxy_wasm_stub::types::{BufferType, Bytes, LogLevel, MapType, MetricType, Status};
#[cfg(feature = "experimental_logs")]
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::ops::Add;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// In-memory stand-in for a proxy-wasm host; it implements the `Host` trait in this file.
///
/// All state (clock, contexts, shared data, metrics, upstreams) lives in plain collections.
pub struct ProxyWasmStub {
/// Simulated current time. Starts at the Unix epoch (see `Default`) and is only moved by
/// `do_tick` / `do_forward`.
pub(crate) clock: SystemTime,
/// Tick period stored by `set_tick_period`; `do_tick` advances `clock` by this amount.
tick: Duration,
/// Id of the context that context-scoped `Host` calls currently operate on.
context_id: u32,
/// State of every created context, keyed by context id.
contexts: HashMap<u32, ProxyWasmContextStub>,
/// Shared key/value store: key -> (value, cas counter).
shared_data: HashMap<String, (Bytes, u32)>,
/// Defined metrics: metric id -> (name, current value).
metrics: HashMap<u32, (String, u64)>,
/// Upstream names that `dispatch_call` accepts; anything else is rejected.
upstreams: Vec<String>,
/// Formatted log lines captured by `log` (only with the `experimental_logs` feature).
#[cfg(feature = "experimental_logs")]
pub(crate) logs: RefCell<Vec<String>>,
}
/// Builds an empty stub: clock at the Unix epoch, zero tick period, context id 0 selected,
/// and no contexts, shared data, metrics or upstreams registered.
impl Default for ProxyWasmStub {
    fn default() -> Self {
        Self {
            // The simulated clock always starts at the epoch so timestamps are reproducible.
            clock: UNIX_EPOCH,
            tick: Duration::ZERO,
            context_id: 0,
            contexts: HashMap::new(),
            shared_data: HashMap::new(),
            metrics: HashMap::new(),
            upstreams: Vec::new(),
            #[cfg(feature = "experimental_logs")]
            logs: RefCell::new(Vec::new()),
        }
    }
}
/// State held for a single context of the stub host: properties, buffers, header-like maps,
/// resume / send-response flags and the calls dispatched from this context.
#[derive(Default)]
pub(crate) struct ProxyWasmContextStub {
/// Property values keyed by their full path.
properties: HashMap<Vec<String>, Bytes>,
/// Raw buffer contents per buffer type.
buffers: HashMap<BufferType, Bytes>,
/// Ordered (key, value) pairs per map type; the same key may appear more than once.
maps: HashMap<MapType, Vec<(String, Bytes)>>,
/// Set by `set_resume`, read and reset by `clear_resume`.
resume: bool,
/// Set by `set_send_response`, read and reset by `clear_send_response`.
send_response: bool,
/// Id that the next call recorded by `dispatch_call` will receive.
next_req: u32,
/// Calls dispatched and not yet drained by `pending_calls`: (request id, upstream, call).
requests: Vec<(u32, String, Call)>,
/// Value returned by `get_grpc_status`; presumably (status code, optional message) —
/// TODO confirm against the `Host` trait contract.
request_status: Option<(u32, Option<String>)>,
}
/// Accessors and mutators for the state of a single context.
impl ProxyWasmContextStub {
    /// Returns at most `max_size` bytes of the `buffer_type` buffer, starting at `start`.
    ///
    /// Returns `None` when the buffer was never created. A `start` at or beyond the end of
    /// the buffer yields an empty vector; the requested window is clamped to the buffer end.
    pub(crate) fn get_buffer(
        &self,
        buffer_type: BufferType,
        start: usize,
        max_size: usize,
    ) -> Option<Bytes> {
        self.buffers.get(&buffer_type).map(|bytes| {
            // Saturate: callers may pass `usize::MAX` to mean "everything", which must not
            // overflow when added to a non-zero `start`.
            let end = start.saturating_add(max_size).min(bytes.len());
            // `get` returns `None` when `start` lies past the end of the buffer.
            bytes
                .get(start..end)
                .map(|window| window.to_vec())
                .unwrap_or_default()
        })
    }

    /// Replaces the `size` bytes starting at `start` of the `buffer_type` buffer with `value`.
    ///
    /// Does nothing when the buffer was never created. A `start` beyond the end appends
    /// `value`; a range reaching beyond the end replaces everything from `start` onwards.
    pub(crate) fn set_buffer(
        &mut self,
        buffer_type: BufferType,
        start: usize,
        size: usize,
        value: &[u8],
    ) {
        if let Some(bytes) = self.buffers.get_mut(&buffer_type) {
            // Clamp both ends to the buffer; saturating avoids overflow for huge `size`.
            let head = start.min(bytes.len());
            let tail = start.saturating_add(size).min(bytes.len());
            let mut updated = Vec::with_capacity(head + value.len() + (bytes.len() - tail));
            updated.extend_from_slice(&bytes[..head]);
            updated.extend_from_slice(value);
            updated.extend_from_slice(&bytes[tail..]);
            *bytes = updated;
        }
    }

    /// Returns a copy of all pairs of the `map_type` map, or an empty list if it is missing.
    pub(crate) fn get_map_bytes(&self, map_type: MapType) -> Vec<(String, Bytes)> {
        self.maps.get(&map_type).cloned().unwrap_or_default()
    }

    /// Returns the first value whose key matches `key` ASCII-case-insensitively.
    ///
    /// A missing map is treated like an empty one (consistent with `get_map_bytes`) and
    /// yields `None` instead of panicking.
    pub(crate) fn get_map_value_bytes(&self, map_type: MapType, key: &str) -> Option<Bytes> {
        self.maps
            .get(&map_type)?
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(key))
            .map(|(_, v)| v.clone())
    }

    /// Removes every pair matching `key` (ASCII-case-insensitively) and, when `value` is
    /// `Some`, appends a single new pair. The map is created if it does not exist yet.
    pub(crate) fn set_map_value_bytes(
        &mut self,
        map_type: MapType,
        key: &str,
        value: Option<&[u8]>,
    ) {
        let entries = self.maps.entry(map_type).or_default();
        entries.retain(|(k, _)| !k.eq_ignore_ascii_case(key));
        if let Some(value) = value {
            entries.push((key.to_string(), value.to_vec()));
        }
    }

    /// Appends a pair without touching existing pairs for the same key. The map is created
    /// if it does not exist yet.
    pub(crate) fn add_map_value_bytes(&mut self, map_type: MapType, key: &str, value: &[u8]) {
        self.maps
            .entry(map_type)
            .or_default()
            .push((key.to_string(), value.to_vec()));
    }

    /// Creates the `buffer_type` buffer with `value`, replacing any previous content.
    pub(crate) fn create_buffer(&mut self, buffer_type: BufferType, value: Vec<u8>) {
        self.buffers.insert(buffer_type, value);
    }

    /// Creates the `map_type` map with `value`, replacing any previous content.
    pub(crate) fn create_map_bytes(&mut self, map_type: MapType, value: Vec<(String, Bytes)>) {
        self.maps.insert(map_type, value);
    }

    /// Returns a copy of the whole `buffer_type` buffer, or an empty vector if it is missing.
    pub(crate) fn read_buffer(&self, buffer_type: BufferType) -> Vec<u8> {
        self.buffers.get(&buffer_type).cloned().unwrap_or_default()
    }

    /// Returns the resume flag and resets it to `false`.
    pub(crate) fn clear_resume(&mut self) -> bool {
        std::mem::take(&mut self.resume)
    }

    /// Returns the send-response flag and resets it to `false`.
    pub(crate) fn clear_send_response(&mut self) -> bool {
        std::mem::take(&mut self.send_response)
    }

    /// Returns a copy of the property stored under `path`, if any.
    pub(crate) fn get_property(&self, path: Vec<&str>) -> Option<Bytes> {
        let key: Vec<String> = path.into_iter().map(String::from).collect();
        self.properties.get(&key).cloned()
    }

    /// Stores a property under `path`; a `None` value is stored as empty bytes.
    pub(crate) fn create_property<K: Into<String>, J: Into<Vec<u8>>>(
        &mut self,
        path: Vec<K>,
        value: Option<J>,
    ) {
        let key: Vec<String> = path.into_iter().map(|segment| segment.into()).collect();
        let stored: Vec<u8> = value.map(|bytes| bytes.into()).unwrap_or_default();
        self.properties.insert(key, stored);
    }

    /// Records an outbound call for `upstream` and returns the request id assigned to it.
    pub(crate) fn dispatch_call(&mut self, upstream: String, call: Call) -> u32 {
        let req_id = self.next_req;
        self.next_req += 1;
        self.requests.push((req_id, upstream, call));
        req_id
    }

    /// Returns the status previously stored with `set_grpc_status`, if any.
    pub(crate) fn get_grpc_status(&self) -> Option<(u32, Option<String>)> {
        self.request_status.clone()
    }

    /// Stores the status that `get_grpc_status` will return.
    pub(crate) fn set_grpc_status(&mut self, status: (u32, Option<String>)) {
        self.request_status = Some(status);
    }

    /// Drops every recorded call whose request id equals `token_id`.
    pub(crate) fn cancel_grpc_call(&mut self, token_id: u32) {
        self.requests.retain(|(id, _, _)| *id != token_id);
    }

    /// Raises the resume flag.
    pub(crate) fn set_resume(&mut self) {
        self.resume = true;
    }

    /// Raises the send-response flag.
    pub(crate) fn set_send_response(&mut self) {
        self.send_response = true;
    }

    /// Records a local response: raises the send-response flag, stores `body` (empty when
    /// `None`) as the HTTP response body buffer and stores `:status` followed by `headers`
    /// as the HTTP response headers map.
    pub(crate) fn send_http_response(
        &mut self,
        status_code: u32,
        headers: Vec<(&str, &str)>,
        body: Option<&[u8]>,
    ) {
        self.set_send_response();
        self.create_buffer(
            BufferType::HttpResponseBody,
            body.map(|b| b.to_vec()).unwrap_or_default(),
        );
        // The pseudo-header always comes first, then the caller's headers in order.
        let mut effective_headers =
            vec![(":status".to_string(), status_code.to_string().into_bytes())];
        effective_headers.extend(
            headers
                .into_iter()
                .map(|(k, v)| (k.to_string(), v.as_bytes().to_vec())),
        );
        self.create_map_bytes(MapType::HttpResponseHeaders, effective_headers);
    }

    /// Drains and returns every recorded call, leaving the pending list empty.
    pub(crate) fn pending_calls(&mut self) -> Vec<(u32, String, Call)> {
        std::mem::take(&mut self.requests)
    }

    /// Replaces all properties with `props`.
    pub(crate) fn set_properties(&mut self, props: HashMap<Vec<String>, Bytes>) {
        self.properties = props;
    }

    /// Returns a copy of all properties.
    pub(crate) fn get_properties(&self) -> HashMap<Vec<String>, Bytes> {
        self.properties.clone()
    }
}
/// An outbound call recorded by `dispatch_call` until it is drained via `pending_calls`.
pub(crate) enum Call {
/// HTTP call; built in `dispatch_http_call` from the request headers and body.
Http(RequestResponse),
/// gRPC call; built in `dispatch_grpc_call` from service, method, metadata and message.
Grpc(UnitGrpcRequest),
}
/// `Host` implementation that answers every host call from the stub's in-memory state.
///
/// Context-scoped calls act on the context selected by `context_id` and panic if that
/// context has not been created.
impl Host for ProxyWasmStub {
    /// Prints `<millis since epoch> - <level>: <message>` to stdout; with the
    /// `experimental_logs` feature the same line is also appended to `logs`.
    fn log(&self, level: LogLevel, message: &str) -> Result<(), Status> {
        let elapsed_ms = self.clock.duration_since(UNIX_EPOCH).unwrap().as_millis();
        let log_line = format!("{} - {level:?}: {message}", elapsed_ms);
        println!("{log_line}");
        #[cfg(feature = "experimental_logs")]
        self.logs.borrow_mut().push(log_line);
        Ok(())
    }

    /// Looks up a property of the current context.
    fn get_property(&self, path: Vec<&str>) -> Result<Option<Bytes>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        Ok(ctx.get_property(path))
    }

    /// Stores a property on the current context (`None` is stored as empty bytes).
    fn set_property(&mut self, path: Vec<&str>, value: Option<&[u8]>) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.create_property(path, value);
        Ok(())
    }

    /// Returns the value and cas counter stored under `key`, or `(None, None)` if absent.
    fn get_shared_data(&self, key: &str) -> Result<(Option<Bytes>, Option<u32>), Status> {
        let found = match self.shared_data.get(key) {
            Some((value, cas)) => (Some(value.clone()), Some(*cas)),
            None => (None, None),
        };
        Ok(found)
    }

    /// Writes `value` (empty when `None`) under `key`.
    ///
    /// A new key starts with cas 1. For an existing key the write succeeds when `cas` is
    /// `None`, 0, or equal to the stored counter, which is then incremented; otherwise
    /// `Status::CasMismatch` is returned and nothing changes.
    fn set_shared_data(
        &mut self,
        key: &str,
        value: Option<&[u8]>,
        cas: Option<u32>,
    ) -> Result<(), Status> {
        let expected = cas.unwrap_or(0);
        let stored = value.unwrap_or_default().to_vec();
        match self.shared_data.entry(key.to_string()) {
            Entry::Vacant(slot) => {
                slot.insert((stored, 1));
            }
            Entry::Occupied(mut slot) => {
                let current = slot.get().1;
                if expected != 0 && expected != current {
                    return Err(Status::CasMismatch);
                }
                slot.insert((stored, current + 1));
            }
        }
        Ok(())
    }

    /// Reads a window of a buffer of the current context.
    fn get_buffer(
        &self,
        buffer_type: BufferType,
        start: usize,
        max_size: usize,
    ) -> Result<Option<Bytes>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        Ok(ctx.get_buffer(buffer_type, start, max_size))
    }

    /// Replaces a window of a buffer of the current context.
    fn set_buffer(
        &mut self,
        buffer_type: BufferType,
        start: usize,
        size: usize,
        value: &[u8],
    ) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.set_buffer(buffer_type, start, size, value);
        Ok(())
    }

    /// Selects the context that subsequent context-scoped calls act on.
    fn set_effective_context(&mut self, context_id: u32) -> Result<(), Status> {
        self.do_set_context(context_id);
        Ok(())
    }

    /// Returns the pairs of a map of the current context, values decoded lossily as UTF-8.
    fn get_map(&self, map_type: MapType) -> Result<Vec<(String, String)>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        let pairs: Vec<(String, String)> = ctx
            .get_map_bytes(map_type)
            .into_iter()
            .map(|(name, raw)| (name, String::from_utf8_lossy(&raw).into_owned()))
            .collect();
        Ok(pairs)
    }

    /// Returns the pairs of a map of the current context as raw bytes.
    fn get_map_bytes(&self, map_type: MapType) -> Result<Vec<(String, Bytes)>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        Ok(ctx.get_map_bytes(map_type))
    }

    /// Replaces a whole map of the current context with string values.
    fn set_map(&mut self, map_type: MapType, map: Vec<(&str, &str)>) -> Result<(), Status> {
        self.do_set_map_bytes(map_type, map)
    }

    /// Replaces a whole map of the current context with byte values.
    fn set_map_bytes(&mut self, map_type: MapType, map: Vec<(&str, &[u8])>) -> Result<(), Status> {
        self.do_set_map_bytes(map_type, map)
    }

    /// Returns the value stored for `key`, decoded lossily as UTF-8.
    fn get_map_value(&self, map_type: MapType, key: &str) -> Result<Option<String>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        let raw = ctx.get_map_value_bytes(map_type, key);
        Ok(raw.map(|bytes| String::from_utf8_lossy(&bytes).into_owned()))
    }

    /// Returns the raw value stored for `key`.
    fn get_map_value_bytes(&self, map_type: MapType, key: &str) -> Result<Option<Bytes>, Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        Ok(ctx.get_map_value_bytes(map_type, key))
    }

    /// String flavour of `set_map_value_bytes`.
    fn set_map_value(
        &mut self,
        map_type: MapType,
        key: &str,
        value: Option<&str>,
    ) -> Result<(), Status> {
        let raw = value.map(str::as_bytes);
        self.set_map_value_bytes(map_type, key, raw)
    }

    /// Replaces (or, with `None`, removes) the entries stored for `key`.
    fn set_map_value_bytes(
        &mut self,
        map_type: MapType,
        key: &str,
        value: Option<&[u8]>,
    ) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.set_map_value_bytes(map_type, key, value);
        Ok(())
    }

    /// String flavour of `add_map_value_bytes`.
    fn add_map_value(&mut self, map_type: MapType, key: &str, value: &str) -> Result<(), Status> {
        self.add_map_value_bytes(map_type, key, value.as_bytes())
    }

    /// Appends an entry for `key` without touching existing ones.
    fn add_map_value_bytes(
        &mut self,
        map_type: MapType,
        key: &str,
        value: &[u8],
    ) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.add_map_value_bytes(map_type, key, value);
        Ok(())
    }

    /// Records a local HTTP response on the current context.
    fn send_http_response(
        &mut self,
        status_code: u32,
        headers: Vec<(&str, &str)>,
        body: Option<&[u8]>,
    ) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.send_http_response(status_code, headers, body);
        Ok(())
    }

    /// Returns the simulated clock, not the wall clock.
    fn get_current_time(&self) -> Result<SystemTime, Status> {
        Ok(self.clock)
    }

    /// Stores the tick period used by `do_tick`.
    fn set_tick_period(&mut self, period: Duration) -> Result<(), Status> {
        self.tick = period;
        Ok(())
    }

    /// Raises the resume flag of the current context.
    fn resume_http_request(&mut self) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.set_resume();
        Ok(())
    }

    /// Shares the single resume flag with `resume_http_request`.
    fn resume_http_response(&mut self) -> Result<(), Status> {
        self.resume_http_request()
    }

    /// Returns the id of the metric called `name`, registering it with value 0 if needed.
    /// The metric type is ignored.
    fn define_metric(&mut self, _metric_type: MetricType, name: &str) -> Result<u32, Status> {
        for (id, (metric_name, _)) in &self.metrics {
            if metric_name.as_str() == name {
                return Ok(*id);
            }
        }
        // Metrics are never removed, so the current count is the next free id.
        let metric_id = self.metrics.len() as u32;
        self.metrics.insert(metric_id, (name.to_string(), 0));
        Ok(metric_id)
    }

    /// Returns the value of a metric, or `Status::NotFound` for an unknown id.
    fn get_metric(&self, metric_id: u32) -> Result<u64, Status> {
        match self.metrics.get(&metric_id) {
            Some((_, value)) => Ok(*value),
            None => Err(Status::NotFound),
        }
    }

    /// Overwrites the value of a metric, or returns `Status::NotFound` for an unknown id.
    fn record_metric(&mut self, metric_id: u32, value: u64) -> Result<(), Status> {
        let (_, current) = self.metrics.get_mut(&metric_id).ok_or(Status::NotFound)?;
        *current = value;
        Ok(())
    }

    /// Adds a signed `offset` to a metric, or returns `Status::NotFound` for an unknown id.
    fn increment_metric(&mut self, metric_id: u32, offset: i64) -> Result<(), Status> {
        let (_, current) = self.metrics.get_mut(&metric_id).ok_or(Status::NotFound)?;
        // Widen to i128 so the sum itself cannot overflow; the cast back truncates to 64 bits.
        let widened = *current as i128 + offset as i128;
        *current = widened as u64;
        Ok(())
    }

    /// Records an HTTP call for `upstream`; trailers and timeout are ignored.
    fn dispatch_http_call(
        &mut self,
        upstream: &str,
        headers: Vec<(&str, &str)>,
        body: Option<&[u8]>,
        _trailers: Vec<(&str, &str)>,
        _timeout: Duration,
    ) -> Result<u32, Status> {
        let call = Call::Http(RequestResponse::new(headers, body));
        self.dispatch_call(upstream, call)
    }

    /// Records a gRPC call for `upstream_name`; the timeout is ignored.
    fn dispatch_grpc_call(
        &mut self,
        upstream_name: &str,
        service_name: &str,
        method_name: &str,
        initial_metadata: Vec<(&str, &[u8])>,
        message: Option<&[u8]>,
        _timeout: Duration,
    ) -> Result<u32, Status> {
        let request = UnitGrpcRequest::new(service_name, method_name, initial_metadata, message);
        self.dispatch_call(upstream_name, Call::Grpc(request))
    }

    /// Returns the gRPC status stored on the current context, or `Status::NotFound`.
    fn get_grpc_status(&self) -> Result<(u32, Option<String>), Status> {
        let ctx = self.contexts.get(&self.context_id).unwrap();
        ctx.get_grpc_status().ok_or(Status::NotFound)
    }

    /// Drops the recorded call with id `token_id` from the current context.
    fn cancel_grpc_call(&mut self, token_id: u32) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.cancel_grpc_call(token_id);
        Ok(())
    }

    /// Handles the foreign functions known to the stub:
    ///
    /// * `get_shared_data_keys` returns every shared-data key, each followed by a NUL byte.
    /// * `remove_shared_data_key` expects the key bytes followed by a little-endian `u32`
    ///   cas. It removes the key and returns its value followed by the stored cas
    ///   (little-endian), `None` if the key is absent, or `Status::CasMismatch` when a
    ///   non-zero cas differs from the stored one.
    ///
    /// Any other name yields `Status::NotFound`.
    fn call_foreign_function(
        &mut self,
        function_name: &str,
        arguments: Option<&[u8]>,
    ) -> Result<Option<Bytes>, Status> {
        match function_name {
            "get_shared_data_keys" => {
                let packed: Vec<u8> = self
                    .shared_data
                    .keys()
                    .flat_map(|key| key.bytes().chain(std::iter::once(0u8)))
                    .collect();
                Ok(Some(packed))
            }
            "remove_shared_data_key" => {
                const U32_SIZE: usize = std::mem::size_of::<u32>();
                let arguments = arguments.ok_or(Status::BadArgument)?;
                // The trailing four bytes carry the cas; shorter input cannot be valid.
                let split = match arguments.len().checked_sub(U32_SIZE) {
                    Some(index) => index,
                    None => return Err(Status::BadArgument),
                };
                let (key_bytes, cas_bytes) = arguments.split_at(split);
                let key = std::str::from_utf8(key_bytes).map_err(|_| Status::BadArgument)?;
                let cas = u32::from_le_bytes(cas_bytes.try_into().unwrap());
                let stored_cas = match self.shared_data.get(key) {
                    Some((_, stored_cas)) => *stored_cas,
                    None => return Ok(None),
                };
                if cas != 0 && cas != stored_cas {
                    return Err(Status::CasMismatch);
                }
                match self.shared_data.remove(key) {
                    Some((mut payload, _)) => {
                        payload.extend_from_slice(&stored_cas.to_le_bytes());
                        Ok(Some(payload))
                    }
                    None => Ok(None),
                }
            }
            _ => Err(Status::NotFound),
        }
    }
}
/// Private helpers used by the `Host` implementation and by the public API of the stub.
impl ProxyWasmStub {
    /// Replaces the whole `map_type` map of the current context with `map`.
    ///
    /// Panics if the current context has not been created.
    fn do_set_map_bytes<A: AsRef<[u8]>>(
        &mut self,
        map_type: MapType,
        map: Vec<(&str, A)>,
    ) -> Result<(), Status> {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        let owned = map
            .into_iter()
            .map(|(name, raw)| (name.to_string(), raw.as_ref().to_vec()))
            .collect::<Vec<_>>();
        ctx.create_map_bytes(map_type, owned);
        Ok(())
    }

    /// Registers a fresh, empty context under `context_id`, replacing any existing one.
    fn do_create_context(&mut self, context_id: u32) {
        let fresh = ProxyWasmContextStub::default();
        self.contexts.insert(context_id, fresh);
    }

    /// Selects the context that context-scoped host calls act on.
    fn do_set_context(&mut self, context_id: u32) {
        self.context_id = context_id;
    }

    /// Advances the clock by the configured tick period and returns that period.
    fn do_tick(&mut self) -> Duration {
        let period = self.tick;
        self.clock = self.clock.add(period);
        period
    }

    /// Advances the clock by `duration`.
    ///
    /// Panics when a non-zero tick period is set.
    fn do_forward(&mut self, duration: Duration) {
        if self.tick.is_zero() {
            self.clock = self.clock.add(duration);
        } else {
            panic!("Force uptading the clock time when tick period is set")
        }
    }

    /// Records `call` on the current context and returns its request id.
    ///
    /// Returns `Status::BadArgument` when `upstream` was not registered via `add_upstream`;
    /// panics if the current context has not been created.
    fn dispatch_call(&mut self, upstream: &str, call: Call) -> Result<u32, Status> {
        let known = self
            .upstreams
            .iter()
            .any(|candidate| candidate.as_str() == upstream);
        if !known {
            return Err(Status::BadArgument);
        }
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        Ok(ctx.dispatch_call(upstream.to_string(), call))
    }
}
/// Public API for driving the stub: creating contexts, seeding their state, moving the
/// clock and inspecting what the code under test did.
///
/// Methods taking a context id panic when that context has not been created.
impl ProxyWasmStub {
    /// Registers a fresh, empty context under `context_id`.
    pub fn create_context(&mut self, context_id: u32) {
        self.do_create_context(context_id);
    }

    /// Selects the context that context-scoped host calls act on.
    pub fn set_context(&mut self, context_id: u32) {
        self.do_set_context(context_id);
    }

    /// Stores a property on the given context; a `None` value is stored as empty bytes.
    pub fn create_property<K, J>(&mut self, context_id: u32, path: Vec<K>, value: Option<J>)
    where
        K: Into<String>,
        J: Into<Vec<u8>>,
    {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.create_property(path, value)
    }

    /// Replaces all properties of the given context.
    pub fn set_properties(&mut self, context_id: u32, props: HashMap<Vec<String>, Bytes>) {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.set_properties(props);
    }

    /// Returns a copy of all properties of the given context.
    pub fn get_properties(&self, context_id: u32) -> HashMap<Vec<String>, Bytes> {
        let ctx = self.contexts.get(&context_id).unwrap();
        ctx.get_properties()
    }

    /// Creates (or overwrites) a buffer on the given context.
    pub fn create_buffer(&mut self, context_id: u32, buffer_type: BufferType, value: Vec<u8>) {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.create_buffer(buffer_type, value)
    }

    /// Creates (or overwrites) a map on the given context.
    pub fn create_map(&mut self, context_id: u32, map_type: MapType, value: Vec<(String, Bytes)>) {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.create_map_bytes(map_type, value)
    }

    /// Returns a copy of a map of the given context (empty if the map is missing).
    pub fn read_map(&self, contex_id: u32, map_type: MapType) -> Vec<(String, Bytes)> {
        let ctx = self.contexts.get(&contex_id).unwrap();
        ctx.get_map_bytes(map_type)
    }

    /// Returns a copy of a buffer of the given context (empty if the buffer is missing).
    pub fn read_buffer(&self, contex_id: u32, buffer_type: BufferType) -> Vec<u8> {
        let ctx = self.contexts.get(&contex_id).unwrap();
        ctx.read_buffer(buffer_type)
    }

    /// Registers an upstream name so that calls dispatched to it are accepted.
    pub fn add_upstream(&mut self, upstream: String) {
        self.upstreams.push(upstream);
    }

    /// Advances the clock by the configured tick period and returns that period.
    pub fn tick(&mut self) -> Duration {
        self.do_tick()
    }

    /// Advances the clock by `duration`; panics when a non-zero tick period is set.
    pub fn forward(&mut self, duration: Duration) {
        self.do_forward(duration)
    }

    /// Returns the resume flag of the given context and resets it.
    pub fn clear_resume(&mut self, context_id: u32) -> bool {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.clear_resume()
    }

    /// Returns the send-response flag of the given context and resets it.
    pub fn clear_send_response(&mut self, context_id: u32) -> bool {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.clear_send_response()
    }

    /// Drains and returns the calls dispatched from the given context.
    pub fn pending_calls(&mut self, context_id: u32) -> Vec<(u32, String, Call)> {
        let ctx = self.contexts.get_mut(&context_id).unwrap();
        ctx.pending_calls()
    }

    /// Stores the gRPC status on the currently selected context (not an explicit id).
    pub fn set_grpc_status(&mut self, status: (u32, Option<String>)) {
        let ctx = self.contexts.get_mut(&self.context_id).unwrap();
        ctx.set_grpc_status(status)
    }

    /// Returns a copy of all metrics: id -> (name, value).
    #[cfg(feature = "experimental")]
    pub fn get_metrics(&self) -> HashMap<u32, (String, u64)> {
        self.metrics.clone()
    }
}