use loupe::MemoryUsage;
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::ops::Deref;
use wasmer_types::{LocalFunctionIndex, ModuleInfo};
use wasmparser::{BinaryReader, Operator, Range, Type};
use crate::error::{MiddlewareError, WasmResult};
use crate::translator::environ::FunctionBinaryReader;
/// A module-level middleware: it may rewrite the module's [`ModuleInfo`] and
/// creates a [`FunctionMiddleware`] for each locally defined function.
///
/// NOTE(review): the `Send + Sync + MemoryUsage` supertraits presumably allow
/// instances to be shared across compilation threads and included in memory
/// accounting — confirm against the users of this trait.
pub trait ModuleMiddleware: Debug + Send + Sync + MemoryUsage {
    /// Creates the function-level middleware for the function identified by
    /// `local_function_index`.
    fn generate_function_middleware(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Box<dyn FunctionMiddleware>;
    /// Hook for modifying the module's [`ModuleInfo`].
    ///
    /// The default implementation does nothing.
    fn transform_module_info(&self, _: &mut ModuleInfo) {}
}
/// A middleware stage that transforms the operator stream of a single function.
pub trait FunctionMiddleware: Debug {
    /// Consumes one `operator` and emits zero or more operators into `state`.
    ///
    /// The default implementation forwards `operator` unchanged, so a stage that
    /// does not override this method behaves as a pass-through.
    ///
    /// # Errors
    ///
    /// Implementations may return a [`MiddlewareError`] to signal failure; the
    /// default implementation never fails.
    fn feed<'a>(
        &mut self,
        operator: Operator<'a>,
        state: &mut MiddlewareReaderState<'a>,
    ) -> Result<(), MiddlewareError> {
        // Pass-through: re-emit the operator as-is.
        state.push_operator(operator);
        Ok(())
    }
}
/// A function-body reader that routes each operator through a chain of
/// [`FunctionMiddleware`] stages before returning it.
#[derive(Debug)]
pub struct MiddlewareBinaryReader<'a> {
    /// The raw reader together with the queue of middleware-emitted operators.
    state: MiddlewareReaderState<'a>,
    /// Middleware stages, run in vector order; when empty, raw operators are
    /// returned without buffering.
    chain: Vec<Box<dyn FunctionMiddleware>>,
}
/// State handed to each [`FunctionMiddleware`] stage: stages emit their output
/// operators into it (via `push_operator` or `Extend`).
#[derive(Debug)]
pub struct MiddlewareReaderState<'a> {
    /// Reader over the raw input bytes.
    inner: BinaryReader<'a>,
    /// Operators emitted but not yet consumed; pushed at the back, taken from the front.
    pending_operations: VecDeque<Operator<'a>>,
}
/// Operations over an ordered collection of [`ModuleMiddleware`]s.
pub trait ModuleMiddlewareChain {
    /// Builds the function-level middleware chain for the function identified
    /// by `local_function_index`.
    fn generate_function_middleware_chain(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware>>;
    /// Lets the middlewares in the collection transform `module_info`.
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo);
}
/// Treats any slice of pointers to `dyn ModuleMiddleware` (e.g. `Box` or `Arc`)
/// as a middleware chain, processed in slice order.
impl<T: Deref<Target = dyn ModuleMiddleware>> ModuleMiddlewareChain for [T] {
    /// Returns one function middleware per element, in the same order as the slice.
    fn generate_function_middleware_chain(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware>> {
        let mut stages = Vec::with_capacity(self.len());
        for middleware in self {
            stages.push(middleware.generate_function_middleware(local_function_index));
        }
        stages
    }

    /// Calls `transform_module_info` on each element in slice order, so later
    /// middlewares observe the changes made by earlier ones.
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) {
        self.iter()
            .for_each(|middleware| middleware.transform_module_info(module_info));
    }
}
impl<'a> MiddlewareReaderState<'a> {
    /// Appends `operator` to the back of the pending-operation queue.
    pub fn push_operator(&mut self, operator: Operator<'a>) {
        let Self {
            pending_operations, ..
        } = self;
        pending_operations.push_back(operator);
    }
}
/// Lets a middleware stage emit several owned operators at once.
impl<'a> Extend<Operator<'a>> for MiddlewareReaderState<'a> {
    /// Appends every operator from `operators` to the pending queue, keeping their order.
    fn extend<I: IntoIterator<Item = Operator<'a>>>(&mut self, operators: I) {
        let queue = &mut self.pending_operations;
        Extend::extend(queue, operators);
    }
}
/// Lets a middleware stage emit borrowed operators (e.g. from a slice literal);
/// each one is cloned into the queue.
impl<'a: 'b, 'b> Extend<&'b Operator<'a>> for MiddlewareReaderState<'a> {
    /// Clones each operator from `operators` and appends it to the pending queue, in order.
    fn extend<I: IntoIterator<Item = &'b Operator<'a>>>(&mut self, operators: I) {
        // Reuse the owned-operator impl once the operators have been cloned.
        <Self as Extend<Operator<'a>>>::extend(self, operators.into_iter().cloned());
    }
}
impl<'a> MiddlewareBinaryReader<'a> {
    /// Creates a reader over `data` with no middleware installed.
    ///
    /// `original_offset` is forwarded unchanged to `BinaryReader::new_with_offset`.
    pub fn new_with_offset(data: &'a [u8], original_offset: usize) -> Self {
        let state = MiddlewareReaderState {
            inner: BinaryReader::new_with_offset(data, original_offset),
            pending_operations: VecDeque::new(),
        };
        Self {
            state,
            chain: Vec::new(),
        }
    }

    /// Installs `stages` as the middleware chain, replacing any previous one.
    ///
    /// Stages are run in vector order for every raw operator.
    pub fn set_middleware_chain(&mut self, stages: Vec<Box<dyn FunctionMiddleware>>) {
        self.chain = stages;
    }
}
impl<'a> FunctionBinaryReader<'a> for MiddlewareBinaryReader<'a> {
    /// Reads the local-declaration count directly from the underlying reader.
    fn read_local_count(&mut self) -> WasmResult<u32> {
        Ok(self.state.inner.read_var_u32()?)
    }

    /// Reads one local declaration: a `u32` count followed by a value type.
    fn read_local_decl(&mut self) -> WasmResult<(u32, Type)> {
        let count = self.state.inner.read_var_u32()?;
        let ty = self.state.inner.read_type()?;
        Ok((count, ty))
    }

    /// Returns the next operator after it has passed through the middleware chain.
    ///
    /// With an empty chain the raw operator is returned directly. Otherwise raw
    /// operators are read one at a time. Each is fed through every stage in order,
    /// and the output of one stage becomes the input of the next. Reading stops
    /// once the chain has produced at least one operator. Because a stage may emit
    /// zero, one or many operators per input, surplus operators stay queued and
    /// are returned by later calls.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying reader, including running out of
    /// input while the chain has produced nothing. Also propagates errors from
    /// any stage's `feed`.
    fn read_operator(&mut self) -> WasmResult<Operator<'a>> {
        if self.chain.is_empty() {
            // Fast path: no middleware, so nothing needs to be buffered.
            return Ok(self.state.inner.read_operator()?);
        }

        // Stages may swallow operators, so keep pulling raw input until output appears.
        while self.state.pending_operations.is_empty() {
            let raw_op = self.state.inner.read_operator()?;
            self.state.pending_operations.push_back(raw_op);

            for stage in &mut self.chain {
                // Move the previous stage's output out of the shared queue; the
                // current stage then pushes its own output back into that queue.
                let pending: SmallVec<[Operator<'a>; 2]> =
                    self.state.pending_operations.drain(..).collect();
                for pending_op in pending {
                    stage.feed(pending_op, &mut self.state)?;
                }
            }
        }

        Ok(self
            .state
            .pending_operations
            .pop_front()
            .expect("fill loop only exits once an operator is pending"))
    }

    /// Current position of the underlying reader.
    ///
    /// Operators still queued by the middleware have already been consumed from it.
    fn current_position(&self) -> usize {
        self.state.inner.current_position()
    }

    /// Original (offset-adjusted) position of the underlying reader.
    fn original_position(&self) -> usize {
        self.state.inner.original_position()
    }

    /// Bytes left in the underlying reader.
    ///
    /// This does not account for operators already queued by the middleware.
    fn bytes_remaining(&self) -> usize {
        self.state.inner.bytes_remaining()
    }

    /// Returns `true` only when no further operator can be produced: the
    /// underlying reader is exhausted and no middleware-emitted operators are
    /// still queued.
    ///
    /// Checking only the underlying reader would report end-of-input too early
    /// when a stage expands the last raw operator into several operators.
    fn eof(&self) -> bool {
        self.state.inner.eof() && self.state.pending_operations.is_empty()
    }

    /// Range of the underlying reader.
    fn range(&self) -> Range {
        self.state.inner.range()
    }
}