1mod backend;
8
9pub use backend::{host_function, mutable_global, Backend, GasMeter};
10
11#[cfg(test)]
12mod validation;
13
14#[cfg(not(feature = "ignore_custom_section"))]
15use crate::utils::transform::update_custom_section_function_indices;
16use crate::utils::{
17 module_info::{copy_locals, truncate_len_from_encoder, ModuleInfo},
18 translator::{ConstExprKind, DefaultTranslator, Translator},
19};
20use alloc::{string::String, vec, vec::Vec};
21use anyhow::{anyhow, Result};
22use core::{cmp::min, mem, num::NonZeroU32};
23use wasm_encoder::{
24 ElementMode, ElementSection, ElementSegment, Elements, ExportKind, ExportSection, Function,
25 Instruction, SectionId, StartSection,
26};
27use wasmparser::{
28 ElementItems, ElementKind, ExternalKind, FuncType, FunctionBody, GlobalType, Operator, Type,
29 ValType,
30};
31
32pub trait Rules {
34 fn instruction_cost(&self, instruction: &Operator) -> Option<u32>;
40
41 fn memory_grow_cost(&self) -> MemoryGrowCost;
50
51 fn call_per_local_cost(&self) -> u32;
53}
54
/// How `memory.grow` is charged, as returned by [`Rules::memory_grow_cost`].
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum MemoryGrowCost {
    /// `memory.grow` pays only its regular per-instruction cost; no helper is injected.
    Free,
    /// Every `memory.grow` additionally charges this amount per requested page.
    Linear(NonZeroU32),
}

impl MemoryGrowCost {
    /// Returns `true` when `memory.grow` needs the extra per-page charge, i.e. for
    /// [`MemoryGrowCost::Linear`].
    fn enabled(&self) -> bool {
        matches!(self, Self::Linear(_))
    }
}
80
/// A [`Rules`] implementation that charges the same cost for every instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConstantCostRules {
    /// Cost charged for every operator.
    instruction_cost: u32,
    /// Per-page cost of `memory.grow`; `0` means growing memory carries no extra charge.
    memory_grow_cost: u32,
    /// Cost charged per declared local of a function.
    call_per_local_cost: u32,
}

impl ConstantCostRules {
    /// Creates rules with a uniform `instruction_cost`, a per-page `memory_grow_cost`
    /// (`0` disables the extra charge) and a per-local `call_per_local_cost`.
    pub fn new(instruction_cost: u32, memory_grow_cost: u32, call_per_local_cost: u32) -> Self {
        Self { instruction_cost, memory_grow_cost, call_per_local_cost }
    }
}

impl Default for ConstantCostRules {
    /// One unit per instruction, no extra `memory.grow` charge, one unit per local.
    fn default() -> Self {
        Self { instruction_cost: 1, memory_grow_cost: 0, call_per_local_cost: 1 }
    }
}
112
113impl Rules for ConstantCostRules {
114 fn instruction_cost(&self, _: &Operator) -> Option<u32> {
115 Some(self.instruction_cost)
116 }
117
118 fn memory_grow_cost(&self) -> MemoryGrowCost {
119 NonZeroU32::new(self.memory_grow_cost).map_or(MemoryGrowCost::Free, MemoryGrowCost::Linear)
120 }
121
122 fn call_per_local_cost(&self) -> u32 {
123 self.call_per_local_cost
124 }
125}
126
127pub fn inject<R: Rules, B: Backend>(
169 module_info: &mut ModuleInfo,
170 backend: B,
171 rules: &R,
172) -> Result<Vec<u8>> {
173 let gas_meter = backend.gas_meter(module_info, rules);
175
176 let import_count = module_info.imported_functions_count;
177 let functions_count = module_info.num_functions();
178
179 let (
182 gas_func_idx,
184 grow_cnt_func_idx,
186 gas_fn_cost,
188 ) = match gas_meter {
189 GasMeter::External { module: gas_module, function } => {
190 let ty = Type::Func(FuncType::new(vec![ValType::I64], vec![]));
192 module_info.add_import_func(gas_module, function, ty)?;
193
194 (
195 import_count,
197 functions_count + 1,
198 0,
199 )
200 },
201 GasMeter::Internal { module: _gas_module, global: global_name, ref func, cost } => {
202 let gas_global_idx = module_info.num_globals();
203
204 module_info.add_global(
206 GlobalType { content_type: ValType::I64, mutable: true },
207 &wasm_encoder::ConstExpr::i64_const(0),
208 )?;
209
210 module_info.add_exports(&[(
211 String::from(global_name),
212 ExportKind::Global,
213 gas_global_idx,
214 )])?;
215
216 let ty = Type::Func(FuncType::new(vec![ValType::I64], vec![]));
218 module_info.add_functions(&[(ty, func.clone())])?;
219
220 (
221 functions_count,
225 functions_count + 1,
226 cost,
227 )
228 },
229 };
230
231 let mut need_grow_counter = false;
232 let mut error = false;
233
234 if module_info.code_section_entry_count > 0 {
238 let mut code_section_builder = wasm_encoder::CodeSection::new();
239
240 for (func_body, is_last) in module_info
241 .code_section()?
242 .ok_or_else(|| anyhow!("no code section"))?
243 .into_iter()
244 .enumerate()
245 .map(|(index, item)| (item, index as u32 == module_info.code_section_entry_count - 1))
246 {
247 let current_locals = copy_locals(&func_body)?;
248
249 let locals_count = current_locals.iter().map(|(count, _)| count).sum();
250
251 let mut func_builder = wasm_encoder::Function::new(copy_locals(&func_body)?);
252
253 let operator_reader = func_body.get_operators_reader()?;
254 for op in operator_reader {
255 let op = op?;
256 let mut instruction: Option<Instruction> = None;
257 if let GasMeter::External { .. } = gas_meter {
258 if let Operator::Call { function_index } = op {
259 if function_index >= gas_func_idx {
260 instruction = Some(Instruction::Call(function_index + 1));
261 }
262 }
263 }
264 let instruction = match instruction {
265 Some(instruction) => instruction,
266 None => DefaultTranslator.translate_op(&op)?,
267 };
268 func_builder.instruction(&instruction);
269 }
270
271 if let GasMeter::Internal { .. } = gas_meter {
272 if is_last {
276 code_section_builder.function(&func_builder);
277 continue;
278 }
279 }
280
281 match inject_counter(
282 &FunctionBody::new(0, &truncate_len_from_encoder(&func_builder)?),
283 gas_fn_cost,
284 locals_count,
285 rules,
286 gas_func_idx,
287 ) {
288 Ok(new_builder) => func_builder = new_builder,
289 Err(_) => {
290 error = true;
291 break;
292 },
293 }
294 if rules.memory_grow_cost().enabled() {
295 let counter;
296 (func_builder, counter) = inject_grow_counter(
297 &FunctionBody::new(0, &truncate_len_from_encoder(&func_builder)?),
298 grow_cnt_func_idx,
299 )?;
300 if counter > 0 {
301 need_grow_counter = true;
302 }
303 }
304 code_section_builder.function(&func_builder);
305 }
306 module_info.replace_section(SectionId::Code.into(), &code_section_builder)?;
307 }
308
309 if module_info.exports_count > 0 {
310 if let GasMeter::External { .. } = gas_meter {
311 let mut export_sec_builder = ExportSection::new();
312
313 for export in module_info.export_section()?.expect("no export section") {
314 let mut export_index = export.index;
315 if let ExternalKind::Func = export.kind {
316 if export_index >= gas_func_idx {
317 export_index += 1;
318 }
319 }
320 export_sec_builder.export(
321 export.name,
322 DefaultTranslator.translate_export_kind(export.kind)?,
323 export_index,
324 );
325 }
326 module_info.replace_section(SectionId::Export.into(), &export_sec_builder)?;
327 }
328 }
329
330 if module_info.elements_count > 0 {
331 if let GasMeter::External { .. } = gas_meter {
334 let mut ele_sec_builder = ElementSection::new();
335
336 for elem in module_info.element_section()?.expect("no element_section section") {
338 let mut functions = vec![];
339 if let ElementItems::Functions(func_indexes) = elem.items {
340 for func_idx in func_indexes {
341 let mut func_idx = func_idx?;
342 if func_idx >= gas_func_idx {
343 func_idx += 1
344 }
345 functions.push(func_idx);
346 }
347 }
348
349 let offset;
350 let mode = match elem.kind {
351 ElementKind::Active { table_index, offset_expr } => {
352 offset = DefaultTranslator.translate_const_expr(
353 &offset_expr,
354 &ValType::I32,
355 ConstExprKind::ElementOffset,
356 )?;
357
358 ElementMode::Active { table: table_index, offset: &offset }
359 },
360 ElementKind::Passive => ElementMode::Passive,
361 ElementKind::Declared => ElementMode::Declared,
362 };
363
364 let element_type = DefaultTranslator.translate_ref_ty(&elem.ty)?;
365 let elements = Elements::Functions(&functions);
366
367 ele_sec_builder.segment(ElementSegment {
368 mode,
369 element_type,
371 elements,
373 });
374 }
375 module_info.replace_section(SectionId::Element.into(), &ele_sec_builder)?;
376 }
377 }
378
379 if module_info.raw_sections.get_mut(&SectionId::Start.into()).is_some() {
380 if let GasMeter::External { .. } = gas_meter {
381 if let Some(func_idx) = module_info.start_function {
382 if func_idx >= gas_func_idx {
383 let start_section = StartSection { function_index: func_idx + 1 };
384 module_info.replace_section(SectionId::Start.into(), &start_section)?;
385 }
386 }
387 }
388 }
389
390 #[cfg(not(feature = "ignore_custom_section"))]
391 update_custom_section_function_indices(module_info, gas_func_idx)?;
392
393 if error {
394 return Err(anyhow!("inject fail"));
395 }
396
397 if need_grow_counter {
398 if let Some((func, grow_counter_func)) = generate_grow_counter(rules, gas_func_idx) {
399 module_info.add_functions(&[(func, grow_counter_func)])?;
400 }
401 }
402 Ok(module_info.bytes())
403}
404
/// Bookkeeping for one entry of the control-flow stack while computing metered blocks:
/// the function body itself, or a `block`/`if`/`loop`.
#[derive(Debug)]
struct ControlBlock {
    /// Lowest control-stack index that a forward branch from inside this block (or from a
    /// block nested in it) may target. It starts at this block's own stack index and is
    /// lowered by `Counter::branch` and by nested blocks when they are closed.
    lowest_forward_br_target: usize,

    /// The metered block currently accumulating cost inside this control block.
    active_metered_block: MeteredBlock,

    /// Whether this control block is a `loop`. Branches that target a loop are skipped when
    /// updating `lowest_forward_br_target`.
    is_loop: bool,
}
441
/// A region of operators whose combined cost is charged by a single call to the gas function.
#[derive(Debug)]
struct MeteredBlock {
    /// Index, in the function's operator sequence, of the operator before which the charge
    /// is inserted.
    start_pos: usize,
    /// Total cost accumulated for this region.
    cost: u64,
}
452
/// State of the metered-block analysis for a single function body.
struct Counter {
    /// Currently open control blocks, innermost last. The analysis pushes the function
    /// body first, so index 0 is the function body.
    stack: Vec<ControlBlock>,

    /// Metered blocks that have been closed with a non-zero cost.
    finalized_blocks: Vec<MeteredBlock>,
}
466
impl Counter {
    /// Creates a counter with an empty control stack and no finalized blocks.
    fn new() -> Counter {
        Counter { stack: Vec::new(), finalized_blocks: Vec::new() }
    }

    /// Pushes a new control block whose first metered block starts at `cursor`.
    ///
    /// `lowest_forward_br_target` starts at the block's own stack index, meaning "no branch
    /// out of this block seen yet".
    fn begin_control_block(&mut self, cursor: usize, is_loop: bool) {
        let index = self.stack.len();
        self.stack.push(ControlBlock {
            lowest_forward_br_target: index,
            active_metered_block: MeteredBlock { start_pos: cursor, cost: 0 },
            is_loop,
        })
    }

    /// Closes the innermost control block at the `end` operator found at `cursor`.
    fn finalize_control_block(&mut self, cursor: usize) -> Result<()> {
        // Finalize the closing block's active metered block, or merge it into the parent's.
        self.finalize_metered_block(cursor)?;

        let closing_control_block = self.stack.pop().ok_or_else(|| anyhow!("stack not found"))?;
        let closing_control_index = self.stack.len();

        // The outermost block (the function body) was closed; nothing left to update.
        if self.stack.is_empty() {
            return Ok(());
        }

        {
            // Propagate the lowest forward branch target to the enclosing block.
            let control_block = self.stack.last_mut().ok_or_else(|| anyhow!("stack not found"))?;
            control_block.lowest_forward_br_target = min(
                control_block.lowest_forward_br_target,
                closing_control_block.lowest_forward_br_target,
            );
        }

        // If something inside the closed block may branch to an enclosing block, the code
        // after this `end` might be skipped. So the parent's active metered block is also
        // finalized, and a new one starts at `cursor + 1`.
        let may_br_out = closing_control_block.lowest_forward_br_target < closing_control_index;
        if may_br_out {
            self.finalize_metered_block(cursor)?;
        }

        Ok(())
    }

    /// Ends the innermost control block's active metered block and starts a fresh one at
    /// `cursor + 1`.
    ///
    /// If the closed metered block starts at the same position as the parent's active
    /// metered block (as happens for a `block` that inherited its parent's start), its cost
    /// is merged into the parent. Otherwise it is recorded in `finalized_blocks` when its
    /// cost is non-zero.
    fn finalize_metered_block(&mut self, cursor: usize) -> Result<()> {
        let closing_metered_block = {
            let control_block = self.stack.last_mut().ok_or_else(|| anyhow!("stack not found"))?;
            mem::replace(
                &mut control_block.active_metered_block,
                MeteredBlock { start_pos: cursor + 1, cost: 0 },
            )
        };

        // Cannot underflow: `last_mut` above succeeded, so the stack is non-empty.
        let last_index = self.stack.len() - 1;
        if last_index > 0 {
            let prev_control_block = self
                .stack
                .get_mut(last_index - 1)
                .expect("last_index is greater than 0; last_index is stack size - 1; qed");
            let prev_metered_block = &mut prev_control_block.active_metered_block;
            if closing_metered_block.start_pos == prev_metered_block.start_pos {
                prev_metered_block.cost = prev_metered_block
                    .cost
                    .checked_add(closing_metered_block.cost)
                    .ok_or_else(|| anyhow!("overflow occured"))?;
                return Ok(());
            }
        }

        if closing_metered_block.cost > 0 {
            self.finalized_blocks.push(closing_metered_block);
        }
        Ok(())
    }

    /// Records a branch at `cursor` to each control block in `indices` (absolute stack
    /// indices, not relative depths).
    ///
    /// The active metered block ends at the branch. Targets that are loops are skipped.
    /// Every other target lowers the innermost block's `lowest_forward_br_target`.
    fn branch(&mut self, cursor: usize, indices: &[usize]) -> Result<()> {
        self.finalize_metered_block(cursor)?;

        for &index in indices {
            let target_is_loop = {
                let target_block =
                    self.stack.get(index).ok_or_else(|| anyhow!("unable to find stack index"))?;
                target_block.is_loop
            };
            if target_is_loop {
                continue;
            }

            let control_block =
                self.stack.last_mut().ok_or_else(|| anyhow!("stack does not exist"))?;
            control_block.lowest_forward_br_target =
                min(control_block.lowest_forward_br_target, index);
        }

        Ok(())
    }

    /// Stack index of the innermost control block, or `None` if the stack is empty.
    fn active_control_block_index(&self) -> Option<usize> {
        self.stack.len().checked_sub(1)
    }

    /// Mutable access to the innermost control block's active metered block.
    fn active_metered_block(&mut self) -> Result<&mut MeteredBlock> {
        let top_block = self.stack.last_mut().ok_or_else(|| anyhow!("stack does not exist"))?;
        Ok(&mut top_block.active_metered_block)
    }

    /// Adds `val` to the active metered block's cost, failing on `u64` overflow.
    fn increment(&mut self, val: u32) -> Result<()> {
        let top_block = self.active_metered_block()?;
        top_block.cost = top_block
            .cost
            .checked_add(val.into())
            .ok_or_else(|| anyhow!("add cost overflow"))?;
        Ok(())
    }
}
603
604fn inject_grow_counter(
605 func_body: &FunctionBody,
606 grow_counter_func: u32,
607) -> Result<(Function, usize)> {
608 let mut counter = 0;
609 let mut new_func = Function::new(copy_locals(func_body)?);
610 let mut operator_reader = func_body.get_operators_reader()?;
611 while !operator_reader.eof() {
612 let op = operator_reader.read()?;
613 match op {
614 Operator::MemoryGrow { .. } => {
616 new_func.instruction(&wasm_encoder::Instruction::Call(grow_counter_func));
617 counter += 1;
618 },
619 op => {
620 new_func.instruction(&DefaultTranslator.translate_op(&op)?);
621 },
622 }
623 }
624 Ok((new_func, counter))
625}
626
627fn generate_grow_counter<R: Rules>(rules: &R, gas_func: u32) -> Option<(Type, Function)> {
628 let cost = match rules.memory_grow_cost() {
629 MemoryGrowCost::Free => return None,
630 MemoryGrowCost::Linear(val) => val.get(),
631 };
632
633 let mut func = wasm_encoder::Function::new(None);
634 func.instruction(&wasm_encoder::Instruction::LocalGet(0));
635 func.instruction(&wasm_encoder::Instruction::LocalGet(0));
636 func.instruction(&wasm_encoder::Instruction::I64ExtendI32U);
637 func.instruction(&wasm_encoder::Instruction::I64Const(cost as i64));
638 func.instruction(&wasm_encoder::Instruction::I64Mul);
639 func.instruction(&wasm_encoder::Instruction::Call(gas_func));
640 func.instruction(&wasm_encoder::Instruction::MemoryGrow(0));
641 func.instruction(&wasm_encoder::Instruction::End);
642 Some((Type::Func(FuncType::new(vec![ValType::I32], vec![ValType::I32])), func))
643}
644
645fn determine_metered_blocks<R: Rules>(
646 func_body: &wasmparser::FunctionBody,
647 rules: &R,
648 locals_count: u32,
649) -> Result<Vec<MeteredBlock>> {
650 use wasmparser::Operator::*;
651
652 let mut counter = Counter::new();
653
654 counter.begin_control_block(0, false);
656 let locals_init_cost = rules
658 .call_per_local_cost()
659 .checked_mul(locals_count)
660 .ok_or_else(|| anyhow!("overflow occured"))?;
661 counter.increment(locals_init_cost)?;
662
663 let operators = func_body
664 .get_operators_reader()?
665 .into_iter()
666 .collect::<wasmparser::Result<Vec<Operator>>>()?;
667
668 for (cursor, instruction) in operators.iter().enumerate() {
669 let instruction_cost = rules
670 .instruction_cost(instruction)
671 .ok_or_else(|| anyhow!("check gas rule fail"))?;
672 match instruction {
673 Block { blockty: _ } => {
674 counter.increment(instruction_cost)?;
675
676 let top_block_start_pos = counter.active_metered_block()?.start_pos;
681 counter.begin_control_block(top_block_start_pos, false);
682 },
683 If { blockty: _ } => {
684 counter.increment(instruction_cost)?;
685 counter.begin_control_block(cursor + 1, false);
686 },
687 Loop { blockty: _ } => {
688 counter.increment(instruction_cost)?;
689 counter.begin_control_block(cursor + 1, true);
690 },
691 End => {
692 counter.finalize_control_block(cursor)?;
693 },
694 Else => {
695 counter.finalize_metered_block(cursor)?;
696 },
697 Br { relative_depth } | BrIf { relative_depth } => {
698 counter.increment(instruction_cost)?;
699
700 let active_index = counter
702 .active_control_block_index()
703 .ok_or_else(|| anyhow!("active control block not exit"))?;
704
705 let target_index = active_index
706 .checked_sub(*relative_depth as usize)
707 .ok_or_else(|| anyhow!("index not found"))?;
708
709 counter.branch(cursor, &[target_index])?;
710 },
711 BrTable { targets: br_table_data } => {
712 counter.increment(instruction_cost)?;
713
714 let active_index = counter
715 .active_control_block_index()
716 .ok_or_else(|| anyhow!("index not found"))?;
717 let r = br_table_data.targets().collect::<wasmparser::Result<Vec<u32>>>()?;
718 let target_indices = [br_table_data.default()]
719 .iter()
720 .chain(r.iter())
721 .map(|label| active_index.checked_sub(*label as usize))
722 .collect::<Option<Vec<_>>>()
723 .ok_or_else(|| anyhow!("to do check this error"))?;
724 counter.branch(cursor, &target_indices)?;
725 },
726 Return => {
727 counter.increment(instruction_cost)?;
728 counter.branch(cursor, &[0])?;
729 },
730 _ => {
731 counter.increment(instruction_cost)?;
733 },
734 }
735 }
736
737 counter.finalized_blocks.sort_unstable_by_key(|block| block.start_pos);
738 Ok(counter.finalized_blocks)
739}
740
741fn inject_counter<R: Rules>(
742 func_body: &FunctionBody,
743 gas_function_cost: u64,
744 locals_count: u32,
745 rules: &R,
746 gas_func: u32,
747) -> Result<wasm_encoder::Function> {
748 let blocks = determine_metered_blocks(func_body, rules, locals_count)?;
749 insert_metering_calls(func_body, gas_function_cost, blocks, gas_func)
750}
751
752fn insert_metering_calls(
754 func_body: &FunctionBody,
755 gas_function_cost: u64,
756 blocks: Vec<MeteredBlock>,
757 gas_func: u32,
758) -> Result<wasm_encoder::Function> {
759 let mut new_func = wasm_encoder::Function::new(copy_locals(func_body)?);
760
761 let mut block_iter = blocks.into_iter().peekable();
764 let operators = func_body
765 .get_operators_reader()?
766 .into_iter()
767 .collect::<wasmparser::Result<Vec<Operator>>>()?;
768 for (original_pos, instr) in operators.iter().enumerate() {
769 let used_block = if let Some(block) = block_iter.peek() {
771 if block.start_pos == original_pos {
772 let cost = block
773 .cost
774 .checked_add(gas_function_cost)
775 .ok_or_else(|| anyhow!("block cost add overflow"))? as i64;
776 new_func.instruction(&wasm_encoder::Instruction::I64Const(cost));
777 new_func.instruction(&wasm_encoder::Instruction::Call(gas_func));
778 true
779 } else {
780 false
781 }
782 } else {
783 false
784 };
785
786 if used_block {
787 block_iter.next();
788 }
789
790 new_func.instruction(&DefaultTranslator.translate_op(instr)?);
792 }
793
794 if block_iter.next().is_some() {
795 return Err(anyhow!("block should be consume all"));
796 }
797 Ok(new_func)
798}
799
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_encoder::{BlockType, Encode, Instruction::*};

    /// Returns true when function body `index` of `raw_wasm` encodes exactly to `ops2`.
    fn check_expect_function_body(
        raw_wasm: &[u8],
        index: usize,
        ops2: &[wasm_encoder::Instruction],
    ) -> bool {
        let mut body_raw = vec![];
        ops2.iter().for_each(|v| v.encode(&mut body_raw));
        get_function_body(raw_wasm, index).eq(&body_raw)
    }

    /// Raw operator bytes (locals excluded) of function body `index` in `raw_wasm`.
    fn get_function_body(raw_wasm: &[u8], index: usize) -> Vec<u8> {
        let module = ModuleInfo::new(raw_wasm).unwrap();
        let func_sec = module.raw_sections.get(&SectionId::Code.into()).unwrap();
        let func_bodies = module.code_section().unwrap().expect("no code section");

        let func_body = func_bodies
            .get(index)
            .unwrap_or_else(|| panic!("module doesn't have function {} body", index));

        let start = func_body.get_operators_reader().unwrap().original_position();
        func_sec.data[start..func_body.range().end].to_vec()
    }

    /// Operators of function body `index` in `raw_wasm`, translated to `wasm_encoder` form.
    fn get_function_operators(raw_wasm: &[u8], index: usize) -> Vec<Instruction> {
        let module = ModuleInfo::new(raw_wasm).unwrap();
        let func_bodies = module.code_section().unwrap().expect("no code section");

        let func_body = func_bodies
            .get(index)
            .unwrap_or_else(|| panic!("module doesn't have function {} body", index));

        let operators = func_body
            .get_operators_reader()
            .unwrap()
            .into_iter()
            .map(|op| DefaultTranslator.translate_op(&op.unwrap()).unwrap())
            .collect::<Vec<Instruction>>();

        operators
    }

    fn parse_wat(source: &str) -> ModuleInfo {
        let module_bytes = wat::parse_str(source).unwrap();
        ModuleInfo::new(&module_bytes).unwrap()
    }

    #[test]
    fn simple_grow_host_fn() {
        let mut module = parse_wat(
            r#"(module
            (func (result i32)
              global.get 0
              memory.grow)
            (global i32 (i32.const 42))
            (memory 0 1)
            )"#,
        );

        let backend = host_function::Injector::new("env", "gas");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::new(1, 10_000, 1)).unwrap();

        // The gas import is function 0, so the original function and the grow counter move up.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            0,
            &[I64Const(2), Call(0), GlobalGet(0), Call(2), End,]
        ));
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            1,
            &[
                LocalGet(0),
                LocalGet(0),
                I64ExtendI32U,
                I64Const(10000),
                I64Mul,
                Call(0),
                MemoryGrow(0),
                End,
            ]
        ));

        wasmparser::validate(&injected_raw_wasm).unwrap();
    }

    #[test]
    fn simple_grow_mut_global() {
        let mut module = parse_wat(
            r#"(module
            (func (result i32)
              global.get 0
              memory.grow)
            (global i32 (i32.const 42))
            (memory 0 1)
            )"#,
        );

        let backend = mutable_global::Injector::new("env", "gas_left");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::new(1, 10_000, 1)).unwrap();

        // Function 1 is the appended internal gas function.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            1,
            &[
                GlobalGet(1),
                LocalGet(0),
                I64GeU,
                If(BlockType::Empty),
                GlobalGet(1),
                LocalGet(0),
                I64Sub,
                GlobalSet(1),
                Else,
                I64Const(-1i64),
                GlobalSet(1),
                Unreachable,
                End,
                End
            ]
        ));

        // Function 2 is the grow counter, charging through the internal gas function.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            2,
            &[
                LocalGet(0),
                LocalGet(0),
                I64ExtendI32U,
                I64Const(10000i64),
                I64Mul,
                Call(1),
                MemoryGrow(0),
                End
            ]
        ));

        wasmparser::validate(&injected_raw_wasm).unwrap();
    }

    #[test]
    fn grow_no_gas_no_track_host_fn() {
        let mut module = parse_wat(
            r"(module
            (func (result i32)
              global.get 0
              memory.grow)
            (global i32 (i32.const 42))
            (memory 0 1)
            )",
        );

        let backend = host_function::Injector::new("env", "gas");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::default()).unwrap();

        assert!(check_expect_function_body(
            &injected_raw_wasm,
            0,
            &[I64Const(2), Call(0), GlobalGet(0), MemoryGrow(0), End,]
        ));

        // The gas import plus the original function; no grow counter was added.
        assert_eq!(module.num_functions(), 2);

        wasmparser::validate(&injected_raw_wasm).unwrap();
    }

    #[test]
    fn grow_no_gas_no_track_mut_global() {
        let mut module = parse_wat(
            r"(module
            (func (result i32)
              global.get 0
              memory.grow)
            (global i32 (i32.const 42))
            (memory 0 1)
            )",
        );

        // Cost of running the internal gas function, added to every charge.
        let gas_fn_cost = match mutable_global::Injector::new("env", "gas_left")
            .gas_meter(&mut module, &ConstantCostRules::default())
        {
            GasMeter::Internal { cost, .. } => cost as i64,
            _ => 0i64,
        };

        let backend = mutable_global::Injector::new("env", "gas_left");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::default()).unwrap();

        // The internal gas function is appended at index 1; `memory.grow` stays untouched.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            0,
            &[I64Const(2 + gas_fn_cost), Call(1), GlobalGet(0), MemoryGrow(0), End,]
        ));

        // The original function plus the internal gas function; no grow counter was added.
        assert_eq!(module.num_functions(), 2);

        wasmparser::validate(&injected_raw_wasm).unwrap();
    }

    #[test]
    fn call_index_host_fn() {
        let mut module = parse_wat(
            r"(module
            (type (;0;) (func (result i32)))
            (func (;0;) (type 0) (result i32))
            (func (;1;) (type 0) (result i32)
              call 0
              if ;; label = @1
                call 0
                call 0
                call 0
              else
                call 0
                call 0
              end
              call 0
            )
            (global (;0;) i32 (i32.const 0))
            )",
        );

        let backend = host_function::Injector::new("env", "gas");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::default()).unwrap();

        // `call 0` now targets the shifted function 1.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            1,
            &[
                I64Const(3),
                Call(0),
                Call(1),
                If(BlockType::Empty),
                I64Const(3),
                Call(0),
                Call(1),
                Call(1),
                Call(1),
                Else,
                I64Const(2),
                Call(0),
                Call(1),
                Call(1),
                End,
                Call(1),
                End
            ]
        ));
    }

    #[test]
    fn call_index_mut_global() {
        let mut module = parse_wat(
            r"(module
            (type (;0;) (func (result i32)))
            (func (;0;) (type 0) (result i32))
            (func (;1;) (type 0) (result i32)
              call 0
              if ;; label = @1
                call 0
                call 0
                call 0
              else
                call 0
                call 0
              end
              call 0
            )
            (global (;0;) i32 (i32.const 0))
            )",
        );

        let backend = mutable_global::Injector::new("env", "gas_left");
        let injected_raw_wasm =
            super::inject(&mut module, backend, &ConstantCostRules::default()).unwrap();

        // Indices are unchanged; charges go to the appended gas function 2.
        assert!(check_expect_function_body(
            &injected_raw_wasm,
            1,
            &[
                I64Const(14),
                Call(2),
                Call(0),
                If(BlockType::Empty),
                I64Const(14),
                Call(2),
                Call(0),
                Call(0),
                Call(0),
                Else,
                I64Const(13),
                Call(2),
                Call(0),
                Call(0),
                End,
                Call(0),
                End
            ]
        ));
    }

    /// Generates a host-function test and a mutable-global test for the same input.
    ///
    /// `$expected` is written for the host-function backend (gas import at index 0). The
    /// mutable-global variant rewrites each `i64.const c; call 0` pair into
    /// `i64.const (c + gas function cost); call 1`.
    macro_rules! test_gas_counter_injection {
        (names = ($name1:ident, $name2:ident); input = $input:expr; expected = $expected:expr) => {
            #[test]
            fn $name1() {
                let mut module = parse_wat($input);
                let expected_module = parse_wat($expected);
                let injected_wasm = super::inject(
                    &mut module,
                    host_function::Injector::new("env", "gas"),
                    &ConstantCostRules::default(),
                )
                .expect("inject_gas_counter call failed");

                let actual_func_body = get_function_body(&injected_wasm, 0);
                let expected_func_body = get_function_body(&expected_module.bytes(), 0);

                assert_eq!(actual_func_body, expected_func_body);
            }

            #[test]
            fn $name2() {
                let mut module = parse_wat($input);
                let draft_module = parse_wat($expected);
                let gas_fun_cost = match mutable_global::Injector::new("env", "gas_left")
                    .gas_meter(&mut module, &ConstantCostRules::default())
                {
                    GasMeter::Internal { cost, .. } => cost as i64,
                    _ => 0i64,
                };

                let injected_wasm = super::inject(
                    &mut module,
                    mutable_global::Injector::new("env", "gas_left"),
                    &ConstantCostRules::default(),
                )
                .expect("inject_gas_counter call failed");

                let actual_func_body = get_function_body(&injected_wasm, 0);

                let expected_module_bytes = draft_module.bytes();
                let mut expected_func_operators = get_function_operators(&expected_module_bytes, 0);

                // Adapt each host-function charge to the internal gas function.
                let mut iter = expected_func_operators.iter_mut();
                while let Some(ins) = iter.next() {
                    if let I64Const(cost) = ins {
                        if let Some(ins_next) = iter.next() {
                            if let Call(0) = ins_next {
                                *cost += gas_fun_cost;
                                *ins_next = Call(1);
                            }
                        }
                    }
                }
                let mut expected_func_body = vec![];
                expected_func_operators.iter().for_each(|v| v.encode(&mut expected_func_body));

                assert_eq!(actual_func_body, expected_func_body);
            }
        };
    }

    test_gas_counter_injection! {
        names = (simple_host_fn, simple_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 1))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (nested_host_fn, nested_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (get_global 0)
                    (get_global 0))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 6))
                (get_global 0)
                (block
                    (get_global 0)
                    (get_global 0)
                    (get_global 0))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (ifelse_host_fn, ifelse_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (if
                    (then
                        (get_global 0)
                        (get_global 0)
                        (get_global 0))
                    (else
                        (get_global 0)
                        (get_global 0)))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 3))
                (get_global 0)
                (if
                    (then
                        (call 0 (i64.const 3))
                        (get_global 0)
                        (get_global 0)
                        (get_global 0))
                    (else
                        (call 0 (i64.const 2))
                        (get_global 0)
                        (get_global 0)))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (branch_innermost_host_fn, branch_innermost_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (drop)
                    (br 0)
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 6))
                (get_global 0)
                (block
                    (get_global 0)
                    (drop)
                    (br 0)
                    (call 0 (i64.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (branch_outer_block_host_fn, branch_outer_block_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 5))
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i64.const 4))
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (call 0 (i64.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (branch_outer_loop_host_fn, branch_outer_loop_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (loop
                    (get_global 0)
                    (if
                        (then
                            (get_global 0)
                            (br_if 0))
                        (else
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 3))
                (get_global 0)
                (loop
                    (call 0 (i64.const 4))
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i64.const 2))
                            (get_global 0)
                            (br_if 0))
                        (else
                            (call 0 (i64.const 4))
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (return_from_func_host_fn, return_from_func_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (if
                    (then
                        (return)))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 2))
                (get_global 0)
                (if
                    (then
                        (call 0 (i64.const 1))
                        (return)))
                (call 0 (i64.const 1))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (branch_from_if_not_else_host_fn, branch_from_if_not_else_mut_global);
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then (br 1))
                        (else (br 0)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i64.const 5))
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i64.const 1))
                            (br 1))
                        (else
                            (call 0 (i64.const 1))
                            (br 0)))
                    (call 0 (i64.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        names = (empty_loop_host_fn, empty_loop_mut_global);
        input = r#"
        (module
            (func
                (loop
                    (br 0)
                )
                unreachable
            )
        )
        "#;
        expected = r#"
        (module
            (func
                (call 0 (i64.const 2))
                (loop
                    (call 0 (i64.const 1))
                    (br 0)
                )
                unreachable
            )
        )
        "#
    }
}