1use crate::Error;
2
3use wasm_instrument::parity_wasm::{
4 builder,
5 elements::{deserialize_buffer, serialize, External, MemoryType, Module},
6};
7use wasmer::wasmparser;
8
/// Memory bound in WebAssembly pages (64 KiB each): the largest initial size
/// `inject_memory` accepts and the maximum it writes into the module.
static MEMORY_LIMIT: u32 = 512;

/// Limit handed to `wasm_instrument::inject_stack_limiter` by
/// `inject_stack_height`.
static MAX_STACK_HEIGHT: u32 = 16 * 1024;

/// Export names that `check_wasm_exports` requires every module to provide.
static REQUIRED_EXPORTS: &[&str] = &["prepare", "execute"];

/// The only imports `check_wasm_imports` accepts, written as `module.field`.
static SUPPORTED_IMPORTS: &[&str] = &[
    "env.get_span_size",
    "env.read_calldata",
    "env.set_return_data",
    "env.get_ask_count",
    "env.get_min_count",
    "env.get_prepare_time",
    "env.get_execute_time",
    "env.get_ans_count",
    "env.ask_external_data",
    "env.get_external_data_status",
    "env.read_external_data",
    "env.ecvrf_verify",
];
29
30pub fn compile(code: &[u8]) -> Result<Vec<u8>, Error> {
31 wasmparser::validate(code).map_err(|_| Error::ValidationError)?;
33
34 let module = deserialize_buffer(code).map_err(|_| Error::DeserializationError)?;
36 check_wasm_exports(&module)?;
37 check_wasm_imports(&module)?;
38 let module = inject_memory(module)?;
39 let module = inject_stack_height(module)?;
40
41 serialize(module).map_err(|_| Error::SerializationError)
43}
44
45fn check_wasm_exports(module: &Module) -> Result<(), Error> {
46 let available_exports: Vec<&str> = module.export_section().map_or(vec![], |export_section| {
47 export_section.entries().iter().map(|entry| entry.field()).collect()
48 });
49
50 for required_export in REQUIRED_EXPORTS {
51 if !available_exports.contains(required_export) {
52 return Err(Error::InvalidExportsError);
53 }
54 }
55
56 Ok(())
57}
58
59fn check_wasm_imports(module: &Module) -> Result<(), Error> {
60 let required_imports =
61 module.import_section().map_or(vec![], |import_section| import_section.entries().to_vec());
62
63 for required_import in required_imports {
64 let full_name = format!("{}.{}", required_import.module(), required_import.field());
65 if !SUPPORTED_IMPORTS.contains(&full_name.as_str()) {
66 return Err(Error::InvalidImportsError);
67 }
68
69 match required_import.external() {
70 External::Function(_) => (), _ => return Err(Error::InvalidImportsError),
72 };
73 }
74
75 Ok(())
76}
77
78fn inject_memory(module: Module) -> Result<Module, Error> {
79 let mut m = module;
80 let section = match m.memory_section() {
81 Some(section) => section,
82 None => return Err(Error::BadMemorySectionError),
83 };
84
85 let memory = section.entries()[0];
88 let limits = memory.limits();
89
90 if limits.initial() > MEMORY_LIMIT {
91 return Err(Error::BadMemorySectionError);
92 }
93
94 if limits.maximum() != None {
95 return Err(Error::BadMemorySectionError);
96 }
97
98 let memory = MemoryType::new(limits.initial(), Some(MEMORY_LIMIT));
100
101 let entries = m.memory_section_mut().unwrap().entries_mut();
103 entries.pop();
104 entries.push(memory);
105
106 Ok(builder::from_module(m).build())
107}
108
109fn inject_stack_height(module: Module) -> Result<Module, Error> {
110 wasm_instrument::inject_stack_limiter(module, MAX_STACK_HEIGHT)
111 .map_err(|_| Error::StackHeightInjectionError)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use std::io::{Read, Write};
    use std::process::Command;
    use tempfile::NamedTempFile;

    /// Assembles WebAssembly text into a binary by running the external
    /// `wat2wasm` executable on temporary files.
    ///
    /// Panics if the tool cannot be started or exits unsuccessfully, so a
    /// broken fixture is reported here instead of surfacing later as a
    /// confusing failure on empty output.
    fn wat2wasm(wat: impl AsRef<[u8]>) -> Vec<u8> {
        let mut input_file = NamedTempFile::new().unwrap();
        let mut output_file = NamedTempFile::new().unwrap();
        input_file.write_all(wat.as_ref()).unwrap();

        // Paths are passed as-is, so non-UTF-8 temp directories also work.
        let output = Command::new("wat2wasm")
            .arg(input_file.path())
            .arg("-o")
            .arg(output_file.path())
            .output()
            .expect("failed to run `wat2wasm`; is it installed and on PATH?");
        assert!(
            output.status.success(),
            "wat2wasm failed: {}",
            String::from_utf8_lossy(&output.stderr)
        );

        let mut wasm = Vec::new();
        output_file.read_to_end(&mut wasm).unwrap();
        wasm
    }

    /// Parses a wasm binary into a `Module`, panicking with the parser error
    /// if the bytes cannot be deserialized.
    fn get_module_from_wasm(code: &[u8]) -> Module {
        deserialize_buffer(code).expect("cannot deserialize wasm module")
    }

    // A single memory without a maximum is accepted.
    #[test]
    fn test_inject_memory_ok() {
        let wasm = wat2wasm(r#"(module (memory 1))"#);
        let module = get_module_from_wasm(&wasm);
        assert_matches!(inject_memory(module), Ok(_));
    }

    // A module without any memory section is rejected.
    #[test]
    fn test_inject_memory_no_memory() {
        let wasm = wat2wasm("(module)");
        let module = get_module_from_wasm(&wasm);
        assert_eq!(inject_memory(module), Err(Error::BadMemorySectionError));
    }

    // A hand-assembled module declaring two memories fails validation in
    // `compile`.
    #[test]
    fn test_inject_memory_two_memories() {
        let wasm = hex::decode(concat!(
            "0061736d", // magic: "\0asm"
            "01000000", // version: 1
            "05",       // section id: memory
            "05",       // section size: 5 bytes
            "02",       // number of memory entries
            "0009",     // memory 0: flags = 0 (no maximum), initial = 9
            "0009",     // memory 1: flags = 0 (no maximum), initial = 9
        ))
        .unwrap();
        let r = compile(&wasm);
        assert_eq!(r, Err(Error::ValidationError));
    }

    // An initial size equal to the limit passes; one page above it is rejected.
    #[test]
    fn test_inject_memory_initial_size() {
        let wasm_ok = wat2wasm("(module (memory 512))");
        let module = get_module_from_wasm(&wasm_ok);
        assert_matches!(inject_memory(module), Ok(_));
        let wasm_too_big = wat2wasm("(module (memory 513))");
        let module = get_module_from_wasm(&wasm_too_big);
        assert_eq!(inject_memory(module), Err(Error::BadMemorySectionError));
    }

    // A memory that already declares a maximum is rejected.
    #[test]
    fn test_inject_memory_maximum_size() {
        let wasm = wat2wasm("(module (memory 1 5))");
        let module = get_module_from_wasm(&wasm);
        assert_eq!(inject_memory(module), Err(Error::BadMemorySectionError));
    }

    // The stack limiter's output must match the expected instrumented module
    // byte for byte.
    #[test]
    fn test_inject_stack_height() {
        let wasm = wat2wasm(
            r#"(module
 (func
 (local $idx i32)
 (local.set $idx (i32.const 0))
 (block
 (loop
 (local.set $idx (local.get $idx) (i32.const 1) (i32.add) )
 (br_if 0 (i32.lt_u (local.get $idx) (i32.const 1000000000)))
 )
 )
 )
 (func (;"execute": Resolves with result "beeb";)
 )
 (memory 17)
 (data (i32.const 1048576) "beeb") (;str = "beeb";)
 (export "prepare" (func 0))
 (export "execute" (func 1)))
 "#,
        );
        let module = inject_stack_height(get_module_from_wasm(&wasm)).unwrap();
        let wasm = serialize(module).unwrap();
        let expected = wat2wasm(
            r#"(module
 (type (;0;) (func))
 (func (;0;) (type 0)
 (local i32)
 i32.const 0
 local.set 0
 block ;; label = @1
 loop ;; label = @2
 local.get 0
 i32.const 1
 i32.add
 local.set 0
 local.get 0
 i32.const 1000000000
 i32.lt_u
 br_if 0 (;@2;)
 end
 end)
 (func (;1;) (type 0))
 (func (;2;) (type 0)
 global.get 0
 i32.const 5
 i32.add
 global.set 0
 global.get 0
 i32.const 16384
 i32.gt_u
 if ;; label = @1
 unreachable
 end
 call 0
 global.get 0
 i32.const 5
 i32.sub
 global.set 0)
 (func (;3;) (type 0)
 global.get 0
 i32.const 2
 i32.add
 global.set 0
 global.get 0
 i32.const 16384
 i32.gt_u
 if ;; label = @1
 unreachable
 end
 call 1
 global.get 0
 i32.const 2
 i32.sub
 global.set 0)
 (memory (;0;) 17)
 (global (;0;) (mut i32) (i32.const 0))
 (export "prepare" (func 2))
 (export "execute" (func 3))
 (data (;0;) (i32.const 1048576) "beeb"))"#,
        );
        assert_eq!(wasm, expected);
    }

    // An unknown `env` import is rejected; a supported one is accepted.
    #[test]
    fn test_check_wasm_imports() {
        let wasm = wat2wasm(
            r#"(module
 (type (func (param i64 i64 i64 i64) (result i64)))
 (import "env" "beeb" (func (type 0))))"#,
        );
        let module = get_module_from_wasm(&wasm);
        assert_eq!(check_wasm_imports(&module), Err(Error::InvalidImportsError));
        let wasm = wat2wasm(
            r#"(module
 (type (func (param i64 i64 i64 i64) (result i64)))
 (import "env" "ask_external_data" (func (type 0))))"#,
        );
        let module = get_module_from_wasm(&wasm);
        assert_eq!(check_wasm_imports(&module), Ok(()));
    }

    // Both `prepare` and `execute` must be exported; either one alone fails.
    #[test]
    fn test_check_wasm_exports() {
        let wasm = wat2wasm(
            r#"(module
 (func $execute (export "execute")))"#,
        );
        let module = get_module_from_wasm(&wasm);
        assert_eq!(check_wasm_exports(&module), Err(Error::InvalidExportsError));
        let wasm = wat2wasm(
            r#"(module
 (func $prepare (export "prepare")))"#,
        );
        let module = get_module_from_wasm(&wasm);
        assert_eq!(check_wasm_exports(&module), Err(Error::InvalidExportsError));
        let wasm = wat2wasm(
            r#"(module
 (func $execute (export "execute"))
 (func $prepare (export "prepare"))
 )"#,
        );
        let module = get_module_from_wasm(&wasm);
        assert_eq!(check_wasm_exports(&module), Ok(()));
    }

    // End-to-end: `compile` output must match the expected module, including
    // the injected memory maximum and stack-limiter thunk.
    #[test]
    fn test_compile() {
        let wasm = wat2wasm(
            r#"(module
 (type (func (param i64 i64 i64 i64) (result i64)))
 (import "env" "ask_external_data" (func (type 0)))
 (func
 (local $idx i32)
 (local.set $idx (i32.const 0))
 (block
 (loop
 (local.set $idx (local.get $idx) (i32.const 1) (i32.add) )
 (br_if 0 (i32.lt_u (local.get $idx) (i32.const 1000000000)))
 )
 )
 )
 (func (;"execute": Resolves with result "beeb";)
 )
 (memory 17)
 (data (i32.const 1048576) "beeb") (;str = "beeb";)
 (export "prepare" (func 0))
 (export "execute" (func 1)))
 "#,
        );
        let code = compile(&wasm).unwrap();
        let expected = wat2wasm(
            r#"(module
 (type (;0;) (func (param i64 i64 i64 i64) (result i64)))
 (type (;1;) (func))
 (import "env" "ask_external_data" (func (;0;) (type 0)))
 (func (;1;) (type 1)
 (local i32)
 i32.const 0
 local.set 0
 block ;; label = @1
 loop ;; label = @2
 local.get 0
 i32.const 1
 i32.add
 local.set 0
 local.get 0
 i32.const 1000000000
 i32.lt_u
 br_if 0 (;@2;)
 end
 end)
 (func (;2;) (type 1))
 (func (;3;) (type 1)
 global.get 0
 i32.const 5
 i32.add
 global.set 0
 global.get 0
 i32.const 16384
 i32.gt_u
 if ;; label = @1
 unreachable
 end
 call 1
 global.get 0
 i32.const 5
 i32.sub
 global.set 0)
 (memory (;0;) 17 512)
 (global (;0;) (mut i32) (i32.const 0))
 (export "prepare" (func 0))
 (export "execute" (func 3))
 (data (;0;) (i32.const 1048576) "beeb"))"#,
        );
        assert_eq!(code, expected);
    }
}