use smallvec::SmallVec;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::ops::{Deref, Range};
use wasmer_types::{LocalFunctionIndex, MiddlewareError, ModuleInfo, WasmError, WasmResult};
use wasmparser::{BinaryReader, FunctionBody, Operator, OperatorsReader, ValType};
use super::error::from_binaryreadererror_wasmerror;
use crate::translator::environ::FunctionBinaryReader;
/// A module-level middleware that can take part in compilation.
///
/// Implementors produce a per-function [`FunctionMiddleware`] for a given local
/// function index, and may optionally adjust the module's [`ModuleInfo`].
pub trait ModuleMiddleware: Debug + Send + Sync {
    /// Creates the function-level middleware for the function at `local_function_index`.
    fn generate_function_middleware<'a>(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Box<dyn FunctionMiddleware<'a> + 'a>;

    /// Gives the middleware a chance to mutate the module metadata.
    ///
    /// The default implementation leaves the metadata untouched and succeeds.
    fn transform_module_info(
        &self,
        _module_info: &mut ModuleInfo,
    ) -> Result<(), MiddlewareError> {
        Ok(())
    }
}
/// A per-function middleware stage that can observe local declarations and
/// rewrite the operator stream.
pub trait FunctionMiddleware<'a>: Debug {
    /// Receives the types of the function's declared locals once every local
    /// declaration group has been read. The default implementation ignores them.
    fn locals_info(&mut self, _locals: &[ValType]) {}

    /// Processes one `operator`, pushing zero or more operators onto `state`;
    /// whatever is pushed becomes the input of the next stage in the chain.
    ///
    /// The default implementation forwards `operator` unchanged.
    fn feed(
        &mut self,
        operator: Operator<'a>,
        state: &mut MiddlewareReaderState<'a>,
    ) -> Result<(), MiddlewareError> {
        state.push_operator(operator);
        Ok(())
    }
}
/// A function-body reader that routes every operator through a chain of
/// [`FunctionMiddleware`] stages before returning it from `read_operator`.
pub struct MiddlewareBinaryReader<'a> {
    /// The underlying reader plus the queue of operators produced by the stages.
    state: MiddlewareReaderState<'a>,
    /// Middleware stages applied in order; when empty, operators pass straight through.
    chain: Vec<Box<dyn FunctionMiddleware<'a> + 'a>>,
}
/// The reader backing a [`MiddlewareBinaryReader`]. It starts as a raw byte
/// reader and is replaced by an operators reader when operators are first read.
enum MiddlewareInnerReader<'a> {
    /// Raw byte access, used while the local declarations are decoded.
    Binary {
        /// Cursor that advances over the local declarations.
        reader: BinaryReader<'a>,
        /// Untouched copy positioned at the start of the function body.
        original_reader: BinaryReader<'a>,
    },
    /// Operator-level access, used once the first operator has been requested.
    Operator(OperatorsReader<'a>),
}
/// Mutable state shared between a [`MiddlewareBinaryReader`] and its middleware stages.
pub struct MiddlewareReaderState<'a> {
    /// The current underlying reader. It is wrapped in `Option` so the variant can
    /// be swapped; the reading methods expect it to be `Some`.
    inner: Option<MiddlewareInnerReader<'a>>,
    /// Operators pushed by middleware stages, consumed front-to-back.
    pending_operations: VecDeque<Operator<'a>>,
    /// Number of local-declaration groups announced by the local count.
    local_decls_group: u32,
    /// Number of local-declaration groups consumed so far.
    local_decls_group_read: u32,
    /// Types of all declared locals, one entry per local (groups expanded).
    locals: Vec<ValType>,
}
/// Operations that treat an ordered collection of module middlewares as a single chain.
pub trait ModuleMiddlewareChain {
    /// Builds the function middleware chain for the function at `local_function_index`.
    fn generate_function_middleware_chain<'a>(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware<'a> + 'a>>;

    /// Lets the chain transform `module_info`, reporting the first failure.
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError>;
}
impl<T: Deref<Target = dyn ModuleMiddleware>> ModuleMiddlewareChain for [T] {
    /// Creates one function middleware per module middleware, preserving slice order.
    fn generate_function_middleware_chain<'a>(
        &self,
        local_function_index: LocalFunctionIndex,
    ) -> Vec<Box<dyn FunctionMiddleware<'a> + 'a>> {
        let mut stages = Vec::with_capacity(self.len());
        for middleware in self {
            stages.push(middleware.generate_function_middleware(local_function_index));
        }
        stages
    }

    /// Runs each middleware's `transform_module_info` in slice order and stops at
    /// the first error.
    fn apply_on_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
        self.iter()
            .try_for_each(|middleware| middleware.transform_module_info(module_info))
    }
}
impl<'a> MiddlewareReaderState<'a> {
    /// Appends `operator` to the back of the pending-operator queue.
    pub fn push_operator(&mut self, operator: Operator<'a>) {
        let queue = &mut self.pending_operations;
        queue.push_back(operator);
    }
}
impl<'a> Extend<Operator<'a>> for MiddlewareReaderState<'a> {
    /// Appends every operator yielded by `iter` to the pending queue, keeping their order.
    fn extend<I: IntoIterator<Item = Operator<'a>>>(&mut self, iter: I) {
        let queue = &mut self.pending_operations;
        queue.extend(iter);
    }
}
impl<'a: 'b, 'b> Extend<&'b Operator<'a>> for MiddlewareReaderState<'a> {
    /// Clones every borrowed operator yielded by `iter` onto the pending queue, in order.
    fn extend<I: IntoIterator<Item = &'b Operator<'a>>>(&mut self, iter: I) {
        let borrowed = iter.into_iter();
        self.pending_operations.reserve(borrowed.size_hint().0);
        for op in borrowed {
            self.pending_operations.push_back(op.clone());
        }
    }
}
impl<'a> MiddlewareBinaryReader<'a> {
    /// Creates a reader over the function body bytes in `data`.
    ///
    /// `original_offset` is passed through to `BinaryReader::new` as the position
    /// of `data` within the original binary. The middleware chain starts empty.
    pub fn new_with_offset(data: &'a [u8], original_offset: usize) -> Self {
        let reader = BinaryReader::new(data, original_offset);
        // Keep an untouched copy at the start of the body; the operators reader is
        // later built from it.
        let original_reader = reader.clone();
        let state = MiddlewareReaderState {
            inner: Some(MiddlewareInnerReader::Binary {
                reader,
                original_reader,
            }),
            pending_operations: VecDeque::new(),
            local_decls_group: 0,
            local_decls_group_read: 0,
            locals: Vec::new(),
        };
        Self {
            state,
            chain: Vec::new(),
        }
    }

    /// Replaces the current middleware chain with `stages`.
    pub fn set_middleware_chain(&mut self, stages: Vec<Box<dyn FunctionMiddleware<'a> + 'a>>) {
        self.chain = stages;
    }

    /// Hands the collected local types to every stage of the chain, in chain order.
    fn emit_locals_info(&mut self) {
        let locals = &self.state.locals;
        self.chain
            .iter_mut()
            .for_each(|middleware| middleware.locals_info(locals));
    }
}
impl<'a> FunctionBinaryReader<'a> for MiddlewareBinaryReader<'a> {
    /// Reads the number of local-declaration groups at the start of the function body.
    ///
    /// Records the count so `read_local_decl` can tell when the last group has been
    /// read. If there are no groups, the middleware chain is notified immediately.
    ///
    /// # Errors
    ///
    /// Fails if the count cannot be decoded, or if operator reading has already started.
    fn read_local_count(&mut self) -> WasmResult<u32> {
        let total = match self.state.inner.as_mut().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader
                .read_var_u32()
                .map_err(from_binaryreadererror_wasmerror),
            // Report the offset in the original binary, like errors converted from
            // wasmparser do.
            MiddlewareInnerReader::Operator(..) => Err(WasmError::InvalidWebAssembly {
                message: "locals must be read before the function body".to_string(),
                offset: self.original_position(),
            }),
        }?;
        self.state.local_decls_group = total;
        // `total` comes straight from the binary. Each group takes at least two bytes
        // (a LEB128 count and a value type), so the remaining bytes bound how many
        // groups can really follow. Capping the pre-allocation hint by that bound stops
        // a bogus count from forcing a huge up-front allocation.
        let reserve_hint = (total as usize).min(self.bytes_remaining() / 2);
        self.state.locals.reserve(reserve_hint);
        if total == 0 {
            self.emit_locals_info();
        }
        Ok(total)
    }

    /// Reads one `(count, type)` local-declaration group and records `count` locals of
    /// that type.
    ///
    /// When the last announced group has been read, the collected locals are passed to
    /// every middleware stage.
    ///
    /// # Errors
    ///
    /// Fails if the group cannot be decoded, or if operator reading has already started.
    fn read_local_decl(&mut self) -> WasmResult<(u32, ValType)> {
        let (count, ty) = match self.state.inner.as_mut().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => {
                let count = reader
                    .read_var_u32()
                    .map_err(from_binaryreadererror_wasmerror)?;
                let ty: ValType = reader
                    .read::<ValType>()
                    .map_err(from_binaryreadererror_wasmerror)?;
                Ok((count, ty))
            }
            MiddlewareInnerReader::Operator(..) => Err(WasmError::InvalidWebAssembly {
                message: "locals must be read before the function body".to_string(),
                offset: self.original_position(),
            }),
        }?;
        // NOTE(review): `count` is not bounded here. This relies on the module having
        // been validated (which limits the number of locals) before translation; confirm.
        self.state
            .locals
            .extend(std::iter::repeat(ty).take(count as usize));
        self.state.local_decls_group_read += 1;
        if self.state.local_decls_group_read == self.state.local_decls_group {
            self.emit_locals_info();
        }
        Ok((count, ty))
    }

    /// Returns the next operator after it has passed through every middleware stage.
    ///
    /// On the first call, the raw byte reader is replaced by an operators reader built
    /// from the untouched copy of the function body.
    ///
    /// # Errors
    ///
    /// Propagates decoding errors from the operators reader and errors returned by
    /// any middleware stage.
    fn read_operator(&mut self) -> WasmResult<Operator<'a>> {
        // Build the replacement reader *before* overwriting `inner`. If building it
        // fails, the binary reader stays in place, so later position queries keep
        // working instead of panicking on a missing inner state.
        if let Some(MiddlewareInnerReader::Binary {
            original_reader, ..
        }) = &self.state.inner
        {
            let operator_reader = FunctionBody::new(original_reader.clone())
                .get_operators_reader()
                .map_err(from_binaryreadererror_wasmerror)?;
            self.state.inner = Some(MiddlewareInnerReader::Operator(operator_reader));
        }
        let read_operator = |state: &mut MiddlewareReaderState<'a>| {
            let Some(MiddlewareInnerReader::Operator(operator_reader)) = state.inner.as_mut()
            else {
                unreachable!("inner reader is switched to operator mode above");
            };
            operator_reader
                .read()
                .map_err(from_binaryreadererror_wasmerror)
        };
        if self.chain.is_empty() {
            return read_operator(&mut self.state);
        }
        // A stage may swallow operators (push nothing), so keep pulling raw operators
        // until the chain has produced at least one.
        while self.state.pending_operations.is_empty() {
            let raw_op = read_operator(&mut self.state)?;
            self.state.pending_operations.push_back(raw_op);
            for stage in &mut self.chain {
                // Move this stage's input out of the queue; whatever the stage pushes
                // back becomes the input of the next stage.
                let pending: SmallVec<[Operator<'a>; 2]> =
                    self.state.pending_operations.drain(..).collect();
                for pending_op in pending {
                    stage.feed(pending_op, &mut self.state)?;
                }
            }
        }
        Ok(self
            .state
            .pending_operations
            .pop_front()
            .expect("loop exits only once an operator is pending"))
    }

    /// Position of the underlying reader, as reported by wasmparser's `current_position`.
    fn current_position(&self) -> usize {
        match self.state.inner.as_ref().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader.current_position(),
            MiddlewareInnerReader::Operator(operator_reader) => {
                operator_reader.get_binary_reader().current_position()
            }
        }
    }

    /// Position of the underlying reader within the original binary.
    fn original_position(&self) -> usize {
        match self.state.inner.as_ref().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader.original_position(),
            MiddlewareInnerReader::Operator(operator_reader) => operator_reader.original_position(),
        }
    }

    /// Number of bytes the underlying reader has not consumed yet.
    fn bytes_remaining(&self) -> usize {
        match self.state.inner.as_ref().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader.bytes_remaining(),
            MiddlewareInnerReader::Operator(operator_reader) => {
                operator_reader.get_binary_reader().bytes_remaining()
            }
        }
    }

    /// Whether the underlying reader has reached the end of its input.
    fn eof(&self) -> bool {
        match self.state.inner.as_ref().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader.eof(),
            MiddlewareInnerReader::Operator(operator_reader) => operator_reader.eof(),
        }
    }

    /// Byte range covered by the underlying reader.
    fn range(&self) -> Range<usize> {
        match self.state.inner.as_ref().expect("inner state must exist") {
            MiddlewareInnerReader::Binary { reader, .. } => reader.range(),
            MiddlewareInnerReader::Operator(operator_reader) => {
                operator_reader.get_binary_reader().range()
            }
        }
    }
}