pwasm_utils/
ext.rs

1use crate::std::{borrow::ToOwned, string::String, vec::Vec};
2
3use byteorder::{ByteOrder, LittleEndian};
4use parity_wasm::{builder, elements};
5
6use crate::optimizer::{export_section, import_section};
7
8type Insertion = (usize, u32, u32, String);
9
10pub fn update_call_index(
11	instructions: &mut elements::Instructions,
12	original_imports: usize,
13	inserts: &[Insertion],
14) {
15	use parity_wasm::elements::Instruction::*;
16	for instruction in instructions.elements_mut().iter_mut() {
17		if let Call(call_index) = instruction {
18			if let Some(pos) = inserts.iter().position(|x| x.1 == *call_index) {
19				*call_index = (original_imports + pos) as u32;
20			} else if *call_index as usize > original_imports {
21				*call_index += inserts.len() as u32;
22			}
23		}
24	}
25}
26
27pub fn memory_section(module: &mut elements::Module) -> Option<&mut elements::MemorySection> {
28	for section in module.sections_mut() {
29		if let elements::Section::Memory(sect) = section {
30			return Some(sect)
31		}
32	}
33	None
34}
35
36pub fn externalize_mem(
37	mut module: elements::Module,
38	adjust_pages: Option<u32>,
39	max_pages: u32,
40) -> elements::Module {
41	let mut entry = memory_section(&mut module)
42		.expect("Memory section to exist")
43		.entries_mut()
44		.pop()
45		.expect("Own memory entry to exist in memory section");
46
47	if let Some(adjust_pages) = adjust_pages {
48		assert!(adjust_pages <= max_pages);
49		entry = elements::MemoryType::new(adjust_pages, Some(max_pages));
50	}
51
52	if entry.limits().maximum().is_none() {
53		entry = elements::MemoryType::new(entry.limits().initial(), Some(max_pages));
54	}
55
56	let mut builder = builder::from_module(module);
57	builder.push_import(elements::ImportEntry::new(
58		"env".to_owned(),
59		"memory".to_owned(),
60		elements::External::Memory(entry),
61	));
62
63	builder.build()
64}
65
66fn foreach_public_func_name<F>(mut module: elements::Module, f: F) -> elements::Module
67where
68	F: Fn(&mut String),
69{
70	if let Some(section) = import_section(&mut module) {
71		for entry in section.entries_mut() {
72			if let elements::External::Function(_) = *entry.external() {
73				f(entry.field_mut())
74			}
75		}
76	}
77
78	if let Some(section) = export_section(&mut module) {
79		for entry in section.entries_mut() {
80			if let elements::Internal::Function(_) = *entry.internal() {
81				f(entry.field_mut())
82			}
83		}
84	}
85
86	module
87}
88
89pub fn underscore_funcs(module: elements::Module) -> elements::Module {
90	foreach_public_func_name(module, |n| n.insert(0, '_'))
91}
92
93pub fn ununderscore_funcs(module: elements::Module) -> elements::Module {
94	foreach_public_func_name(module, |n| {
95		n.remove(0);
96	})
97}
98
99pub fn shrink_unknown_stack(
100	mut module: elements::Module,
101	// for example, `shrink_amount = (1MB - 64KB)` will limit stack to 64KB
102	shrink_amount: u32,
103) -> (elements::Module, u32) {
104	let mut new_stack_top = 0;
105	for section in module.sections_mut() {
106		match section {
107			elements::Section::Data(data_section) => {
108				for data_segment in data_section.entries_mut() {
109					if *data_segment
110						.offset()
111						.as_ref()
112						.expect("parity-wasm is compiled without bulk-memory operations")
113						.code() == [elements::Instruction::I32Const(4), elements::Instruction::End]
114					{
115						assert_eq!(data_segment.value().len(), 4);
116						let current_val = LittleEndian::read_u32(data_segment.value());
117						let new_val = current_val - shrink_amount;
118						LittleEndian::write_u32(data_segment.value_mut(), new_val);
119						new_stack_top = new_val;
120					}
121				}
122			},
123			_ => continue,
124		}
125	}
126	(module, new_stack_top)
127}
128
129pub fn externalize(module: elements::Module, replaced_funcs: Vec<&str>) -> elements::Module {
130	// Save import functions number for later
131	let import_funcs_total = module
132		.import_section()
133		.expect("Import section to exist")
134		.entries()
135		.iter()
136		.filter(|e| matches!(e.external(), &elements::External::Function(_)))
137		.count();
138
139	// First, we find functions indices that are to be rewired to externals
140	//   Triple is (function_index (callable), type_index, function_name)
141	let mut replaces: Vec<Insertion> = replaced_funcs
142		.into_iter()
143		.filter_map(|f| {
144			let export = module
145				.export_section()
146				.expect("Export section to exist")
147				.entries()
148				.iter()
149				.enumerate()
150				.find(|&(_, entry)| entry.field() == f)
151				.expect("All functions of interest to exist");
152
153			if let elements::Internal::Function(func_idx) = *export.1.internal() {
154				let type_ref =
155					module.function_section().expect("Functions section to exist").entries()
156						[func_idx as usize - import_funcs_total]
157						.type_ref();
158
159				Some((export.0, func_idx, type_ref, export.1.field().to_owned()))
160			} else {
161				None
162			}
163		})
164		.collect();
165
166	replaces.sort_by_key(|e| e.0);
167
168	// Second, we duplicate them as import definitions
169	let mut mbuilder = builder::from_module(module);
170	for (_, _, type_ref, field) in replaces.iter() {
171		mbuilder.push_import(
172			builder::import().module("env").field(field).external().func(*type_ref).build(),
173		);
174	}
175
176	// Back to mutable access
177	let mut module = mbuilder.build();
178
179	// Third, rewire all calls to imported functions and update all other calls indices
180	for section in module.sections_mut() {
181		match section {
182			elements::Section::Code(code_section) =>
183				for func_body in code_section.bodies_mut() {
184					update_call_index(func_body.code_mut(), import_funcs_total, &replaces);
185				},
186			elements::Section::Export(export_section) => {
187				for export in export_section.entries_mut() {
188					if let elements::Internal::Function(func_index) = export.internal_mut() {
189						if *func_index >= import_funcs_total as u32 {
190							*func_index += replaces.len() as u32;
191						}
192					}
193				}
194			},
195			elements::Section::Element(elements_section) => {
196				for segment in elements_section.entries_mut() {
197					// update all indirect call addresses initial values
198					for func_index in segment.members_mut() {
199						if *func_index >= import_funcs_total as u32 {
200							*func_index += replaces.len() as u32;
201						}
202					}
203				}
204			},
205			_ => {},
206		}
207	}
208
209	module
210}