casper_wasm_utils/
stack_height.rs1use crate::std::{mem, string::String, vec::Vec};
53
54use casper_wasm::{
55 builder,
56 elements::{self, Instruction, Instructions, Type},
57};
58
59const fn instrument_call(
61 callee_idx: u32,
62 callee_stack_cost: u32,
63 stack_height_global_idx: u32,
64 stack_limit: u32,
65) -> [Instruction; INSTRUMENT_CALL_LENGTH] {
66 use casper_wasm::elements::Instruction::*;
67 [
68 GetGlobal(stack_height_global_idx),
70 I32Const(callee_stack_cost as i32),
71 I32Add,
72 SetGlobal(stack_height_global_idx),
73 GetGlobal(stack_height_global_idx),
75 I32Const(stack_limit as i32),
76 I32GtU,
77 If(elements::BlockType::NoResult),
78 Unreachable,
79 End,
80 Call(callee_idx),
82 GetGlobal(stack_height_global_idx),
84 I32Const(callee_stack_cost as i32),
85 I32Sub,
86 SetGlobal(stack_height_global_idx),
87 ]
88}
89
/// Number of instructions `instrument_call` emits per call site: 4 to bump the counter,
/// 6 for the limit check, 1 for the call itself and 4 to restore the counter.
const INSTRUMENT_CALL_LENGTH: usize = 4 + 6 + 1 + 4;
91
92mod max_height;
93mod thunk;
94
/// Error produced while injecting the stack limiter into a module.
///
/// The payload is a human-readable description of the failure.
#[derive(Debug)]
// The message is (as far as this module shows) only observed through the derived `Debug`
// impl, which the `dead_code` lint does not count as a read of the field. Silence just that
// lint instead of the whole `unused` group.
#[allow(dead_code)]
pub struct Error(String);
101
/// State shared by the stack-limiter passes (`instrument_functions` and thunk generation).
pub(crate) struct Context {
    /// Global index targeted by the injected `GetGlobal`/`SetGlobal` counter updates.
    stack_height_global_idx: u32,
    /// Stack cost of each function, indexed by function index. As built by
    /// `compute_stack_costs`, imported functions have cost 0.
    func_stack_costs: Vec<u32>,
    /// Counter value above which instrumented code executes `unreachable`.
    stack_limit: u32,
}

impl Context {
    /// Index of the stack-height global.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx
    }

    /// Pre-computed stack cost of `func_idx`, or `None` if the index is outside the table.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let idx = func_idx as usize;
        self.func_stack_costs.get(idx).copied()
    }

    /// Configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
124
125pub fn inject_limiter(
133 mut module: elements::Module,
134 stack_limit: u32,
135) -> Result<elements::Module, Error> {
136 let mut ctx = Context {
137 stack_height_global_idx: generate_stack_height_global(&mut module),
138 func_stack_costs: compute_stack_costs(&module)?,
139 stack_limit,
140 };
141
142 instrument_functions(&mut ctx, &mut module)?;
143 let module = thunk::generate_thunks(&mut ctx, module)?;
144
145 Ok(module)
146}
147
148fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
150 let global_entry = builder::global()
151 .value_type()
152 .i32()
153 .mutable()
154 .init_expr(Instruction::I32Const(0))
155 .build();
156
157 for section in module.sections_mut() {
159 if let elements::Section::Global(gs) = section {
160 gs.entries_mut().push(global_entry);
161 return (gs.entries().len() as u32) - 1;
162 }
163 }
164
165 module.sections_mut().push(elements::Section::Global(
167 elements::GlobalSection::with_entries(vec![global_entry]),
168 ));
169 0
170}
171
172fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
176 let func_imports = module.import_count(elements::ImportCountType::Function);
177
178 (0..module.functions_space())
180 .map(|func_idx| {
181 if func_idx < func_imports {
182 Ok(0)
184 } else {
185 compute_stack_cost(func_idx as u32, module)
186 }
187 })
188 .collect()
189}
190
191fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
195 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
198 let defined_func_idx = func_idx
199 .checked_sub(func_imports)
200 .ok_or_else(|| Error("This should be a index of a defined function".into()))?;
201
202 let code_section = module
203 .code_section()
204 .ok_or_else(|| Error("Due to validation code section should exists".into()))?;
205 let body = &code_section
206 .bodies()
207 .get(defined_func_idx as usize)
208 .ok_or_else(|| Error("Function body is out of bounds".into()))?;
209
210 let mut locals_count: u32 = 0;
211 for local_group in body.locals() {
212 locals_count = locals_count
213 .checked_add(local_group.count())
214 .ok_or_else(|| Error("Overflow in local count".into()))?;
215 }
216
217 let max_stack_height = max_height::compute(defined_func_idx, module)?;
218
219 locals_count
220 .checked_add(max_stack_height)
221 .ok_or_else(|| Error("Overflow in adding locals_count and max_stack_height".into()))
222}
223
224fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
225 for section in module.sections_mut() {
226 if let elements::Section::Code(code_section) = section {
227 for func_body in code_section.bodies_mut() {
228 let opcodes = func_body.code_mut();
229 instrument_function(ctx, opcodes)?;
230 }
231 }
232 }
233 Ok(())
234}
235
236fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
263 use Instruction::*;
264
265 struct InstrumentCall {
266 offset: usize,
267 callee: u32,
268 cost: u32,
269 }
270
271 let calls: Vec<_> = func
272 .elements()
273 .iter()
274 .enumerate()
275 .filter_map(|(offset, instruction)| {
276 if let Call(callee) = *instruction {
277 ctx.stack_cost(callee)
278 .filter(|&cost| cost > 0)
279 .map(|cost| InstrumentCall {
280 callee,
281 offset,
282 cost,
283 })
284 } else {
285 None
286 }
287 })
288 .collect();
289
290 let len = func.elements().len() + calls.len() * (INSTRUMENT_CALL_LENGTH - 1);
292 let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
293 let new_instrs = func.elements_mut();
294
295 let mut calls = calls.into_iter().peekable();
296 for (original_pos, instr) in original_instrs.into_iter().enumerate() {
297 let did_instrument = if let Some(call) = calls.peek() {
299 if call.offset == original_pos {
300 let new_seq = instrument_call(
301 call.callee,
302 call.cost,
303 ctx.stack_height_global_idx(),
304 ctx.stack_limit(),
305 );
306 new_instrs.extend(new_seq);
307 true
308 } else {
309 false
310 }
311 } else {
312 false
313 };
314
315 if did_instrument {
316 calls.next();
317 } else {
318 new_instrs.push(instr);
319 }
320 }
321
322 if calls.next().is_some() {
323 return Err(Error("Not all calls were used".into()));
324 }
325
326 Ok(())
327}
328
329fn resolve_func_type(
330 func_idx: u32,
331 module: &elements::Module,
332) -> Result<&elements::FunctionType, Error> {
333 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
334 let functions = module
335 .function_section()
336 .map(|fs| fs.entries())
337 .unwrap_or(&[]);
338
339 let func_imports = module.import_count(elements::ImportCountType::Function);
340 let sig_idx = if func_idx < func_imports as u32 {
341 module
342 .import_section()
343 .expect("function import count is not zero; import section must exists; qed")
344 .entries()
345 .iter()
346 .filter_map(|entry| match entry.external() {
347 elements::External::Function(idx) => Some(*idx),
348 _ => None,
349 })
350 .nth(func_idx as usize)
351 .expect(
352 "func_idx is less than function imports count;
353 nth function import must be `Some`;
354 qed",
355 )
356 } else {
357 functions
358 .get(func_idx as usize - func_imports)
359 .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
360 .type_ref()
361 };
362 let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
363 Error(format!(
364 "Signature {} (specified by func {}) isn't defined",
365 sig_idx, func_idx
366 ))
367 })?;
368 Ok(ty)
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use casper_wasm::elements;

    /// Compiles WebAssembly text into a parsed module.
    fn parse_wat(source: &str) -> elements::Module {
        elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm"))
            .expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt accepts it as a valid module.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        wabt::Module::read_binary(binary, &Default::default())
            .expect("Wabt failed to read final binary")
            .validate()
            .expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        let module = parse_wat(
            r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    get_local 0
    get_local 1
    i32.add
  )
)
"#,
        );

        let module = inject_limiter(module, 1024).expect("Failed to inject stack counter");
        validate_module(module);
    }

    /// Exercises call-site instrumentation: the callee has a non-zero stack cost, so the
    /// `call` must be wrapped with the counter update and limit check.
    #[test]
    fn test_instrumented_call_is_valid() {
        let module = parse_wat(
            r#"
(module
  (func $callee (param i32) (result i32)
    get_local 0
  )
  (func (export "call") (result i32)
    i32.const 1
    call $callee
  )
)
"#,
        );

        let module = inject_limiter(module, 1024).expect("Failed to inject stack counter");

        // The limit check emits `unreachable`, which the original module never contained.
        let has_limit_check = module
            .code_section()
            .expect("instrumented module must have a code section")
            .bodies()
            .iter()
            .any(|body| body.code().elements().contains(&Instruction::Unreachable));
        assert!(has_limit_check, "call site was not instrumented");

        validate_module(module);
    }
}