gwasm_instrument/stack_limiter/
mod.rs

1//! Contains the code for the stack height limiter instrumentation.
2
3use alloc::{vec, vec::Vec};
4use core::mem;
5use max_height::{MaxStackHeightCounter, MaxStackHeightCounterContext};
6use parity_wasm::{
7	builder,
8	elements::{self, FunctionType, Instruction, Instructions, Type},
9};
10
11mod max_height;
12mod thunk;
13
/// State shared by the steps of the stack limiter pass.
pub(crate) struct Context {
	/// Index of the global variable that tracks the stack height.
	stack_height_global_idx: u32,
	/// Stack costs, looked up by function index (see [`Context::stack_cost`]).
	func_stack_costs: Vec<u32>,
	/// Stack limit this pass was configured with.
	stack_limit: u32,
}

impl Context {
	/// Index of the stack height global variable in the global index space.
	fn stack_height_global_idx(&self) -> u32 {
		let Self { stack_height_global_idx, .. } = self;
		*stack_height_global_idx
	}

	/// Looks up the stack cost recorded for `func_idx`.
	///
	/// Yields `None` when no cost is recorded for that index.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		let slot = func_idx as usize;
		self.func_stack_costs.get(slot).copied()
	}

	/// The configured stack limit.
	fn stack_limit(&self) -> u32 {
		let Self { stack_limit, .. } = self;
		*stack_limit
	}
}
36
37/// Inject the instumentation that makes stack overflows deterministic, by introducing
38/// an upper bound of the stack size.
39///
40/// This pass introduces a global mutable variable to track stack height,
41/// and instruments all calls with preamble and postamble.
42///
43/// Stack height is increased prior the call. Otherwise, the check would
44/// be made after the stack frame is allocated.
45///
46/// The preamble is inserted before the call. It increments
47/// the global stack height variable with statically determined "stack cost"
48/// of the callee. If after the increment the stack height exceeds
49/// the limit (specified by the `rules`) then execution traps.
50/// Otherwise, the call is executed.
51///
52/// The postamble is inserted after the call. The purpose of the postamble is to decrease
53/// the stack height by the "stack cost" of the callee function.
54///
55/// Note, that we can't instrument all possible ways to return from the function. The simplest
56/// example would be a trap issued by the host function.
57/// That means stack height global won't be equal to zero upon the next execution after such trap.
58///
59/// # Thunks
60///
61/// Because stack height is increased prior the call few problems arises:
62///
63/// - Stack height isn't increased upon an entry to the first function, i.e. exported function.
64/// - Start function is executed externally (similar to exported functions).
65/// - It is statically unknown what function will be invoked in an indirect call.
66///
67/// The solution for this problems is to generate a intermediate functions, called 'thunks', which
68/// will increase before and decrease the stack height after the call to original function, and
69/// then make exported function and table entries, start section to point to a corresponding thunks.
70///
71/// # Stack cost
72///
73/// Stack cost of the function is calculated as a sum of it's locals
74/// and the maximal height of the value stack.
75///
76/// All values are treated equally, as they have the same size.
77///
78/// The rationale is that this makes it possible to use the following very naive wasm executor:
79///
80/// - values are implemented by a union, so each value takes a size equal to the size of the largest
81///   possible value type this union can hold. (In MVP it is 8 bytes)
82/// - each value from the value stack is placed on the native stack.
83/// - each local variable and function argument is placed on the native stack.
84/// - arguments pushed by the caller are copied into callee stack rather than shared between the
85///   frames.
86/// - upon entry into the function entire stack frame is allocated.
87pub fn inject(
88	module: elements::Module,
89	stack_limit: u32,
90) -> Result<elements::Module, &'static str> {
91	inject_with_config(
92		module,
93		InjectionConfig {
94			stack_limit,
95			injection_fn: |_| [Instruction::Unreachable],
96			stack_height_export_name: None,
97		},
98	)
99}
100
101/// Represents the injection configuration. See [`inject_with_config`] for more details.
102pub struct InjectionConfig<'a, I, F>
103where
104	I: IntoIterator<Item = Instruction>,
105	I::IntoIter: ExactSizeIterator + Clone,
106	F: Fn(&FunctionType) -> I,
107{
108	pub stack_limit: u32,
109	pub injection_fn: F,
110	pub stack_height_export_name: Option<&'a str>,
111}
112
113/// Same as the [`inject`] function, but allows to configure exit instructions when the stack limit
114/// is reached and the export name of the stack height global.
115pub fn inject_with_config<I: IntoIterator<Item = Instruction>>(
116	mut module: elements::Module,
117	injection_config: InjectionConfig<'_, I, impl Fn(&FunctionType) -> I>,
118) -> Result<elements::Module, &'static str>
119where
120	I::IntoIter: ExactSizeIterator + Clone,
121{
122	let InjectionConfig { stack_limit, injection_fn, stack_height_export_name } = injection_config;
123	let mut ctx = Context {
124		stack_height_global_idx: generate_stack_height_global(
125			&mut module,
126			stack_height_export_name,
127		),
128		func_stack_costs: compute_stack_costs(&module, &injection_fn)?,
129		stack_limit,
130	};
131
132	instrument_functions(&mut ctx, &mut module, &injection_fn)?;
133	let module = thunk::generate_thunks(&mut ctx, module, &injection_fn)?;
134
135	Ok(module)
136}
137
138/// Generate a new global that will be used for tracking current stack height.
139fn generate_stack_height_global(
140	module: &mut elements::Module,
141	stack_height_export_name: Option<&str>,
142) -> u32 {
143	let global_entry = builder::global()
144		.value_type()
145		.i32()
146		.mutable()
147		.init_expr(Instruction::I32Const(0))
148		.build();
149
150	let stack_height_global_idx = match module.global_section_mut() {
151		Some(global_section) => {
152			global_section.entries_mut().push(global_entry);
153			(global_section.entries().len() as u32) - 1
154		},
155		None => {
156			module.sections_mut().push(elements::Section::Global(
157				elements::GlobalSection::with_entries(vec![global_entry]),
158			));
159			0
160		},
161	};
162
163	if let Some(stack_height_export_name) = stack_height_export_name {
164		let export_entry = elements::ExportEntry::new(
165			stack_height_export_name.into(),
166			elements::Internal::Global(stack_height_global_idx),
167		);
168
169		match module.export_section_mut() {
170			Some(export_section) => {
171				export_section.entries_mut().push(export_entry);
172			},
173			None => {
174				module.sections_mut().push(elements::Section::Export(
175					elements::ExportSection::with_entries(vec![export_entry]),
176				));
177			},
178		}
179	}
180
181	stack_height_global_idx
182}
183
184/// Calculate stack costs for all functions.
185///
186/// Returns a vector with a stack cost for each function, including imports.
187fn compute_stack_costs<I: IntoIterator<Item = Instruction>>(
188	module: &elements::Module,
189	injection_fn: impl Fn(&FunctionType) -> I,
190) -> Result<Vec<u32>, &'static str>
191where
192	I::IntoIter: ExactSizeIterator + Clone,
193{
194	let functions_space = module
195		.functions_space()
196		.try_into()
197		.map_err(|_| "Can't convert functions space to u32")?;
198
199	// Don't create context when there are no functions (this will fail).
200	if functions_space == 0 {
201		return Ok(Vec::new())
202	}
203
204	// This context already contains the module, number of imports and section references.
205	// So we can use it to optimize access to these objects.
206	let context: MaxStackHeightCounterContext = module.try_into()?;
207
208	(0..functions_space)
209		.map(|func_idx| {
210			if func_idx < context.func_imports {
211				// We can't calculate stack_cost of the import functions.
212				Ok(0)
213			} else {
214				compute_stack_cost(func_idx, context, &injection_fn)
215			}
216		})
217		.collect()
218}
219
220/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
221/// number of arguments plus number of local variables) and the maximal stack
222/// height.
223fn compute_stack_cost<I: IntoIterator<Item = Instruction>>(
224	func_idx: u32,
225	context: MaxStackHeightCounterContext,
226	injection_fn: impl Fn(&FunctionType) -> I,
227) -> Result<u32, &'static str>
228where
229	I::IntoIter: ExactSizeIterator + Clone,
230{
231	// To calculate the cost of a function we need to convert index from
232	// function index space to defined function spaces.
233	let defined_func_idx = func_idx
234		.checked_sub(context.func_imports)
235		.ok_or("This should be a index of a defined function")?;
236
237	let body = context
238		.code_section
239		.bodies()
240		.get(defined_func_idx as usize)
241		.ok_or("Function body is out of bounds")?;
242
243	let mut locals_count: u32 = 0;
244	for local_group in body.locals() {
245		locals_count =
246			locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
247	}
248
249	let max_stack_height = MaxStackHeightCounter::new_with_context(context, injection_fn)
250		.count_instrumented_calls(true)
251		.compute_for_defined_func(defined_func_idx)?;
252
253	locals_count
254		.checked_add(max_stack_height)
255		.ok_or("Overflow in adding locals_count and max_stack_height")
256}
257
258fn instrument_functions<I: IntoIterator<Item = Instruction>>(
259	ctx: &mut Context,
260	module: &mut elements::Module,
261	injection_fn: impl Fn(&FunctionType) -> I,
262) -> Result<(), &'static str>
263where
264	I::IntoIter: ExactSizeIterator + Clone,
265{
266	if ctx.func_stack_costs.is_empty() {
267		return Ok(())
268	}
269
270	// Func stack costs collection is not empty, so stack height counter has counted costs
271	// for module with non empty function and type sections.
272	let types = module.type_section().map(|ts| ts.types()).expect("checked earlier").to_vec();
273	let functions = module
274		.function_section()
275		.map(|fs| fs.entries())
276		.expect("checked earlier")
277		.to_vec();
278
279	if let Some(code_section) = module.code_section_mut() {
280		for (func_idx, func_body) in code_section.bodies_mut().iter_mut().enumerate() {
281			let opcodes = func_body.code_mut();
282
283			let signature_index = &functions[func_idx];
284			let signature = &types[signature_index.type_ref() as usize];
285			let Type::Function(signature) = signature;
286
287			instrument_function(ctx, opcodes, signature, &injection_fn)?;
288		}
289	}
290
291	Ok(())
292}
293
294/// This function searches `call` instructions and wrap each call
295/// with preamble and postamble.
296///
297/// Before:
298///
299/// ```text
300/// local.get 0
301/// local.get 1
302/// call 228
303/// drop
304/// ```
305///
306/// After:
307///
308/// ```text
309/// local.get 0
310/// local.get 1
311///
312/// < ... preamble ... >
313///
314/// call 228
315///
316/// < .. postamble ... >
317///
318/// drop
319/// ```
320fn instrument_function<I: IntoIterator<Item = Instruction>>(
321	ctx: &mut Context,
322	func: &mut Instructions,
323	signature: &FunctionType,
324	injection_fn: impl Fn(&FunctionType) -> I,
325) -> Result<(), &'static str>
326where
327	I::IntoIter: ExactSizeIterator + Clone,
328{
329	use Instruction::*;
330
331	struct InstrumentCall {
332		offset: usize,
333		callee: u32,
334		cost: u32,
335	}
336
337	let calls: Vec<_> = func
338		.elements()
339		.iter()
340		.enumerate()
341		.filter_map(|(offset, instruction)| {
342			if let Call(callee) = instruction {
343				ctx.stack_cost(*callee).and_then(|cost| {
344					if cost > 0 {
345						Some(InstrumentCall { callee: *callee, offset, cost })
346					} else {
347						None
348					}
349				})
350			} else {
351				None
352			}
353		})
354		.collect();
355
356	// To pre-allocate memory, we need to count `8 + N + 6 - 1`, i.e. `13 + N`.
357	// We need to subtract one because it is assumed that we already have the original call
358	// instruction in `func.elements()`. See `instrument_call` function for details.
359	let body_of_condition = injection_fn(signature).into_iter();
360	let len = func.elements().len() + calls.len() * (13 + body_of_condition.len());
361	let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
362	let new_instrs = func.elements_mut();
363
364	let mut calls = calls.into_iter().peekable();
365	for (original_pos, instr) in original_instrs.into_iter().enumerate() {
366		// whether there is some call instruction at this position that needs to be instrumented
367		let did_instrument = if let Some(call) = calls.peek() {
368			if call.offset == original_pos {
369				instrument_call(
370					new_instrs,
371					call.callee,
372					call.cost as i32,
373					ctx.stack_height_global_idx(),
374					ctx.stack_limit(),
375					body_of_condition.clone(),
376					[],
377				);
378				true
379			} else {
380				false
381			}
382		} else {
383			false
384		};
385
386		if did_instrument {
387			calls.next();
388		} else {
389			new_instrs.push(instr);
390		}
391	}
392
393	if calls.next().is_some() {
394		return Err("Not all calls were used")
395	}
396
397	Ok(())
398}
399
400/// This function generates preamble and postamble.
401fn instrument_call(
402	instructions: &mut Vec<Instruction>,
403	callee_idx: u32,
404	callee_stack_cost: i32,
405	stack_height_global_idx: u32,
406	stack_limit: u32,
407	body_of_condition: impl IntoIterator<Item = Instruction>,
408	arguments: impl IntoIterator<Item = Instruction>,
409) {
410	use Instruction::*;
411
412	// 8 + body_of_condition.len() + 1 instructions
413	generate_preamble(
414		instructions,
415		callee_stack_cost,
416		stack_height_global_idx,
417		stack_limit,
418		body_of_condition,
419	);
420
421	// arguments.len() instructions
422	instructions.extend(arguments);
423
424	// Original call, 1 instruction
425	instructions.push(Call(callee_idx));
426
427	// 4 instructions
428	generate_postamble(instructions, callee_stack_cost, stack_height_global_idx);
429}
430
431/// This function generates preamble.
432fn generate_preamble(
433	instructions: &mut Vec<Instruction>,
434	callee_stack_cost: i32,
435	stack_height_global_idx: u32,
436	stack_limit: u32,
437	body_of_condition: impl IntoIterator<Item = Instruction>,
438) {
439	use Instruction::*;
440
441	// 8 instructions
442	instructions.extend_from_slice(&[
443		// stack_height += stack_cost(F)
444		GetGlobal(stack_height_global_idx),
445		I32Const(callee_stack_cost),
446		I32Add,
447		SetGlobal(stack_height_global_idx),
448		// if stack_counter > LIMIT: unreachable or custom instructions
449		GetGlobal(stack_height_global_idx),
450		I32Const(stack_limit as i32),
451		I32GtU,
452		If(elements::BlockType::NoResult),
453	]);
454
455	// body_of_condition.len() instructions
456	instructions.extend(body_of_condition);
457
458	// 1 instruction
459	instructions.push(End);
460}
461
462/// This function generates postamble.
463#[inline]
464fn generate_postamble(
465	instructions: &mut Vec<Instruction>,
466	callee_stack_cost: i32,
467	stack_height_global_idx: u32,
468) {
469	use Instruction::*;
470
471	// 4 instructions
472	instructions.extend_from_slice(&[
473		// stack_height -= stack_cost(F)
474		GetGlobal(stack_height_global_idx),
475		I32Const(callee_stack_cost),
476		I32Sub,
477		SetGlobal(stack_height_global_idx),
478	]);
479}
480
481fn resolve_func_type(
482	func_idx: u32,
483	module: &elements::Module,
484) -> Result<&FunctionType, &'static str> {
485	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
486	let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
487
488	let func_imports = module.import_count(elements::ImportCountType::Function);
489	let sig_idx = if func_idx < func_imports as u32 {
490		module
491			.import_section()
492			.expect("function import count is not zero; import section must exists; qed")
493			.entries()
494			.iter()
495			.filter_map(|entry| match entry.external() {
496				elements::External::Function(idx) => Some(*idx),
497				_ => None,
498			})
499			.nth(func_idx as usize)
500			.expect(
501				"func_idx is less than function imports count;
502				nth function import must be `Some`;
503				qed",
504			)
505	} else {
506		functions
507			.get(func_idx as usize - func_imports)
508			.ok_or("Function at the specified index is not defined")?
509			.type_ref()
510	};
511	let Type::Function(ty) = types
512		.get(sig_idx as usize)
513		.ok_or("The signature as specified by a function isn't defined")?;
514	Ok(ty)
515}
516
#[cfg(test)]
mod tests {
	use super::*;
	use parity_wasm::elements;

	/// Turns WAT source into a deserialized module, panicking on any failure.
	fn parse_wat(source: &str) -> elements::Module {
		let wasm = wat::parse_str(source).expect("Failed to wat2wasm");
		elements::deserialize_buffer(&wasm).expect("Failed to deserialize the module")
	}

	/// Serializes the module and runs the result through the `wasmparser` validator.
	fn validate_module(module: elements::Module) {
		let bytes = elements::serialize(module).expect("Failed to serialize");
		wasmparser::validate(&bytes).expect("Invalid module");
	}

	/// An exported function with parameters and a result must still validate after injection.
	#[test]
	fn test_with_params_and_result() {
		let source = r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		local.get 0
	local.get 1
	i32.add
	)
)
"#;
		let original = parse_wat(source);

		let instrumented = inject(original, 1024).expect("Failed to inject stack counter");
		validate_module(instrumented);
	}
}