// gear_wasm_instrument/gas_metering/mod.rs

// This file is part of Gear.
//
// Copyright (C) 2017-2024 Parity Technologies.
// Copyright (C) 2025 Gear Technologies Inc.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
19
//! This module is used to instrument a Wasm module with gas metering code.
//!
//! The primary public interface is the [`inject`] function, which transforms a given
//! module into one that charges gas for code to be executed. See the function documentation for
//! usage and details.
25
26#[cfg(test)]
27mod validation;
28
29use crate::{
30    Module,
31    module::{Function, Import, Instruction, ModuleBuilder},
32};
33use alloc::vec::Vec;
34use core::{cmp::min, mem, num::NonZeroU32};
35use wasmparser::{FuncType, TypeRef, ValType};
36
37#[derive(Debug, derive_more::From)]
38pub enum GasMeteringError {
39    Counter(CounterError),
40    LocalsInitCost,
41    NoActiveControlBlock,
42    MissingInstructionRule(Instruction),
43    ActiveIndexRelativeDepthUnderflow,
44    UnusedBlock,
45    LocalsCountOverflow,
46    ActiveIndexLabelUnderflow,
47}
48
49/// An interface that describes instruction costs.
50pub trait Rules {
51    /// Returns the cost for the passed `instruction`.
52    ///
53    /// Returning `None` makes the gas instrumentation end with an error. This is meant
54    /// as a way to have a partial rule set where any instruction that is not specified
55    /// is considered as forbidden.
56    fn instruction_cost(&self, instruction: &Instruction) -> Option<u32>;
57
58    /// Returns the costs for growing the memory using the `memory.grow` instruction.
59    ///
60    /// Please note that these costs are in addition to the costs specified by `instruction_cost`
61    /// for the `memory.grow` instruction. Those are meant as dynamic costs which take the
62    /// amount of pages that the memory is grown by into consideration. This is not possible
63    /// using `instruction_cost` because those costs depend on the stack and must be injected as
64    /// code into the function calling `memory.grow`. Therefore returning anything but
65    /// [`MemoryGrowCost::Free`] introduces some overhead to the `memory.grow` instruction.
66    fn memory_grow_cost(&self) -> MemoryGrowCost;
67
68    /// A surcharge cost to calling a function that is added per local variable of the function.
69    fn call_per_local_cost(&self) -> u32;
70}
71
/// Per-page pricing of `memory.grow`, charged in addition to its static instruction cost.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemoryGrowCost {
    /// No per-page charge is applied.
    ///
    /// # Note
    ///
    /// This is reasonable when static validation already limits the number of pages a module
    /// may use to a rather small value. The static cost of `memory.grow` can then be benchmarked
    /// for the worst case (growing to the maximum number of pages).
    Free,
    /// Charge the given amount for every page by which the memory is grown.
    Linear(NonZeroU32),
}

impl MemoryGrowCost {
    /// Returns `true` when per-page charging code has to be injected, which is the case for
    /// [`MemoryGrowCost::Linear`] only.
    fn enabled(&self) -> bool {
        matches!(self, Self::Linear(_))
    }
}
97
/// A [`Rules`] implementation in which every instruction costs the same.
///
/// This is a simplification that is mostly useful for development and testing.
///
/// # Note
///
/// In a production environment it usually makes no sense to assign every instruction the same
/// cost. A proper implementation of [`Rules`] should be provided instead, typically one derived
/// from benchmarks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConstantCostRules {
    instruction_cost: u32,
    memory_grow_cost: u32,
    call_per_local_cost: u32,
}

impl ConstantCostRules {
    /// Creates a new [`ConstantCostRules`].
    ///
    /// - `instruction_cost` is charged for every instruction.
    /// - `memory_grow_cost` is charged per page by which `memory.grow` grows the memory; `0`
    ///   disables this dynamic charge (it maps to [`MemoryGrowCost::Free`]).
    /// - `call_per_local_cost` is charged per local variable of a called function.
    pub fn new(instruction_cost: u32, memory_grow_cost: u32, call_per_local_cost: u32) -> Self {
        Self {
            instruction_cost,
            memory_grow_cost,
            call_per_local_cost,
        }
    }
}

impl Default for ConstantCostRules {
    /// Uses an instruction cost of `1` and a per-local cost of `1`, and disables the per-page
    /// `memory.grow` charge.
    fn default() -> Self {
        Self {
            instruction_cost: 1,
            memory_grow_cost: 0,
            call_per_local_cost: 1,
        }
    }
}
137
138impl Rules for ConstantCostRules {
139    fn instruction_cost(&self, _instruction: &Instruction) -> Option<u32> {
140        Some(self.instruction_cost)
141    }
142
143    fn memory_grow_cost(&self) -> MemoryGrowCost {
144        NonZeroU32::new(self.memory_grow_cost).map_or(MemoryGrowCost::Free, MemoryGrowCost::Linear)
145    }
146
147    fn call_per_local_cost(&self) -> u32 {
148        self.call_per_local_cost
149    }
150}
151
152/// Transforms a given module into one that charges gas for code to be executed by proxy of an
153/// imported gas metering function.
154///
155/// The output module imports a function "gas" from the specified module with type signature
156/// [i32] -> []. The argument is the amount of gas required to continue execution. The external
157/// function is meant to keep track of the total amount of gas used and trap or otherwise halt
158/// execution of the runtime if the gas usage exceeds some allowed limit.
159///
160/// The body of each function is divided into metered blocks, and the calls to charge gas are
161/// inserted at the beginning of every such block of code. A metered block is defined so that,
162/// unless there is a trap, either all of the instructions are executed or none are. These are
163/// similar to basic blocks in a control flow graph, except that in some cases multiple basic
164/// blocks can be merged into a single metered block. This is the case if any path through the
165/// control flow graph containing one basic block also contains another.
166///
167/// Charging gas is at the beginning of each metered block ensures that 1) all instructions
168/// executed are already paid for, 2) instructions that will not be executed are not charged for
169/// unless execution traps, and 3) the number of calls to "gas" is minimized. The corollary is that
170/// modules instrumented with this metering code may charge gas for instructions not executed in
171/// the event of a trap.
172///
173/// Additionally, each `memory.grow` instruction found in the module is instrumented to first make
174/// a call to charge gas for the additional pages requested. This cannot be done as part of the
175/// block level gas charges as the gas cost is not static and depends on the stack argument to
176/// `memory.grow`.
177///
178/// The above transformations are performed for every function body defined in the module. This
179/// function also rewrites all function indices references by code, table elements, etc., since
180/// the addition of an imported functions changes the indices of module-defined functions. If the
181/// the module has a NameSection, added by calling `parse_names`, the indices will also be updated.
182///
183/// This routine runs in time linear in the size of the input module.
184///
185/// The function fails if the module contains any operation forbidden by gas rule set, returning
186/// the original module as an Err.
187pub fn inject<R: Rules>(
188    module: Module,
189    rules: &R,
190    gas_module_name: &'static str,
191) -> Result<Module, GasMeteringError> {
192    // Injecting gas counting external
193    let gas_func = module.import_count(|ty| matches!(ty, TypeRef::Func(_)));
194
195    let mut mbuilder = ModuleBuilder::from_module(module);
196
197    let import_sig = mbuilder.push_type(FuncType::new([ValType::I32], []));
198    mbuilder.push_import(Import::func(gas_module_name, "gas", import_sig));
199
200    let module = mbuilder
201        .shift_func_index(gas_func as u32)
202        .shift_all()
203        .build();
204
205    post_injection_handler(module, rules, gas_func)
206}
207
208/// Helper procedure that makes adjustments after gas metering function injected.
209///
210/// See documentation for [`inject`] for more details.
211pub fn post_injection_handler<R: Rules>(
212    mut module: Module,
213    rules: &R,
214    gas_charge_index: usize,
215) -> Result<Module, GasMeteringError> {
216    // calculate actual function index of the imported definition
217    //    (subtract all imports that are NOT functions)
218
219    let import_count = module.import_count(|ty| matches!(ty, TypeRef::Func(_)));
220    let total_func = module.functions_space() as u32;
221    let mut need_grow_counter = false;
222
223    if let Some(code_section) = &mut module.code_section {
224        for (i, func_body) in code_section.iter_mut().enumerate() {
225            if i + import_count == gas_charge_index {
226                continue;
227            }
228
229            let locals_count = func_body
230                .locals
231                .iter()
232                .try_fold(0u32, |count, val_type| count.checked_add(val_type.0))
233                .ok_or(GasMeteringError::LocalsCountOverflow)?;
234            inject_counter(
235                &mut func_body.instructions,
236                rules,
237                locals_count,
238                gas_charge_index as u32,
239            )?;
240
241            if rules.memory_grow_cost().enabled()
242                && inject_grow_counter(&mut func_body.instructions, total_func) > 0
243            {
244                need_grow_counter = true;
245            }
246        }
247    }
248
249    match need_grow_counter {
250        true => Ok(add_grow_counter(module, rules, gas_charge_index as u32)),
251        false => Ok(module),
252    }
253}
254
255/// A control flow block is opened with the `block`, `loop`, and `if` instructions and is closed
256/// with `end`. Each block implicitly defines a new label. The control blocks form a stack during
257/// program execution.
258///
259/// An example of block:
260///
261/// ```wasm
262/// loop
263///   i32.const 1
264///   local.get 0
265///   i32.sub
266///   local.tee 0
267///   br_if 0
268/// end
269/// ```
270///
271/// The start of the block is `i32.const 1`.
272#[derive(Debug)]
273struct ControlBlock {
274    /// The lowest control stack index corresponding to a forward jump targeted by a br, br_if, or
275    /// br_table instruction within this control block. The index must refer to a control block
276    /// that is not a loop, meaning it is a forward jump. Given the way Wasm control flow is
277    /// structured, the lowest index on the stack represents the furthest forward branch target.
278    ///
279    /// This value will always be at most the index of the block itself, even if there is no
280    /// explicit br instruction targeting this control block. This does not affect how the value is
281    /// used in the metering algorithm.
282    lowest_forward_br_target: usize,
283
284    /// The active metering block that new instructions contribute a gas cost towards.
285    active_metered_block: MeteredBlock,
286
287    /// Whether the control block is a loop. Loops have the distinguishing feature that branches to
288    /// them jump to the beginning of the block, not the end as with the other control blocks.
289    is_loop: bool,
290}
291
292/// A block of code that metering instructions will be inserted at the beginning of. Metered blocks
293/// are constructed with the property that, in the absence of any traps, either all instructions in
294/// the block are executed or none are.
295#[derive(Debug)]
296struct MeteredBlock {
297    /// Index of the first instruction (aka `Opcode`) in the block.
298    start_pos: usize,
299    /// Sum of costs of all instructions until end of the block.
300    cost: BlockCostCounter,
301}
302
/// Cost accumulator for a metered block that tolerates `u32` overflow.
///
/// The represented total is `overflows * MAX_GAS_ARG + accumulator`. Each overflow is later
/// paid by a separate `gas` call with the maximum argument. The derived ordering compares
/// `overflows` first and `accumulator` second.
#[derive(Debug, PartialEq, PartialOrd)]
#[cfg_attr(test, derive(Copy, Clone, Default))]
struct BlockCostCounter {
    /// Number of times the running sum wrapped past `u32::MAX`.
    ///
    /// This counter saturates instead of wrapping. Reaching `usize::MAX` is not a practical
    /// concern: each overflow becomes its own `gas` call in the instrumented code, so an
    /// instruction vector that large could never be allocated, and instrumentation would fail
    /// for that reason first (a `Vec` cannot exceed `isize::MAX` bytes). Bounding the size of the
    /// instrumented Wasm is the caller's concern.
    overflows: usize,
    /// Remaining cost that has not been accounted for by `overflows`.
    accumulator: u32,
}

impl BlockCostCounter {
    /// Maximum value of a single `gas` call argument.
    ///
    /// Costs larger than this are split into several `gas` calls: one call with this value per
    /// recorded overflow, followed by one call for the remainder.
    const MAX_GAS_ARG: u32 = u32::MAX;

    /// A counter representing zero cost.
    fn zero() -> Self {
        Self::initialize(0)
    }

    /// A counter representing `initial_cost`.
    fn initialize(initial_cost: u32) -> Self {
        Self {
            overflows: 0,
            accumulator: initial_cost,
        }
    }

    /// Adds the total cost represented by `counter` to `self`.
    fn add(&mut self, counter: BlockCostCounter) {
        // Saturation is acceptable here; see the `overflows` field docs.
        self.overflows = self.overflows.saturating_add(counter.overflows);
        self.increment(counter.accumulator)
    }

    /// Adds `val` to the cost, recording an overflow whenever the sum exceeds `u32::MAX`.
    fn increment(&mut self, val: u32) {
        if let Some(res) = self.accumulator.checked_add(val) {
            self.accumulator = res;
        } else {
            // accumulator + val == u32::MAX + remainder, with remainder <= u32::MAX.
            self.accumulator = val - (u32::MAX - self.accumulator);
            // Saturation is acceptable here; see the `overflows` field docs.
            self.overflows = self.overflows.saturating_add(1);
        }
    }

    /// Returns `(overflows, remainder)`: the number of full [`Self::MAX_GAS_ARG`] charges that
    /// emerged while summing the block's cost, and the current non-overflowed remainder.
    fn block_costs(&self) -> (usize, u32) {
        (self.overflows, self.accumulator)
    }

    /// Returns how many `gas` calls are needed to charge this cost: one per overflow, plus one
    /// for a non-zero remainder.
    fn costs_num(&self) -> usize {
        if self.accumulator != 0 {
            // Saturate like every other update of `overflows` instead of panicking or wrapping.
            self.overflows.saturating_add(1)
        } else {
            self.overflows
        }
    }
}
381
/// Failures of the control-stack bookkeeping performed while computing metered blocks.
///
/// They signal control flow that the metering pass could not follow, for example instructions
/// that appear after the function's final `end`.
#[derive(Debug)]
pub enum CounterError {
    /// The control stack was empty when its innermost block was needed.
    StackLast,
    /// Popping the control stack failed because it was empty.
    StackPop,
    /// A branch target index did not refer to any entry of the control stack.
    StackGet,
}
388
389/// Counter is used to manage state during the gas metering algorithm implemented by
390/// `inject_counter`.
391struct Counter {
392    /// A stack of control blocks. This stack grows when new control blocks are opened with
393    /// `block`, `loop`, and `if` and shrinks when control blocks are closed with `end`. The first
394    /// block on the stack corresponds to the function body, not to any labelled block. Therefore
395    /// the actual Wasm label index associated with each control block is 1 less than its position
396    /// in this stack.
397    stack: Vec<ControlBlock>,
398
399    /// A list of metered blocks that have been finalized, meaning they will no longer change.
400    finalized_blocks: Vec<MeteredBlock>,
401}
402
403impl Counter {
404    fn new() -> Counter {
405        Counter {
406            stack: Vec::new(),
407            finalized_blocks: Vec::new(),
408        }
409    }
410
411    /// Open a new control block. The cursor is the position of the first instruction in the block.
412    fn begin_control_block(&mut self, cursor: usize, is_loop: bool) {
413        let index = self.stack.len();
414        self.stack.push(ControlBlock {
415            lowest_forward_br_target: index,
416            active_metered_block: MeteredBlock {
417                start_pos: cursor,
418                cost: BlockCostCounter::zero(),
419            },
420            is_loop,
421        })
422    }
423
424    /// Close the last control block. The cursor is the position of the final (pseudo-)instruction
425    /// in the block.
426    fn finalize_control_block(&mut self, cursor: usize) -> Result<(), CounterError> {
427        // This either finalizes the active metered block or merges its cost into the active
428        // metered block in the previous control block on the stack.
429        self.finalize_metered_block(cursor)?;
430
431        // Pop the control block stack.
432        let closing_control_block = self.stack.pop().ok_or(CounterError::StackPop)?;
433        let closing_control_index = self.stack.len();
434
435        if self.stack.is_empty() {
436            return Ok(());
437        }
438
439        // Update the lowest_forward_br_target for the control block now on top of the stack.
440        {
441            let control_block = self.stack.last_mut().ok_or(CounterError::StackLast)?;
442            control_block.lowest_forward_br_target = min(
443                control_block.lowest_forward_br_target,
444                closing_control_block.lowest_forward_br_target,
445            );
446        }
447
448        // If there may have been a branch to a lower index, then also finalize the active metered
449        // block for the previous control block. Otherwise, finalize it and begin a new one.
450        let may_br_out = closing_control_block.lowest_forward_br_target < closing_control_index;
451        if may_br_out {
452            self.finalize_metered_block(cursor)?;
453        }
454
455        Ok(())
456    }
457
458    /// Finalize the current active metered block.
459    ///
460    /// Finalized blocks have final cost which will not change later.
461    fn finalize_metered_block(&mut self, cursor: usize) -> Result<(), CounterError> {
462        let closing_metered_block = {
463            let control_block = self.stack.last_mut().ok_or(CounterError::StackLast)?;
464            mem::replace(
465                &mut control_block.active_metered_block,
466                MeteredBlock {
467                    start_pos: cursor + 1,
468                    cost: BlockCostCounter::zero(),
469                },
470            )
471        };
472
473        // If the block was opened with a `block`, then its start position will be set to that of
474        // the active metered block in the control block one higher on the stack. This is because
475        // any instructions between a `block` and the first branch are part of the same basic block
476        // as the preceding instruction. In this case, instead of finalizing the block, merge its
477        // cost into the other active metered block to avoid injecting unnecessary instructions.
478        let last_index = self.stack.len() - 1;
479        if last_index > 0 {
480            let prev_control_block = self
481                .stack
482                .get_mut(last_index - 1)
483                .expect("last_index is greater than 0; last_index is stack size - 1; qed");
484            let prev_metered_block = &mut prev_control_block.active_metered_block;
485            if closing_metered_block.start_pos == prev_metered_block.start_pos {
486                prev_metered_block.cost.add(closing_metered_block.cost);
487                return Ok(());
488            }
489        }
490
491        if closing_metered_block.cost > BlockCostCounter::zero() {
492            self.finalized_blocks.push(closing_metered_block);
493        }
494        Ok(())
495    }
496
497    /// Handle a branch instruction in the program. The cursor is the index of the branch
498    /// instruction in the program. The indices are the stack positions of the target control
499    /// blocks. Recall that the index is 0 for a `return` and relatively indexed from the top of
500    /// the stack by the label of `br`, `br_if`, and `br_table` instructions.
501    fn branch(&mut self, cursor: usize, indices: &[usize]) -> Result<(), CounterError> {
502        self.finalize_metered_block(cursor)?;
503
504        // Update the lowest_forward_br_target of the current control block.
505        for &index in indices {
506            let target_is_loop = {
507                let target_block = self.stack.get(index).ok_or(CounterError::StackGet)?;
508                target_block.is_loop
509            };
510            if target_is_loop {
511                continue;
512            }
513
514            let control_block = self.stack.last_mut().ok_or(CounterError::StackLast)?;
515            control_block.lowest_forward_br_target =
516                min(control_block.lowest_forward_br_target, index);
517        }
518
519        Ok(())
520    }
521
522    /// Returns the stack index of the active control block. Returns None if stack is empty.
523    fn active_control_block_index(&self) -> Option<usize> {
524        self.stack.len().checked_sub(1)
525    }
526
527    /// Get a reference to the currently active metered block.
528    fn active_metered_block(&mut self) -> Result<&mut MeteredBlock, CounterError> {
529        let top_block = self.stack.last_mut().ok_or(CounterError::StackLast)?;
530        Ok(&mut top_block.active_metered_block)
531    }
532
533    /// Increment the cost of the current block by the specified value.
534    fn increment(&mut self, val: u32) -> Result<(), CounterError> {
535        let top_block = self.active_metered_block()?;
536        top_block.cost.increment(val);
537        Ok(())
538    }
539}
540
541fn inject_grow_counter(instructions: &mut Vec<Instruction>, grow_counter_func: u32) -> usize {
542    use Instruction::*;
543    let mut counter = 0;
544    for instruction in instructions {
545        if let MemoryGrow(_) = *instruction {
546            *instruction = Call(grow_counter_func);
547            counter += 1;
548        }
549    }
550    counter
551}
552
553fn add_grow_counter<R: Rules>(module: Module, rules: &R, gas_func: u32) -> Module {
554    use Instruction::*;
555
556    let cost = match rules.memory_grow_cost() {
557        MemoryGrowCost::Free => return module,
558        MemoryGrowCost::Linear(val) => val.get(),
559    };
560
561    let mut b = ModuleBuilder::from_module(module);
562    b.add_func(
563        FuncType::new([ValType::I32], [ValType::I32]),
564        Function::from_instructions([
565            LocalGet(0),
566            LocalGet(0),
567            I32Const(cost as i32),
568            I32Mul,
569            // todo: there should be strong guarantee that it does not return anything on
570            // stack?
571            Call(gas_func),
572            MemoryGrow(0),
573            End,
574        ]),
575    );
576
577    b.build()
578}
579
580fn determine_metered_blocks<R: Rules>(
581    instructions: &[Instruction],
582    rules: &R,
583    locals_count: u32,
584) -> Result<Vec<MeteredBlock>, GasMeteringError> {
585    use Instruction::*;
586
587    let mut counter = Counter::new();
588
589    // Begin an implicit function (i.e. `func...end`) block.
590    counter.begin_control_block(0, false);
591
592    // Add locals initialization cost to the function block.
593    let locals_init_cost = rules
594        .call_per_local_cost()
595        .checked_mul(locals_count)
596        .ok_or(GasMeteringError::LocalsInitCost)?;
597    counter.increment(locals_init_cost)?;
598
599    for (cursor, instruction) in instructions.iter().enumerate() {
600        let instruction_cost = rules
601            .instruction_cost(instruction)
602            .ok_or_else(|| GasMeteringError::MissingInstructionRule(instruction.clone()))?;
603        match instruction {
604            Block { .. } => {
605                counter.increment(instruction_cost)?;
606
607                // Begin new block. The cost of the following opcodes until `end` or `else` will
608                // be included into this block. The start position is set to that of the previous
609                // active metered block to signal that they should be merged in order to reduce
610                // unnecessary metering instructions.
611                let top_block_start_pos = counter.active_metered_block()?.start_pos;
612                counter.begin_control_block(top_block_start_pos, false);
613            }
614            If { .. } => {
615                counter.increment(instruction_cost)?;
616                counter.begin_control_block(cursor + 1, false);
617            }
618            Loop { .. } => {
619                counter.increment(instruction_cost)?;
620                counter.begin_control_block(cursor + 1, true);
621            }
622            End => {
623                counter.finalize_control_block(cursor)?;
624            }
625            Else => {
626                counter.finalize_metered_block(cursor)?;
627            }
628            Br(relative_depth) | BrIf(relative_depth) => {
629                counter.increment(instruction_cost)?;
630
631                // Label is a relative index into the control stack.
632                let active_index = counter
633                    .active_control_block_index()
634                    .ok_or(GasMeteringError::NoActiveControlBlock)?;
635                let target_index = active_index
636                    .checked_sub(*relative_depth as usize)
637                    .ok_or(GasMeteringError::ActiveIndexRelativeDepthUnderflow)?;
638                counter.branch(cursor, &[target_index])?;
639            }
640            BrTable(targets) => {
641                counter.increment(instruction_cost)?;
642
643                let active_index = counter
644                    .active_control_block_index()
645                    .ok_or(GasMeteringError::NoActiveControlBlock)?;
646                let target_indices = [targets.default]
647                    .into_iter()
648                    .chain(targets.targets.clone())
649                    .map(|label| active_index.checked_sub(label as usize))
650                    .collect::<Option<Vec<_>>>()
651                    .ok_or(GasMeteringError::ActiveIndexLabelUnderflow)?;
652                counter.branch(cursor, target_indices.as_slice())?;
653            }
654            Return => {
655                counter.increment(instruction_cost)?;
656                counter.branch(cursor, &[0])?;
657            }
658            _ => {
659                // An ordinal non control flow instruction increments the cost of the current block.
660                counter.increment(instruction_cost)?;
661            }
662        }
663    }
664
665    counter
666        .finalized_blocks
667        .sort_unstable_by_key(|block| block.start_pos);
668    Ok(counter.finalized_blocks)
669}
670
671fn inject_counter<R: Rules>(
672    instructions: &mut Vec<Instruction>,
673    rules: &R,
674    locals_count: u32,
675    gas_func: u32,
676) -> Result<(), GasMeteringError> {
677    let blocks = determine_metered_blocks(instructions, rules, locals_count)?;
678    insert_metering_calls(instructions, blocks, gas_func)
679}
680
681// Then insert metering calls into a sequence of instructions given the block locations and costs.
682fn insert_metering_calls(
683    instructions: &mut Vec<Instruction>,
684    blocks: Vec<MeteredBlock>,
685    gas_func: u32,
686) -> Result<(), GasMeteringError> {
687    let block_cost_instrs = calculate_blocks_costs_num(&blocks);
688    // To do this in linear time, construct a new vector of instructions, copying over old
689    // instructions one by one and injecting new ones as required.
690    let new_instrs_len = instructions.len() + 2 * block_cost_instrs;
691    let original_instrs = mem::replace(instructions, Vec::with_capacity(new_instrs_len));
692    let new_instrs = instructions;
693
694    let mut block_iter = blocks.into_iter().peekable();
695    for (original_pos, instr) in original_instrs.into_iter().enumerate() {
696        // If there the next block starts at this position, inject metering instructions.
697        let used_block = if let Some(block) = block_iter.peek() {
698            if block.start_pos == original_pos {
699                insert_gas_call(new_instrs, block, gas_func);
700                true
701            } else {
702                false
703            }
704        } else {
705            false
706        };
707
708        if used_block {
709            block_iter.next();
710        }
711
712        // Copy over the original instruction.
713        new_instrs.push(instr);
714    }
715
716    if block_iter.next().is_some() {
717        return Err(GasMeteringError::UnusedBlock);
718    }
719
720    Ok(())
721}
722
723// Calculates total amount of costs (potential gas charging calls) in blocks
724fn calculate_blocks_costs_num(blocks: &[MeteredBlock]) -> usize {
725    blocks.iter().map(|block| block.cost.costs_num()).sum()
726}
727
728fn insert_gas_call(new_instrs: &mut Vec<Instruction>, current_block: &MeteredBlock, gas_func: u32) {
729    use Instruction::*;
730
731    let (mut overflows_num, current_cost) = current_block.cost.block_costs();
732    // First insert gas charging call with maximum argument due to overflows.
733    while overflows_num != 0 {
734        new_instrs.push(I32Const(BlockCostCounter::MAX_GAS_ARG as i32));
735        new_instrs.push(Call(gas_func));
736        overflows_num -= 1;
737    }
738    // Second insert remaining block's cost, if necessary.
739    if current_cost != 0 {
740        new_instrs.push(I32Const(current_cost as i32));
741        new_instrs.push(Call(gas_func));
742    }
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        module::{ConstExpr, Global, Instruction::*},
        test_gas_counter_injection,
        tests::parse_wat,
    };
    use wasmparser::{BlockType, GlobalType};

    /// Returns the instructions of the code-section entry at `index`.
    ///
    /// Returns `None` if the module has no code section or no body at that index.
    /// `index` counts code-section entries only, so imported functions (e.g. the
    /// injected gas import) are not included.
    fn get_function_body(module: &Module, index: usize) -> Option<&[Instruction]> {
        module
            .code_section
            .as_ref()
            .and_then(|code_section| code_section.get(index))
            .map(|func_body| func_body.instructions.as_slice())
    }

    /// Builds a module with one immutable i32 global and two `(i32) -> ()` functions.
    ///
    /// Function 0 is empty. Function 1 calls function 0 before, inside both arms of,
    /// and after an `if`/`else`.
    fn prebuilt_simple_module() -> Module {
        let mut mbuilder = ModuleBuilder::default();

        mbuilder.push_global(Global {
            ty: GlobalType {
                content_type: ValType::I32,
                mutable: false,
                shared: false,
            },
            init_expr: ConstExpr::default(),
        });

        mbuilder.add_func(FuncType::new([ValType::I32], []), Function::default());

        mbuilder.add_func(
            FuncType::new([ValType::I32], []),
            Function::from_instructions([
                Call(0),
                If(BlockType::Empty),
                Call(0),
                Call(0),
                Call(0),
                Else,
                Call(0),
                Call(0),
                End,
                Call(0),
                End,
            ]),
        );

        mbuilder.build()
    }

    #[test]
    fn simple_grow() {
        let module = parse_wat(
            r#"(module
			(func (result i32)
			  global.get 0
			  memory.grow)
			(global i32 (i32.const 42))
			(memory 0 1)
			)"#,
        );

        let injected_module = inject(module, &ConstantCostRules::new(1, 10_000, 1), "env").unwrap();

        assert_eq!(
            get_function_body(&injected_module, 0).unwrap(),
            [I32Const(2), Call(0), GlobalGet(0), Call(2), End]
        );
        assert_eq!(
            get_function_body(&injected_module, 1).unwrap(),
            [
                LocalGet(0),
                LocalGet(0),
                I32Const(10000),
                I32Mul,
                Call(0),
                MemoryGrow(0),
                End,
            ]
        );

        let binary = injected_module.serialize().expect("serialization failed");
        wasmparser::validate(&binary).unwrap();
    }

    #[test]
    fn grow_no_gas_no_track() {
        let module = parse_wat(
            r"(module
			(func (result i32)
			  global.get 0
			  memory.grow)
			(global i32 (i32.const 42))
			(memory 0 1)
			)",
        );

        let injected_module = inject(module, &ConstantCostRules::default(), "env").unwrap();

        assert_eq!(
            get_function_body(&injected_module, 0).unwrap(),
            [I32Const(2), Call(0), GlobalGet(0), MemoryGrow(0), End]
        );

        assert_eq!(injected_module.functions_space(), 2);

        let binary = injected_module.serialize().expect("serialization failed");
        wasmparser::validate(&binary).unwrap();
    }

    #[test]
    fn call_index() {
        let injected_module = inject(
            prebuilt_simple_module(),
            &ConstantCostRules::default(),
            "env",
        )
        .unwrap();

        // The injected gas import takes function index 0, so the original calls to
        // function 0 are shifted to `Call(1)`.
        assert_eq!(
            get_function_body(&injected_module, 1).unwrap(),
            [
                I32Const(3),
                Call(0),
                Call(1),
                If(BlockType::Empty),
                I32Const(3),
                Call(0),
                Call(1),
                Call(1),
                Call(1),
                Else,
                I32Const(2),
                Call(0),
                Call(1),
                Call(1),
                End,
                Call(1),
                End
            ]
        );
    }

    #[test]
    fn cost_overflow() {
        let instruction_cost = u32::MAX / 2;
        let injected_module = inject(
            prebuilt_simple_module(),
            &ConstantCostRules::new(instruction_cost, 0, instruction_cost),
            "env",
        )
        .unwrap();

        assert_eq!(
            get_function_body(&injected_module, 1).unwrap(),
            [
                // The block costs 3 * instruction_cost = 6442450941, which exceeds the
                // maximum gas call argument (u32::MAX, emitted as `I32Const(-1)`).
                // The charge is therefore split into one maximal call plus a call for
                // the remainder: 6442450941 - 4294967295 = 2147483646 = instruction_cost - 1.
                I32Const(-1),
                Call(0),
                I32Const((instruction_cost - 1) as i32),
                Call(0),
                Call(1),
                If(BlockType::Empty),
                // Same 3 * instruction_cost split as above.
                I32Const(-1),
                Call(0),
                I32Const((instruction_cost - 1) as i32),
                Call(0),
                Call(1),
                Call(1),
                Call(1),
                Else,
                // 2 * instruction_cost = u32::MAX - 1 fits in a single call;
                // its bit pattern as i32 is -2.
                I32Const(-2),
                Call(0),
                Call(1),
                Call(1),
                End,
                Call(1),
                End
            ]
        );
    }

    /// Generates a `#[test]` named `$name`.
    ///
    /// The test injects gas metering into the WAT module `$input` using
    /// `ConstantCostRules::default()`. It then asserts that the first function body
    /// equals the first function body of the WAT module `$expected`.
    macro_rules! test_gas_counter_injection {
        (name = $name:ident; input = $input:expr; expected = $expected:expr) => {
            #[test]
            fn $name() {
                let input_module = parse_wat($input);
                let expected_module = parse_wat($expected);

                let injected_module = inject(input_module, &ConstantCostRules::default(), "env")
                    .expect("inject_gas_counter call failed");

                let actual_func_body = get_function_body(&injected_module, 0)
                    .expect("injected module must have a function body");
                let expected_func_body = get_function_body(&expected_module, 0)
                    .expect("post-module must have a function body");

                assert_eq!(actual_func_body, expected_func_body);
            }
        };
    }

    test_gas_counter_injection! {
        name = simple;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 1))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = nested;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(block
					(global.get 0)
					(global.get 0)
					(global.get 0))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 6))
				(global.get 0)
				(block
					(global.get 0)
					(global.get 0)
					(global.get 0))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = ifelse;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(if
					(then
						(global.get 0)
						(global.get 0)
						(global.get 0))
					(else
						(global.get 0)
						(global.get 0)))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 3))
				(global.get 0)
				(if
					(then
						(call 0 (i32.const 3))
						(global.get 0)
						(global.get 0)
						(global.get 0))
					(else
						(call 0 (i32.const 2))
						(global.get 0)
						(global.get 0)))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = branch_innermost;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(block
					(global.get 0)
					(drop)
					(br 0)
					(global.get 0)
					(drop))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 6))
				(global.get 0)
				(block
					(global.get 0)
					(drop)
					(br 0)
					(call 0 (i32.const 2))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = branch_outer_block;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(block
					(global.get 0)
					(if
						(then
							(global.get 0)
							(global.get 0)
							(drop)
							(br_if 1)))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 5))
				(global.get 0)
				(block
					(global.get 0)
					(if
						(then
							(call 0 (i32.const 4))
							(global.get 0)
							(global.get 0)
							(drop)
							(br_if 1)))
					(call 0 (i32.const 2))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = branch_outer_loop;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(loop
					(global.get 0)
					(if
						(then
							(global.get 0)
							(br_if 0))
						(else
							(global.get 0)
							(global.get 0)
							(drop)
							(br_if 1)))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 3))
				(global.get 0)
				(loop
					(call 0 (i32.const 4))
					(global.get 0)
					(if
						(then
							(call 0 (i32.const 2))
							(global.get 0)
							(br_if 0))
						(else
							(call 0 (i32.const 4))
							(global.get 0)
							(global.get 0)
							(drop)
							(br_if 1)))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = return_from_func;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(if
					(then
						(return)))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 2))
				(global.get 0)
				(if
					(then
						(call 0 (i32.const 1))
						(return)))
				(call 0 (i32.const 1))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = branch_from_if_not_else;
        input = r#"
		(module
			(func (result i32)
				(global.get 0)
				(block
					(global.get 0)
					(if
						(then (br 1))
						(else (br 0)))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#;
        expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 5))
				(global.get 0)
				(block
					(global.get 0)
					(if
						(then
							(call 0 (i32.const 1))
							(br 1))
						(else
							(call 0 (i32.const 1))
							(br 0)))
					(call 0 (i32.const 2))
					(global.get 0)
					(drop))
				(global.get 0)))
		"#
    }

    test_gas_counter_injection! {
        name = empty_loop;
        input = r#"
		(module
			(func
				(loop
					(br 0)
				)
				unreachable
			)
		)
		"#;
        expected = r#"
		(module
			(func
				(call 0 (i32.const 2))
				(loop
					(call 0 (i32.const 1))
					(br 0)
				)
				unreachable
			)
		)
		"#
    }
}