use std::fmt;
use std::sync::{Arc, Mutex};
use wasmer::wasmparser::{BlockType as WpTypeOrFuncType, Operator};
use wasmer::{
ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, LocalFunctionIndex, MiddlewareError,
MiddlewareReaderState, ModuleMiddleware, Mutability, Type,
};
use wasmer_types::{GlobalIndex, ModuleInfo};
/// Indexes of the two globals that [`Metering`] injects into a module.
///
/// Field `.0` is the remaining-points counter and field `.1` is the
/// points-exhausted flag.
#[derive(Clone)]
struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);

impl MeteringGlobalIndexes {
    /// Index of the global holding the remaining point budget.
    fn remaining_points(&self) -> GlobalIndex {
        let Self(remaining, _) = self;
        *remaining
    }

    /// Index of the global that the injected checks set to 1 when the budget runs out.
    fn points_exhausted(&self) -> GlobalIndex {
        let Self(_, exhausted) = self;
        *exhausted
    }
}
impl fmt::Debug for MeteringGlobalIndexes {
    /// Prints both global indexes under descriptive field names rather than `.0` / `.1`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MeteringGlobalIndexes");
        builder.field("remaining_points", &self.remaining_points());
        builder.field("points_exhausted", &self.points_exhausted());
        builder.finish()
    }
}
/// Module-level metering middleware.
///
/// Holds the starting point budget, the per-operator cost function that is
/// shared with every [`FunctionMetering`], and, once `transform_module_info`
/// has run, the indexes of the injected metering globals.
pub struct Metering<F>
where
    F: Fn(&Operator) -> u64 + Send + Sync,
{
    /// Value that the remaining-points global is initialised with.
    initial_limit: u64,
    /// Cost assigned to each operator; shared with the per-function middlewares.
    cost_function: Arc<F>,
    /// `None` until the metering globals have been added to a module.
    global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
}
/// Per-function metering state, created by [`Metering`] for each local function.
pub struct FunctionMetering<F>
where
    F: Fn(&Operator) -> u64 + Send + Sync,
{
    /// Cost assigned to each operator, shared with the parent [`Metering`].
    cost_function: Arc<F>,
    /// Indexes of the globals that the injected checks read and write.
    global_indexes: MeteringGlobalIndexes,
    /// Total cost of the operators fed since the last emitted charge.
    accumulated_cost: u64,
}
impl<F> Metering<F>
where
    F: Fn(&Operator) -> u64 + Send + Sync,
{
    /// Creates a metering middleware.
    ///
    /// * `initial_limit`: the point budget that the injected
    ///   `wasmer_metering_remaining_points` global starts with.
    /// * `cost_function`: returns the number of points charged for each operator.
    ///
    /// No globals are allocated yet; that happens in `transform_module_info`.
    pub fn new(initial_limit: u64, cost_function: F) -> Self {
        let cost_function = Arc::new(cost_function);
        let global_indexes = Mutex::new(None);
        Self {
            initial_limit,
            cost_function,
            global_indexes,
        }
    }
}
/// Returns `true` if `operator` ends the current metering basic block.
///
/// [`FunctionMetering`] emits its budget check-and-charge sequence only in
/// front of operators for which this returns `true`. Every operator that can
/// be the *source* of a control transfer (branches, calls, returns, throws)
/// or its *target* (loop headers, block ends, `else`, `catch`/`catch_all`
/// clauses) must therefore be listed.
///
/// If a branch target is missing, the accumulated cost of the code before it
/// is charged after the target instead. For example, without `CatchAll` the
/// cost of the tail of a legacy `try` body would be charged inside the
/// `catch_all` handler. On the normal, no-exception path that cost would
/// never be charged at all.
pub fn is_accounting(operator: &Operator) -> bool {
    matches!(
        operator,
        // Core control flow.
        Operator::Loop { .. } // branch target (loop header)
            | Operator::End // branch target (block end)
            | Operator::If { .. } // branch source
            | Operator::Else // end of the `if` arm, start of the `else` arm
            | Operator::Br { .. } // branch source
            | Operator::BrTable { .. } // branch source
            | Operator::BrIf { .. } // branch source
            | Operator::Call { .. } // call: branch source
            | Operator::CallIndirect { .. } // call: branch source
            | Operator::Return // branch source
            // Exception handling.
            | Operator::Throw { .. } // branch source
            | Operator::ThrowRef // branch source
            | Operator::Rethrow { .. } // branch source
            | Operator::Delegate { .. } // branch source
            | Operator::Catch { .. } // branch target
            | Operator::CatchAll // branch target
            // Tail calls.
            | Operator::ReturnCall { .. } // branch source
            | Operator::ReturnCallIndirect { .. } // branch source
            // GC.
            | Operator::BrOnCast { .. } // branch source
            | Operator::BrOnCastFail { .. } // branch source
            // Function references.
            | Operator::CallRef { .. } // branch source
            | Operator::ReturnCallRef { .. } // branch source
            | Operator::BrOnNull { .. } // branch source
            | Operator::BrOnNonNull { .. } // branch source
    )
}
impl<F> fmt::Debug for Metering<F>
where
    F: Fn(&Operator) -> u64 + Send + Sync,
{
    /// Formats every field; the opaque cost function is shown as a placeholder string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Metering");
        out.field("initial_limit", &self.initial_limit);
        out.field("cost_function", &"<function>");
        out.field("global_indexes", &self.global_indexes);
        out.finish()
    }
}
impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Metering<F> {
    /// Creates the metering state for one local function.
    ///
    /// The new [`FunctionMetering`] shares this middleware's cost function and
    /// receives a copy of the recorded global indexes. It starts with no
    /// pending cost.
    ///
    /// # Panics
    ///
    /// Panics if no metering globals have been recorded yet, i.e.
    /// `transform_module_info` has not run on this instance. Also panics if
    /// the internal mutex is poisoned.
    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
        let global_indexes = self
            .global_indexes
            .lock()
            .expect("Metering: global_indexes mutex poisoned")
            .clone()
            .expect(
                "Metering::generate_function_middleware: metering globals not allocated; \
                 `transform_module_info` must run first",
            );
        Box::new(FunctionMetering {
            cost_function: Arc::clone(&self.cost_function),
            global_indexes,
            accumulated_cost: 0,
        })
    }

    /// Adds the two metering globals to `module_info` and exports them.
    ///
    /// * `wasmer_metering_remaining_points`: a mutable `i64` initialised to
    ///   `initial_limit`. The `u64` is bit-reinterpreted as `i64`; the injected
    ///   checks use unsigned comparison (`I64LtU`), so values above
    ///   `i64::MAX` keep their meaning.
    /// * `wasmer_metering_points_exhausted`: a mutable `i32` initialised to 0.
    ///
    /// The indexes of both globals are recorded for later use by
    /// `generate_function_middleware`.
    ///
    /// # Panics
    ///
    /// Panics if this instance has already recorded globals, i.e. if the same
    /// `Metering` is used for more than one module. Also panics if the
    /// internal mutex is poisoned.
    fn transform_module_info(&self, module_info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
        let mut global_indexes = self
            .global_indexes
            .lock()
            .expect("Metering: global_indexes mutex poisoned");
        assert!(
            global_indexes.is_none(),
            "Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."
        );

        // Remaining point budget (i64 global holding a u64 bit pattern).
        let remaining_points_global_index = module_info
            .globals
            .push(GlobalType::new(Type::I64, Mutability::Var));
        module_info
            .global_initializers
            .push(GlobalInit::I64Const(self.initial_limit as i64));
        module_info.exports.insert(
            "wasmer_metering_remaining_points".to_string(),
            ExportIndex::Global(remaining_points_global_index),
        );

        // Flag set to 1 by the injected check just before it traps.
        let points_exhausted_global_index = module_info
            .globals
            .push(GlobalType::new(Type::I32, Mutability::Var));
        module_info
            .global_initializers
            .push(GlobalInit::I32Const(0));
        module_info.exports.insert(
            "wasmer_metering_points_exhausted".to_string(),
            ExportIndex::Global(points_exhausted_global_index),
        );

        *global_indexes = Some(MeteringGlobalIndexes(
            remaining_points_global_index,
            points_exhausted_global_index,
        ));
        Ok(())
    }
}
impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
    /// Formats every field, including the pending `accumulated_cost`.
    ///
    /// The opaque cost function is shown as a placeholder string, matching
    /// the `Debug` output of [`Metering`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FunctionMetering")
            .field("cost_function", &"<function>")
            .field("global_indexes", &self.global_indexes)
            .field("accumulated_cost", &self.accumulated_cost)
            .finish()
    }
}
impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMetering<F> {
    /// Feeds one operator through the metering middleware.
    ///
    /// First, the operator's cost is added to `accumulated_cost`. If `operator`
    /// is recognised by [`is_accounting`] and some cost is pending, this
    /// sequence is emitted *before* the operator:
    ///
    /// 1. If `remaining_points < accumulated_cost` (unsigned), set
    ///    `points_exhausted` to 1 and execute `unreachable`.
    /// 2. Otherwise, subtract `accumulated_cost` from `remaining_points`.
    ///
    /// The pending cost is then reset, and the original operator is forwarded.
    fn feed<'a>(
        &mut self,
        operator: Operator<'a>,
        state: &mut MiddlewareReaderState<'a>,
    ) -> Result<(), MiddlewareError> {
        // Add this operator's cost before deciding whether to flush, so that
        // the accounting operator itself (e.g. a `Call`) is included in the
        // charge emitted ahead of it.
        //
        // Saturate instead of using `+=`. With a user-supplied cost function,
        // `+=` panics when overflow checks are on. When they are off (the
        // release default), it wraps, silently turning a huge cost
        // (e.g. `u64::MAX`) into a tiny one. A saturated cost of `u64::MAX`
        // traps at run time unless the remaining budget is exactly `u64::MAX`.
        self.accumulated_cost = self
            .accumulated_cost
            .saturating_add((self.cost_function)(&operator));

        if is_accounting(&operator) && self.accumulated_cost > 0 {
            let remaining_points = self.global_indexes.remaining_points().as_u32();
            let points_exhausted = self.global_indexes.points_exhausted().as_u32();
            // The budget global is an i64, so the u64 cost is bit-reinterpreted.
            // `I64LtU` and `I64Sub` act on the same bit pattern, so the
            // comparison and subtraction are effectively unsigned.
            let cost = self.accumulated_cost as i64;
            state.extend(&[
                // if (remaining_points as u64) < cost { points_exhausted = 1; unreachable }
                Operator::GlobalGet {
                    global_index: remaining_points,
                },
                Operator::I64Const { value: cost },
                Operator::I64LtU,
                Operator::If {
                    blockty: WpTypeOrFuncType::Empty,
                },
                Operator::I32Const { value: 1 },
                Operator::GlobalSet {
                    global_index: points_exhausted,
                },
                Operator::Unreachable,
                Operator::End,
                // remaining_points -= cost
                Operator::GlobalGet {
                    global_index: remaining_points,
                },
                Operator::I64Const { value: cost },
                Operator::I64Sub,
                Operator::GlobalSet {
                    global_index: remaining_points,
                },
            ]);
            self.accumulated_cost = 0;
        }
        state.push_operator(operator);
        Ok(())
    }
}