casper_wasm_utils/
stack_height.rs

1//! The pass that tries to make stack overflows deterministic, by introducing
2//! an upper bound of the stack size.
3//!
4//! This pass introduces a global mutable variable to track stack height,
5//! and instruments all calls with preamble and postamble.
6//!
//! The stack height is increased prior to the call. Otherwise, the check would
//! be made after the stack frame is allocated.
9//!
//! The preamble is inserted before the call. It increments
//! the global stack height variable by the statically determined "stack cost"
//! of the callee. If, after the increment, the stack height exceeds
//! the limit (the `stack_limit` passed to `inject_limiter`), execution traps.
//! Otherwise, the call is executed.
15//!
16//! The postamble is inserted after the call. The purpose of the postamble is to decrease
17//! the stack height by the "stack cost" of the callee function.
18//!
//! Note that we can't instrument all possible ways to return from a function. The simplest
//! example is a trap issued by a host function.
//! This means the stack height global won't be equal to zero upon the next execution after such a trap.
22//!
23//! # Thunks
24//!
//! Because the stack height is increased prior to the call, a few problems arise:
//!
//! - The stack height isn't increased upon entry to the first function, i.e. an exported function.
//! - The start function is executed externally (similar to exported functions).
//! - It is statically unknown which function will be invoked by an indirect call.
30//!
//! The solution to these problems is to generate intermediate functions, called 'thunks', which
//! increase the stack height before, and decrease it after, the call to the original function.
//! Exported functions, table entries and the start section are then redirected to the corresponding thunks.
34//!
35//! # Stack cost
36//!
//! The stack cost of a function is calculated as the sum of its locals
//! and the maximal height of its value stack.
39//!
40//! All values are treated equally, as they have the same size.
41//!
42//! The rationale is that this makes it possible to use the following very naive wasm executor:
43//!
44//! - values are implemented by a union, so each value takes a size equal to
45//!   the size of the largest possible value type this union can hold. (In MVP it is 8 bytes)
46//! - each value from the value stack is placed on the native stack.
47//! - each local variable and function argument is placed on the native stack.
48//! - arguments pushed by the caller are copied into callee stack rather than shared
49//!   between the frames.
50//! - upon entry into the function entire stack frame is allocated.
51
52use crate::std::{mem, string::String, vec::Vec};
53
54use casper_wasm::{
55    builder,
56    elements::{self, Instruction, Instructions, Type},
57};
58
59/// Const function to generate preamble and postamble.
60const fn instrument_call(
61    callee_idx: u32,
62    callee_stack_cost: u32,
63    stack_height_global_idx: u32,
64    stack_limit: u32,
65) -> [Instruction; INSTRUMENT_CALL_LENGTH] {
66    use casper_wasm::elements::Instruction::*;
67    [
68        // stack_height += stack_cost(F)
69        GetGlobal(stack_height_global_idx),
70        I32Const(callee_stack_cost as i32),
71        I32Add,
72        SetGlobal(stack_height_global_idx),
73        // if stack_counter > LIMIT: unreachable
74        GetGlobal(stack_height_global_idx),
75        I32Const(stack_limit as i32),
76        I32GtU,
77        If(elements::BlockType::NoResult),
78        Unreachable,
79        End,
80        // Original call
81        Call(callee_idx),
82        // stack_height -= stack_cost(F)
83        GetGlobal(stack_height_global_idx),
84        I32Const(callee_stack_cost as i32),
85        I32Sub,
86        SetGlobal(stack_height_global_idx),
87    ]
88}
89
90const INSTRUMENT_CALL_LENGTH: usize = 15;
91
92mod max_height;
93mod thunk;
94
/// Error that occurred while processing the module.
///
/// This means that the module is invalid (or that a computed quantity, such as
/// a stack cost, overflowed). The wrapped string is a human-readable message,
/// which is also what the [`Display`](core::fmt::Display) impl prints.
#[derive(Debug)]
pub struct Error(String);

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(&self.0)
    }
}
101
/// Shared state for the instrumentation passes.
pub(crate) struct Context {
    /// Index of the stack-height global in the module's global index space.
    stack_height_global_idx: u32,
    /// Stack cost of every function, indexed by function index (imports included).
    func_stack_costs: Vec<u32>,
    /// Height above which the instrumented code traps.
    stack_limit: u32,
}

impl Context {
    /// Index (in the global index space) of the stack-height global.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx
    }

    /// Stack cost of the function at `func_idx`, or `None` if the index is
    /// outside the known function range.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let slot = func_idx as usize;
        self.func_stack_costs.get(slot).copied()
    }

    /// Configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
124
125/// Instrument a module with stack height limiter.
126///
127/// See module-level documentation for more details.
128///
129/// # Errors
130///
131/// Returns `Err` if module is invalid and can't be
132pub fn inject_limiter(
133    mut module: elements::Module,
134    stack_limit: u32,
135) -> Result<elements::Module, Error> {
136    let mut ctx = Context {
137        stack_height_global_idx: generate_stack_height_global(&mut module),
138        func_stack_costs: compute_stack_costs(&module)?,
139        stack_limit,
140    };
141
142    instrument_functions(&mut ctx, &mut module)?;
143    let module = thunk::generate_thunks(&mut ctx, module)?;
144
145    Ok(module)
146}
147
148/// Generate a new global that will be used for tracking current stack height.
149fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
150    let global_entry = builder::global()
151        .value_type()
152        .i32()
153        .mutable()
154        .init_expr(Instruction::I32Const(0))
155        .build();
156
157    // Try to find an existing global section.
158    for section in module.sections_mut() {
159        if let elements::Section::Global(gs) = section {
160            gs.entries_mut().push(global_entry);
161            return (gs.entries().len() as u32) - 1;
162        }
163    }
164
165    // Existing section not found, create one!
166    module.sections_mut().push(elements::Section::Global(
167        elements::GlobalSection::with_entries(vec![global_entry]),
168    ));
169    0
170}
171
172/// Calculate stack costs for all functions.
173///
174/// Returns a vector with a stack cost for each function, including imports.
175fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
176    let func_imports = module.import_count(elements::ImportCountType::Function);
177
178    // TODO: optimize!
179    (0..module.functions_space())
180        .map(|func_idx| {
181            if func_idx < func_imports {
182                // We can't calculate stack_cost of the import functions.
183                Ok(0)
184            } else {
185                compute_stack_cost(func_idx as u32, module)
186            }
187        })
188        .collect()
189}
190
191/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
192/// number of arguments plus number of local variables) and the maximal stack
193/// height.
194fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
195    // To calculate the cost of a function we need to convert index from
196    // function index space to defined function spaces.
197    let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
198    let defined_func_idx = func_idx
199        .checked_sub(func_imports)
200        .ok_or_else(|| Error("This should be a index of a defined function".into()))?;
201
202    let code_section = module
203        .code_section()
204        .ok_or_else(|| Error("Due to validation code section should exists".into()))?;
205    let body = &code_section
206        .bodies()
207        .get(defined_func_idx as usize)
208        .ok_or_else(|| Error("Function body is out of bounds".into()))?;
209
210    let mut locals_count: u32 = 0;
211    for local_group in body.locals() {
212        locals_count = locals_count
213            .checked_add(local_group.count())
214            .ok_or_else(|| Error("Overflow in local count".into()))?;
215    }
216
217    let max_stack_height = max_height::compute(defined_func_idx, module)?;
218
219    locals_count
220        .checked_add(max_stack_height)
221        .ok_or_else(|| Error("Overflow in adding locals_count and max_stack_height".into()))
222}
223
224fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
225    for section in module.sections_mut() {
226        if let elements::Section::Code(code_section) = section {
227            for func_body in code_section.bodies_mut() {
228                let opcodes = func_body.code_mut();
229                instrument_function(ctx, opcodes)?;
230            }
231        }
232    }
233    Ok(())
234}
235
236/// This function searches `call` instructions and wrap each call
237/// with preamble and postamble.
238///
239/// Before:
240///
241/// ```text
242/// get_local 0
243/// get_local 1
244/// call 228
245/// drop
246/// ```
247///
248/// After:
249///
250/// ```text
251/// get_local 0
252/// get_local 1
253///
254/// < ... preamble ... >
255///
256/// call 228
257///
258/// < .. postamble ... >
259///
260/// drop
261/// ```
262fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
263    use Instruction::*;
264
265    struct InstrumentCall {
266        offset: usize,
267        callee: u32,
268        cost: u32,
269    }
270
271    let calls: Vec<_> = func
272        .elements()
273        .iter()
274        .enumerate()
275        .filter_map(|(offset, instruction)| {
276            if let Call(callee) = *instruction {
277                ctx.stack_cost(callee)
278                    .filter(|&cost| cost > 0)
279                    .map(|cost| InstrumentCall {
280                        callee,
281                        offset,
282                        cost,
283                    })
284            } else {
285                None
286            }
287        })
288        .collect();
289
290    // The `instrumented_call!` contains the call itself. This is why we need to subtract one.
291    let len = func.elements().len() + calls.len() * (INSTRUMENT_CALL_LENGTH - 1);
292    let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
293    let new_instrs = func.elements_mut();
294
295    let mut calls = calls.into_iter().peekable();
296    for (original_pos, instr) in original_instrs.into_iter().enumerate() {
297        // whether there is some call instruction at this position that needs to be instrumented
298        let did_instrument = if let Some(call) = calls.peek() {
299            if call.offset == original_pos {
300                let new_seq = instrument_call(
301                    call.callee,
302                    call.cost,
303                    ctx.stack_height_global_idx(),
304                    ctx.stack_limit(),
305                );
306                new_instrs.extend(new_seq);
307                true
308            } else {
309                false
310            }
311        } else {
312            false
313        };
314
315        if did_instrument {
316            calls.next();
317        } else {
318            new_instrs.push(instr);
319        }
320    }
321
322    if calls.next().is_some() {
323        return Err(Error("Not all calls were used".into()));
324    }
325
326    Ok(())
327}
328
329fn resolve_func_type(
330    func_idx: u32,
331    module: &elements::Module,
332) -> Result<&elements::FunctionType, Error> {
333    let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
334    let functions = module
335        .function_section()
336        .map(|fs| fs.entries())
337        .unwrap_or(&[]);
338
339    let func_imports = module.import_count(elements::ImportCountType::Function);
340    let sig_idx = if func_idx < func_imports as u32 {
341        module
342            .import_section()
343            .expect("function import count is not zero; import section must exists; qed")
344            .entries()
345            .iter()
346            .filter_map(|entry| match entry.external() {
347                elements::External::Function(idx) => Some(*idx),
348                _ => None,
349            })
350            .nth(func_idx as usize)
351            .expect(
352                "func_idx is less than function imports count;
353				nth function import must be `Some`;
354				qed",
355            )
356    } else {
357        functions
358            .get(func_idx as usize - func_imports)
359            .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
360            .type_ref()
361    };
362    let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
363        Error(format!(
364            "Signature {} (specified by func {}) isn't defined",
365            sig_idx, func_idx
366        ))
367    })?;
368    Ok(ty)
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use casper_wasm::elements;

    /// Parses WebAssembly text into a module, panicking on any error.
    fn parse_wat(source: &str) -> elements::Module {
        elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm"))
            .expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt accepts it as a valid module.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        wabt::Module::read_binary(binary, &Default::default())
            .expect("Wabt failed to read final binary")
            .validate()
            .expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        let module = parse_wat(
            r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		get_local 0
	get_local 1
	i32.add
	)
)
"#,
        );

        let module = inject_limiter(module, 1024).expect("Failed to inject stack counter");
        validate_module(module);
    }

    /// Exercises the call-instrumentation path (a call to a function with a
    /// non-zero stack cost) in a module that already defines a global, so the
    /// stack-height global is appended to an existing global section.
    #[test]
    fn test_instrumented_call_with_existing_global() {
        let module = parse_wat(
            r#"
(module
    (global (mut i32) (i32.const 0))
    (func $callee (param i32) (result i32)
        get_local 0
        i32.const 1
        i32.add
    )
    (func (export "call") (result i32)
        i32.const 41
        call $callee
    )
)
"#,
        );

        let module = inject_limiter(module, 1024).expect("Failed to inject stack counter");

        // The stack-height global is appended after the module's own global.
        let global_count = module
            .global_section()
            .expect("Global section must exist after injection")
            .entries()
            .len();
        assert_eq!(global_count, 2);

        validate_module(module);
    }
}