casper_wasm_utils/
optimizer.rs

1use crate::std::collections::BTreeSet as Set;
2use crate::std::{mem, vec::Vec};
3
4use crate::symbols::{expand_symbols, push_code_symbols, resolve_function, Symbol};
5use casper_wasm::elements;
6use log::trace;
7
/// Errors that [`optimize`] can report.
///
/// The type is a plain field-less enum, so it is cheap to copy and can be
/// compared directly (e.g. in tests asserting on a specific failure).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Since optimizer starts with export entries, export
    ///   section is supposed to exist.
    NoExportSection,
}
14
15pub fn optimize(
16    module: &mut elements::Module, // Module to optimize
17    used_exports: Vec<&str>,       // List of only exports that will be usable after optimization
18) -> Result<(), Error> {
19    // WebAssembly exports optimizer
20    // Motivation: emscripten compiler backend compiles in many unused exports
21    //   which in turn compile in unused imports and leaves unused functions
22
23    // try to parse name section
24    let module_temp = mem::take(module);
25    let module_temp = module_temp
26        .parse_names()
27        .unwrap_or_else(|(_err, module)| module);
28    *module = module_temp;
29
30    // Algo starts from the top, listing all items that should stay
31    let mut stay: Set<_> = module
32        .export_section()
33        .ok_or(Error::NoExportSection)?
34        .entries()
35        .iter()
36        .enumerate()
37        .filter_map(|(index, entry)| {
38            if used_exports.iter().any(|e| *e == entry.field()) {
39                Some(Symbol::Export(index))
40            } else {
41                None
42            }
43        })
44        .collect();
45
46    // If there is start function in module, it should start
47    if let Some(ss) = module.start_section() {
48        stay.insert(resolve_function(module, ss));
49    }
50
51    // All symbols used in data/element segments are also should be preserved
52    let mut init_symbols = Vec::new();
53    if let Some(data_section) = module.data_section() {
54        for segment in data_section.entries() {
55            push_code_symbols(
56                module,
57                segment
58                    .offset()
59                    .as_ref()
60                    .expect("parity-wasm is compiled without bulk-memory operations")
61                    .code(),
62                &mut init_symbols,
63            );
64        }
65    }
66    if let Some(elements_section) = module.elements_section() {
67        for segment in elements_section.entries() {
68            push_code_symbols(
69                module,
70                segment
71                    .offset()
72                    .as_ref()
73                    .expect("parity-wasm is compiled without bulk-memory operations")
74                    .code(),
75                &mut init_symbols,
76            );
77            for func_index in segment.members() {
78                stay.insert(resolve_function(module, *func_index));
79            }
80        }
81    }
82
83    stay.extend(init_symbols.drain(..));
84
85    // Call function which will traverse the list recursively, filling stay with all symbols
86    // that are already used by those which already there
87    expand_symbols(module, &mut stay);
88
89    for symbol in stay.iter() {
90        trace!("symbol to stay: {:?}", symbol);
91    }
92
93    // Keep track of referreable symbols to rewire calls/globals
94    let mut eliminated_funcs = Vec::new();
95    let mut eliminated_globals = Vec::new();
96    let mut eliminated_types = Vec::new();
97
98    // First, iterate through types
99    let mut index = 0;
100    let mut old_index = 0;
101
102    loop {
103        if type_section(module)
104            .map(|section| section.types_mut().len())
105            .unwrap_or(0)
106            == index
107        {
108            break;
109        }
110
111        if stay.contains(&Symbol::Type(old_index)) {
112            index += 1;
113        } else {
114            type_section(module)
115                    .expect("If type section does not exists, the loop will break at the beginning of first iteration")
116                    .types_mut().remove(index);
117            eliminated_types.push(old_index);
118            trace!("Eliminated type({})", old_index);
119        }
120        old_index += 1;
121    }
122
123    // Second, iterate through imports
124    let mut top_funcs = 0;
125    let mut top_globals = 0;
126    index = 0;
127    old_index = 0;
128
129    if let Some(imports) = import_section(module) {
130        loop {
131            let mut remove = false;
132            match imports.entries()[index].external() {
133                elements::External::Function(_) => {
134                    if stay.contains(&Symbol::Import(old_index)) {
135                        index += 1;
136                    } else {
137                        remove = true;
138                        eliminated_funcs.push(top_funcs);
139                        trace!(
140                            "Eliminated import({}) func({}, {})",
141                            old_index,
142                            top_funcs,
143                            imports.entries()[index].field()
144                        );
145                    }
146                    top_funcs += 1;
147                }
148                elements::External::Global(_) => {
149                    if stay.contains(&Symbol::Import(old_index)) {
150                        index += 1;
151                    } else {
152                        remove = true;
153                        eliminated_globals.push(top_globals);
154                        trace!(
155                            "Eliminated import({}) global({}, {})",
156                            old_index,
157                            top_globals,
158                            imports.entries()[index].field()
159                        );
160                    }
161                    top_globals += 1;
162                }
163                _ => {
164                    index += 1;
165                }
166            }
167            if remove {
168                imports.entries_mut().remove(index);
169            }
170
171            old_index += 1;
172
173            if index == imports.entries().len() {
174                break;
175            }
176        }
177    }
178
179    // Third, iterate through globals
180    if let Some(globals) = global_section(module) {
181        index = 0;
182        old_index = 0;
183
184        loop {
185            if globals.entries_mut().len() == index {
186                break;
187            }
188            if stay.contains(&Symbol::Global(old_index)) {
189                index += 1;
190            } else {
191                globals.entries_mut().remove(index);
192                eliminated_globals.push(top_globals + old_index);
193                trace!("Eliminated global({})", top_globals + old_index);
194            }
195            old_index += 1;
196        }
197    }
198
199    // Forth, delete orphaned functions
200    if function_section(module).is_some() && code_section(module).is_some() {
201        index = 0;
202        old_index = 0;
203
204        loop {
205            if function_section(module)
206                .expect("Functons section to exist")
207                .entries_mut()
208                .len()
209                == index
210            {
211                break;
212            }
213            if stay.contains(&Symbol::Function(old_index)) {
214                index += 1;
215            } else {
216                function_section(module)
217                    .expect("Functons section to exist")
218                    .entries_mut()
219                    .remove(index);
220                code_section(module)
221                    .expect("Code section to exist")
222                    .bodies_mut()
223                    .remove(index);
224
225                eliminated_funcs.push(top_funcs + old_index);
226                trace!("Eliminated function({})", top_funcs + old_index);
227            }
228            old_index += 1;
229        }
230    }
231
232    // Fifth, eliminate unused exports
233    {
234        let exports = export_section(module).ok_or(Error::NoExportSection)?;
235
236        index = 0;
237        old_index = 0;
238
239        loop {
240            if exports.entries_mut().len() == index {
241                break;
242            }
243            if stay.contains(&Symbol::Export(old_index)) {
244                index += 1;
245            } else {
246                trace!(
247                    "Eliminated export({}, {})",
248                    old_index,
249                    exports.entries_mut()[index].field()
250                );
251                exports.entries_mut().remove(index);
252            }
253            old_index += 1;
254        }
255    }
256
257    if !eliminated_globals.is_empty()
258        || !eliminated_funcs.is_empty()
259        || !eliminated_types.is_empty()
260    {
261        // Finaly, rewire all calls, globals references and types to the new indices
262        //   (only if there is anything to do)
263        // When sorting primitives sorting unstable is faster without any difference in result.
264        eliminated_globals.sort_unstable();
265        eliminated_funcs.sort_unstable();
266        eliminated_types.sort_unstable();
267
268        for section in module.sections_mut() {
269            match section {
270                elements::Section::Start(func_index) if !eliminated_funcs.is_empty() => {
271                    let totalle = eliminated_funcs
272                        .iter()
273                        .take_while(|i| (**i as u32) < *func_index)
274                        .count();
275                    *func_index -= totalle as u32;
276                }
277                elements::Section::Function(function_section) if !eliminated_types.is_empty() => {
278                    for func_signature in function_section.entries_mut() {
279                        let totalle = eliminated_types
280                            .iter()
281                            .take_while(|i| (**i as u32) < func_signature.type_ref())
282                            .count();
283                        *func_signature.type_ref_mut() -= totalle as u32;
284                    }
285                }
286                elements::Section::Import(import_section) if !eliminated_types.is_empty() => {
287                    for import_entry in import_section.entries_mut() {
288                        if let elements::External::Function(type_ref) = import_entry.external_mut()
289                        {
290                            let totalle = eliminated_types
291                                .iter()
292                                .take_while(|i| (**i as u32) < *type_ref)
293                                .count();
294                            *type_ref -= totalle as u32;
295                        }
296                    }
297                }
298                elements::Section::Code(code_section)
299                    if !eliminated_globals.is_empty() || !eliminated_funcs.is_empty() =>
300                {
301                    for func_body in code_section.bodies_mut() {
302                        if !eliminated_funcs.is_empty() {
303                            update_call_index(func_body.code_mut(), &eliminated_funcs);
304                        }
305                        if !eliminated_globals.is_empty() {
306                            update_global_index(
307                                func_body.code_mut().elements_mut(),
308                                &eliminated_globals,
309                            )
310                        }
311                        if !eliminated_types.is_empty() {
312                            update_type_index(func_body.code_mut(), &eliminated_types)
313                        }
314                    }
315                }
316                elements::Section::Export(export_section) => {
317                    for export in export_section.entries_mut() {
318                        match export.internal_mut() {
319                            elements::Internal::Function(func_index) => {
320                                let totalle = eliminated_funcs
321                                    .iter()
322                                    .take_while(|i| (**i as u32) < *func_index)
323                                    .count();
324                                *func_index -= totalle as u32;
325                            }
326                            elements::Internal::Global(global_index) => {
327                                let totalle = eliminated_globals
328                                    .iter()
329                                    .take_while(|i| (**i as u32) < *global_index)
330                                    .count();
331                                *global_index -= totalle as u32;
332                            }
333                            _ => {}
334                        }
335                    }
336                }
337                elements::Section::Global(global_section) => {
338                    for global_entry in global_section.entries_mut() {
339                        update_global_index(
340                            global_entry.init_expr_mut().code_mut(),
341                            &eliminated_globals,
342                        )
343                    }
344                }
345                elements::Section::Data(data_section) => {
346                    for segment in data_section.entries_mut() {
347                        update_global_index(
348                            segment
349                                .offset_mut()
350                                .as_mut()
351                                .expect("parity-wasm is compiled without bulk-memory operations")
352                                .code_mut(),
353                            &eliminated_globals,
354                        )
355                    }
356                }
357                elements::Section::Element(elements_section) => {
358                    for segment in elements_section.entries_mut() {
359                        update_global_index(
360                            segment
361                                .offset_mut()
362                                .as_mut()
363                                .expect("parity-wasm is compiled without bulk-memory operations")
364                                .code_mut(),
365                            &eliminated_globals,
366                        );
367                        // update all indirect call addresses initial values
368                        for func_index in segment.members_mut() {
369                            let totalle = eliminated_funcs
370                                .iter()
371                                .take_while(|i| (**i as u32) < *func_index)
372                                .count();
373                            *func_index -= totalle as u32;
374                        }
375                    }
376                }
377                elements::Section::Name(name_section) => {
378                    if let Some(func_name) = name_section.functions_mut() {
379                        let mut func_name_map = mem::take(func_name.names_mut());
380                        for index in &eliminated_funcs {
381                            func_name_map.remove(*index as u32);
382                        }
383                        let updated_map = func_name_map
384                            .into_iter()
385                            .map(|(index, value)| {
386                                let totalle = eliminated_funcs
387                                    .iter()
388                                    .take_while(|i| (**i as u32) < index)
389                                    .count() as u32;
390                                (index - totalle, value)
391                            })
392                            .collect();
393                        *func_name.names_mut() = updated_map;
394                    }
395
396                    if let Some(local_name) = name_section.locals_mut() {
397                        let mut local_names_map = mem::take(local_name.local_names_mut());
398                        for index in &eliminated_funcs {
399                            local_names_map.remove(*index as u32);
400                        }
401                        let updated_map = local_names_map
402                            .into_iter()
403                            .map(|(index, value)| {
404                                let totalle = eliminated_funcs
405                                    .iter()
406                                    .take_while(|i| (**i as u32) < index)
407                                    .count() as u32;
408                                (index - totalle, value)
409                            })
410                            .collect();
411                        *local_name.local_names_mut() = updated_map;
412                    }
413                }
414                _ => {}
415            }
416        }
417    }
418
419    // Also drop all custom sections
420    module
421        .sections_mut()
422        .retain(|section| !matches!(section, elements::Section::Custom(_)));
423
424    Ok(())
425}
426
427pub fn update_call_index(instructions: &mut elements::Instructions, eliminated_indices: &[usize]) {
428    use casper_wasm::elements::Instruction::*;
429    for instruction in instructions.elements_mut().iter_mut() {
430        if let Call(call_index) = instruction {
431            let totalle = eliminated_indices
432                .iter()
433                .take_while(|i| (**i as u32) < *call_index)
434                .count();
435            trace!(
436                "rewired call {} -> call {}",
437                *call_index,
438                *call_index - totalle as u32
439            );
440            *call_index -= totalle as u32;
441        }
442    }
443}
444
445/// Updates global references considering the _ordered_ list of eliminated indices
446pub fn update_global_index(
447    instructions: &mut [elements::Instruction],
448    eliminated_indices: &[usize],
449) {
450    use casper_wasm::elements::Instruction::*;
451    for instruction in instructions.iter_mut() {
452        match instruction {
453            GetGlobal(index) | SetGlobal(index) => {
454                let totalle = eliminated_indices
455                    .iter()
456                    .take_while(|i| (**i as u32) < *index)
457                    .count();
458                trace!(
459                    "rewired global {} -> global {}",
460                    *index,
461                    *index - totalle as u32
462                );
463                *index -= totalle as u32;
464            }
465            _ => {}
466        }
467    }
468}
469
470/// Updates global references considering the _ordered_ list of eliminated indices
471pub fn update_type_index(instructions: &mut elements::Instructions, eliminated_indices: &[usize]) {
472    use casper_wasm::elements::Instruction::*;
473    for instruction in instructions.elements_mut().iter_mut() {
474        if let CallIndirect(call_index, _) = instruction {
475            let totalle = eliminated_indices
476                .iter()
477                .take_while(|i| (**i as u32) < *call_index)
478                .count();
479            trace!(
480                "rewired call_indrect {} -> call_indirect {}",
481                *call_index,
482                *call_index - totalle as u32
483            );
484            *call_index -= totalle as u32;
485        }
486    }
487}
488
489pub fn import_section(module: &mut elements::Module) -> Option<&mut elements::ImportSection> {
490    for section in module.sections_mut() {
491        if let elements::Section::Import(sect) = section {
492            return Some(sect);
493        }
494    }
495    None
496}
497
498pub fn global_section(module: &mut elements::Module) -> Option<&mut elements::GlobalSection> {
499    for section in module.sections_mut() {
500        if let elements::Section::Global(sect) = section {
501            return Some(sect);
502        }
503    }
504    None
505}
506
507pub fn function_section(module: &mut elements::Module) -> Option<&mut elements::FunctionSection> {
508    for section in module.sections_mut() {
509        if let elements::Section::Function(sect) = section {
510            return Some(sect);
511        }
512    }
513    None
514}
515
516pub fn code_section(module: &mut elements::Module) -> Option<&mut elements::CodeSection> {
517    for section in module.sections_mut() {
518        if let elements::Section::Code(sect) = section {
519            return Some(sect);
520        }
521    }
522    None
523}
524
525pub fn export_section(module: &mut elements::Module) -> Option<&mut elements::ExportSection> {
526    for section in module.sections_mut() {
527        if let elements::Section::Export(sect) = section {
528            return Some(sect);
529        }
530    }
531    None
532}
533
534pub fn type_section(module: &mut elements::Module) -> Option<&mut elements::TypeSection> {
535    for section in module.sections_mut() {
536        if let elements::Section::Type(sect) = section {
537            return Some(sect);
538        }
539    }
540    None
541}
542
#[cfg(test)]
mod tests {

    use super::*;
    use casper_wasm::{builder, elements};

    /// @spec 0
    /// Optimizer presumes that export section exists and contains
    /// all symbols passed as a second parameter. Since empty module
    /// obviously contains no export section, optimizer should return
    /// error on it.
    #[test]
    fn empty() {
        let mut module = builder::module().build();
        let result = optimize(&mut module, vec!["_call"]);

        assert!(result.is_err());
    }

    /// @spec 1
    /// Imagine the unoptimized module has two own functions, `_call` and `_random`
    /// and exports both of them in the export section. During optimization, the `_random`
    /// function should vanish completely, given we pass `_call` as the only function to stay
    /// in the module.
    #[test]
    fn minimal() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .export()
            .field("_random")
            .internal()
            .func(1)
            .build()
            .build();
        assert_eq!(
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            2
        );

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            "There should only 1 (one) export entry in the optimized module"
        );

        // `_random` is neither exported nor called any more, so only `_call` remains.
        assert_eq!(
            1,
            module
                .function_section()
                .expect("functions section to be generated")
                .entries()
                .len(),
            "There should be only 1 (one) function in the optimized module"
        );
    }

    /// @spec 2
    /// Imagine there is one exported function in unoptimized module, `_call`, that we specify as the one
    /// to stay during the optimization. The code of this function uses a global during the execution.
    /// The said global should survive the optimization.
    #[test]
    fn globals() {
        let mut module = builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::GetGlobal(0),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module.global_section().expect("global section to be generated").entries().len(),
            "There should 1 (one) global entry in the optimized module, since _call function uses it"
        );
    }

    /// @spec 2
    /// Imagine there is one exported function in unoptimized module, `_call`, that we specify as the one
    /// to stay during the optimization. The code of this function uses one global during the execution,
    /// but we have a bunch of other unused globals in the code. The unused globals should not survive
    /// the optimization, while the used one should.
    #[test]
    fn globals_2() {
        let mut module = builder::module()
            .global()
            .value_type()
            .i32()
            .build()
            .global()
            .value_type()
            .i64()
            .build()
            .global()
            .value_type()
            .f32()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::GetGlobal(1),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module.global_section().expect("global section to be generated").entries().len(),
            "There should 1 (one) global entry in the optimized module, since _call function uses only one"
        );
    }

    /// @spec 3
    /// Imagine the unoptimized module has two own functions, `_call` and `_random`
    /// and exports both of them in the export section. Function `_call` also calls `_random`
    /// in its function body. The optimization should kick `_random` function from the export section
    /// but preserve its body.
    #[test]
    fn call_ref() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::Call(1),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(0)
            .build()
            .export()
            .field("_random")
            .internal()
            .func(1)
            .build()
            .build();
        assert_eq!(
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            2
        );

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            1,
            module
                .export_section()
                .expect("export section to be generated")
                .entries()
                .len(),
            "There should only 1 (one) export entry in the optimized module"
        );

        assert_eq!(
            2,
            module
                .function_section()
                .expect("functions section to be generated")
                .entries()
                .len(),
            "There should 2 (two) functions in the optimized module"
        );
    }

    /// @spec 4
    /// Imagine the unoptimized module has an indirect call to function of type 1
    /// The type should persist so that indirect call would work
    #[test]
    fn call_indirect() {
        let mut module = builder::module()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .param()
            .i32()
            .param()
            .i32()
            .build()
            .build()
            .function()
            .signature()
            .param()
            .i32()
            .build()
            .body()
            .with_instructions(elements::Instructions::new(vec![
                elements::Instruction::CallIndirect(1, 0),
                elements::Instruction::End,
            ]))
            .build()
            .build()
            .export()
            .field("_call")
            .internal()
            .func(2)
            .build()
            .build();

        optimize(&mut module, vec!["_call"]).expect("optimizer to succeed");

        assert_eq!(
            2,
            module
                .type_section()
                .expect("type section to be generated")
                .types()
                .len(),
            "There should 2 (two) types left in the module, 1 for indirect call and one for _call"
        );

        // The first instruction of the only surviving body is the indirect call.
        let indirect_opcode = &module
            .code_section()
            .expect("code section to be generated")
            .bodies()[0]
            .code()
            .elements()[0];
        match *indirect_opcode {
            elements::Instruction::CallIndirect(0, 0) => {}
            _ => {
                panic!(
                    "Expected call_indirect to use index 0 after optimization, since previois 0th was eliminated, but got {:?}",
                    indirect_opcode
                );
            }
        }
    }
}