1use crate::std::collections::BTreeSet as Set;
2use crate::std::{mem, vec::Vec};
3
4use crate::symbols::{expand_symbols, push_code_symbols, resolve_function, Symbol};
5use casper_wasm::elements;
6use log::trace;
7
/// Errors that [`optimize`] can report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The module has no export section. The optimizer takes the requested exports as
    /// its starting point, so it cannot run without one.
    NoExportSection,
}
14
15pub fn optimize(
16 module: &mut elements::Module, used_exports: Vec<&str>, ) -> Result<(), Error> {
19 let module_temp = mem::take(module);
25 let module_temp = module_temp
26 .parse_names()
27 .unwrap_or_else(|(_err, module)| module);
28 *module = module_temp;
29
30 let mut stay: Set<_> = module
32 .export_section()
33 .ok_or(Error::NoExportSection)?
34 .entries()
35 .iter()
36 .enumerate()
37 .filter_map(|(index, entry)| {
38 if used_exports.iter().any(|e| *e == entry.field()) {
39 Some(Symbol::Export(index))
40 } else {
41 None
42 }
43 })
44 .collect();
45
46 if let Some(ss) = module.start_section() {
48 stay.insert(resolve_function(module, ss));
49 }
50
51 let mut init_symbols = Vec::new();
53 if let Some(data_section) = module.data_section() {
54 for segment in data_section.entries() {
55 push_code_symbols(
56 module,
57 segment
58 .offset()
59 .as_ref()
60 .expect("parity-wasm is compiled without bulk-memory operations")
61 .code(),
62 &mut init_symbols,
63 );
64 }
65 }
66 if let Some(elements_section) = module.elements_section() {
67 for segment in elements_section.entries() {
68 push_code_symbols(
69 module,
70 segment
71 .offset()
72 .as_ref()
73 .expect("parity-wasm is compiled without bulk-memory operations")
74 .code(),
75 &mut init_symbols,
76 );
77 for func_index in segment.members() {
78 stay.insert(resolve_function(module, *func_index));
79 }
80 }
81 }
82
83 stay.extend(init_symbols.drain(..));
84
85 expand_symbols(module, &mut stay);
88
89 for symbol in stay.iter() {
90 trace!("symbol to stay: {:?}", symbol);
91 }
92
93 let mut eliminated_funcs = Vec::new();
95 let mut eliminated_globals = Vec::new();
96 let mut eliminated_types = Vec::new();
97
98 let mut index = 0;
100 let mut old_index = 0;
101
102 loop {
103 if type_section(module)
104 .map(|section| section.types_mut().len())
105 .unwrap_or(0)
106 == index
107 {
108 break;
109 }
110
111 if stay.contains(&Symbol::Type(old_index)) {
112 index += 1;
113 } else {
114 type_section(module)
115 .expect("If type section does not exists, the loop will break at the beginning of first iteration")
116 .types_mut().remove(index);
117 eliminated_types.push(old_index);
118 trace!("Eliminated type({})", old_index);
119 }
120 old_index += 1;
121 }
122
123 let mut top_funcs = 0;
125 let mut top_globals = 0;
126 index = 0;
127 old_index = 0;
128
129 if let Some(imports) = import_section(module) {
130 loop {
131 let mut remove = false;
132 match imports.entries()[index].external() {
133 elements::External::Function(_) => {
134 if stay.contains(&Symbol::Import(old_index)) {
135 index += 1;
136 } else {
137 remove = true;
138 eliminated_funcs.push(top_funcs);
139 trace!(
140 "Eliminated import({}) func({}, {})",
141 old_index,
142 top_funcs,
143 imports.entries()[index].field()
144 );
145 }
146 top_funcs += 1;
147 }
148 elements::External::Global(_) => {
149 if stay.contains(&Symbol::Import(old_index)) {
150 index += 1;
151 } else {
152 remove = true;
153 eliminated_globals.push(top_globals);
154 trace!(
155 "Eliminated import({}) global({}, {})",
156 old_index,
157 top_globals,
158 imports.entries()[index].field()
159 );
160 }
161 top_globals += 1;
162 }
163 _ => {
164 index += 1;
165 }
166 }
167 if remove {
168 imports.entries_mut().remove(index);
169 }
170
171 old_index += 1;
172
173 if index == imports.entries().len() {
174 break;
175 }
176 }
177 }
178
179 if let Some(globals) = global_section(module) {
181 index = 0;
182 old_index = 0;
183
184 loop {
185 if globals.entries_mut().len() == index {
186 break;
187 }
188 if stay.contains(&Symbol::Global(old_index)) {
189 index += 1;
190 } else {
191 globals.entries_mut().remove(index);
192 eliminated_globals.push(top_globals + old_index);
193 trace!("Eliminated global({})", top_globals + old_index);
194 }
195 old_index += 1;
196 }
197 }
198
199 if function_section(module).is_some() && code_section(module).is_some() {
201 index = 0;
202 old_index = 0;
203
204 loop {
205 if function_section(module)
206 .expect("Functons section to exist")
207 .entries_mut()
208 .len()
209 == index
210 {
211 break;
212 }
213 if stay.contains(&Symbol::Function(old_index)) {
214 index += 1;
215 } else {
216 function_section(module)
217 .expect("Functons section to exist")
218 .entries_mut()
219 .remove(index);
220 code_section(module)
221 .expect("Code section to exist")
222 .bodies_mut()
223 .remove(index);
224
225 eliminated_funcs.push(top_funcs + old_index);
226 trace!("Eliminated function({})", top_funcs + old_index);
227 }
228 old_index += 1;
229 }
230 }
231
232 {
234 let exports = export_section(module).ok_or(Error::NoExportSection)?;
235
236 index = 0;
237 old_index = 0;
238
239 loop {
240 if exports.entries_mut().len() == index {
241 break;
242 }
243 if stay.contains(&Symbol::Export(old_index)) {
244 index += 1;
245 } else {
246 trace!(
247 "Eliminated export({}, {})",
248 old_index,
249 exports.entries_mut()[index].field()
250 );
251 exports.entries_mut().remove(index);
252 }
253 old_index += 1;
254 }
255 }
256
257 if !eliminated_globals.is_empty()
258 || !eliminated_funcs.is_empty()
259 || !eliminated_types.is_empty()
260 {
261 eliminated_globals.sort_unstable();
265 eliminated_funcs.sort_unstable();
266 eliminated_types.sort_unstable();
267
268 for section in module.sections_mut() {
269 match section {
270 elements::Section::Start(func_index) if !eliminated_funcs.is_empty() => {
271 let totalle = eliminated_funcs
272 .iter()
273 .take_while(|i| (**i as u32) < *func_index)
274 .count();
275 *func_index -= totalle as u32;
276 }
277 elements::Section::Function(function_section) if !eliminated_types.is_empty() => {
278 for func_signature in function_section.entries_mut() {
279 let totalle = eliminated_types
280 .iter()
281 .take_while(|i| (**i as u32) < func_signature.type_ref())
282 .count();
283 *func_signature.type_ref_mut() -= totalle as u32;
284 }
285 }
286 elements::Section::Import(import_section) if !eliminated_types.is_empty() => {
287 for import_entry in import_section.entries_mut() {
288 if let elements::External::Function(type_ref) = import_entry.external_mut()
289 {
290 let totalle = eliminated_types
291 .iter()
292 .take_while(|i| (**i as u32) < *type_ref)
293 .count();
294 *type_ref -= totalle as u32;
295 }
296 }
297 }
298 elements::Section::Code(code_section)
299 if !eliminated_globals.is_empty() || !eliminated_funcs.is_empty() =>
300 {
301 for func_body in code_section.bodies_mut() {
302 if !eliminated_funcs.is_empty() {
303 update_call_index(func_body.code_mut(), &eliminated_funcs);
304 }
305 if !eliminated_globals.is_empty() {
306 update_global_index(
307 func_body.code_mut().elements_mut(),
308 &eliminated_globals,
309 )
310 }
311 if !eliminated_types.is_empty() {
312 update_type_index(func_body.code_mut(), &eliminated_types)
313 }
314 }
315 }
316 elements::Section::Export(export_section) => {
317 for export in export_section.entries_mut() {
318 match export.internal_mut() {
319 elements::Internal::Function(func_index) => {
320 let totalle = eliminated_funcs
321 .iter()
322 .take_while(|i| (**i as u32) < *func_index)
323 .count();
324 *func_index -= totalle as u32;
325 }
326 elements::Internal::Global(global_index) => {
327 let totalle = eliminated_globals
328 .iter()
329 .take_while(|i| (**i as u32) < *global_index)
330 .count();
331 *global_index -= totalle as u32;
332 }
333 _ => {}
334 }
335 }
336 }
337 elements::Section::Global(global_section) => {
338 for global_entry in global_section.entries_mut() {
339 update_global_index(
340 global_entry.init_expr_mut().code_mut(),
341 &eliminated_globals,
342 )
343 }
344 }
345 elements::Section::Data(data_section) => {
346 for segment in data_section.entries_mut() {
347 update_global_index(
348 segment
349 .offset_mut()
350 .as_mut()
351 .expect("parity-wasm is compiled without bulk-memory operations")
352 .code_mut(),
353 &eliminated_globals,
354 )
355 }
356 }
357 elements::Section::Element(elements_section) => {
358 for segment in elements_section.entries_mut() {
359 update_global_index(
360 segment
361 .offset_mut()
362 .as_mut()
363 .expect("parity-wasm is compiled without bulk-memory operations")
364 .code_mut(),
365 &eliminated_globals,
366 );
367 for func_index in segment.members_mut() {
369 let totalle = eliminated_funcs
370 .iter()
371 .take_while(|i| (**i as u32) < *func_index)
372 .count();
373 *func_index -= totalle as u32;
374 }
375 }
376 }
377 elements::Section::Name(name_section) => {
378 if let Some(func_name) = name_section.functions_mut() {
379 let mut func_name_map = mem::take(func_name.names_mut());
380 for index in &eliminated_funcs {
381 func_name_map.remove(*index as u32);
382 }
383 let updated_map = func_name_map
384 .into_iter()
385 .map(|(index, value)| {
386 let totalle = eliminated_funcs
387 .iter()
388 .take_while(|i| (**i as u32) < index)
389 .count() as u32;
390 (index - totalle, value)
391 })
392 .collect();
393 *func_name.names_mut() = updated_map;
394 }
395
396 if let Some(local_name) = name_section.locals_mut() {
397 let mut local_names_map = mem::take(local_name.local_names_mut());
398 for index in &eliminated_funcs {
399 local_names_map.remove(*index as u32);
400 }
401 let updated_map = local_names_map
402 .into_iter()
403 .map(|(index, value)| {
404 let totalle = eliminated_funcs
405 .iter()
406 .take_while(|i| (**i as u32) < index)
407 .count() as u32;
408 (index - totalle, value)
409 })
410 .collect();
411 *local_name.local_names_mut() = updated_map;
412 }
413 }
414 _ => {}
415 }
416 }
417 }
418
419 module
421 .sections_mut()
422 .retain(|section| !matches!(section, elements::Section::Custom(_)));
423
424 Ok(())
425}
426
427pub fn update_call_index(instructions: &mut elements::Instructions, eliminated_indices: &[usize]) {
428 use casper_wasm::elements::Instruction::*;
429 for instruction in instructions.elements_mut().iter_mut() {
430 if let Call(call_index) = instruction {
431 let totalle = eliminated_indices
432 .iter()
433 .take_while(|i| (**i as u32) < *call_index)
434 .count();
435 trace!(
436 "rewired call {} -> call {}",
437 *call_index,
438 *call_index - totalle as u32
439 );
440 *call_index -= totalle as u32;
441 }
442 }
443}
444
445pub fn update_global_index(
447 instructions: &mut [elements::Instruction],
448 eliminated_indices: &[usize],
449) {
450 use casper_wasm::elements::Instruction::*;
451 for instruction in instructions.iter_mut() {
452 match instruction {
453 GetGlobal(index) | SetGlobal(index) => {
454 let totalle = eliminated_indices
455 .iter()
456 .take_while(|i| (**i as u32) < *index)
457 .count();
458 trace!(
459 "rewired global {} -> global {}",
460 *index,
461 *index - totalle as u32
462 );
463 *index -= totalle as u32;
464 }
465 _ => {}
466 }
467 }
468}
469
470pub fn update_type_index(instructions: &mut elements::Instructions, eliminated_indices: &[usize]) {
472 use casper_wasm::elements::Instruction::*;
473 for instruction in instructions.elements_mut().iter_mut() {
474 if let CallIndirect(call_index, _) = instruction {
475 let totalle = eliminated_indices
476 .iter()
477 .take_while(|i| (**i as u32) < *call_index)
478 .count();
479 trace!(
480 "rewired call_indrect {} -> call_indirect {}",
481 *call_index,
482 *call_index - totalle as u32
483 );
484 *call_index -= totalle as u32;
485 }
486 }
487}
488
489pub fn import_section(module: &mut elements::Module) -> Option<&mut elements::ImportSection> {
490 for section in module.sections_mut() {
491 if let elements::Section::Import(sect) = section {
492 return Some(sect);
493 }
494 }
495 None
496}
497
498pub fn global_section(module: &mut elements::Module) -> Option<&mut elements::GlobalSection> {
499 for section in module.sections_mut() {
500 if let elements::Section::Global(sect) = section {
501 return Some(sect);
502 }
503 }
504 None
505}
506
507pub fn function_section(module: &mut elements::Module) -> Option<&mut elements::FunctionSection> {
508 for section in module.sections_mut() {
509 if let elements::Section::Function(sect) = section {
510 return Some(sect);
511 }
512 }
513 None
514}
515
516pub fn code_section(module: &mut elements::Module) -> Option<&mut elements::CodeSection> {
517 for section in module.sections_mut() {
518 if let elements::Section::Code(sect) = section {
519 return Some(sect);
520 }
521 }
522 None
523}
524
525pub fn export_section(module: &mut elements::Module) -> Option<&mut elements::ExportSection> {
526 for section in module.sections_mut() {
527 if let elements::Section::Export(sect) = section {
528 return Some(sect);
529 }
530 }
531 None
532}
533
534pub fn type_section(module: &mut elements::Module) -> Option<&mut elements::TypeSection> {
535 for section in module.sections_mut() {
536 if let elements::Section::Type(sect) = section {
537 return Some(sect);
538 }
539 }
540 None
541}
542
#[cfg(test)]
mod tests {

    use super::*;
    use casper_wasm::{builder, elements};

    /// A module built without any entries is expected to be rejected by the optimizer.
    #[test]
    fn empty() {
        let mut module = builder::module().build();
        let result = optimize(&mut module, vec!["_call"]);

        assert!(result.is_err());
    }

    /// Two functions, each exported under its own name. Keeping only `_call` must leave
    /// a single export and a single function.
    #[test]
    fn minimal() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .export()
            .field("_random")
            .internal()
            .func(1)
            .build()
            .build();
        assert_eq!(
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            2
        );

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            "There should only 1 (one) export entry in the optimized module"
        );

        assert_eq!(
            1,
            module
                .function_section()
                .expect("functions section to be generated")
                .entries()
                .len(),
            "There should only be 1 (one) function in the optimized module"
        );
    }

    /// One global that is read by the exported function: the global must be kept.
    #[test]
    fn globals() {
        let mut module = builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::GetGlobal(0),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module.global_section().expect("global section to be generated").entries().len(),
            "There should 1 (one) global entry in the optimized module, since _call function uses it"
        );
    }

    /// Three globals of which the exported function reads only the second one: the
    /// other two must be removed.
    #[test]
    fn globals_2() {
        let mut module = builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .global()
            .value_type()
            .i64()
            .build()
            .global()
            .value_type()
            .f32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::GetGlobal(1),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module.global_section().expect("global section to be generated").entries().len(),
            "There should 1 (one) global entry in the optimized module, since _call function uses only one"
        );
    }

    /// The kept export calls the second function directly, so both functions must
    /// survive even though the second function's own export is dropped.
    #[test]
    fn call_ref() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::Call(1),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .export()
            .field("_random")
            .internal()
            .func(1)
            .build()
            .build();
        assert_eq!(
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            2
        );

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            "There should only 1 (one) export entry in the optimized module"
        );

        assert_eq!(
            2,
            module
                .function_section()
                .expect("functions section to be generated")
                .entries()
                .len(),
            "There should 2 (two) functions in the optimized module"
        );
    }

    /// The exported function performs a `call_indirect` through type 1. Type 0 is
    /// expected to be eliminated, so the instruction must end up referring to type 0.
    #[test]
    fn call_indirect() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::CallIndirect(1, 0),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(2)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            2,
            module
                .type_section()
                .expect("type section to be generated")
                .types()
                .len(),
            "There should 2 (two) types left in the module, 1 for indirect call and one for _call"
        );

        let indirect_opcode = &module
            .code_section()
            .expect("code section to be generated")
            .bodies()[0]
            .code()
            .elements()[0];
        match *indirect_opcode {
            elements::Instruction::CallIndirect(0, 0) => {}
            _ => {
                panic!(
                    "Expected call_indirect to use index 0 after optimization, since previous 0th was eliminated, but got {:?}",
                    indirect_opcode
                );
            }
        }
    }
}