owasm_utils/stack_height/
mod.rs

1//! The pass that tries to make stack overflows deterministic, by introducing
2//! an upper bound of the stack size.
3//!
4//! This pass introduces a global mutable variable to track stack height,
5//! and instruments all calls with preamble and postamble.
6//!
//! Stack height is increased prior to the call. Otherwise, the check would
//! be made after the stack frame is allocated.
9//!
10//! The preamble is inserted before the call. It increments
11//! the global stack height variable with statically determined "stack cost"
12//! of the callee. If after the increment the stack height exceeds
13//! the limit (specified by the `rules`) then execution traps.
14//! Otherwise, the call is executed.
15//!
16//! The postamble is inserted after the call. The purpose of the postamble is to decrease
17//! the stack height by the "stack cost" of the callee function.
18//!
19//! Note, that we can't instrument all possible ways to return from the function. The simplest
20//! example would be a trap issued by the host function.
21//! That means stack height global won't be equal to zero upon the next execution after such trap.
22//!
23//! # Thunks
24//!
//! Because the stack height is increased prior to the call, a few problems arise:
26//!
27//! - Stack height isn't increased upon an entry to the first function, i.e. exported function.
28//! - It is statically unknown what function will be invoked in an indirect call.
29//!
//! The solution to these problems is to generate intermediate functions, called 'thunks', which
//! increase the stack height before and decrease it after the call to the original function, and
//! then make exported functions and table entries point to the corresponding thunks.
33//!
34//! # Stack cost
35//!
//! Stack cost of the function is calculated as a sum of its locals
//! and the maximal height of the value stack.
38//!
39//! All values are treated equally, as they have the same size.
40//!
//! The rationale is that this makes it possible to use a very naive wasm executor, that is:
42//!
43//! - values are implemented by a union, so each value takes a size equal to
44//!   the size of the largest possible value type this union can hold. (In MVP it is 8 bytes)
45//! - each value from the value stack is placed on the native stack.
46//! - each local variable and function argument is placed on the native stack.
47//! - arguments pushed by the caller are copied into callee stack rather than shared
48//!   between the frames.
49//! - upon entry into the function entire stack frame is allocated.
50
51use std::string::String;
52use std::vec::Vec;
53
54use parity_wasm::elements::{self, Type};
55use parity_wasm::builder;
56
/// Macro to generate preamble and postamble.
///
/// Expands to an array of instructions meant to replace a single `call`:
/// the preamble (add the callee's stack cost to the stack-height global and
/// trap if the limit is exceeded), the original `Call`, and the postamble
/// (subtract the same cost again).
///
/// Arguments, in order:
///
/// - `$callee_idx`: index of the function being called.
/// - `$callee_stack_cost`: stack cost of the callee; it is used directly as an
///   `I32Const` operand, so the caller must pass it as an `i32`.
/// - `$stack_height_global_idx`: index of the stack-height global.
/// - `$stack_limit`: maximum allowed stack height; cast to `i32` here.
///
/// Each argument expression is substituted everywhere it appears, so it is
/// evaluated more than once. The `elements` path used for the block type is
/// resolved at the expansion site, so the caller needs
/// `parity_wasm::elements` in scope.
macro_rules! instrument_call {
	($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
		use $crate::parity_wasm::elements::Instruction::*;
		[
			// stack_height += stack_cost(F)
			GetGlobal($stack_height_global_idx),
			I32Const($callee_stack_cost),
			I32Add,
			SetGlobal($stack_height_global_idx),
			// if stack_counter > LIMIT: unreachable
			// (the comparison is unsigned, so the `i32` constants are
			// interpreted as `u32` values)
			GetGlobal($stack_height_global_idx),
			I32Const($stack_limit as i32),
			I32GtU,
			If(elements::BlockType::NoResult),
			Unreachable,
			End,
			// Original call
			Call($callee_idx),
			// stack_height -= stack_cost(F)
			GetGlobal($stack_height_global_idx),
			I32Const($callee_stack_cost),
			I32Sub,
			SetGlobal($stack_height_global_idx),
		]
	}};
}
84
85mod max_height;
86mod thunk;
87
/// Error that occurred during processing the module.
///
/// This means that the module is invalid.
///
/// The wrapped `String` is a human-readable description of what went wrong.
#[derive(Debug)]
pub struct Error(String);
93
/// State shared between the individual steps of the stack-height pass.
pub(crate) struct Context {
	/// Index of the stack-height global; `None` until
	/// `generate_stack_height_global` has run.
	stack_height_global_idx: Option<u32>,
	/// Stack cost per function, indexed by function index (imports included);
	/// `None` until `compute_stack_costs` has run.
	func_stack_costs: Option<Vec<u32>>,
	/// Maximum allowed stack height, as passed to `inject_limiter`.
	stack_limit: u32,
}
99
impl Context {
	/// Returns index in a global index space of a stack_height global variable.
	///
	/// # Panics
	///
	/// Panics if the index hasn't been generated yet, i.e. the field is still
	/// `None`.
	///
	/// NOTE(review): the panic message refers to `inject_stack_counter_global`,
	/// but the function in this file that sets the index is
	/// `generate_stack_height_global`.
	fn stack_height_global_idx(&self) -> u32 {
		self.stack_height_global_idx.expect(
			"stack_height_global_idx isn't yet generated;
			Did you call `inject_stack_counter_global`",
		)
	}

	/// Returns `stack_cost` for `func_idx`.
	///
	/// Returns `None` if `func_idx` is greater than the last function index
	/// for which a cost was computed.
	///
	/// # Panics
	///
	/// Panics if stack costs haven't been computed yet.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		self.func_stack_costs
			.as_ref()
			.expect(
				"func_stack_costs isn't yet computed;
				Did you call `compute_stack_costs`?",
			)
			.get(func_idx as usize)
			.cloned()
	}

	/// Returns stack limit specified by the rules.
	fn stack_limit(&self) -> u32 {
		self.stack_limit
	}
}
131
132/// Instrument a module with stack height limiter.
133///
134/// See module-level documentation for more details.
135///
136/// # Errors
137///
138/// Returns `Err` if module is invalid and can't be
139pub fn inject_limiter(
140	mut module: elements::Module,
141	stack_limit: u32,
142) -> Result<elements::Module, Error> {
143	let mut ctx = Context {
144		stack_height_global_idx: None,
145		func_stack_costs: None,
146		stack_limit,
147	};
148
149	generate_stack_height_global(&mut ctx, &mut module);
150	compute_stack_costs(&mut ctx, &module)?;
151	instrument_functions(&mut ctx, &mut module)?;
152	let module = thunk::generate_thunks(&mut ctx, module)?;
153
154	Ok(module)
155}
156
157/// Generate a new global that will be used for tracking current stack height.
158fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) {
159	let global_entry = builder::global()
160		.value_type()
161		.i32()
162		.mutable()
163		.init_expr(elements::Instruction::I32Const(0))
164		.build();
165
166	// Try to find an existing global section.
167	for section in module.sections_mut() {
168		if let elements::Section::Global(ref mut gs) = *section {
169			gs.entries_mut().push(global_entry);
170
171			let stack_height_global_idx = (gs.entries().len() as u32) - 1;
172			ctx.stack_height_global_idx = Some(stack_height_global_idx);
173			return;
174		}
175	}
176
177	// Existing section not found, create one!
178	module.sections_mut().push(elements::Section::Global(
179		elements::GlobalSection::with_entries(vec![global_entry]),
180	));
181	ctx.stack_height_global_idx = Some(0);
182}
183
184/// Calculate stack costs for all functions.
185///
186/// Returns a vector with a stack cost for each function, including imports.
187fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> {
188	let func_imports = module.import_count(elements::ImportCountType::Function);
189	let mut func_stack_costs = vec![0; module.functions_space()];
190	// TODO: optimize!
191	for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() {
192		// We can't calculate stack_cost of the import functions.
193		if func_idx >= func_imports {
194			*func_stack_cost = compute_stack_cost(func_idx as u32, &module)?;
195		}
196	}
197
198	ctx.func_stack_costs = Some(func_stack_costs);
199	Ok(())
200}
201
202/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
203/// number of arguments plus number of local variables) and the maximal stack
204/// height.
205fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
206	// To calculate the cost of a function we need to convert index from
207	// function index space to defined function spaces.
208	let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
209	let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
210		Error("This should be a index of a defined function".into())
211	})?;
212
213	let code_section = module.code_section().ok_or_else(|| {
214		Error("Due to validation code section should exists".into())
215	})?;
216	let body = &code_section
217		.bodies()
218		.get(defined_func_idx as usize)
219		.ok_or_else(|| Error("Function body is out of bounds".into()))?;
220	let locals_count = body.locals().len() as u32;
221
222	let max_stack_height =
223		max_height::compute(
224			defined_func_idx,
225			module
226		)?;
227
228	Ok(locals_count + max_stack_height)
229}
230
231fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
232	for section in module.sections_mut() {
233		if let elements::Section::Code(ref mut code_section) = *section {
234			for func_body in code_section.bodies_mut() {
235				let mut opcodes = func_body.code_mut();
236				instrument_function(ctx, opcodes)?;
237			}
238		}
239	}
240	Ok(())
241}
242
243/// This function searches `call` instructions and wrap each call
244/// with preamble and postamble.
245///
246/// Before:
247///
248/// ```text
249/// get_local 0
250/// get_local 1
251/// call 228
252/// drop
253/// ```
254///
255/// After:
256///
257/// ```text
258/// get_local 0
259/// get_local 1
260///
261/// < ... preamble ... >
262///
263/// call 228
264///
265/// < .. postamble ... >
266///
267/// drop
268/// ```
269fn instrument_function(
270	ctx: &mut Context,
271	instructions: &mut elements::Instructions,
272) -> Result<(), Error> {
273	use parity_wasm::elements::Instruction::*;
274
275	let mut cursor = 0;
276	loop {
277		if cursor >= instructions.elements().len() {
278			break;
279		}
280
281		enum Action {
282			InstrumentCall {
283				callee_idx: u32,
284				callee_stack_cost: u32,
285			},
286			Nop,
287		}
288
289		let action: Action = {
290			let instruction = &instructions.elements()[cursor];
291			match *instruction {
292				Call(ref callee_idx) => {
293					let callee_stack_cost = ctx
294						.stack_cost(*callee_idx)
295						.ok_or_else(||
296							Error(
297								format!("Call to function that out-of-bounds: {}", callee_idx)
298							)
299						)?;
300
301					// Instrument only calls to a functions which stack_cost is
302					// non-zero.
303					if callee_stack_cost > 0 {
304						Action::InstrumentCall {
305							callee_idx: *callee_idx,
306							callee_stack_cost,
307						}
308					} else {
309						Action::Nop
310					}
311				},
312				_ => Action::Nop,
313			}
314		};
315
316		match action {
317			// We need to wrap a `call idx` instruction
318			// with a code that adjusts stack height counter
319			// and then restores it.
320			Action::InstrumentCall { callee_idx, callee_stack_cost } => {
321				let new_seq = instrument_call!(
322					callee_idx,
323					callee_stack_cost as i32,
324					ctx.stack_height_global_idx(),
325					ctx.stack_limit()
326				);
327
328				// Replace the original `call idx` instruction with
329				// a wrapped call sequence.
330				//
331				// To splice actually take a place, we need to consume iterator
332				// splice returns. So we just `count()` it.
333				let _ = instructions
334					.elements_mut()
335					.splice(cursor..(cursor + 1), new_seq.iter().cloned())
336					.count();
337
338				// Advance cursor to be after the inserted sequence.
339				cursor += new_seq.len();
340			}
341			// Do nothing for other instructions.
342			_ => {
343				cursor += 1;
344			}
345		}
346	}
347
348	Ok(())
349}
350
351fn resolve_func_type(
352	func_idx: u32,
353	module: &elements::Module,
354) -> Result<&elements::FunctionType, Error> {
355	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
356	let functions = module
357		.function_section()
358		.map(|fs| fs.entries())
359		.unwrap_or(&[]);
360
361	let func_imports = module.import_count(elements::ImportCountType::Function);
362	let sig_idx = if func_idx < func_imports as u32 {
363		module
364			.import_section()
365			.expect("function import count is not zero; import section must exists; qed")
366			.entries()
367			.iter()
368			.filter_map(|entry| match *entry.external() {
369				elements::External::Function(ref idx) => Some(*idx),
370				_ => None,
371			})
372			.nth(func_idx as usize)
373			.expect(
374				"func_idx is less than function imports count;
375				nth function import must be `Some`;
376				qed",
377			)
378	} else {
379		functions
380			.get(func_idx as usize - func_imports)
381			.ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
382			.type_ref()
383	};
384	let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| {
385		Error(format!(
386			"Signature {} (specified by func {}) isn't defined",
387			sig_idx, func_idx
388		))
389	})?;
390	Ok(ty)
391}
392
#[cfg(test)]
mod tests {
	extern crate wabt;
	use parity_wasm::elements;
	use super::*;

	/// Converts WebAssembly text into a `parity_wasm` module, panicking on
	/// any conversion or deserialization failure.
	fn parse_wat(source: &str) -> elements::Module {
		elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm"))
			.expect("Failed to deserialize the module")
	}

	/// Serializes the module and checks with wabt that the resulting binary
	/// is a valid module, panicking otherwise.
	fn validate_module(module: elements::Module) {
		let binary = elements::serialize(module).expect("Failed to serialize");
		wabt::Module::read_binary(&binary, &Default::default())
			.expect("Wabt failed to read final binary")
			.validate()
			.expect("Invalid module");
	}

	/// Instrumenting a module with a single exported function that has
	/// parameters and a result must succeed and yield a valid module.
	#[test]
	fn test_with_params_and_result() {
		let module = parse_wat(
			r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    get_local 0
	get_local 1
	i32.add
  )
)
"#,
		);

		let module = inject_limiter(module, 1024)
			.expect("Failed to inject stack counter");
		validate_module(module);
	}
}