1#[cfg(test)]
8mod validation;
9
10use crate::std::{cmp::min, mem, vec::Vec};
11
12use crate::rules::Rules;
13use parity_wasm::{builder, elements, elements::ValueType};
14
15pub fn update_call_index(instructions: &mut elements::Instructions, inserted_index: u32) {
16 use parity_wasm::elements::Instruction::*;
17 for instruction in instructions.elements_mut().iter_mut() {
18 if let Call(call_index) = instruction {
19 if *call_index >= inserted_index {
20 *call_index += 1
21 }
22 }
23 }
24}
25
26#[derive(Debug)]
45struct ControlBlock {
46 lowest_forward_br_target: usize,
55
56 active_metered_block: MeteredBlock,
58
59 is_loop: bool,
62}
63
/// A run of instructions that is charged for with a single gas call.
#[derive(Debug)]
pub(crate) struct MeteredBlock {
    /// Index (in the original instruction list) of the instruction the gas call is inserted
    /// in front of.
    start_pos: usize,
    /// Accumulated cost of the instructions belonging to this block.
    cost: u32,
}
74
75struct Counter {
78 stack: Vec<ControlBlock>,
84
85 finalized_blocks: Vec<MeteredBlock>,
87}
88
89impl Counter {
90 fn new() -> Counter {
91 Counter { stack: Vec::new(), finalized_blocks: Vec::new() }
92 }
93
94 fn begin_control_block(&mut self, cursor: usize, is_loop: bool) {
96 let index = self.stack.len();
97 self.stack.push(ControlBlock {
98 lowest_forward_br_target: index,
99 active_metered_block: MeteredBlock { start_pos: cursor, cost: 0 },
100 is_loop,
101 })
102 }
103
104 fn finalize_control_block(&mut self, cursor: usize) -> Result<(), ()> {
107 self.finalize_metered_block(cursor)?;
110
111 let closing_control_block = self.stack.pop().ok_or(())?;
113 let closing_control_index = self.stack.len();
114
115 if self.stack.is_empty() {
116 return Ok(())
117 }
118
119 {
121 let control_block = self.stack.last_mut().ok_or(())?;
122 control_block.lowest_forward_br_target = min(
123 control_block.lowest_forward_br_target,
124 closing_control_block.lowest_forward_br_target,
125 );
126 }
127
128 let may_br_out = closing_control_block.lowest_forward_br_target < closing_control_index;
131 if may_br_out {
132 self.finalize_metered_block(cursor)?;
133 }
134
135 Ok(())
136 }
137
138 fn finalize_metered_block(&mut self, cursor: usize) -> Result<(), ()> {
142 let closing_metered_block = {
143 let control_block = self.stack.last_mut().ok_or(())?;
144 mem::replace(
145 &mut control_block.active_metered_block,
146 MeteredBlock { start_pos: cursor + 1, cost: 0 },
147 )
148 };
149
150 let last_index = self.stack.len() - 1;
156 if last_index > 0 {
157 let prev_control_block = self
158 .stack
159 .get_mut(last_index - 1)
160 .expect("last_index is greater than 0; last_index is stack size - 1; qed");
161 let prev_metered_block = &mut prev_control_block.active_metered_block;
162 if closing_metered_block.start_pos == prev_metered_block.start_pos {
163 prev_metered_block.cost += closing_metered_block.cost;
164 return Ok(())
165 }
166 }
167
168 if closing_metered_block.cost > 0 {
169 self.finalized_blocks.push(closing_metered_block);
170 }
171 Ok(())
172 }
173
174 fn branch(&mut self, cursor: usize, indices: &[usize]) -> Result<(), ()> {
179 self.finalize_metered_block(cursor)?;
180
181 for &index in indices {
183 let target_is_loop = {
184 let target_block = self.stack.get(index).ok_or(())?;
185 target_block.is_loop
186 };
187 if target_is_loop {
188 continue
189 }
190
191 let control_block = self.stack.last_mut().ok_or(())?;
192 control_block.lowest_forward_br_target =
193 min(control_block.lowest_forward_br_target, index);
194 }
195
196 Ok(())
197 }
198
199 fn active_control_block_index(&self) -> Option<usize> {
201 self.stack.len().checked_sub(1)
202 }
203
204 fn active_metered_block(&mut self) -> Result<&mut MeteredBlock, ()> {
206 let top_block = self.stack.last_mut().ok_or(())?;
207 Ok(&mut top_block.active_metered_block)
208 }
209
210 fn increment(&mut self, val: u32) -> Result<(), ()> {
212 let top_block = self.active_metered_block()?;
213 top_block.cost = top_block.cost.checked_add(val).ok_or(())?;
214 Ok(())
215 }
216}
217
218fn inject_grow_counter(instructions: &mut elements::Instructions, grow_counter_func: u32) -> usize {
219 use parity_wasm::elements::Instruction::*;
220 let mut counter = 0;
221 for instruction in instructions.elements_mut() {
222 if let GrowMemory(_) = *instruction {
223 *instruction = Call(grow_counter_func);
224 counter += 1;
225 }
226 }
227 counter
228}
229
230fn add_grow_counter<R: Rules>(
231 module: elements::Module,
232 rules: &R,
233 gas_func: u32,
234) -> elements::Module {
235 use crate::rules::MemoryGrowCost;
236 use parity_wasm::elements::Instruction::*;
237
238 let cost = match rules.memory_grow_cost() {
239 None => return module,
240 Some(MemoryGrowCost::Linear(val)) => val.get(),
241 };
242
243 let mut b = builder::from_module(module);
244 b.push_function(
245 builder::function()
246 .signature()
247 .with_param(ValueType::I32)
248 .with_result(ValueType::I32)
249 .build()
250 .body()
251 .with_instructions(elements::Instructions::new(vec![
252 GetLocal(0),
253 GetLocal(0),
254 I32Const(cost as i32),
255 I32Mul,
256 Call(gas_func),
258 GrowMemory(0),
259 End,
260 ]))
261 .build()
262 .build(),
263 );
264
265 b.build()
266}
267
268pub(crate) fn determine_metered_blocks<R: Rules>(
269 instructions: &elements::Instructions,
270 rules: &R,
271) -> Result<Vec<MeteredBlock>, ()> {
272 use parity_wasm::elements::Instruction::*;
273
274 let mut counter = Counter::new();
275
276 counter.begin_control_block(0, false);
278
279 for cursor in 0..instructions.elements().len() {
280 let instruction = &instructions.elements()[cursor];
281 let instruction_cost = rules.instruction_cost(instruction).ok_or(())?;
282 match instruction {
283 Block(_) => {
284 counter.increment(instruction_cost)?;
285
286 let top_block_start_pos = counter.active_metered_block()?.start_pos;
291 counter.begin_control_block(top_block_start_pos, false);
292 },
293 If(_) => {
294 counter.increment(instruction_cost)?;
295 counter.begin_control_block(cursor + 1, false);
296 },
297 Loop(_) => {
298 counter.increment(instruction_cost)?;
299 counter.begin_control_block(cursor + 1, true);
300 },
301 End => {
302 counter.finalize_control_block(cursor)?;
303 },
304 Else => {
305 counter.finalize_metered_block(cursor)?;
306 },
307 Br(label) | BrIf(label) => {
308 counter.increment(instruction_cost)?;
309
310 let active_index = counter.active_control_block_index().ok_or(())?;
312 let target_index = active_index.checked_sub(*label as usize).ok_or(())?;
313 counter.branch(cursor, &[target_index])?;
314 },
315 BrTable(br_table_data) => {
316 counter.increment(instruction_cost)?;
317
318 let active_index = counter.active_control_block_index().ok_or(())?;
319 let target_indices = [br_table_data.default]
320 .iter()
321 .chain(br_table_data.table.iter())
322 .map(|label| active_index.checked_sub(*label as usize))
323 .collect::<Option<Vec<_>>>()
324 .ok_or(())?;
325 counter.branch(cursor, &target_indices)?;
326 },
327 Return => {
328 counter.increment(instruction_cost)?;
329 counter.branch(cursor, &[0])?;
330 },
331 _ => {
332 counter.increment(instruction_cost)?;
334 },
335 }
336 }
337
338 counter.finalized_blocks.sort_unstable_by_key(|block| block.start_pos);
339 Ok(counter.finalized_blocks)
340}
341
342pub fn inject_counter<R: Rules>(
343 instructions: &mut elements::Instructions,
344 rules: &R,
345 gas_func: u32,
346) -> Result<(), ()> {
347 let blocks = determine_metered_blocks(instructions, rules)?;
348 insert_metering_calls(instructions, blocks, gas_func)
349}
350
351fn insert_metering_calls(
353 instructions: &mut elements::Instructions,
354 blocks: Vec<MeteredBlock>,
355 gas_func: u32,
356) -> Result<(), ()> {
357 use parity_wasm::elements::Instruction::*;
358
359 let new_instrs_len = instructions.elements().len() + 2 * blocks.len();
362 let original_instrs =
363 mem::replace(instructions.elements_mut(), Vec::with_capacity(new_instrs_len));
364 let new_instrs = instructions.elements_mut();
365
366 let mut block_iter = blocks.into_iter().peekable();
367 for (original_pos, instr) in original_instrs.into_iter().enumerate() {
368 let used_block = if let Some(block) = block_iter.peek() {
370 if block.start_pos == original_pos {
371 new_instrs.push(I32Const(block.cost as i32));
372 new_instrs.push(Call(gas_func));
373 true
374 } else {
375 false
376 }
377 } else {
378 false
379 };
380
381 if used_block {
382 block_iter.next();
383 }
384
385 new_instrs.push(instr);
387 }
388
389 if block_iter.next().is_some() {
390 return Err(())
391 }
392
393 Ok(())
394}
395
396pub fn inject_gas_counter<R: Rules>(
431 module: elements::Module,
432 rules: &R,
433 gas_module_name: &str,
434) -> Result<elements::Module, elements::Module> {
435 let mut mbuilder = builder::from_module(module);
437 let import_sig =
438 mbuilder.push_signature(builder::signature().with_param(ValueType::I32).build_sig());
439
440 mbuilder.push_import(
441 builder::import()
442 .module(gas_module_name)
443 .field("gas")
444 .external()
445 .func(import_sig)
446 .build(),
447 );
448
449 let mut module = mbuilder.build();
451
452 let gas_func = module.import_count(elements::ImportCountType::Function) as u32 - 1;
456 let total_func = module.functions_space() as u32;
457 let mut need_grow_counter = false;
458 let mut error = false;
459
460 for section in module.sections_mut() {
462 match section {
463 elements::Section::Code(code_section) =>
464 for func_body in code_section.bodies_mut() {
465 update_call_index(func_body.code_mut(), gas_func);
466 if inject_counter(func_body.code_mut(), rules, gas_func).is_err() {
467 error = true;
468 break
469 }
470 if rules.memory_grow_cost().is_some() &&
471 inject_grow_counter(func_body.code_mut(), total_func) > 0
472 {
473 need_grow_counter = true;
474 }
475 },
476 elements::Section::Export(export_section) => {
477 for export in export_section.entries_mut() {
478 if let elements::Internal::Function(func_index) = export.internal_mut() {
479 if *func_index >= gas_func {
480 *func_index += 1
481 }
482 }
483 }
484 },
485 elements::Section::Element(elements_section) => {
486 for segment in elements_section.entries_mut() {
489 for func_index in segment.members_mut() {
491 if *func_index >= gas_func {
492 *func_index += 1
493 }
494 }
495 }
496 },
497 elements::Section::Start(start_idx) =>
498 if *start_idx >= gas_func {
499 *start_idx += 1
500 },
501 _ => {},
502 }
503 }
504
505 if error {
506 return Err(module)
507 }
508
509 if need_grow_counter {
510 Ok(add_grow_counter(module, rules, gas_func))
511 } else {
512 Ok(module)
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules;
    use parity_wasm::{builder, elements, elements::Instruction::*, serialize};

    /// Returns the instructions of the `index`-th body of the module's code section, or `None`
    /// if the module has no code section or no body at that index.
    pub fn get_function_body(
        module: &elements::Module,
        index: usize,
    ) -> Option<&[elements::Instruction]> {
        let func_body = module.code_section()?.bodies().get(index)?;
        Some(func_body.code().elements())
    }

    /// Builds a module with one `i32` global and one function that takes an `i32` parameter and
    /// has `body` as its instructions.
    fn single_function_module(body: Vec<elements::Instruction>) -> elements::Module {
        builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(body))
            .build()
            .build()
            .build()
    }

    /// With a grow cost configured, `GrowMemory` is redirected to an appended helper function.
    #[test]
    fn simple_grow() {
        let module = single_function_module(vec![GetGlobal(0), GrowMemory(0), End]);
        let grow_rules = rules::Set::default().with_grow_cost(10000);

        let injected_module = inject_gas_counter(module, &grow_rules, "env").unwrap();

        let expected_body = vec![I32Const(2), Call(0), GetGlobal(0), Call(2), End];
        assert_eq!(get_function_body(&injected_module, 0).unwrap(), &expected_body[..]);

        let expected_grow_counter =
            vec![GetLocal(0), GetLocal(0), I32Const(10000), I32Mul, Call(0), GrowMemory(0), End];
        assert_eq!(get_function_body(&injected_module, 1).unwrap(), &expected_grow_counter[..]);

        let binary = serialize(injected_module).expect("serialization failed");
        wabt::wasm2wat(&binary).unwrap();
    }

    /// Without a grow cost, `GrowMemory` stays in place and no helper function is added.
    #[test]
    fn grow_no_gas_no_track() {
        let module = single_function_module(vec![GetGlobal(0), GrowMemory(0), End]);

        let injected_module = inject_gas_counter(module, &rules::Set::default(), "env").unwrap();

        let expected_body = vec![I32Const(2), Call(0), GetGlobal(0), GrowMemory(0), End];
        assert_eq!(get_function_body(&injected_module, 0).unwrap(), &expected_body[..]);

        assert_eq!(injected_module.functions_space(), 2);

        let binary = serialize(injected_module).expect("serialization failed");
        wabt::wasm2wat(&binary).unwrap();
    }

    /// Calls to existing functions are shifted by one to account for the inserted gas import.
    #[test]
    fn call_index() {
        let caller_body = elements::Instructions::new(vec![
            Call(0),
            If(elements::BlockType::NoResult),
            Call(0),
            Call(0),
            Call(0),
            Else,
            Call(0),
            Call(0),
            End,
            Call(0),
            End,
        ]);

        let module = builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(caller_body)
            .build()
            .build()
            .build();

        let injected_module = inject_gas_counter(module, &rules::Set::default(), "env").unwrap();

        let expected_body = vec![
            I32Const(3),
            Call(0),
            Call(1),
            If(elements::BlockType::NoResult),
            I32Const(3),
            Call(0),
            Call(1),
            Call(1),
            Call(1),
            Else,
            I32Const(2),
            Call(0),
            Call(1),
            Call(1),
            End,
            Call(1),
            End,
        ];
        assert_eq!(get_function_body(&injected_module, 1).unwrap(), &expected_body[..]);
    }

    /// Instrumentation fails when the rule set forbids an instruction used by the module.
    #[test]
    fn forbidden() {
        let module = single_function_module(vec![F32Const(555555), End]);
        let no_float_rules = rules::Set::default().with_forbidden_floats();

        assert!(
            inject_gas_counter(module, &no_float_rules, "env").is_err(),
            "Should be error because of the forbidden operation"
        );
    }

    /// Converts WAT `source` to a binary without validating it and deserializes the result.
    fn parse_wat(source: &str) -> elements::Module {
        let wat_converter = wabt::Wat2Wasm::new().validate(false);
        let module_bytes = wat_converter.convert(source).expect("failed to parse module");
        elements::deserialize_buffer(module_bytes.as_ref()).expect("failed to parse module")
    }

    /// Generates a test that instruments `input` with the default rules and compares the first
    /// function body of the result with the first function body of `expected`.
    macro_rules! test_gas_counter_injection {
        (name = $name:ident; input = $input:expr; expected = $expected:expr) => {
            #[test]
            fn $name() {
                let input_module = parse_wat($input);
                let expected_module = parse_wat($expected);

                let default_rules = rules::Set::default();
                let injected_module = inject_gas_counter(input_module, &default_rules, "env")
                    .expect("inject_gas_counter call failed");

                let actual_func_body = get_function_body(&injected_module, 0)
                    .expect("injected module must have a function body");
                let expected_func_body = get_function_body(&expected_module, 0)
                    .expect("post-module must have a function body");

                assert_eq!(actual_func_body, expected_func_body);
            }
        };
    }

    test_gas_counter_injection! {
        name = simple;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 1))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = nested;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (get_global 0)
                    (get_global 0))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 6))
                (get_global 0)
                (block
                    (get_global 0)
                    (get_global 0)
                    (get_global 0))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = ifelse;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (if
                    (then
                        (get_global 0)
                        (get_global 0)
                        (get_global 0))
                    (else
                        (get_global 0)
                        (get_global 0)))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 3))
                (get_global 0)
                (if
                    (then
                        (call 0 (i32.const 3))
                        (get_global 0)
                        (get_global 0)
                        (get_global 0))
                    (else
                        (call 0 (i32.const 2))
                        (get_global 0)
                        (get_global 0)))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = branch_innermost;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (drop)
                    (br 0)
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 6))
                (get_global 0)
                (block
                    (get_global 0)
                    (drop)
                    (br 0)
                    (call 0 (i32.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = branch_outer_block;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 5))
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i32.const 4))
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (call 0 (i32.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = branch_outer_loop;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (loop
                    (get_global 0)
                    (if
                        (then
                            (get_global 0)
                            (br_if 0))
                        (else
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 3))
                (get_global 0)
                (loop
                    (call 0 (i32.const 4))
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i32.const 2))
                            (get_global 0)
                            (br_if 0))
                        (else
                            (call 0 (i32.const 4))
                            (get_global 0)
                            (get_global 0)
                            (drop)
                            (br_if 1)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = return_from_func;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (if
                    (then
                        (return)))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 2))
                (get_global 0)
                (if
                    (then
                        (call 0 (i32.const 1))
                        (return)))
                (call 0 (i32.const 1))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = branch_from_if_not_else;
        input = r#"
        (module
            (func (result i32)
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then (br 1))
                        (else (br 0)))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#;
        expected = r#"
        (module
            (func (result i32)
                (call 0 (i32.const 5))
                (get_global 0)
                (block
                    (get_global 0)
                    (if
                        (then
                            (call 0 (i32.const 1))
                            (br 1))
                        (else
                            (call 0 (i32.const 1))
                            (br 0)))
                    (call 0 (i32.const 2))
                    (get_global 0)
                    (drop))
                (get_global 0)))
        "#
    }

    test_gas_counter_injection! {
        name = empty_loop;
        input = r#"
        (module
            (func
                (loop
                    (br 0)
                )
                unreachable
            )
        )
        "#;
        expected = r#"
        (module
            (func
                (call 0 (i32.const 2))
                (loop
                    (call 0 (i32.const 1))
                    (br 0)
                )
                unreachable
            )
        )
        "#
    }
}