owasm_utils/stack_height/
mod.rs1use std::string::String;
52use std::vec::Vec;
53
54use parity_wasm::elements::{self, Type};
55use parity_wasm::builder;
56
/// Expands to the instruction sequence that replaces a single `call $callee_idx`:
///
/// 1. add `$callee_stack_cost` to the stack-height global;
/// 2. trap (`unreachable`) if the global is now greater than `$stack_limit`
///    (compared with `i32.gt_u`, i.e. unsigned);
/// 3. perform the original `call`;
/// 4. subtract `$callee_stack_cost` from the global again.
///
/// Every argument is evaluated exactly once, and all `parity_wasm` paths are spelled through
/// `$crate`, so the expansion does not depend on what the call site has imported.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::parity_wasm::elements::Instruction::*;

        let callee_idx = $callee_idx;
        let callee_stack_cost = $callee_stack_cost;
        let stack_height_global_idx = $stack_height_global_idx;
        let stack_limit = $stack_limit as i32;

        [
            // stack_height += callee_stack_cost
            GetGlobal(stack_height_global_idx),
            I32Const(callee_stack_cost),
            I32Add,
            SetGlobal(stack_height_global_idx),
            // if stack_height > stack_limit { unreachable }
            GetGlobal(stack_height_global_idx),
            I32Const(stack_limit),
            I32GtU,
            If($crate::parity_wasm::elements::BlockType::NoResult),
            Unreachable,
            End,
            // the original call
            Call(callee_idx),
            // stack_height -= callee_stack_cost
            GetGlobal(stack_height_global_idx),
            I32Const(callee_stack_cost),
            I32Sub,
            SetGlobal(stack_height_global_idx),
        ]
    }};
}
84
85mod max_height;
86mod thunk;
87
/// Failure reported while instrumenting a module with the stack-height limiter.
///
/// The payload is a human-readable description of what went wrong.
#[derive(Debug)]
pub struct Error(String);
93
/// State shared by the instrumentation passes.
///
/// The `Option` fields start as `None` and are filled in by earlier passes. The accessors
/// panic if they are read before that has happened.
pub(crate) struct Context {
    /// Index of the global that tracks the current stack height;
    /// set by `generate_stack_height_global`.
    stack_height_global_idx: Option<u32>,
    /// Stack cost for each function index; set by `compute_stack_costs`.
    func_stack_costs: Option<Vec<u32>>,
    /// Upper bound the tracked stack height is checked against.
    stack_limit: u32,
}

impl Context {
    /// Returns the index of the stack-height global.
    ///
    /// # Panics
    /// Panics if `generate_stack_height_global` has not run yet.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx.expect(
            "stack_height_global_idx isn't yet generated; \
             Did you call `generate_stack_height_global`?",
        )
    }

    /// Returns the stack cost of the function at `func_idx`, or `None` when the index is
    /// past the end of the cost table.
    ///
    /// # Panics
    /// Panics if `compute_stack_costs` has not run yet.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        self.func_stack_costs
            .as_ref()
            .expect(
                "func_stack_costs isn't yet computed; \
                 Did you call `compute_stack_costs`?",
            )
            .get(func_idx as usize)
            .cloned()
    }

    /// Returns the configured stack-height limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
131
132pub fn inject_limiter(
140 mut module: elements::Module,
141 stack_limit: u32,
142) -> Result<elements::Module, Error> {
143 let mut ctx = Context {
144 stack_height_global_idx: None,
145 func_stack_costs: None,
146 stack_limit,
147 };
148
149 generate_stack_height_global(&mut ctx, &mut module);
150 compute_stack_costs(&mut ctx, &module)?;
151 instrument_functions(&mut ctx, &mut module)?;
152 let module = thunk::generate_thunks(&mut ctx, module)?;
153
154 Ok(module)
155}
156
157fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) {
159 let global_entry = builder::global()
160 .value_type()
161 .i32()
162 .mutable()
163 .init_expr(elements::Instruction::I32Const(0))
164 .build();
165
166 for section in module.sections_mut() {
168 if let elements::Section::Global(ref mut gs) = *section {
169 gs.entries_mut().push(global_entry);
170
171 let stack_height_global_idx = (gs.entries().len() as u32) - 1;
172 ctx.stack_height_global_idx = Some(stack_height_global_idx);
173 return;
174 }
175 }
176
177 module.sections_mut().push(elements::Section::Global(
179 elements::GlobalSection::with_entries(vec![global_entry]),
180 ));
181 ctx.stack_height_global_idx = Some(0);
182}
183
184fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> {
188 let func_imports = module.import_count(elements::ImportCountType::Function);
189 let mut func_stack_costs = vec![0; module.functions_space()];
190 for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() {
192 if func_idx >= func_imports {
194 *func_stack_cost = compute_stack_cost(func_idx as u32, &module)?;
195 }
196 }
197
198 ctx.func_stack_costs = Some(func_stack_costs);
199 Ok(())
200}
201
202fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
206 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
209 let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
210 Error("This should be a index of a defined function".into())
211 })?;
212
213 let code_section = module.code_section().ok_or_else(|| {
214 Error("Due to validation code section should exists".into())
215 })?;
216 let body = &code_section
217 .bodies()
218 .get(defined_func_idx as usize)
219 .ok_or_else(|| Error("Function body is out of bounds".into()))?;
220 let locals_count = body.locals().len() as u32;
221
222 let max_stack_height =
223 max_height::compute(
224 defined_func_idx,
225 module
226 )?;
227
228 Ok(locals_count + max_stack_height)
229}
230
231fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
232 for section in module.sections_mut() {
233 if let elements::Section::Code(ref mut code_section) = *section {
234 for func_body in code_section.bodies_mut() {
235 let mut opcodes = func_body.code_mut();
236 instrument_function(ctx, opcodes)?;
237 }
238 }
239 }
240 Ok(())
241}
242
243fn instrument_function(
270 ctx: &mut Context,
271 instructions: &mut elements::Instructions,
272) -> Result<(), Error> {
273 use parity_wasm::elements::Instruction::*;
274
275 let mut cursor = 0;
276 loop {
277 if cursor >= instructions.elements().len() {
278 break;
279 }
280
281 enum Action {
282 InstrumentCall {
283 callee_idx: u32,
284 callee_stack_cost: u32,
285 },
286 Nop,
287 }
288
289 let action: Action = {
290 let instruction = &instructions.elements()[cursor];
291 match *instruction {
292 Call(ref callee_idx) => {
293 let callee_stack_cost = ctx
294 .stack_cost(*callee_idx)
295 .ok_or_else(||
296 Error(
297 format!("Call to function that out-of-bounds: {}", callee_idx)
298 )
299 )?;
300
301 if callee_stack_cost > 0 {
304 Action::InstrumentCall {
305 callee_idx: *callee_idx,
306 callee_stack_cost,
307 }
308 } else {
309 Action::Nop
310 }
311 },
312 _ => Action::Nop,
313 }
314 };
315
316 match action {
317 Action::InstrumentCall { callee_idx, callee_stack_cost } => {
321 let new_seq = instrument_call!(
322 callee_idx,
323 callee_stack_cost as i32,
324 ctx.stack_height_global_idx(),
325 ctx.stack_limit()
326 );
327
328 let _ = instructions
334 .elements_mut()
335 .splice(cursor..(cursor + 1), new_seq.iter().cloned())
336 .count();
337
338 cursor += new_seq.len();
340 }
341 _ => {
343 cursor += 1;
344 }
345 }
346 }
347
348 Ok(())
349}
350
351fn resolve_func_type(
352 func_idx: u32,
353 module: &elements::Module,
354) -> Result<&elements::FunctionType, Error> {
355 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
356 let functions = module
357 .function_section()
358 .map(|fs| fs.entries())
359 .unwrap_or(&[]);
360
361 let func_imports = module.import_count(elements::ImportCountType::Function);
362 let sig_idx = if func_idx < func_imports as u32 {
363 module
364 .import_section()
365 .expect("function import count is not zero; import section must exists; qed")
366 .entries()
367 .iter()
368 .filter_map(|entry| match *entry.external() {
369 elements::External::Function(ref idx) => Some(*idx),
370 _ => None,
371 })
372 .nth(func_idx as usize)
373 .expect(
374 "func_idx is less than function imports count;
375 nth function import must be `Some`;
376 qed",
377 )
378 } else {
379 functions
380 .get(func_idx as usize - func_imports)
381 .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
382 .type_ref()
383 };
384 let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| {
385 Error(format!(
386 "Signature {} (specified by func {}) isn't defined",
387 sig_idx, func_idx
388 ))
389 })?;
390 Ok(ty)
391}
392
#[cfg(test)]
mod tests {
    extern crate wabt;

    use super::*;
    use parity_wasm::elements;

    /// Assembles WAT text with wabt and parses the resulting binary into a module.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wabt::wat2wasm(source).expect("Failed to wat2wasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt reads it back and validates it.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        let parsed = wabt::Module::read_binary(&binary, &Default::default())
            .expect("Wabt failed to read final binary");
        parsed.validate().expect("Invalid module");
    }

    /// A function with parameters and a result must still validate after instrumentation.
    #[test]
    fn test_with_params_and_result() {
        let source = r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    get_local 0
    get_local 1
    i32.add
  )
)
"#;

        let instrumented = inject_limiter(parse_wat(source), 1024)
            .expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}