use super::FuncIdx;
use crate::{
core::{wasm, UntypedVal},
ExternRef,
Func,
Ref,
Val,
F32,
F64,
};
use alloc::{boxed::Box, vec::Vec};
use core::{fmt, mem};
use wasmparser::AbstractHeapType;
#[cfg(feature = "simd")]
use crate::V128;
/// Something that can be evaluated to a value as part of a Wasm constant expression.
pub trait Eval {
    /// Evaluates `self`, using `ctx` to resolve globals and function references.
    ///
    /// Returns `None` if a value required for evaluation is not available from `ctx`.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal>;
}
/// Supplies the external values that a constant expression may reference.
pub trait EvalContext {
    /// Returns the value of the global at `index`, or `None` if it is not available.
    fn get_global(&self, index: u32) -> Option<Val>;

    /// Returns the function reference at `index`, or `None` if it is not available.
    fn get_func(&self, index: u32) -> Option<Ref<Func>>;
}
/// An [`EvalContext`] that provides neither globals nor function references.
///
/// Both of its lookups yield `None`, so evaluating with it only succeeds for
/// expressions that use neither `global.get` nor `ref.func`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
pub struct EmptyEvalContext;
impl EvalContext for EmptyEvalContext {
    /// Never yields a value: no globals are available in this context.
    fn get_global(&self, _: u32) -> Option<Val> {
        None
    }

    /// Never yields a value: no functions are available in this context.
    fn get_func(&self, _: u32) -> Option<Ref<Func>> {
        None
    }
}
/// A single node of a compiled constant expression.
#[derive(Debug)]
pub enum Op {
    /// A value fixed at construction time, e.g. from `i32.const` or `ref.null`.
    Const(ConstOp),
    /// A `global.get` whose value is looked up in the [`EvalContext`].
    Global(GlobalOp),
    /// A `ref.func` whose function reference is looked up in the [`EvalContext`].
    FuncRef(FuncRefOp),
    /// A composite computed by a closure, e.g. `i32.add` over two other operators.
    Expr(ExprOp),
}
/// An operator that always evaluates to the same precomputed value.
#[derive(Debug)]
pub struct ConstOp {
    /// The value produced on every evaluation.
    value: UntypedVal,
}

impl Eval for ConstOp {
    /// Returns the stored value; the context is not consulted.
    fn eval(&self, _: &dyn EvalContext) -> Option<UntypedVal> {
        let value = self.value;
        Some(value)
    }
}
/// An operator that reads a global variable (`global.get`) via the [`EvalContext`].
#[derive(Debug)]
pub struct GlobalOp {
    /// Index of the global variable to read.
    global_index: u32,
}

impl Eval for GlobalOp {
    /// Looks up the global in `ctx`; yields `None` if `ctx` cannot provide it.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal> {
        let global = ctx.get_global(self.global_index)?;
        Some(UntypedVal::from(global))
    }
}
/// An operator that resolves a function reference (`ref.func`) via the [`EvalContext`].
#[derive(Debug)]
pub struct FuncRefOp {
    /// Index of the referenced function.
    function_index: u32,
}

impl Eval for FuncRefOp {
    /// Looks up the function in `ctx`; yields `None` if `ctx` cannot provide it.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal> {
        let func = ctx.get_func(self.function_index)?;
        Some(UntypedVal::from(func))
    }
}
/// An operator whose value is computed by a boxed closure, typically combining
/// the values of other operators.
// The boxed closure type is spelled out inline because this single field is its
// only use; the lint is silenced rather than introducing a one-off type alias.
#[allow(clippy::type_complexity)]
pub struct ExprOp {
    /// Computes this operator's value from an [`EvalContext`].
    expr: Box<dyn Fn(&dyn EvalContext) -> Option<UntypedVal> + Send + Sync>,
}
impl fmt::Debug for ExprOp {
    /// Formats as `ExprOp { .. }`.
    ///
    /// The boxed closure has no `Debug` representation, so its field is omitted.
    /// The output is explicitly marked as non-exhaustive instead of pretending
    /// the struct has no fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ExprOp").finish_non_exhaustive()
    }
}
impl Eval for ExprOp {
    /// Invokes the stored closure with `ctx`.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal> {
        let compute = &*self.expr;
        compute(ctx)
    }
}
impl Op {
pub fn constant<T>(value: T) -> Self
where
T: Into<Val>,
{
Self::Const(ConstOp {
value: value.into().into(),
})
}
pub fn global(global_index: u32) -> Self {
Self::Global(GlobalOp { global_index })
}
pub fn funcref(function_index: u32) -> Self {
Self::FuncRef(FuncRefOp { function_index })
}
pub fn expr<T>(expr: T) -> Self
where
T: Fn(&dyn EvalContext) -> Option<UntypedVal> + Send + Sync + 'static,
{
Self::Expr(ExprOp {
expr: Box::new(expr),
})
}
}
impl Eval for Op {
    /// Forwards evaluation to whichever concrete operator this node wraps.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal> {
        match self {
            Self::Const(inner) => inner.eval(ctx),
            Self::Global(inner) => inner.eval(ctx),
            Self::FuncRef(inner) => inner.eval(ctx),
            Self::Expr(inner) => inner.eval(ctx),
        }
    }
}
/// A compiled Wasm constant expression that can be evaluated against an [`EvalContext`].
#[derive(Debug)]
pub struct ConstExpr {
    /// Root operator of the expression tree.
    op: Op,
}

impl Eval for ConstExpr {
    /// Evaluates the root operator, which recursively evaluates its operands.
    fn eval(&self, ctx: &dyn EvalContext) -> Option<UntypedVal> {
        let root = &self.op;
        root.eval(ctx)
    }
}
/// Builds an [`Op::Expr`] that evaluates `$lhs` then `$rhs` (short-circuiting on
/// `None`), converts both results into the operand types of `$expr`, and converts
/// `$expr`'s result back into an [`UntypedVal`].
macro_rules! def_expr {
    ($lhs:ident, $rhs:ident, $expr:expr) => {{
        Op::expr(move |ctx: &dyn EvalContext| -> Option<UntypedVal> {
            let lhs_value = $lhs.eval(ctx)?;
            let rhs_value = $rhs.eval(ctx)?;
            let result = $expr(lhs_value.into(), rhs_value.into());
            Some(result.into())
        })
    }};
}
/// A stack of [`Op`]s used while translating Wasm const expression operators.
///
/// The top-most operator is kept in the dedicated `top` slot, and every operator
/// below it lives in `ops`. Invariant maintained by `push`/`pop`: `top` is `None`
/// only when `ops` is empty. Therefore the stack is empty exactly when `top` is `None`.
#[derive(Debug, Default)]
pub struct ConstExprStack {
    /// The top-most operator, if any.
    top: Option<Op>,
    /// All operators below `top`; the last element becomes the next `top`.
    ops: Vec<Op>,
}

impl ConstExprStack {
    /// Returns `true` if the stack holds no operators.
    ///
    /// This checks `top` rather than `ops`. A stack with exactly one operator
    /// stores it in `top` while `ops` is empty, so checking `ops` alone would
    /// wrongly report such a stack as empty.
    pub fn is_empty(&self) -> bool {
        self.top.is_none()
    }

    /// Pushes `op`, moving the previous top operator (if any) down into `ops`.
    pub fn push(&mut self, op: Op) {
        if let Some(old_top) = self.top.replace(op) {
            self.ops.push(old_top);
        }
    }

    /// Pops and returns the top-most operator, or `None` if the stack is empty.
    pub fn pop(&mut self) -> Option<Op> {
        let new_top = self.ops.pop();
        mem::replace(&mut self.top, new_top)
    }

    /// Pops the two top-most operators and returns them as `(lhs, rhs)`.
    ///
    /// The top-most operator is the right-hand side because it was pushed last.
    /// Returns `None` if fewer than two operators are on the stack. In that case
    /// the stack is left untouched, so no operator is silently discarded.
    pub fn pop2(&mut self) -> Option<(Op, Op)> {
        if self.ops.is_empty() {
            // At most one operator (in `top`); bail out before popping anything.
            return None;
        }
        // `ops` is non-empty, so by the invariant `top` is `Some` as well.
        let rhs = self.pop()?;
        let lhs = self.pop()?;
        Some((lhs, rhs))
    }
}
impl ConstExpr {
    /// Compiles a Wasm constant expression into a [`ConstExpr`].
    ///
    /// Each operator is translated into an [`Op`] and pushed onto a
    /// [`ConstExprStack`]. Binary arithmetic operators pop their two operands
    /// and push a fused [`Op::Expr`]. The single operator remaining after `end`
    /// becomes the root.
    ///
    /// # Panics
    ///
    /// - If an operator cannot be read or is not allowed in const expressions.
    /// - If `ref.null` uses a heap type other than non-shared `func` or `extern`.
    /// - If the expression is malformed in a way Wasm validation rejects,
    ///   e.g. missing operands, trailing bytes after `end`, or no result.
    pub fn new(expr: wasmparser::ConstExpr<'_>) -> Self {
        /// Pops two operands from `stack` and fuses them with `expr` into an [`Op::Expr`].
        ///
        /// Each operand-kind combination has its own arm. As a result, every
        /// resulting closure captures the concrete operator types (e.g. `ConstOp`,
        /// `GlobalOp`) rather than `Op`, presumably to avoid re-dispatching on `Op`
        /// during evaluation.
        fn expr_op<Lhs, Rhs, T>(stack: &mut ConstExprStack, expr: fn(Lhs, Rhs) -> T) -> Op
        where
            Lhs: From<UntypedVal> + 'static,
            Rhs: From<UntypedVal> + 'static,
            T: 'static,
            UntypedVal: From<T>,
        {
            let (lhs, rhs) = stack
                .pop2()
                .expect("must have 2 operators on the stack due to Wasm validation");
            match (lhs, rhs) {
                (Op::Const(lhs), Op::Const(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Const(lhs), Op::Global(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Const(lhs), Op::FuncRef(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Const(lhs), Op::Expr(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Global(lhs), Op::Const(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Global(lhs), Op::Global(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Global(lhs), Op::FuncRef(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Global(lhs), Op::Expr(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::FuncRef(lhs), Op::Const(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::FuncRef(lhs), Op::Global(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::FuncRef(lhs), Op::FuncRef(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::FuncRef(lhs), Op::Expr(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Expr(lhs), Op::Const(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Expr(lhs), Op::Global(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Expr(lhs), Op::FuncRef(rhs)) => def_expr!(lhs, rhs, expr),
                (Op::Expr(lhs), Op::Expr(rhs)) => def_expr!(lhs, rhs, expr),
            }
        }

        /// Translates `ref.null` with heap type `hty` into a constant null reference.
        fn ref_null(hty: wasmparser::HeapType) -> Op {
            let null = match hty {
                wasmparser::HeapType::Abstract {
                    shared: false,
                    ty: AbstractHeapType::Func,
                } => Val::from(<Ref<Func>>::Null),
                wasmparser::HeapType::Abstract {
                    shared: false,
                    ty: AbstractHeapType::Extern,
                } => Val::from(<Ref<ExternRef>>::Null),
                invalid => {
                    panic!("invalid heap type for `ref.null`: {invalid:?}")
                }
            };
            Op::constant(null)
        }

        let mut reader = expr.get_operators_reader();
        let mut stack = ConstExprStack::default();
        loop {
            let wasm_op = match reader.read() {
                Ok(wasm_op) => wasm_op,
                Err(error) => panic!("invalid const expression operator: {error}"),
            };
            let op = match wasm_op {
                wasmparser::Operator::I32Const { value } => Op::constant(value),
                wasmparser::Operator::I64Const { value } => Op::constant(value),
                wasmparser::Operator::F32Const { value } => {
                    Op::constant(F32::from_bits(value.bits()))
                }
                wasmparser::Operator::F64Const { value } => {
                    Op::constant(F64::from_bits(value.bits()))
                }
                #[cfg(feature = "simd")]
                wasmparser::Operator::V128Const { value } => {
                    Op::constant(V128::from(value.i128() as u128))
                }
                wasmparser::Operator::GlobalGet { global_index } => Op::global(global_index),
                wasmparser::Operator::RefNull { hty } => ref_null(hty),
                wasmparser::Operator::RefFunc { function_index } => Op::funcref(function_index),
                wasmparser::Operator::I32Add => expr_op(&mut stack, wasm::i32_add),
                wasmparser::Operator::I32Sub => expr_op(&mut stack, wasm::i32_sub),
                wasmparser::Operator::I32Mul => expr_op(&mut stack, wasm::i32_mul),
                wasmparser::Operator::I64Add => expr_op(&mut stack, wasm::i64_add),
                wasmparser::Operator::I64Sub => expr_op(&mut stack, wasm::i64_sub),
                wasmparser::Operator::I64Mul => expr_op(&mut stack, wasm::i64_mul),
                wasmparser::Operator::End => break,
                op => panic!("unexpected Wasm const expression operator: {op:?}"),
            };
            stack.push(op);
        }
        reader
            .ensure_end()
            .expect("Wasm validation requires const expressions to have an `end`");
        let root = stack
            .pop()
            .expect("must contain the root const expression at this point");
        debug_assert!(stack.is_empty());
        Self { op: root }
    }

    /// Creates a [`ConstExpr`] that consists of a single `ref.func` of `function_index`.
    pub fn new_funcref(function_index: u32) -> Self {
        let op = Op::funcref(function_index);
        Self { op }
    }

    /// Returns the referenced function index if this expression is a plain `ref.func`.
    ///
    /// Returns `None` for every other expression, including composite ones.
    pub fn funcref(&self) -> Option<FuncIdx> {
        match &self.op {
            Op::FuncRef(op) => Some(FuncIdx::from(op.function_index)),
            _ => None,
        }
    }

    /// Evaluates the expression without access to any globals or functions.
    ///
    /// Yields `None` if the expression depends on `global.get` or `ref.func`.
    pub fn eval_const(&self) -> Option<UntypedVal> {
        let ctx = EmptyEvalContext;
        self.eval(&ctx)
    }

    /// Evaluates the expression, resolving globals through `global_get` and
    /// function references through `func_get`.
    pub fn eval_with_context<G, F>(&self, global_get: G, func_get: F) -> Option<UntypedVal>
    where
        G: Fn(u32) -> Val,
        F: Fn(u32) -> Ref<Func>,
    {
        /// Adapts a pair of lookup closures to the [`EvalContext`] trait.
        struct ClosureContext<G, F> {
            globals: G,
            funcs: F,
        }

        impl<G, F> EvalContext for ClosureContext<G, F>
        where
            G: Fn(u32) -> Val,
            F: Fn(u32) -> Ref<Func>,
        {
            fn get_global(&self, index: u32) -> Option<Val> {
                let value = (self.globals)(index);
                Some(value)
            }

            fn get_func(&self, index: u32) -> Option<Ref<Func>> {
                let func = (self.funcs)(index);
                Some(func)
            }
        }

        let ctx = ClosureContext {
            globals: global_get,
            funcs: func_get,
        };
        self.eval(&ctx)
    }
}