gear_pwasm_utils/stack_height/
mod.rs1use crate::std::{mem, string::String, vec::Vec};
53
54use parity_wasm::{
55 builder,
56 elements::{self, Instruction, Instructions, Type},
57};
58
/// Expands to a fixed-size array of instructions that wraps one `call` with
/// stack-height accounting.
///
/// Arguments, in order:
/// * `$callee_idx` - index of the function being called.
/// * `$callee_stack_cost` - cost added before and removed after the call; it is
///   passed to `I32Const` as-is, so it must already be an `i32`.
/// * `$stack_height_global_idx` - index of the global holding the current height.
/// * `$stack_limit` - limit the height is compared against; cast to `i32` here.
///
/// Each argument expression is expanded more than once, so callers should pass
/// side-effect-free expressions.
///
/// All paths are spelled through `$crate`, so the expansion does not depend on
/// which names the call site happens to have imported.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::parity_wasm::elements::Instruction::*;
        [
            // stack_height += callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Add,
            SetGlobal($stack_height_global_idx),
            // if stack_height > stack_limit (unsigned comparison) { unreachable }
            GetGlobal($stack_height_global_idx),
            I32Const($stack_limit as i32),
            I32GtU,
            If($crate::parity_wasm::elements::BlockType::NoResult),
            Unreachable,
            End,
            // The call being wrapped.
            Call($callee_idx),
            // stack_height -= callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Sub,
            SetGlobal($stack_height_global_idx),
        ]
    }};
}
86
87mod max_height;
88mod thunk;
89
/// Error returned by the stack-height instrumentation.
///
/// Wraps a human-readable message describing what went wrong.
#[derive(Debug)]
pub struct Error(String);
95
/// Shared state used while instrumenting a module.
pub(crate) struct Context {
    /// Index of the global that holds the current stack height.
    stack_height_global_idx: u32,
    /// Stack cost of every function, indexed by function index.
    func_stack_costs: Vec<u32>,
    /// Limit the stack-height global is compared against.
    stack_limit: u32,
}

impl Context {
    /// Index of the global that holds the current stack height.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx
    }

    /// Stack cost recorded for `func_idx`, or `None` when no cost is recorded
    /// for that index.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let slot = func_idx as usize;
        self.func_stack_costs.get(slot).copied()
    }

    /// Configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
118
119pub fn inject_limiter(
127 mut module: elements::Module,
128 stack_limit: u32,
129) -> Result<elements::Module, Error> {
130 let mut ctx = Context {
131 stack_height_global_idx: generate_stack_height_global(&mut module),
132 func_stack_costs: compute_stack_costs(&module)?,
133 stack_limit,
134 };
135
136 instrument_functions(&mut ctx, &mut module)?;
137 let module = thunk::generate_thunks(&mut ctx, module)?;
138
139 Ok(module)
140}
141
142fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
144 let global_entry = builder::global()
145 .value_type()
146 .i32()
147 .mutable()
148 .init_expr(Instruction::I32Const(0))
149 .build();
150
151 for section in module.sections_mut() {
153 if let elements::Section::Global(gs) = section {
154 gs.entries_mut().push(global_entry);
155 return (gs.entries().len() as u32) - 1
156 }
157 }
158
159 module
161 .sections_mut()
162 .push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry])));
163 0
164}
165
166fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
170 let func_imports = module.import_count(elements::ImportCountType::Function);
171
172 (0..module.functions_space())
174 .map(|func_idx| {
175 if func_idx < func_imports {
176 Ok(0)
178 } else {
179 compute_stack_cost(func_idx as u32, module)
180 }
181 })
182 .collect()
183}
184
185fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
189 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
192 let defined_func_idx = func_idx
193 .checked_sub(func_imports)
194 .ok_or_else(|| Error("This should be a index of a defined function".into()))?;
195
196 let code_section = module
197 .code_section()
198 .ok_or_else(|| Error("Due to validation code section should exists".into()))?;
199 let body = &code_section
200 .bodies()
201 .get(defined_func_idx as usize)
202 .ok_or_else(|| Error("Function body is out of bounds".into()))?;
203
204 let mut locals_count: u32 = 0;
205 for local_group in body.locals() {
206 locals_count = locals_count
207 .checked_add(local_group.count())
208 .ok_or_else(|| Error("Overflow in local count".into()))?;
209 }
210
211 let max_stack_height = max_height::compute(defined_func_idx, module)?;
212
213 locals_count
214 .checked_add(max_stack_height)
215 .ok_or_else(|| Error("Overflow in adding locals_count and max_stack_height".into()))
216}
217
218fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
219 for section in module.sections_mut() {
220 if let elements::Section::Code(code_section) = section {
221 for func_body in code_section.bodies_mut() {
222 let opcodes = func_body.code_mut();
223 instrument_function(ctx, opcodes)?;
224 }
225 }
226 }
227 Ok(())
228}
229
230fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
257 use Instruction::*;
258
259 struct InstrumentCall {
260 offset: usize,
261 callee: u32,
262 cost: u32,
263 }
264
265 let calls: Vec<_> = func
266 .elements()
267 .iter()
268 .enumerate()
269 .filter_map(|(offset, instruction)| {
270 if let Call(callee) = instruction {
271 ctx.stack_cost(*callee).and_then(|cost| {
272 if cost > 0 {
273 Some(InstrumentCall { callee: *callee, offset, cost })
274 } else {
275 None
276 }
277 })
278 } else {
279 None
280 }
281 })
282 .collect();
283
284 let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
286 let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
287 let new_instrs = func.elements_mut();
288
289 let mut calls = calls.into_iter().peekable();
290 for (original_pos, instr) in original_instrs.into_iter().enumerate() {
291 let did_instrument = if let Some(call) = calls.peek() {
293 if call.offset == original_pos {
294 let new_seq = instrument_call!(
295 call.callee,
296 call.cost as i32,
297 ctx.stack_height_global_idx(),
298 ctx.stack_limit()
299 );
300 new_instrs.extend(new_seq);
301 true
302 } else {
303 false
304 }
305 } else {
306 false
307 };
308
309 if did_instrument {
310 calls.next();
311 } else {
312 new_instrs.push(instr);
313 }
314 }
315
316 if calls.next().is_some() {
317 return Err(Error("Not all calls were used".into()))
318 }
319
320 Ok(())
321}
322
323fn resolve_func_type(
324 func_idx: u32,
325 module: &elements::Module,
326) -> Result<&elements::FunctionType, Error> {
327 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
328 let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
329
330 let func_imports = module.import_count(elements::ImportCountType::Function);
331 let sig_idx = if func_idx < func_imports as u32 {
332 module
333 .import_section()
334 .expect("function import count is not zero; import section must exists; qed")
335 .entries()
336 .iter()
337 .filter_map(|entry| match entry.external() {
338 elements::External::Function(idx) => Some(*idx),
339 _ => None,
340 })
341 .nth(func_idx as usize)
342 .expect(
343 "func_idx is less than function imports count;
344 nth function import must be `Some`;
345 qed",
346 )
347 } else {
348 functions
349 .get(func_idx as usize - func_imports)
350 .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
351 .type_ref()
352 };
353 let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
354 Error(format!("Signature {} (specified by func {}) isn't defined", sig_idx, func_idx))
355 })?;
356 Ok(ty)
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use parity_wasm::elements;

    /// Converts WAT text to a wasm binary and deserializes it into a module.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wabt::wat2wasm(source).expect("Failed to wat2wasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asks wabt to read and validate the result.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        wabt::Module::read_binary(&binary, &Default::default())
            .expect("Wabt failed to read final binary")
            .validate()
            .expect("Invalid module");
    }

    /// Injecting the limiter into a module with a single two-parameter function
    /// must succeed and produce a module that still validates.
    #[test]
    fn test_with_params_and_result() {
        let source = r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		get_local 0
		get_local 1
		i32.add
	)
)
"#;

        let instrumented =
            inject_limiter(parse_wat(source), 1024).expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}