1use std::cmp;
4use std::collections::BTreeMap;
5use std::ops::Range;
6
7use anyhow::{bail, Context};
8use wasm_encoder::Encode;
9use wasmparser::{FuncValidator, FunctionBody, ValidatorResources};
10
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
14mod local_function;
15
16use crate::emit::{Emit, EmitContext};
17use crate::error::Result;
18use crate::ir::InstrLocId;
19use crate::module::imports::ImportId;
20use crate::module::Module;
21use crate::parse::IndicesToIds;
22use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
23use crate::ty::TypeId;
24use crate::ty::ValType;
25use crate::{ExportItem, FunctionBuilder, InstrSeqBuilder, LocalId, Memory, MemoryId};
26
27pub use self::local_function::LocalFunction;
28
/// The arena id that identifies a [`Function`].
pub type FunctionId = Id<Function>;

/// A list of value types; going by the name, the parameter types of a function.
pub type FuncParams = Vec<ValType>;

/// A list of value types; going by the name, the result types of a function.
pub type FuncResults = Vec<ValType>;
37
/// A function tracked by a module.
///
/// The function is an import, a locally defined function, or a placeholder
/// whose body has not been attached yet; see [`FunctionKind`].
#[derive(Debug)]
pub struct Function {
    // The arena id this function was allocated under; exposed through `id()`.
    id: FunctionId,

    /// Which kind of function this is (import, local, or uninitialized).
    pub kind: FunctionKind,

    /// An optional name for this function.
    pub name: Option<String>,
}
53
54impl Tombstone for Function {
55 fn on_delete(&mut self) {
56 let ty = self.ty();
57 self.kind = FunctionKind::Uninitialized(ty);
58 self.name = None;
59 }
60}
61
62impl Function {
63 fn new_uninitialized(id: FunctionId, ty: TypeId) -> Function {
64 Function {
65 id,
66 kind: FunctionKind::Uninitialized(ty),
67 name: None,
68 }
69 }
70
71 pub fn id(&self) -> FunctionId {
73 self.id
74 }
75
76 pub fn ty(&self) -> TypeId {
78 match &self.kind {
79 FunctionKind::Local(l) => l.ty(),
80 FunctionKind::Import(i) => i.ty,
81 FunctionKind::Uninitialized(t) => *t,
82 }
83 }
84}
85
/// The kind of a [`Function`].
#[derive(Debug)]
pub enum FunctionKind {
    /// A function that is imported into the module.
    Import(ImportedFunction),

    /// A function that is defined inside the module.
    Local(LocalFunction),

    /// A placeholder that only records the function's type.
    ///
    /// Functions are created in this state when they are declared, and are
    /// reset to it when they are deleted from the arena.
    Uninitialized(TypeId),
}
101
102impl FunctionKind {
103 pub fn unwrap_import(&self) -> &ImportedFunction {
106 match self {
107 FunctionKind::Import(import) => import,
108 _ => panic!("not an import function"),
109 }
110 }
111
112 pub fn unwrap_local(&self) -> &LocalFunction {
115 match self {
116 FunctionKind::Local(l) => l,
117 _ => panic!("not a local function"),
118 }
119 }
120
121 pub fn unwrap_import_mut(&mut self) -> &mut ImportedFunction {
124 match self {
125 FunctionKind::Import(import) => import,
126 _ => panic!("not an import function"),
127 }
128 }
129
130 pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction {
133 match self {
134 FunctionKind::Local(l) => l,
135 _ => panic!("not a local function"),
136 }
137 }
138}
139
/// A function that is imported into the module rather than defined in it.
#[derive(Debug)]
pub struct ImportedFunction {
    /// The id of the import entry that brings this function in.
    pub import: ImportId,
    /// The id of this function's type.
    pub ty: TypeId,
}
148
/// The set of functions (imported, local, and not-yet-initialized) of a module.
#[derive(Debug, Default)]
pub struct ModuleFunctions {
    // Storage for every function; deleting an entry resets it via `Tombstone`.
    arena: TombstoneArena<Function>,

    // NOTE(review): not read or written in this part of the file; presumably
    // the byte offset of the code section in the parsed input -- confirm.
    pub(crate) code_section_offset: usize,
}
158
159impl ModuleFunctions {
160 pub fn new() -> ModuleFunctions {
162 Default::default()
163 }
164
165 pub fn add_import(&mut self, ty: TypeId, import: ImportId) -> FunctionId {
167 self.arena.alloc_with_id(|id| Function {
168 id,
169 kind: FunctionKind::Import(ImportedFunction { import, ty }),
170 name: None,
171 })
172 }
173
174 pub fn add_local(&mut self, func: LocalFunction) -> FunctionId {
176 let func_name = func.builder().name.clone();
177 self.arena.alloc_with_id(|id| Function {
178 id,
179 kind: FunctionKind::Local(func),
180 name: func_name,
181 })
182 }
183
184 pub fn get(&self, id: FunctionId) -> &Function {
186 &self.arena[id]
187 }
188
189 pub fn get_mut(&mut self, id: FunctionId) -> &mut Function {
191 &mut self.arena[id]
192 }
193
194 pub fn by_name(&self, name: &str) -> Option<FunctionId> {
202 self.arena.iter().find_map(|(id, f)| {
203 if f.name.as_deref() == Some(name) {
204 Some(id)
205 } else {
206 None
207 }
208 })
209 }
210
211 pub fn delete(&mut self, id: FunctionId) {
217 self.arena.delete(id);
218 }
219
220 pub fn iter(&self) -> impl Iterator<Item = &Function> {
222 self.arena.iter().map(|(_, f)| f)
223 }
224
225 #[cfg(feature = "parallel")]
229 pub fn par_iter(&self) -> impl ParallelIterator<Item = &Function> {
230 self.arena.par_iter().map(|(_, f)| f)
231 }
232
233 pub fn iter_local(&self) -> impl Iterator<Item = (FunctionId, &LocalFunction)> {
235 self.iter().filter_map(|f| match &f.kind {
236 FunctionKind::Local(local) => Some((f.id(), local)),
237 _ => None,
238 })
239 }
240
241 #[cfg(feature = "parallel")]
245 pub fn par_iter_local(&self) -> impl ParallelIterator<Item = (FunctionId, &LocalFunction)> {
246 self.par_iter().filter_map(|f| match &f.kind {
247 FunctionKind::Local(local) => Some((f.id(), local)),
248 _ => None,
249 })
250 }
251
252 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Function> {
254 self.arena.iter_mut().map(|(_, f)| f)
255 }
256
257 #[cfg(feature = "parallel")]
261 pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut Function> {
262 self.arena.par_iter_mut().map(|(_, f)| f)
263 }
264
265 pub fn iter_local_mut(&mut self) -> impl Iterator<Item = (FunctionId, &mut LocalFunction)> {
267 self.iter_mut().filter_map(|f| {
268 let id = f.id();
269 match &mut f.kind {
270 FunctionKind::Local(local) => Some((id, local)),
271 _ => None,
272 }
273 })
274 }
275
276 #[cfg(feature = "parallel")]
280 pub fn par_iter_local_mut(
281 &mut self,
282 ) -> impl ParallelIterator<Item = (FunctionId, &mut LocalFunction)> {
283 self.par_iter_mut().filter_map(|f| {
284 let id = f.id();
285 match &mut f.kind {
286 FunctionKind::Local(local) => Some((id, local)),
287 _ => None,
288 }
289 })
290 }
291
292 pub(crate) fn emit_func_section(&self, cx: &mut EmitContext) {
293 log::debug!("emit function section");
294 let functions = used_local_functions(cx);
295 if functions.is_empty() {
296 return;
297 }
298 let mut func_section = wasm_encoder::FunctionSection::new();
299 for (id, function, _size) in functions {
300 let index = cx.indices.get_type_index(function.ty());
301 func_section.function(index);
302
303 cx.indices.push_func(id);
309 }
310 cx.wasm_module.section(&func_section);
311 }
312}
313
314impl Module {
315 pub(crate) fn declare_local_functions(
318 &mut self,
319 section: wasmparser::FunctionSectionReader,
320 ids: &mut IndicesToIds,
321 ) -> Result<()> {
322 log::debug!("parse function section");
323 for func in section {
324 let ty = ids.get_type(func?)?;
325 let id = self
326 .funcs
327 .arena
328 .alloc_with_id(|id| Function::new_uninitialized(id, ty));
329 let idx = ids.push_func(id);
330 if self.config.generate_synthetic_names_for_anonymous_items {
331 self.funcs.get_mut(id).name = Some(format!("f{}", idx));
332 }
333 }
334
335 Ok(())
336 }
337
338 pub(crate) fn parse_local_functions(
340 &mut self,
341 functions: Vec<(FunctionBody<'_>, FuncValidator<ValidatorResources>)>,
342 indices: &mut IndicesToIds,
343 on_instr_pos: Option<&(dyn Fn(&usize) -> InstrLocId + Sync + Send + 'static)>,
344 ) -> Result<()> {
345 log::debug!("parse code section");
346 let num_imports = self.funcs.arena.len() - functions.len();
347
348 let mut bodies = Vec::with_capacity(functions.len());
353 for (i, (body, mut validator)) in functions.into_iter().enumerate() {
354 let index = (num_imports + i) as u32;
355 let id = indices.get_func(index)?;
356 let ty = match self.funcs.arena[id].kind {
357 FunctionKind::Uninitialized(ty) => ty,
358 _ => unreachable!(),
359 };
360
361 let mut args = Vec::new();
364 let type_ = self.types.get(ty);
365 for ty in type_.params().iter() {
366 let local_id = self.locals.add(*ty);
367 let idx = indices.push_local(id, local_id);
368 args.push(local_id);
369 if self.config.generate_synthetic_names_for_anonymous_items {
370 let name = format!("arg{}", idx);
371 self.locals.get_mut(local_id).name = Some(name);
372 }
373 }
374
375 let results = type_.results().to_vec();
380 self.types.add_entry_ty(&results);
381
382 let mut locals_reader = body.get_locals_reader()?;
384 for _ in 0..locals_reader.get_count() {
385 let pos = locals_reader.original_position();
386 let (count, ty) = locals_reader.read()?;
387 validator.define_locals(pos, count, ty)?;
388 let ty = ValType::from_wasmparser(&ty, indices, 0)?;
389 for _ in 0..count {
390 let local_id = self.locals.add(ty);
391 let idx = indices.push_local(id, local_id);
392 if self.config.generate_synthetic_names_for_anonymous_items {
393 let name = format!("l{}", idx);
394 self.locals.get_mut(local_id).name = Some(name);
395 }
396 }
397 }
398
399 bodies.push((id, body, args, ty, validator));
400 }
401
402 let results = maybe_parallel!(bodies.(into_iter | into_par_iter))
405 .map(|(id, body, args, ty, validator)| {
406 (
407 id,
408 LocalFunction::parse(
409 self,
410 indices,
411 id,
412 ty,
413 args,
414 body,
415 on_instr_pos,
416 validator,
417 ),
418 )
419 })
420 .collect::<Vec<_>>();
421
422 for (id, func) in results {
425 let func = func?;
426 self.funcs.arena[id].kind = FunctionKind::Local(func);
427 }
428
429 Ok(())
430 }
431
432 pub fn get_memory_id(&self) -> Result<MemoryId> {
437 if self.memories.len() > 1 {
438 bail!("multiple memories unsupported")
439 }
440
441 self.memories
442 .iter()
443 .next()
444 .map(Memory::id)
445 .context("module does not export a memory")
446 }
447
448 pub fn replace_exported_func(
467 &mut self,
468 fid: FunctionId,
469 builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
470 ) -> Result<FunctionId> {
471 let original_export_id = self
472 .exports
473 .get_exported_func(fid)
474 .map(|e| e.id())
475 .with_context(|| format!("no exported function with ID [{fid:?}]"))?;
476
477 if let Function {
478 kind: FunctionKind::Local(lf),
479 ..
480 } = self.funcs.get(fid)
481 {
482 let ty = self.types.get(lf.ty());
484 let (params, results) = (ty.params().to_vec(), ty.results().to_vec());
485
486 let mut builder = FunctionBuilder::new(&mut self.types, ¶ms, &results);
488 let mut new_fn_body = builder.func_body();
489 builder_fn((&mut new_fn_body, &lf.args));
490 let func = builder.local_func(lf.args.clone());
491 let new_fn_id = self.funcs.add_local(func);
492
493 let export = self.exports.get_mut(original_export_id);
495 export.item = ExportItem::Function(new_fn_id);
496 Ok(new_fn_id)
497 } else {
498 bail!("cannot replace function [{fid:?}], it is not an exported function");
499 }
500 }
501
502 pub fn replace_imported_func(
521 &mut self,
522 fid: FunctionId,
523 builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
524 ) -> Result<FunctionId> {
525 let original_import_id = self
526 .imports
527 .get_imported_func(fid)
528 .map(|import| import.id())
529 .with_context(|| format!("no exported function with ID [{fid:?}]"))?;
530
531 if let Function {
532 kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
533 ..
534 } = self.funcs.get(fid)
535 {
536 let ty = self.types.get(*tid);
538 let (params, results) = (ty.params().to_vec(), ty.results().to_vec());
539
540 let args = params
542 .iter()
543 .map(|ty| self.locals.add(*ty))
544 .collect::<Vec<_>>();
545
546 let mut builder = FunctionBuilder::new(&mut self.types, ¶ms, &results);
548 let mut new_fn_body = builder.func_body();
549 builder_fn((&mut new_fn_body, &args));
550 let new_func_kind = FunctionKind::Local(builder.local_func(args));
551
552 let func = self.funcs.get_mut(fid);
555 func.kind = new_func_kind;
556
557 self.imports.delete(original_import_id);
558
559 Ok(fid)
560 } else {
561 bail!("cannot replace function [{fid:?}], it is not an imported function");
562 }
563 }
564}
565
566fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> {
567 let mut functions = Vec::new();
572 for f in cx.module.funcs.iter() {
573 match &f.kind {
574 FunctionKind::Local(l) => functions.push((f.id(), l, l.size())),
575 FunctionKind::Import(_) => {}
576 FunctionKind::Uninitialized(_) => unreachable!(),
577 }
578 }
579
580 functions.sort_by_key(|(id, _, size)| (cmp::Reverse(*size), *id));
586
587 functions
588}
589
590fn collect_non_default_code_offsets(
591 code_transform: &mut BTreeMap<InstrLocId, usize>,
592 code_offset: usize,
593 map: Vec<(InstrLocId, usize)>,
594) {
595 for (src, dst) in map {
596 let dst = dst + code_offset;
597 if !src.is_default() {
598 code_transform.insert(src, dst);
599 }
600 }
601}
602
603impl Emit for ModuleFunctions {
604 fn emit(&self, cx: &mut EmitContext) {
605 log::debug!("emit code section");
606 let functions = used_local_functions(cx);
607 if functions.is_empty() {
608 return;
609 }
610
611 let mut wasm_code_section = wasm_encoder::CodeSection::new();
612 let generate_map = cx.module.config.preserve_code_transform;
613
614 let bytes = maybe_parallel!(functions.(into_iter | into_par_iter))
618 .map(|(id, func, _size)| {
619 log::debug!("emit function {:?} {:?}", id, cx.module.funcs.get(id).name);
620 let mut wasm = Vec::new();
621 let mut map = if generate_map { Some(Vec::new()) } else { None };
622
623 let (locals_types, used_locals, local_indices) =
624 func.emit_locals(cx.module, cx.indices);
625 let mut wasm_function = wasm_encoder::Function::new(locals_types);
626 func.emit_instructions(
627 cx.indices,
628 &local_indices,
629 &mut wasm_function,
630 map.as_mut(),
631 );
632 wasm_function.encode(&mut wasm);
633 (
634 wasm,
635 wasm_function.byte_len(),
636 id,
637 used_locals,
638 local_indices,
639 map,
640 )
641 })
642 .collect::<Vec<_>>();
643
644 let mut instruction_map = BTreeMap::new();
645 cx.indices.locals.reserve(bytes.len());
646
647 let mut offset_data = Vec::new();
648 for (wasm, byte_len, id, used_locals, local_indices, map) in bytes {
649 let leb_len = wasm.len() - byte_len;
650 wasm_code_section.raw(&wasm[leb_len..]);
651 cx.indices.locals.insert(id, local_indices);
652 cx.locals.insert(id, used_locals);
653 offset_data.push((byte_len, id, map, leb_len));
654 }
655 cx.wasm_module.section(&wasm_code_section);
656
657 let code_section_start_offset =
658 cx.wasm_module.as_slice().len() - wasm_code_section.byte_len();
659 let mut cur_offset = code_section_start_offset;
660
661 for (byte_len, id, map, leb_len) in offset_data {
663 let code_start_offset = cur_offset + leb_len;
665 cur_offset += leb_len + byte_len;
666 if let Some(map) = map {
667 collect_non_default_code_offsets(&mut instruction_map, code_start_offset, map);
668 }
669 cx.code_transform.function_ranges.push((
670 id,
671 Range {
672 start: code_start_offset - leb_len,
674 end: cur_offset,
675 },
676 ));
677 }
678 cx.code_transform.function_ranges.sort_by_key(|i| i.0);
679 cx.code_transform.code_section_start = code_section_start_offset - 2;
681 cx.code_transform.instruction_map = instruction_map.into_iter().collect();
682 }
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Export, FunctionBuilder, Module};

    /// Checks shared by the import-replacement tests: the import entry is
    /// gone, neither id resolves to an import any more, and the function is
    /// now local.
    fn assert_import_became_local(
        module: &Module,
        original_fn_id: FunctionId,
        original_import_id: ImportId,
        new_fn_id: FunctionId,
    ) {
        assert!(
            module.imports.iter().all(|i| i.id() != original_import_id),
            "original import is missing",
        );

        assert!(
            module.imports.get_imported_func(original_fn_id).is_none(),
            "replaced import function cannot be gotten by ID"
        );

        assert!(
            module.imports.get_imported_func(new_fn_id).is_none(),
            "new import function cannot be gotten by ID (it is now local)"
        );

        assert!(
            matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)),
            "new local function has the right kind"
        );
    }

    /// A module with exactly one memory reports that memory's id.
    #[test]
    fn get_memory_id() {
        let mut module = Module::default();
        let expected_id = module.memories.add_local(false, false, 0, None, None);
        let found = module.get_memory_id();
        assert!(matches!(found, Ok(id) if id == expected_id));
    }

    /// Replacing an exported function keeps the export and retargets it.
    #[test]
    fn replace_exported_func() {
        let mut module = Module::default();

        // An exported function to be replaced.
        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
        let original_export_id = module.exports.add("dummy", original_fn_id);

        let replacement = module.replace_exported_func(original_fn_id, |(body, _)| {
            body.i32_const(4321).drop();
        });
        let new_fn_id = replacement.expect("function replacement worked");

        assert!(
            module.exports.get_exported_func(original_fn_id).is_none(),
            "replaced function cannot be gotten by ID"
        );

        let export: &Export = module
            .exports
            .get_exported_func(new_fn_id)
            .expect("failed to unwrap exported func");
        let ExportItem::Function(fid) = &export.item else {
            panic!("expected an Export with a Function inside");
        };
        assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
        assert_eq!(export.id(), original_export_id, "export ID is unchanged");
    }

    /// Same as above with a body that only traps; the export name must also
    /// survive the replacement.
    #[test]
    fn replace_exported_func_generated_no_op() {
        let mut module = Module::default();

        // An exported function to be replaced.
        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
        let original_export_id = module.exports.add("dummy", original_fn_id);

        let replacement = module.replace_exported_func(original_fn_id, |(body, _arg_locals)| {
            body.unreachable();
        });
        let new_fn_id = replacement.expect("export function replacement worked");

        assert!(
            module.exports.get_exported_func(original_fn_id).is_none(),
            "replaced export function cannot be gotten by ID"
        );

        let export: &Export = module
            .exports
            .get_exported_func(new_fn_id)
            .expect("failed to unwrap exported func");
        let ExportItem::Function(fid) = &export.item else {
            panic!("expected an Export with a Function inside");
        };
        assert_eq!(export.name, "dummy", "function name on export is unchanged");
        assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
        assert_eq!(export.id(), original_export_id, "export ID is unchanged");
    }

    /// Replacing an imported function turns it into a local one and removes
    /// the import.
    #[test]
    fn replace_imported_func() {
        let mut module = Module::default();

        // An imported function to be replaced.
        let types = module.types.add(&[], &[]);
        let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

        let replacement = module.replace_imported_func(original_fn_id, |(body, _)| {
            body.i32_const(4321).drop();
        });
        let new_fn_id = replacement.expect("import fn replacement worked");

        assert_import_became_local(&module, original_fn_id, original_import_id, new_fn_id);
    }

    /// Same as above with a body that only traps.
    #[test]
    fn replace_imported_func_generated_no_op() {
        let mut module = Module::default();

        // An imported function to be replaced.
        let types = module.types.add(&[], &[]);
        let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

        let replacement = module.replace_imported_func(original_fn_id, |(body, _arg_locals)| {
            body.unreachable();
        });
        let new_fn_id = replacement.expect("import fn replacement worked");

        assert_import_became_local(&module, original_fn_id, original_import_id, new_fn_id);
    }
}