gear_pwasm_utils/stack_height/
mod.rs

1//! The pass that tries to make stack overflows deterministic, by introducing
2//! an upper bound of the stack size.
3//!
4//! This pass introduces a global mutable variable to track stack height,
5//! and instruments all calls with preamble and postamble.
6//!
//! Stack height is increased prior to the call. Otherwise, the check would
//! be made after the stack frame is allocated.
9//!
10//! The preamble is inserted before the call. It increments
11//! the global stack height variable with statically determined "stack cost"
12//! of the callee. If after the increment the stack height exceeds
13//! the limit (specified by the `rules`) then execution traps.
14//! Otherwise, the call is executed.
15//!
16//! The postamble is inserted after the call. The purpose of the postamble is to decrease
17//! the stack height by the "stack cost" of the callee function.
18//!
//! Note that we can't instrument all possible ways to return from a function. The simplest
//! example is a trap issued by a host function.
//! This means that the stack height global won't be equal to zero upon the next execution
//! after such a trap.
22//!
23//! # Thunks
24//!
//! Because the stack height is increased prior to the call, a few problems arise:
26//!
27//! - Stack height isn't increased upon an entry to the first function, i.e. exported function.
28//! - Start function is executed externally (similar to exported functions).
29//! - It is statically unknown what function will be invoked in an indirect call.
30//!
//! The solution to these problems is to generate intermediate functions, called 'thunks', which
//! increase the stack height before, and decrease it after, the call to the original function.
//! Exported functions, table entries and the start section are then redirected to point to the
//! corresponding thunks.
34//!
35//! # Stack cost
36//!
//! The stack cost of a function is calculated as the sum of its locals
//! and the maximal height of the value stack.
39//!
40//! All values are treated equally, as they have the same size.
41//!
42//! The rationale is that this makes it possible to use the following very naive wasm executor:
43//!
44//! - values are implemented by a union, so each value takes a size equal to
45//!   the size of the largest possible value type this union can hold. (In MVP it is 8 bytes)
46//! - each value from the value stack is placed on the native stack.
47//! - each local variable and function argument is placed on the native stack.
48//! - arguments pushed by the caller are copied into callee stack rather than shared
49//!   between the frames.
50//! - upon entry into the function entire stack frame is allocated.
51
52use crate::std::{mem, string::String, vec::Vec};
53
54use parity_wasm::{
55	builder,
56	elements::{self, Instruction, Instructions, Type},
57};
58
/// Macro to generate preamble and postamble.
///
/// Expands to an array of instructions that wraps `call $callee_idx`:
///
/// - preamble: add `$callee_stack_cost` to the stack height global, then trap
///   with `unreachable` if the new height is (unsigned) greater than `$stack_limit`;
/// - the original call;
/// - postamble: subtract `$callee_stack_cost` from the stack height global.
///
/// Every argument expression is evaluated exactly once. All paths are resolved
/// through `$crate`, so the caller does not need any particular items in scope.
macro_rules! instrument_call {
	($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
		use $crate::parity_wasm::elements::{BlockType, Instruction::*};
		// Bind the arguments once; most of them are used several times below.
		let callee_idx = $callee_idx;
		let callee_stack_cost = $callee_stack_cost;
		let stack_height_global_idx = $stack_height_global_idx;
		// The limit is compared with `I32GtU`, so reinterpreting its bits as `i32`
		// keeps the unsigned meaning.
		let stack_limit = $stack_limit as i32;
		[
			// stack_height += stack_cost(F)
			GetGlobal(stack_height_global_idx),
			I32Const(callee_stack_cost),
			I32Add,
			SetGlobal(stack_height_global_idx),
			// if stack_height > LIMIT: unreachable
			GetGlobal(stack_height_global_idx),
			I32Const(stack_limit),
			I32GtU,
			If(BlockType::NoResult),
			Unreachable,
			End,
			// Original call
			Call(callee_idx),
			// stack_height -= stack_cost(F)
			GetGlobal(stack_height_global_idx),
			I32Const(callee_stack_cost),
			I32Sub,
			SetGlobal(stack_height_global_idx),
		]
	}};
}
86
87mod max_height;
88mod thunk;
89
/// Error that occurred while processing the module.
///
/// This means that the module is invalid. The wrapped string is a
/// human-readable description of the problem.
#[derive(Debug)]
pub struct Error(String);

impl core::fmt::Display for Error {
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		f.write_str(&self.0)
	}
}
95
/// State shared by the instrumentation passes of this module.
pub(crate) struct Context {
	/// Index of the injected stack height global variable.
	stack_height_global_idx: u32,
	/// Stack cost of every function in the function index space, imports
	/// included (imported functions are assigned a cost of 0).
	func_stack_costs: Vec<u32>,
	/// Stack height above which instrumented code traps.
	stack_limit: u32,
}

impl Context {
	/// Index of the stack height global variable.
	fn stack_height_global_idx(&self) -> u32 {
		self.stack_height_global_idx
	}

	/// Stack cost of the function at `func_idx`, or `None` if the index is
	/// outside the function index space.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		let idx = func_idx as usize;
		self.func_stack_costs.get(idx).copied()
	}

	/// Stack limit the module is instrumented with.
	fn stack_limit(&self) -> u32 {
		self.stack_limit
	}
}
118
119/// Instrument a module with stack height limiter.
120///
121/// See module-level documentation for more details.
122///
123/// # Errors
124///
125/// Returns `Err` if module is invalid and can't be
126pub fn inject_limiter(
127	mut module: elements::Module,
128	stack_limit: u32,
129) -> Result<elements::Module, Error> {
130	let mut ctx = Context {
131		stack_height_global_idx: generate_stack_height_global(&mut module),
132		func_stack_costs: compute_stack_costs(&module)?,
133		stack_limit,
134	};
135
136	instrument_functions(&mut ctx, &mut module)?;
137	let module = thunk::generate_thunks(&mut ctx, module)?;
138
139	Ok(module)
140}
141
142/// Generate a new global that will be used for tracking current stack height.
143fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
144	let global_entry = builder::global()
145		.value_type()
146		.i32()
147		.mutable()
148		.init_expr(Instruction::I32Const(0))
149		.build();
150
151	// Try to find an existing global section.
152	for section in module.sections_mut() {
153		if let elements::Section::Global(gs) = section {
154			gs.entries_mut().push(global_entry);
155			return (gs.entries().len() as u32) - 1
156		}
157	}
158
159	// Existing section not found, create one!
160	module
161		.sections_mut()
162		.push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry])));
163	0
164}
165
166/// Calculate stack costs for all functions.
167///
168/// Returns a vector with a stack cost for each function, including imports.
169fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
170	let func_imports = module.import_count(elements::ImportCountType::Function);
171
172	// TODO: optimize!
173	(0..module.functions_space())
174		.map(|func_idx| {
175			if func_idx < func_imports {
176				// We can't calculate stack_cost of the import functions.
177				Ok(0)
178			} else {
179				compute_stack_cost(func_idx as u32, module)
180			}
181		})
182		.collect()
183}
184
185/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
186/// number of arguments plus number of local variables) and the maximal stack
187/// height.
188fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
189	// To calculate the cost of a function we need to convert index from
190	// function index space to defined function spaces.
191	let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
192	let defined_func_idx = func_idx
193		.checked_sub(func_imports)
194		.ok_or_else(|| Error("This should be a index of a defined function".into()))?;
195
196	let code_section = module
197		.code_section()
198		.ok_or_else(|| Error("Due to validation code section should exists".into()))?;
199	let body = &code_section
200		.bodies()
201		.get(defined_func_idx as usize)
202		.ok_or_else(|| Error("Function body is out of bounds".into()))?;
203
204	let mut locals_count: u32 = 0;
205	for local_group in body.locals() {
206		locals_count = locals_count
207			.checked_add(local_group.count())
208			.ok_or_else(|| Error("Overflow in local count".into()))?;
209	}
210
211	let max_stack_height = max_height::compute(defined_func_idx, module)?;
212
213	locals_count
214		.checked_add(max_stack_height)
215		.ok_or_else(|| Error("Overflow in adding locals_count and max_stack_height".into()))
216}
217
218fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
219	for section in module.sections_mut() {
220		if let elements::Section::Code(code_section) = section {
221			for func_body in code_section.bodies_mut() {
222				let opcodes = func_body.code_mut();
223				instrument_function(ctx, opcodes)?;
224			}
225		}
226	}
227	Ok(())
228}
229
230/// This function searches `call` instructions and wrap each call
231/// with preamble and postamble.
232///
233/// Before:
234///
235/// ```text
236/// get_local 0
237/// get_local 1
238/// call 228
239/// drop
240/// ```
241///
242/// After:
243///
244/// ```text
245/// get_local 0
246/// get_local 1
247///
248/// < ... preamble ... >
249///
250/// call 228
251///
252/// < .. postamble ... >
253///
254/// drop
255/// ```
256fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
257	use Instruction::*;
258
259	struct InstrumentCall {
260		offset: usize,
261		callee: u32,
262		cost: u32,
263	}
264
265	let calls: Vec<_> = func
266		.elements()
267		.iter()
268		.enumerate()
269		.filter_map(|(offset, instruction)| {
270			if let Call(callee) = instruction {
271				ctx.stack_cost(*callee).and_then(|cost| {
272					if cost > 0 {
273						Some(InstrumentCall { callee: *callee, offset, cost })
274					} else {
275						None
276					}
277				})
278			} else {
279				None
280			}
281		})
282		.collect();
283
284	// The `instrumented_call!` contains the call itself. This is why we need to subtract one.
285	let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
286	let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
287	let new_instrs = func.elements_mut();
288
289	let mut calls = calls.into_iter().peekable();
290	for (original_pos, instr) in original_instrs.into_iter().enumerate() {
291		// whether there is some call instruction at this position that needs to be instrumented
292		let did_instrument = if let Some(call) = calls.peek() {
293			if call.offset == original_pos {
294				let new_seq = instrument_call!(
295					call.callee,
296					call.cost as i32,
297					ctx.stack_height_global_idx(),
298					ctx.stack_limit()
299				);
300				new_instrs.extend(new_seq);
301				true
302			} else {
303				false
304			}
305		} else {
306			false
307		};
308
309		if did_instrument {
310			calls.next();
311		} else {
312			new_instrs.push(instr);
313		}
314	}
315
316	if calls.next().is_some() {
317		return Err(Error("Not all calls were used".into()))
318	}
319
320	Ok(())
321}
322
323fn resolve_func_type(
324	func_idx: u32,
325	module: &elements::Module,
326) -> Result<&elements::FunctionType, Error> {
327	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
328	let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
329
330	let func_imports = module.import_count(elements::ImportCountType::Function);
331	let sig_idx = if func_idx < func_imports as u32 {
332		module
333			.import_section()
334			.expect("function import count is not zero; import section must exists; qed")
335			.entries()
336			.iter()
337			.filter_map(|entry| match entry.external() {
338				elements::External::Function(idx) => Some(*idx),
339				_ => None,
340			})
341			.nth(func_idx as usize)
342			.expect(
343				"func_idx is less than function imports count;
344				nth function import must be `Some`;
345				qed",
346			)
347	} else {
348		functions
349			.get(func_idx as usize - func_imports)
350			.ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
351			.type_ref()
352	};
353	let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
354		Error(format!("Signature {} (specified by func {}) isn't defined", sig_idx, func_idx))
355	})?;
356	Ok(ty)
357}
358
#[cfg(test)]
mod tests {
	use super::*;
	use parity_wasm::elements;

	/// Compile WebAssembly text into a deserialized `parity_wasm` module.
	fn parse_wat(source: &str) -> elements::Module {
		let wasm = wabt::wat2wasm(source).expect("Failed to wat2wasm");
		elements::deserialize_buffer(&wasm).expect("Failed to deserialize the module")
	}

	/// Serialize `module` and check that wabt accepts it as a valid module.
	fn validate_module(module: elements::Module) {
		let binary = elements::serialize(module).expect("Failed to serialize");
		let parsed = wabt::Module::read_binary(&binary, &Default::default())
			.expect("Wabt failed to read final binary");
		parsed.validate().expect("Invalid module");
	}

	#[test]
	fn test_with_params_and_result() {
		let source = r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		get_local 0
	get_local 1
	i32.add
	)
)
"#;
		let instrumented =
			inject_limiter(parse_wat(source), 1024).expect("Failed to inject stack counter");
		validate_module(instrumented);
	}
}