multiversx_sc_meta_lib/tools/wasm_extractor/
extractor.rs

1use colored::Colorize;
2use std::{
3    collections::{HashMap, HashSet},
4    fs,
5    path::{Path, PathBuf},
6};
7use wasmparser::{
8    BinaryReaderError, DataSectionReader, ExportSectionReader, FunctionBody, ImportSectionReader,
9    Operator, Parser, Payload,
10};
11
12use crate::{ei::EIVersion, tools::CodeReport};
13
14use super::{
15    endpoint_info::{EndpointInfo, FunctionInfo},
16    report::WasmReport,
17    whitelisted_opcodes::{is_whitelisted, ERROR_FAIL_ALLOCATOR, WRITE_OP},
18};
19
/// Call graph of the analysed binary, keyed by wasm function index.
type CallGraph = HashMap<usize, FunctionInfo>;

/// Accumulated results of analysing one compiled `.wasm` contract.
#[derive(Default, Debug, Clone)]
pub struct WasmInfo {
    /// One entry per function: imports are inserted while reading the import
    /// section, function bodies while reading the code section.
    pub call_graph: CallGraph,
    /// Indexes of functions considered to write: imports whose name is listed in
    /// `WRITE_OP`, plus functions found to call one of them (directly or through
    /// other functions) during view-endpoint analysis.
    pub write_index_functions: HashSet<usize>,
    /// Endpoints to analyse, keyed by endpoint name.
    pub endpoints: HashMap<String, EndpointInfo>,
    /// Report that is filled in while the binary is parsed.
    pub report: WasmReport,
    /// Raw bytes of the `.wasm` binary being analysed.
    pub data: Vec<u8>,
}
30
31impl WasmInfo {
32    pub fn extract_wasm_report(
33        output_wasm_path: &PathBuf,
34        extract_imports_enabled: bool,
35        check_ei: Option<&EIVersion>,
36        endpoints: &HashMap<&str, bool>,
37    ) -> WasmReport {
38        let wasm_data = fs::read(output_wasm_path)
39            .expect("error occurred while extracting information from .wasm: file not found");
40
41        let wasm_info = WasmInfo::default()
42            .add_endpoints(endpoints)
43            .add_path(output_wasm_path)
44            .add_wasm_data(&wasm_data)
45            .populate_wasm_info(extract_imports_enabled, check_ei)
46            .expect("error occurred while extracting information from .wasm file");
47
48        wasm_info.report
49    }
50
51    pub(crate) fn populate_wasm_info(
52        self,
53        import_extraction_enabled: bool,
54        check_ei: Option<&EIVersion>,
55    ) -> Result<WasmInfo, BinaryReaderError> {
56        let parser = Parser::new(0);
57        let mut wasm_info = self.clone();
58
59        for payload in parser.parse_all(&self.data) {
60            match payload? {
61                Payload::ImportSection(import_section) => {
62                    wasm_info.process_imports(import_section, import_extraction_enabled);
63                    wasm_info.report.ei_check |= is_ei_valid(&wasm_info.report.imports, check_ei);
64                }
65                Payload::DataSection(data_section) => {
66                    wasm_info.report.code.has_allocator |=
67                        is_fail_allocator_triggered(data_section.clone());
68                    wasm_info.report.code.has_panic.max_severity(data_section);
69                }
70                Payload::CodeSectionEntry(code_section) => {
71                    wasm_info.report.memory_grow_flag |= is_mem_grow(&code_section);
72                    wasm_info.create_call_graph(code_section);
73                }
74                Payload::ExportSection(export_section) => {
75                    wasm_info.parse_export_section(export_section);
76                }
77                _ => (),
78            }
79        }
80
81        wasm_info.detect_write_operations_in_views();
82        wasm_info.detect_forbidden_opcodes();
83
84        Ok(wasm_info)
85    }
86
87    pub(crate) fn add_endpoints(self, endpoints: &HashMap<&str, bool>) -> Self {
88        let mut endpoints_map = HashMap::new();
89
90        for (name, readonly) in endpoints {
91            endpoints_map.insert(name.to_string(), EndpointInfo::default(*readonly));
92        }
93
94        WasmInfo {
95            endpoints: endpoints_map,
96            ..self
97        }
98    }
99
100    pub(crate) fn add_wasm_data(self, data: &[u8]) -> Self {
101        WasmInfo {
102            data: data.to_vec(),
103            ..self
104        }
105    }
106
107    fn add_path(self, path: &Path) -> Self {
108        WasmInfo {
109            report: WasmReport {
110                code: CodeReport {
111                    path: path.to_path_buf(),
112                    ..self.report.code
113                },
114                ..self.report
115            },
116            ..self
117        }
118    }
119
120    fn create_call_graph(&mut self, body: FunctionBody) {
121        let mut instructions_reader = body
122            .get_operators_reader()
123            .expect("Failed to get operators reader");
124
125        let mut function_info = FunctionInfo::new();
126        while let Ok(op) = instructions_reader.read() {
127            if let Operator::Call { function_index } = op {
128                let function_usize: usize = function_index.try_into().unwrap();
129                function_info.add_function_index(function_usize);
130            }
131
132            if !is_whitelisted(&op) {
133                let opcode = extract_opcode(op);
134                function_info.add_forbidden_opcode(opcode);
135            }
136        }
137
138        self.call_graph.insert(self.call_graph.len(), function_info);
139    }
140
141    fn process_imports(
142        &mut self,
143        import_section: ImportSectionReader,
144        import_extraction_enabled: bool,
145    ) {
146        for (index, import) in import_section.into_iter().flatten().enumerate() {
147            if import_extraction_enabled {
148                self.report.imports.push(import.name.to_string());
149            }
150            self.call_graph.insert(index, FunctionInfo::new());
151            if WRITE_OP.contains(&import.name) {
152                self.write_index_functions.insert(index);
153            }
154        }
155
156        self.report.imports.sort();
157    }
158
159    fn detect_write_operations_in_views(&mut self) {
160        let mut visited: HashSet<usize> = HashSet::new();
161
162        for index in get_view_endpoints_indexes(&self.endpoints) {
163            mark_write(self, index, &mut visited);
164        }
165
166        for (name, index) in get_view_endpoints(&self.endpoints) {
167            if self.write_index_functions.contains(&index) {
168                println!(
169                    "{} {}",
170                    "Write storage operation in VIEW endpoint:"
171                        .to_string()
172                        .red()
173                        .bold(),
174                    name.red().bold()
175                );
176            }
177        }
178    }
179
180    fn detect_forbidden_opcodes(&mut self) {
181        let mut visited: HashSet<usize> = HashSet::new();
182        for endpoint_info in self.endpoints.values_mut() {
183            mark_forbidden_functions(endpoint_info.index, &mut self.call_graph, &mut visited);
184            endpoint_info.forbidden_opcodes = self
185                .call_graph
186                .get(&endpoint_info.index)
187                .unwrap()
188                .forbidden_opcodes
189                .clone();
190        }
191
192        for (name, endpoint_info) in &self.endpoints {
193            if !endpoint_info.forbidden_opcodes.is_empty() {
194                self.report.forbidden_opcodes.insert(
195                    name.to_string(),
196                    endpoint_info.forbidden_opcodes.iter().cloned().collect(),
197                );
198
199                println!(
200                    "{}{}{} {}",
201                    "Forbidden opcodes detected in endpoint \""
202                        .to_string()
203                        .red()
204                        .bold(),
205                    name.red().bold(),
206                    "\". This are the opcodes:".to_string().red().bold(),
207                    self.report
208                        .forbidden_opcodes
209                        .get(name)
210                        .unwrap()
211                        .join(", ")
212                        .red()
213                        .bold()
214                );
215            }
216        }
217    }
218
219    fn parse_export_section(&mut self, export_section: ExportSectionReader) {
220        if self.endpoints.is_empty() {
221            return;
222        }
223
224        for export in export_section {
225            let export = export.expect("Failed to read export section");
226            if wasmparser::ExternalKind::Func == export.kind {
227                if let Some(endpoint) = self.endpoints.get_mut(export.name) {
228                    endpoint.set_index(export.index.try_into().unwrap());
229                }
230            }
231        }
232    }
233}
234
235pub(crate) fn get_view_endpoints_indexes(endpoints: &HashMap<String, EndpointInfo>) -> Vec<usize> {
236    endpoints
237        .values()
238        .filter(|endpoint_info| endpoint_info.readonly)
239        .map(|endpoint_info| endpoint_info.index)
240        .collect()
241}
242
243pub(crate) fn get_view_endpoints(
244    endpoints: &HashMap<String, EndpointInfo>,
245) -> HashMap<&str, usize> {
246    let mut view_endpoints = HashMap::new();
247
248    for (name, endpoint_info) in endpoints {
249        if endpoint_info.readonly {
250            view_endpoints.insert(name.as_str(), endpoint_info.index);
251        }
252    }
253
254    view_endpoints
255}
256
257fn is_fail_allocator_triggered(data_section: DataSectionReader) -> bool {
258    for data_fragment in data_section.into_iter().flatten() {
259        if data_fragment
260            .data
261            .windows(ERROR_FAIL_ALLOCATOR.len())
262            .any(|data| data == ERROR_FAIL_ALLOCATOR)
263        {
264            println!(
265                "{}",
266                "FailAllocator used while memory allocation is accessible in code. Contract may fail unexpectedly when memory allocation is attempted"
267                    .to_string()
268                    .red()
269                    .bold()
270            );
271            return true;
272        }
273    }
274
275    false
276}
277
278fn mark_write(wasm_info: &mut WasmInfo, func: usize, visited: &mut HashSet<usize>) {
279    // Return early to prevent cycles.
280    if visited.contains(&func) {
281        return;
282    }
283
284    visited.insert(func);
285
286    let callees: Vec<usize> = if let Some(callees) = wasm_info.call_graph.get(&func) {
287        callees.indexes.iter().cloned().collect()
288    } else {
289        return;
290    };
291
292    for callee in callees {
293        if wasm_info.write_index_functions.contains(&callee) {
294            wasm_info.write_index_functions.insert(func);
295        } else {
296            mark_write(wasm_info, callee, visited);
297            if wasm_info.write_index_functions.contains(&callee) {
298                wasm_info.write_index_functions.insert(func);
299            }
300        }
301    }
302}
303
304fn mark_forbidden_functions(func: usize, call_graph: &mut CallGraph, visited: &mut HashSet<usize>) {
305    // Return early to prevent cycles.
306    if visited.contains(&func) {
307        return;
308    }
309
310    visited.insert(func);
311
312    if let Some(function_info) = call_graph.get(&func) {
313        for index in function_info.indexes.clone() {
314            if !call_graph.get(&index).unwrap().forbidden_opcodes.is_empty() {
315                let index_forbidden_opcodes =
316                    call_graph.get(&index).unwrap().forbidden_opcodes.clone();
317
318                call_graph
319                    .get_mut(&func)
320                    .unwrap()
321                    .add_forbidden_opcodes(index_forbidden_opcodes);
322            } else {
323                mark_forbidden_functions(index, call_graph, visited);
324                if !call_graph.get(&index).unwrap().forbidden_opcodes.is_empty() {
325                    let index_forbidden_opcodes =
326                        call_graph.get(&index).unwrap().forbidden_opcodes.clone();
327
328                    call_graph
329                        .get_mut(&func)
330                        .unwrap()
331                        .add_forbidden_opcodes(index_forbidden_opcodes);
332                }
333            }
334        }
335    }
336}
337
338fn is_ei_valid(imports: &[String], check_ei: Option<&EIVersion>) -> bool {
339    if let Some(ei) = check_ei {
340        let mut num_errors = 0;
341        for import in imports {
342            if !ei.contains_vm_hook(import.as_str()) {
343                num_errors += 1;
344            }
345        }
346
347        if num_errors == 0 {
348            return true;
349        }
350    }
351
352    false
353}
354
355fn is_mem_grow(code_section: &FunctionBody) -> bool {
356    let mut instructions_reader = code_section
357        .get_operators_reader()
358        .expect("Failed to get operators reader");
359
360    while let Ok(op) = instructions_reader.read() {
361        if let Operator::MemoryGrow { mem: _ } = op {
362            return true;
363        }
364    }
365
366    false
367}
368
369fn extract_opcode(op: Operator) -> String {
370    let op_str = format!("{:?}", op);
371    let op_vec: Vec<&str> = op_str.split_whitespace().collect();
372
373    op_vec[0].to_owned()
374}