use crate::compiler::RuleId;
use crate::wasm;
use rustc_hash::FxHashMap;
use std::mem;
use walrus::ValType::{F64, I32, I64};
use walrus::ir::ExtendedLoad::ZeroExtend;
use walrus::ir::{BinaryOp, Block, InstrSeqId, LoadKind, MemArg, UnaryOp};
use walrus::{
FunctionBuilder, FunctionId, GlobalId, InstrSeqBuilder, MemoryId, Module,
};
use super::WasmSymbols;
/// Imports the global `$name` of value type `$val_ty` from the `yara_x`
/// import module into `$m`, and binds the resulting global id to a local
/// variable that is also called `$name`.
///
/// The first flag passed to `add_import_global` is `true` here and `false`
/// in `global_const!`. It is presumably the mutability flag. The trailing
/// `false` is presumably the `shared` flag. NOTE(review): confirm both
/// against the `walrus` version in use.
macro_rules! global_var {
    ($m:ident, $name:ident, $val_ty:ident) => {
        let ($name, _) =
            $m.add_import_global("yara_x", stringify!($name), $val_ty, true, false);
    };
}
/// Imports the global `$name` of value type `$val_ty` from the `yara_x`
/// import module into `$m`, and binds the resulting global id to a local
/// variable that is also called `$name`.
///
/// This is identical to `global_var!` except that the first flag passed to
/// `add_import_global` is `false`. That flag is presumably mutability, which
/// would make this a constant global. NOTE(review): confirm the flag order
/// against the `walrus` version in use.
macro_rules! global_const {
    ($m:ident, $name:ident, $val_ty:ident) => {
        let ($name, _) =
            $m.add_import_global("yara_x", stringify!($name), $val_ty, false, false);
    };
}
/// Incrementally builds the WASM module that evaluates compiled rules.
///
/// The emitted code is organised in three levels:
///
/// * `main` calls every namespace function in order and finally returns `0`.
/// * A namespace function holds one `block` per namespace. Inside a block,
///   rule functions are called one after another. Each call is followed by
///   a `br_if` targeting that block, so a rule function returning a non-zero
///   value skips the remaining rule functions of the same namespace.
/// * A rule function evaluates up to `rules_per_func` rules. It returns `1`
///   early when a global rule's condition evaluates to zero, and `0`
///   otherwise.
pub(crate) struct WasmModuleBuilder {
    // The module under construction, together with the symbols shared by
    // emitted code and the imported host functions, keyed by their fully
    // qualified mangled name.
    module: walrus::Module,
    wasm_symbols: WasmSymbols,
    wasm_exports: FxHashMap<String, FunctionId>,
    // Functions currently being filled in. `namespace_block` is the dangling
    // instruction sequence that collects calls for the current namespace
    // until it is appended to `namespace_func` as a `block`.
    main_func: FunctionBuilder,
    namespace_func: FunctionBuilder,
    namespace_block: InstrSeqId,
    rules_func: FunctionBuilder,
    // The rule being emitted right now.
    rule_id: RuleId,
    global_rule: bool,
    // Counters and the limits that trigger splitting into a new function.
    num_rules: usize,
    rules_per_func: usize,
    num_namespaces: usize,
    namespaces_per_func: usize,
}
impl WasmModuleBuilder {
    /// Result type shared by every rule function: a single `i32`.
    const RULES_FUNC_RET: [walrus::ValType; 1] = [I32];

    /// Creates a builder around a brand-new module.
    ///
    /// The module receives these imports:
    /// * every host function yielded by `super::wasm_exports()`, each under
    ///   its own Rust module path;
    /// * the `matching_patterns_bitmap_base`, `filesize` and
    ///   `pattern_search_done` globals;
    /// * the `main_memory` memory.
    ///
    /// The globals and the memory are imported from `yara_x`. The builder
    /// also generates the `check_for_pattern_match` helper. Both split
    /// limits (`rules_per_func`, `namespaces_per_func`) start at 10.
    pub fn new() -> Self {
        let mut module =
            walrus::Module::with_config(walrus::ModuleConfig::new());

        // Import each host function and remember its id under its fully
        // qualified mangled name, so `function_id` can resolve it later.
        let mut imported = FxHashMap::default();
        for host_fn in super::wasm_exports() {
            let sig = module.types.add(
                host_fn.func.walrus_args().as_slice(),
                host_fn.func.walrus_results().as_slice(),
            );
            let mangled = host_fn.fully_qualified_mangled_name();
            let (func_id, _) = module.add_import_func(
                host_fn.rust_module_path,
                mangled.as_str(),
                sig,
            );
            imported.insert(mangled, func_id);
        }

        // Each macro binds a local named after the imported global.
        global_const!(module, matching_patterns_bitmap_base, I32);
        global_var!(module, filesize, I64);
        global_var!(module, pattern_search_done, I32);

        let (main_memory, _) = module.add_import_memory(
            "yara_x",
            "main_memory",
            false,
            false,
            1,
            None,
            None,
        );

        let check_for_pattern_match = Self::gen_check_for_pattern_match(
            &mut module,
            main_memory,
            matching_patterns_bitmap_base,
        );

        // Scratch locals shared by emitted code. They are created in the
        // same order as before, so their ids do not change.
        let i64_tmp_a = module.locals.add(I64);
        let i64_tmp_b = module.locals.add(I64);
        let i32_tmp = module.locals.add(I32);
        let f64_tmp = module.locals.add(F64);

        let wasm_symbols = WasmSymbols {
            main_memory,
            check_for_pattern_match,
            filesize,
            pattern_search_done,
            i64_tmp_a,
            i64_tmp_b,
            i32_tmp,
            f64_tmp,
        };

        let mut namespace_func =
            FunctionBuilder::new(&mut module.types, &[], &[]);
        let rules_func = FunctionBuilder::new(
            &mut module.types,
            &[],
            &Self::RULES_FUNC_RET,
        );
        let main_func = FunctionBuilder::new(&mut module.types, &[], &[I32]);
        let namespace_block = namespace_func.dangling_instr_seq(None).id();

        Self {
            module,
            wasm_symbols,
            wasm_exports: imported,
            main_func,
            namespace_func,
            rules_func,
            namespace_block,
            rule_id: RuleId::default(),
            num_rules: 0,
            num_namespaces: 0,
            namespaces_per_func: 10,
            rules_per_func: 10,
            global_rule: false,
        }
    }

    /// Returns a clone of the symbols created by [`WasmModuleBuilder::new`].
    pub fn wasm_symbols(&self) -> WasmSymbols {
        self.wasm_symbols.clone()
    }

    /// Returns a clone of the map from fully qualified mangled names to the
    /// ids of the imported host functions.
    pub fn wasm_exports(&self) -> FxHashMap<String, FunctionId> {
        self.wasm_exports.clone()
    }

    /// Sets how many namespaces go into one namespace function before a new
    /// one is started.
    pub fn namespaces_per_func(&mut self, n: usize) -> &mut Self {
        self.namespaces_per_func = n;
        self
    }

    /// Sets how many rules go into one rule function before a new one is
    /// started.
    pub fn rules_per_func(&mut self, n: usize) -> &mut Self {
        self.rules_per_func = n;
        self
    }

    /// Begins emitting rule `rule_id`. The current rule function is finished
    /// first if it already holds `rules_per_func` rules.
    ///
    /// Returns the instruction sequence that receives the rule's condition.
    /// [`WasmModuleBuilder::finish_rule`] consumes the `i32` that this code
    /// leaves on the stack.
    pub fn start_rule(
        &mut self,
        rule_id: RuleId,
        global: bool,
    ) -> InstrSeqBuilder<'_> {
        self.num_rules = if self.num_rules == self.rules_per_func {
            // The current rule function is full; the new rule becomes the
            // first one of a fresh function.
            self.finish_rule_func();
            1
        } else {
            self.num_rules + 1
        };
        self.rule_id = rule_id;
        self.global_rule = global;
        self.rules_func.func_body()
    }

    /// Emits the code that reports the result of the current rule.
    ///
    /// The `i32` on top of the stack is tested for zero:
    /// * If it is zero and the rule is global, `rule_no_match(rule_id)` is
    ///   called and the rule function returns `1`.
    /// * If it is zero and the rule is not global, `rule_no_match(rule_id)`
    ///   is called only when the `rules-profiling` feature is enabled.
    /// * If it is non-zero, `rule_match(rule_id)` is called.
    pub fn finish_rule(&mut self) {
        // Resolve `rule_no_match` first, as before, so a missing export
        // panics with the same message.
        let no_match_fn =
            self.function_id(wasm::export__rule_no_match.mangled_name);
        let match_fn = self.function_id(wasm::export__rule_match.mangled_name);

        let id: i32 = self.rule_id.into();
        let is_global = self.global_rule;

        self.rules_func.func_body().unop(UnaryOp::I32Eqz).if_else(
            None,
            |then_| {
                if is_global {
                    then_
                        .i32_const(id)
                        .call(no_match_fn)
                        .i32_const(1)
                        .return_();
                } else {
                    #[cfg(feature = "rules-profiling")]
                    then_.i32_const(id).call(no_match_fn);
                }
            },
            |else_| {
                else_.i32_const(id).call(match_fn);
            },
        );
    }

    /// Starts a new namespace.
    ///
    /// Finishes the pending rule function and namespace block. Once the
    /// current namespace function already holds `namespaces_per_func`
    /// namespaces, it is finished as well.
    pub fn new_namespace(&mut self) {
        self.finish_rule_func();
        self.finish_namespace_block();
        if self.num_namespaces == self.namespaces_per_func {
            self.finish_namespace_func();
            self.num_namespaces = 1;
        } else {
            self.num_namespaces += 1;
        }
        self.num_rules = 0;
    }

    /// Flushes all pending code and returns the finished module. In the
    /// returned module, `main` is exported and returns `0` after calling
    /// every namespace function.
    pub fn build(mut self) -> walrus::Module {
        self.finish_rule_func();
        self.finish_namespace_block();
        self.finish_namespace_func();
        self.main_func.func_body().i32_const(0);
        let main_id =
            self.main_func.finish(Vec::new(), &mut self.module.funcs);
        self.module.exports.add("main", main_id);
        self.module
    }
}
impl WasmModuleBuilder {
    /// Returns the id of the imported host function whose fully qualified
    /// mangled name is `fn_mangled_name`.
    ///
    /// # Panics
    ///
    /// Panics if no such function was imported by
    /// [`WasmModuleBuilder::new`].
    pub fn function_id(&self, fn_mangled_name: &str) -> FunctionId {
        match self.wasm_exports.get(fn_mangled_name) {
            Some(func_id) => *func_id,
            None => panic!("can't find function `{fn_mangled_name}`"),
        }
    }

    /// Appends the pending namespace block to the current namespace function
    /// as a `block` instruction, and starts a new, empty dangling sequence.
    /// A block with no instructions is left as is, so no empty `block` is
    /// emitted.
    fn finish_namespace_block(&mut self) {
        let pending = self.namespace_block;
        if self.namespace_func.instr_seq(pending).instrs().is_empty() {
            return;
        }
        self.namespace_func.func_body().instr(Block { seq: pending });
        self.namespace_block =
            self.namespace_func.dangling_instr_seq(None).id();
    }

    /// Adds the current namespace function to the module and appends a call
    /// to it in `main`. A fresh namespace function with a fresh dangling
    /// block takes its place.
    fn finish_namespace_func(&mut self) {
        let fresh = FunctionBuilder::new(&mut self.module.types, &[], &[]);
        let finished = mem::replace(&mut self.namespace_func, fresh);
        self.namespace_block =
            self.namespace_func.dangling_instr_seq(None).id();
        let finished_id =
            self.module.funcs.add_local(finished.local_func(Vec::new()));
        self.main_func.func_body().call(finished_id);
    }

    /// Swaps in a fresh rule function. If the old one contains any code:
    /// * it is terminated with `i32.const 0` (its normal return value);
    /// * it is added to the module;
    /// * a call to it is appended to the pending namespace block, followed
    ///   by a `br_if` to that block, so a non-zero return exits the block.
    fn finish_rule_func(&mut self) {
        let fresh = FunctionBuilder::new(
            &mut self.module.types,
            &[],
            &Self::RULES_FUNC_RET,
        );
        let mut finished = mem::replace(&mut self.rules_func, fresh);
        if finished.func_body().instrs().is_empty() {
            return;
        }
        finished.func_body().i32_const(0);
        let finished_id =
            self.module.funcs.add_local(finished.local_func(Vec::new()));
        let target = self.namespace_block;
        self.namespace_func
            .instr_seq(target)
            .call(finished_id)
            .br_if(target);
    }

    /// Generates a function `(pattern_id: i32) -> i32` that returns bit
    /// `pattern_id % 8` of the byte at
    /// `matching_patterns_bitmap_base + pattern_id / 8` in `main_memory`.
    /// The result is always `0` or `1`.
    fn gen_check_for_pattern_match(
        module: &mut Module,
        main_memory: MemoryId,
        matching_patterns_bitmap_base: GlobalId,
    ) -> FunctionId {
        let mut builder =
            FunctionBuilder::new(&mut module.types, &[I32], &[I32]);
        let pattern_id = module.locals.add(I32);
        let bit_index = module.locals.add(I32);

        builder
            .func_body()
            // Load the bitmap byte at base + (pattern_id >> 3),
            // zero-extended.
            .local_get(pattern_id)
            .i32_const(3)
            .binop(BinaryOp::I32ShrU)
            .global_get(matching_patterns_bitmap_base)
            .binop(BinaryOp::I32Add)
            .load(
                main_memory,
                LoadKind::I32_8 { kind: ZeroExtend },
                MemArg { align: mem::size_of::<i8>() as u32, offset: 0 },
            )
            // Build the mask 1 << (pattern_id % 8). The bit index is kept
            // in a local so it can be reused below.
            .i32_const(1)
            .local_get(pattern_id)
            .i32_const(8)
            .binop(BinaryOp::I32RemU)
            .local_tee(bit_index)
            .binop(BinaryOp::I32Shl)
            // Compute (byte & mask) >> bit_index, which leaves exactly 0
            // or 1.
            .binop(BinaryOp::I32And)
            .local_get(bit_index)
            .binop(BinaryOp::I32ShrU);

        builder.finish(vec![pattern_id], &mut module.funcs)
    }
}