gwasm_instrument/stack_limiter/
mod.rs1use alloc::{vec, vec::Vec};
4use core::mem;
5use max_height::{MaxStackHeightCounter, MaxStackHeightCounterContext};
6use parity_wasm::{
7 builder,
8 elements::{self, FunctionType, Instruction, Instructions, Type},
9};
10
11mod max_height;
12mod thunk;
13
/// State shared by the instrumentation passes of the stack limiter.
pub(crate) struct Context {
    /// Index of the global that holds the current stack height.
    stack_height_global_idx: u32,
    /// Stack cost of every function, indexed by function index.
    func_stack_costs: Vec<u32>,
    /// Upper bound that the stack-height global is compared against.
    stack_limit: u32,
}

impl Context {
    /// Index of the global that holds the current stack height.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx
    }

    /// Stack cost recorded for `func_idx`, or `None` when the index is out of range.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let slot = func_idx as usize;
        self.func_stack_costs.get(slot).copied()
    }

    /// Configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
36
37pub fn inject(
88 module: elements::Module,
89 stack_limit: u32,
90) -> Result<elements::Module, &'static str> {
91 inject_with_config(
92 module,
93 InjectionConfig {
94 stack_limit,
95 injection_fn: |_| [Instruction::Unreachable],
96 stack_height_export_name: None,
97 },
98 )
99}
100
101pub struct InjectionConfig<'a, I, F>
103where
104 I: IntoIterator<Item = Instruction>,
105 I::IntoIter: ExactSizeIterator + Clone,
106 F: Fn(&FunctionType) -> I,
107{
108 pub stack_limit: u32,
109 pub injection_fn: F,
110 pub stack_height_export_name: Option<&'a str>,
111}
112
113pub fn inject_with_config<I: IntoIterator<Item = Instruction>>(
116 mut module: elements::Module,
117 injection_config: InjectionConfig<'_, I, impl Fn(&FunctionType) -> I>,
118) -> Result<elements::Module, &'static str>
119where
120 I::IntoIter: ExactSizeIterator + Clone,
121{
122 let InjectionConfig { stack_limit, injection_fn, stack_height_export_name } = injection_config;
123 let mut ctx = Context {
124 stack_height_global_idx: generate_stack_height_global(
125 &mut module,
126 stack_height_export_name,
127 ),
128 func_stack_costs: compute_stack_costs(&module, &injection_fn)?,
129 stack_limit,
130 };
131
132 instrument_functions(&mut ctx, &mut module, &injection_fn)?;
133 let module = thunk::generate_thunks(&mut ctx, module, &injection_fn)?;
134
135 Ok(module)
136}
137
138fn generate_stack_height_global(
140 module: &mut elements::Module,
141 stack_height_export_name: Option<&str>,
142) -> u32 {
143 let global_entry = builder::global()
144 .value_type()
145 .i32()
146 .mutable()
147 .init_expr(Instruction::I32Const(0))
148 .build();
149
150 let stack_height_global_idx = match module.global_section_mut() {
151 Some(global_section) => {
152 global_section.entries_mut().push(global_entry);
153 (global_section.entries().len() as u32) - 1
154 },
155 None => {
156 module.sections_mut().push(elements::Section::Global(
157 elements::GlobalSection::with_entries(vec![global_entry]),
158 ));
159 0
160 },
161 };
162
163 if let Some(stack_height_export_name) = stack_height_export_name {
164 let export_entry = elements::ExportEntry::new(
165 stack_height_export_name.into(),
166 elements::Internal::Global(stack_height_global_idx),
167 );
168
169 match module.export_section_mut() {
170 Some(export_section) => {
171 export_section.entries_mut().push(export_entry);
172 },
173 None => {
174 module.sections_mut().push(elements::Section::Export(
175 elements::ExportSection::with_entries(vec![export_entry]),
176 ));
177 },
178 }
179 }
180
181 stack_height_global_idx
182}
183
184fn compute_stack_costs<I: IntoIterator<Item = Instruction>>(
188 module: &elements::Module,
189 injection_fn: impl Fn(&FunctionType) -> I,
190) -> Result<Vec<u32>, &'static str>
191where
192 I::IntoIter: ExactSizeIterator + Clone,
193{
194 let functions_space = module
195 .functions_space()
196 .try_into()
197 .map_err(|_| "Can't convert functions space to u32")?;
198
199 if functions_space == 0 {
201 return Ok(Vec::new())
202 }
203
204 let context: MaxStackHeightCounterContext = module.try_into()?;
207
208 (0..functions_space)
209 .map(|func_idx| {
210 if func_idx < context.func_imports {
211 Ok(0)
213 } else {
214 compute_stack_cost(func_idx, context, &injection_fn)
215 }
216 })
217 .collect()
218}
219
220fn compute_stack_cost<I: IntoIterator<Item = Instruction>>(
224 func_idx: u32,
225 context: MaxStackHeightCounterContext,
226 injection_fn: impl Fn(&FunctionType) -> I,
227) -> Result<u32, &'static str>
228where
229 I::IntoIter: ExactSizeIterator + Clone,
230{
231 let defined_func_idx = func_idx
234 .checked_sub(context.func_imports)
235 .ok_or("This should be a index of a defined function")?;
236
237 let body = context
238 .code_section
239 .bodies()
240 .get(defined_func_idx as usize)
241 .ok_or("Function body is out of bounds")?;
242
243 let mut locals_count: u32 = 0;
244 for local_group in body.locals() {
245 locals_count =
246 locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
247 }
248
249 let max_stack_height = MaxStackHeightCounter::new_with_context(context, injection_fn)
250 .count_instrumented_calls(true)
251 .compute_for_defined_func(defined_func_idx)?;
252
253 locals_count
254 .checked_add(max_stack_height)
255 .ok_or("Overflow in adding locals_count and max_stack_height")
256}
257
258fn instrument_functions<I: IntoIterator<Item = Instruction>>(
259 ctx: &mut Context,
260 module: &mut elements::Module,
261 injection_fn: impl Fn(&FunctionType) -> I,
262) -> Result<(), &'static str>
263where
264 I::IntoIter: ExactSizeIterator + Clone,
265{
266 if ctx.func_stack_costs.is_empty() {
267 return Ok(())
268 }
269
270 let types = module.type_section().map(|ts| ts.types()).expect("checked earlier").to_vec();
273 let functions = module
274 .function_section()
275 .map(|fs| fs.entries())
276 .expect("checked earlier")
277 .to_vec();
278
279 if let Some(code_section) = module.code_section_mut() {
280 for (func_idx, func_body) in code_section.bodies_mut().iter_mut().enumerate() {
281 let opcodes = func_body.code_mut();
282
283 let signature_index = &functions[func_idx];
284 let signature = &types[signature_index.type_ref() as usize];
285 let Type::Function(signature) = signature;
286
287 instrument_function(ctx, opcodes, signature, &injection_fn)?;
288 }
289 }
290
291 Ok(())
292}
293
294fn instrument_function<I: IntoIterator<Item = Instruction>>(
321 ctx: &mut Context,
322 func: &mut Instructions,
323 signature: &FunctionType,
324 injection_fn: impl Fn(&FunctionType) -> I,
325) -> Result<(), &'static str>
326where
327 I::IntoIter: ExactSizeIterator + Clone,
328{
329 use Instruction::*;
330
331 struct InstrumentCall {
332 offset: usize,
333 callee: u32,
334 cost: u32,
335 }
336
337 let calls: Vec<_> = func
338 .elements()
339 .iter()
340 .enumerate()
341 .filter_map(|(offset, instruction)| {
342 if let Call(callee) = instruction {
343 ctx.stack_cost(*callee).and_then(|cost| {
344 if cost > 0 {
345 Some(InstrumentCall { callee: *callee, offset, cost })
346 } else {
347 None
348 }
349 })
350 } else {
351 None
352 }
353 })
354 .collect();
355
356 let body_of_condition = injection_fn(signature).into_iter();
360 let len = func.elements().len() + calls.len() * (13 + body_of_condition.len());
361 let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
362 let new_instrs = func.elements_mut();
363
364 let mut calls = calls.into_iter().peekable();
365 for (original_pos, instr) in original_instrs.into_iter().enumerate() {
366 let did_instrument = if let Some(call) = calls.peek() {
368 if call.offset == original_pos {
369 instrument_call(
370 new_instrs,
371 call.callee,
372 call.cost as i32,
373 ctx.stack_height_global_idx(),
374 ctx.stack_limit(),
375 body_of_condition.clone(),
376 [],
377 );
378 true
379 } else {
380 false
381 }
382 } else {
383 false
384 };
385
386 if did_instrument {
387 calls.next();
388 } else {
389 new_instrs.push(instr);
390 }
391 }
392
393 if calls.next().is_some() {
394 return Err("Not all calls were used")
395 }
396
397 Ok(())
398}
399
400fn instrument_call(
402 instructions: &mut Vec<Instruction>,
403 callee_idx: u32,
404 callee_stack_cost: i32,
405 stack_height_global_idx: u32,
406 stack_limit: u32,
407 body_of_condition: impl IntoIterator<Item = Instruction>,
408 arguments: impl IntoIterator<Item = Instruction>,
409) {
410 use Instruction::*;
411
412 generate_preamble(
414 instructions,
415 callee_stack_cost,
416 stack_height_global_idx,
417 stack_limit,
418 body_of_condition,
419 );
420
421 instructions.extend(arguments);
423
424 instructions.push(Call(callee_idx));
426
427 generate_postamble(instructions, callee_stack_cost, stack_height_global_idx);
429}
430
431fn generate_preamble(
433 instructions: &mut Vec<Instruction>,
434 callee_stack_cost: i32,
435 stack_height_global_idx: u32,
436 stack_limit: u32,
437 body_of_condition: impl IntoIterator<Item = Instruction>,
438) {
439 use Instruction::*;
440
441 instructions.extend_from_slice(&[
443 GetGlobal(stack_height_global_idx),
445 I32Const(callee_stack_cost),
446 I32Add,
447 SetGlobal(stack_height_global_idx),
448 GetGlobal(stack_height_global_idx),
450 I32Const(stack_limit as i32),
451 I32GtU,
452 If(elements::BlockType::NoResult),
453 ]);
454
455 instructions.extend(body_of_condition);
457
458 instructions.push(End);
460}
461
462#[inline]
464fn generate_postamble(
465 instructions: &mut Vec<Instruction>,
466 callee_stack_cost: i32,
467 stack_height_global_idx: u32,
468) {
469 use Instruction::*;
470
471 instructions.extend_from_slice(&[
473 GetGlobal(stack_height_global_idx),
475 I32Const(callee_stack_cost),
476 I32Sub,
477 SetGlobal(stack_height_global_idx),
478 ]);
479}
480
481fn resolve_func_type(
482 func_idx: u32,
483 module: &elements::Module,
484) -> Result<&FunctionType, &'static str> {
485 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
486 let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
487
488 let func_imports = module.import_count(elements::ImportCountType::Function);
489 let sig_idx = if func_idx < func_imports as u32 {
490 module
491 .import_section()
492 .expect("function import count is not zero; import section must exists; qed")
493 .entries()
494 .iter()
495 .filter_map(|entry| match entry.external() {
496 elements::External::Function(idx) => Some(*idx),
497 _ => None,
498 })
499 .nth(func_idx as usize)
500 .expect(
501 "func_idx is less than function imports count;
502 nth function import must be `Some`;
503 qed",
504 )
505 } else {
506 functions
507 .get(func_idx as usize - func_imports)
508 .ok_or("Function at the specified index is not defined")?
509 .type_ref()
510 };
511 let Type::Function(ty) = types
512 .get(sig_idx as usize)
513 .ok_or("The signature as specified by a function isn't defined")?;
514 Ok(ty)
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use parity_wasm::elements;

    /// Assembles WAT text and deserializes the resulting binary into a module.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wat::parse_str(source).expect("Failed to wat2wasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes the module and checks that the binary passes validation.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        wasmparser::validate(&binary).expect("Invalid module");
    }

    /// A function with parameters and a result must still validate after injection.
    #[test]
    fn test_with_params_and_result() {
        let source = r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    local.get 0
    local.get 1
    i32.add
  )
)
"#;

        let original = parse_wat(source);
        let instrumented = inject(original, 1024).expect("Failed to inject stack counter");

        validate_module(instrumented);
    }
}