1use parity_wasm::elements::{ExportEntry, ExportSection, Internal, Module, Section};
2
3use super::{ChiselModule, ModuleError, ModuleKind, ModulePreset, ModuleTranslator};
4
5pub struct RemapStart;
6
7impl ModulePreset for RemapStart {
8 fn with_preset(preset: &str) -> Result<Self, ModuleError> {
9 match preset {
10 "ewasm" => Ok(RemapStart {}),
12 _ => Err(ModuleError::NotSupported),
13 }
14 }
15}
16
17impl<'a> ChiselModule<'a> for RemapStart {
18 type ObjectReference = &'a dyn ModuleTranslator;
19
20 fn id(&'a self) -> String {
21 "remapstart".to_string()
22 }
23
24 fn kind(&'a self) -> ModuleKind {
25 ModuleKind::Translator
26 }
27
28 fn as_abstract(&'a self) -> Self::ObjectReference {
29 self as Self::ObjectReference
30 }
31}
32
33impl ModuleTranslator for RemapStart {
34 fn translate_inplace(&self, module: &mut Module) -> Result<bool, ModuleError> {
35 Ok(remap_start(module))
36 }
37
38 fn translate(&self, module: &Module) -> Result<Option<Module>, ModuleError> {
39 let mut ret = module.clone();
40 if remap_start(&mut ret) {
41 Ok(Some(ret))
42 } else {
43 Ok(None)
44 }
45 }
46}
47
48fn remap_or_export_main(module: &mut Module, export_name: &str, func_idx: u32) {
50 let new_func_export = ExportEntry::new(export_name.to_string(), Internal::Function(func_idx));
51
52 if let Some(export_section) = module.export_section_mut() {
53 let export_section = export_section.entries_mut();
54 if let Some(main_export_loc) = export_section
57 .iter_mut()
58 .position(|e| e.field() == export_name)
59 {
60 export_section[main_export_loc] = new_func_export;
61 } else {
62 export_section.push(new_func_export);
63 }
64 } else {
65 let new_export_section =
66 Section::Export(ExportSection::with_entries(vec![new_func_export]));
67
68 module
70 .insert_section(new_export_section)
71 .expect("insert_section should not fail");
72 }
73}
74
75fn remap_start(module: &mut Module) -> bool {
76 if let Some(start_func_idx) = module.start_section() {
77 remap_or_export_main(module, "main", start_func_idx);
80
81 module.clear_start_section();
83
84 true
85 } else {
86 false
87 }
88}
89
#[cfg(test)]
mod tests {
    use rustc_hex::FromHex;

    use super::*;
    use crate::{ModulePreset, ModuleTranslator};

    /// Start section -> function 2; exports `main` (function 1) and `memory`;
    /// carries a name section.
    const WASM_START_AND_MAIN: &str = "0061736d0100000001080260017e0060
000002170103656e760f657468657265756d5f75736547617300000303020101050301000107110
2046d61696e0001066d656d6f727902000801020a070202000b02000b0020046e616d65010e0201
046d61696e02056d61696e320209030001000001000200";

    /// Exports `main` and `memory`, but has no start section.
    const WASM_NO_START: &str = "0061736d0100000001080260017e0060
000002170103656e760f657468657265756d5f75736547617300000302010105030100010711020
46d61696e0001066d656d6f727902000a040102000b";

    /// Start section -> function 1; no export section at all.
    const WASM_START_NO_EXPORTS: &str = "0061736d0100000001080260017e0060000002170103656e760f657468657265756d5f7573654761730000030302010105030100010801010a070202000b02000b";

    /// Start section -> function 1; export section only exports `memory`.
    const WASM_START_MEMORY_EXPORT_ONLY: &str = "0061736d0100000001080260017e0060000002170103656e760f657468657265756d5f7573654761730000030201010503010001070a01066d656d6f727902000801010a040102000b";

    /// Decodes a hex-encoded fixture and parses it as a wasm module.
    fn parse_module(hex: &str) -> Module {
        let wasm: Vec<u8> = FromHex::from_hex(hex).unwrap();
        Module::from_bytes(&wasm).unwrap()
    }

    /// Returns true if `module` exports `main` as function `func_idx`.
    fn exports_main_as(module: &Module, func_idx: u32) -> bool {
        let target = Internal::Function(func_idx);
        module
            .export_section()
            .into_iter()
            .flat_map(|section| section.entries())
            .any(|e| e.field() == "main" && *e.internal() == target)
    }

    #[test]
    fn remapstart_mutation() {
        let module = parse_module(WASM_START_AND_MAIN).parse_names().unwrap();
        assert!(module.names_section().is_some());
        let start_idx = module
            .start_section()
            .expect("Module missing start function");

        let new = RemapStart::with_preset("ewasm")
            .unwrap()
            .translate(&module)
            .expect("Module internal error")
            .expect("new module not returned");

        assert!(
            new.start_section().is_none(),
            "start section wasn't removed"
        );
        assert!(
            exports_main_as(&new, start_idx),
            "main is not exported as the former start function"
        );
    }

    #[test]
    fn remapstart_no_mutation() {
        let module = parse_module(WASM_NO_START);
        let new = RemapStart::with_preset("ewasm")
            .unwrap()
            .translate(&module)
            .expect("Module internal error");

        assert!(new.is_none());
    }

    #[test]
    fn remapstart_inplace_mutation() {
        let mut module = parse_module(WASM_START_AND_MAIN).parse_names().unwrap();
        assert!(module.names_section().is_some());
        let start_idx = module
            .start_section()
            .expect("Module missing start function");

        let res = RemapStart::with_preset("ewasm")
            .unwrap()
            .translate_inplace(&mut module)
            .unwrap();

        assert!(res, "module was not modified");
        assert!(
            module.start_section().is_none(),
            "start section wasn't removed"
        );
        assert!(
            exports_main_as(&module, start_idx),
            "main is not exported as the former start function"
        );
    }

    #[test]
    fn remapstart_inplace_no_mutation() {
        let mut module = parse_module(WASM_NO_START);
        let res = RemapStart::with_preset("ewasm")
            .unwrap()
            .translate_inplace(&mut module)
            .unwrap();

        assert!(!res, "module was modified");
    }

    #[test]
    fn remapstart_mutation_no_exports() {
        let mut module = parse_module(WASM_START_NO_EXPORTS);
        let start_idx = module
            .start_section()
            .expect("Module missing start function");

        let res = RemapStart::with_preset("ewasm")
            .unwrap()
            .translate_inplace(&mut module)
            .unwrap();

        assert!(res, "module was not modified");
        assert!(
            module.export_section().is_some(),
            "export section does not exist"
        );
        assert!(module.start_section().is_none());
        assert!(exports_main_as(&module, start_idx));
    }

    #[test]
    fn export_section_exists_but_no_main() {
        let mut module = parse_module(WASM_START_MEMORY_EXPORT_ONLY);
        let start_idx = module
            .start_section()
            .expect("Module missing start function");
        let remapper = RemapStart::with_preset("ewasm").expect("Can't fail");

        let mutated = remapper
            .translate_inplace(&mut module)
            .expect("Module internal error");

        assert!(mutated);
        assert!(module.start_section().is_none());
        assert!(exports_main_as(&module, start_idx));
        // The pre-existing export must survive alongside the new `main`.
        assert!(module
            .export_section()
            .expect("Module missing export section")
            .entries()
            .iter()
            .any(|e| e.field() == "memory"));
    }

    #[test]
    fn unsupported_preset_is_rejected() {
        assert!(RemapStart::with_preset("not-a-preset").is_err());
    }
}