Skip to main content

walrus/module/functions/
mod.rs

1//! Functions within a wasm module.
2
3use std::cmp;
4use std::collections::BTreeMap;
5use std::ops::Range;
6
7use anyhow::{bail, Context};
8use wasm_encoder::Encode;
9use wasmparser::{FuncValidator, FunctionBody, ValidatorResources};
10
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
14mod local_function;
15
16use crate::emit::{Emit, EmitContext};
17use crate::error::Result;
18use crate::ir::InstrLocId;
19use crate::module::imports::ImportId;
20use crate::module::Module;
21use crate::parse::IndicesToIds;
22use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
23use crate::ty::TypeId;
24use crate::ty::ValType;
25use crate::{ExportItem, FunctionBuilder, InstrSeqBuilder, LocalId, Memory, MemoryId};
26
27pub use self::local_function::LocalFunction;
28
29/// A function identifier.
30pub type FunctionId = Id<Function>;
31
32/// Parameter(s) to a function
33pub type FuncParams = Vec<ValType>;
34
35/// Result(s) of a given function
36pub type FuncResults = Vec<ValType>;
37
38/// A wasm function.
39///
40/// Either defined locally or externally and then imported; see `FunctionKind`.
41#[derive(Debug)]
42pub struct Function {
43    // NB: Not public so that it can't get out of sync with the arena that this
44    // function lives within.
45    id: FunctionId,
46
47    /// The kind of function this is.
48    pub kind: FunctionKind,
49
50    /// An optional name associated with this function
51    pub name: Option<String>,
52}
53
54impl Tombstone for Function {
55    fn on_delete(&mut self) {
56        let ty = self.ty();
57        self.kind = FunctionKind::Uninitialized(ty);
58        self.name = None;
59    }
60}
61
62impl Function {
63    fn new_uninitialized(id: FunctionId, ty: TypeId) -> Function {
64        Function {
65            id,
66            kind: FunctionKind::Uninitialized(ty),
67            name: None,
68        }
69    }
70
71    /// Get this function's identifier.
72    pub fn id(&self) -> FunctionId {
73        self.id
74    }
75
76    /// Get this function's type's identifier.
77    pub fn ty(&self) -> TypeId {
78        match &self.kind {
79            FunctionKind::Local(l) => l.ty(),
80            FunctionKind::Import(i) => i.ty,
81            FunctionKind::Uninitialized(t) => *t,
82        }
83    }
84}
85
86/// The local- or external-specific bits of a function.
87#[derive(Debug)]
88pub enum FunctionKind {
89    /// An externally defined, imported wasm function.
90    Import(ImportedFunction),
91
92    /// A locally defined wasm function.
93    Local(LocalFunction),
94
95    /// A locally defined wasm function that we haven't parsed yet (but have
96    /// reserved its id and associated it with its original input wasm module
97    /// index). This should only exist within
98    /// `ModuleFunctions::add_local_functions`.
99    Uninitialized(TypeId),
100}
101
102impl FunctionKind {
103    /// Get the underlying `FunctionKind::Import` or panic if this is not an
104    /// import function
105    pub fn unwrap_import(&self) -> &ImportedFunction {
106        match self {
107            FunctionKind::Import(import) => import,
108            _ => panic!("not an import function"),
109        }
110    }
111
112    /// Get the underlying `FunctionKind::Local` or panic if this is not a local
113    /// function.
114    pub fn unwrap_local(&self) -> &LocalFunction {
115        match self {
116            FunctionKind::Local(l) => l,
117            _ => panic!("not a local function"),
118        }
119    }
120
121    /// Get the underlying `FunctionKind::Import` or panic if this is not an
122    /// import function
123    pub fn unwrap_import_mut(&mut self) -> &mut ImportedFunction {
124        match self {
125            FunctionKind::Import(import) => import,
126            _ => panic!("not an import function"),
127        }
128    }
129
130    /// Get the underlying `FunctionKind::Local` or panic if this is not a local
131    /// function.
132    pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction {
133        match self {
134            FunctionKind::Local(l) => l,
135            _ => panic!("not a local function"),
136        }
137    }
138}
139
140/// An externally defined, imported function.
141#[derive(Debug)]
142pub struct ImportedFunction {
143    /// The import that brings this function into the module.
144    pub import: ImportId,
145    /// The type signature of this imported function.
146    pub ty: TypeId,
147}
148
149/// The set of functions within a module.
150#[derive(Debug, Default)]
151pub struct ModuleFunctions {
152    /// The arena containing this module's functions.
153    arena: TombstoneArena<Function>,
154
155    /// Original code section offset.
156    pub(crate) code_section_offset: usize,
157}
158
159impl ModuleFunctions {
160    /// Construct a new, empty set of functions for a module.
161    pub fn new() -> ModuleFunctions {
162        Default::default()
163    }
164
165    /// Create a new externally defined, imported function.
166    pub fn add_import(&mut self, ty: TypeId, import: ImportId) -> FunctionId {
167        self.arena.alloc_with_id(|id| Function {
168            id,
169            kind: FunctionKind::Import(ImportedFunction { import, ty }),
170            name: None,
171        })
172    }
173
174    /// Create a new internally defined function
175    pub fn add_local(&mut self, func: LocalFunction) -> FunctionId {
176        let func_name = func.builder().name.clone();
177        self.arena.alloc_with_id(|id| Function {
178            id,
179            kind: FunctionKind::Local(func),
180            name: func_name,
181        })
182    }
183
184    /// Gets a reference to a function given its id
185    pub fn get(&self, id: FunctionId) -> &Function {
186        &self.arena[id]
187    }
188
189    /// Gets a reference to a function given its id
190    pub fn get_mut(&mut self, id: FunctionId) -> &mut Function {
191        &mut self.arena[id]
192    }
193
194    /// Get a function ID by its name.
195    ///
196    /// The name used is the "name" custom section name and *not* the export
197    /// name, if a function happens to be exported.
198    ///
199    /// Note that function names are *not* guaranteed to be unique. This will
200    /// return the first function in the module with the given name.
201    pub fn by_name(&self, name: &str) -> Option<FunctionId> {
202        self.arena.iter().find_map(|(id, f)| {
203            if f.name.as_deref() == Some(name) {
204                Some(id)
205            } else {
206                None
207            }
208        })
209    }
210
211    /// Removes a function from this module.
212    ///
213    /// It is up to you to ensure that any potential references to the deleted
214    /// function are also removed, eg `call` expressions, exports, table
215    /// elements, etc.
216    pub fn delete(&mut self, id: FunctionId) {
217        self.arena.delete(id);
218    }
219
220    /// Get a shared reference to this module's functions.
221    pub fn iter(&self) -> impl Iterator<Item = &Function> {
222        self.arena.iter().map(|(_, f)| f)
223    }
224
225    /// Get a shared reference to this module's functions.
226    ///
227    /// Requires the `parallel` feature of this crate to be enabled.
228    #[cfg(feature = "parallel")]
229    pub fn par_iter(&self) -> impl ParallelIterator<Item = &Function> {
230        self.arena.par_iter().map(|(_, f)| f)
231    }
232
233    /// Get an iterator of this module's local functions
234    pub fn iter_local(&self) -> impl Iterator<Item = (FunctionId, &LocalFunction)> {
235        self.iter().filter_map(|f| match &f.kind {
236            FunctionKind::Local(local) => Some((f.id(), local)),
237            _ => None,
238        })
239    }
240
241    /// Get a parallel iterator of this module's local functions
242    ///
243    /// Requires the `parallel` feature of this crate to be enabled.
244    #[cfg(feature = "parallel")]
245    pub fn par_iter_local(&self) -> impl ParallelIterator<Item = (FunctionId, &LocalFunction)> {
246        self.par_iter().filter_map(|f| match &f.kind {
247            FunctionKind::Local(local) => Some((f.id(), local)),
248            _ => None,
249        })
250    }
251
252    /// Get a mutable reference to this module's functions.
253    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Function> {
254        self.arena.iter_mut().map(|(_, f)| f)
255    }
256
257    /// Get a mutable reference to this module's functions.
258    ///
259    /// Requires the `parallel` feature of this crate to be enabled.
260    #[cfg(feature = "parallel")]
261    pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut Function> {
262        self.arena.par_iter_mut().map(|(_, f)| f)
263    }
264
265    /// Get an iterator of this module's local functions
266    pub fn iter_local_mut(&mut self) -> impl Iterator<Item = (FunctionId, &mut LocalFunction)> {
267        self.iter_mut().filter_map(|f| {
268            let id = f.id();
269            match &mut f.kind {
270                FunctionKind::Local(local) => Some((id, local)),
271                _ => None,
272            }
273        })
274    }
275
276    /// Get a parallel iterator of this module's local functions
277    ///
278    /// Requires the `parallel` feature of this crate to be enabled.
279    #[cfg(feature = "parallel")]
280    pub fn par_iter_local_mut(
281        &mut self,
282    ) -> impl ParallelIterator<Item = (FunctionId, &mut LocalFunction)> {
283        self.par_iter_mut().filter_map(|f| {
284            let id = f.id();
285            match &mut f.kind {
286                FunctionKind::Local(local) => Some((id, local)),
287                _ => None,
288            }
289        })
290    }
291
292    pub(crate) fn emit_func_section(&self, cx: &mut EmitContext) {
293        log::debug!("emit function section");
294        let functions = used_local_functions(cx);
295        if functions.is_empty() {
296            return;
297        }
298        let mut func_section = wasm_encoder::FunctionSection::new();
299        for (id, function, _size) in functions {
300            let index = cx.indices.get_type_index(function.ty());
301            func_section.function(index);
302
303            // Assign an index to all local defined functions before we start
304            // translating them. While translating they may refer to future
305            // functions, so we'll need to have an index for it by that point.
306            // We're guaranteed the function section is emitted before the code
307            // section so we should be covered here.
308            cx.indices.push_func(id);
309        }
310        cx.wasm_module.section(&func_section);
311    }
312}
313
314impl Module {
315    /// Declare local functions after seeing the `function` section of a wasm
316    /// executable.
317    pub(crate) fn declare_local_functions(
318        &mut self,
319        section: wasmparser::FunctionSectionReader,
320        ids: &mut IndicesToIds,
321    ) -> Result<()> {
322        log::debug!("parse function section");
323        for func in section {
324            let ty = ids.get_type(func?)?;
325            let id = self
326                .funcs
327                .arena
328                .alloc_with_id(|id| Function::new_uninitialized(id, ty));
329            let idx = ids.push_func(id);
330            if self.config.generate_synthetic_names_for_anonymous_items {
331                self.funcs.get_mut(id).name = Some(format!("f{}", idx));
332            }
333        }
334
335        Ok(())
336    }
337
338    /// Add the locally defined functions in the wasm module to this instance.
339    pub(crate) fn parse_local_functions(
340        &mut self,
341        functions: Vec<(FunctionBody<'_>, FuncValidator<ValidatorResources>)>,
342        indices: &mut IndicesToIds,
343        on_instr_pos: Option<&(dyn Fn(&usize) -> InstrLocId + Sync + Send + 'static)>,
344    ) -> Result<()> {
345        log::debug!("parse code section");
346        let num_imports = self.funcs.arena.len() - functions.len();
347
348        // First up serially create corresponding `LocalId` instances for all
349        // functions as well as extract the operators parser for each function.
350        // This is pretty tough to parallelize, but we can look into it later if
351        // necessary and it's a bottleneck!
352        let mut bodies = Vec::with_capacity(functions.len());
353        for (i, (body, mut validator)) in functions.into_iter().enumerate() {
354            let index = (num_imports + i) as u32;
355            let id = indices.get_func(index)?;
356            let ty = match self.funcs.arena[id].kind {
357                FunctionKind::Uninitialized(ty) => ty,
358                _ => unreachable!(),
359            };
360
361            // First up, implicitly add locals for all function arguments. We also
362            // record these in the function itself for later processing.
363            let mut args = Vec::new();
364            let type_ = self.types.get(ty);
365            for ty in type_.params().iter() {
366                let local_id = self.locals.add(*ty);
367                let idx = indices.push_local(id, local_id);
368                args.push(local_id);
369                if self.config.generate_synthetic_names_for_anonymous_items {
370                    let name = format!("arg{}", idx);
371                    self.locals.get_mut(local_id).name = Some(name);
372                }
373            }
374
375            // Ensure that there exists a `Type` for the function's entry
376            // block. This is required because multi-value blocks reference a
377            // `Type`, however function entry's type is implicit in the
378            // encoding, and doesn't already exist in the `ModuleTypes`.
379            let results = type_.results().to_vec();
380            self.types.add_entry_ty(&results);
381
382            // Next up comes all the locals of the function.
383            let mut locals_reader = body.get_locals_reader()?;
384            for _ in 0..locals_reader.get_count() {
385                let pos = locals_reader.original_position();
386                let (count, ty) = locals_reader.read()?;
387                validator.define_locals(pos, count, ty)?;
388                let ty = ValType::from_wasmparser(&ty, indices, 0)?;
389                for _ in 0..count {
390                    let local_id = self.locals.add(ty);
391                    let idx = indices.push_local(id, local_id);
392                    if self.config.generate_synthetic_names_for_anonymous_items {
393                        let name = format!("l{}", idx);
394                        self.locals.get_mut(local_id).name = Some(name);
395                    }
396                }
397            }
398
399            bodies.push((id, body, args, ty, validator));
400        }
401
402        // Wasm modules can often have a lot of functions and this operation can
403        // take some time, so parse all function bodies in parallel.
404        let results = maybe_parallel!(bodies.(into_iter | into_par_iter))
405            .map(|(id, body, args, ty, validator)| {
406                (
407                    id,
408                    LocalFunction::parse(
409                        self,
410                        indices,
411                        id,
412                        ty,
413                        args,
414                        body,
415                        on_instr_pos,
416                        validator,
417                    ),
418                )
419            })
420            .collect::<Vec<_>>();
421
422        // After all the function bodies are collected and finished push them
423        // into our function arena.
424        for (id, func) in results {
425            let func = func?;
426            self.funcs.arena[id].kind = FunctionKind::Local(func);
427        }
428
429        Ok(())
430    }
431
432    /// Retrieve the ID for the first exported memory.
433    ///
434    /// This method does not work in contexts with [multi-memory enabled](https://github.com/WebAssembly/multi-memory),
435    /// and will error if more than one memory is present.
436    pub fn get_memory_id(&self) -> Result<MemoryId> {
437        if self.memories.len() > 1 {
438            bail!("multiple memories unsupported")
439        }
440
441        self.memories
442            .iter()
443            .next()
444            .map(Memory::id)
445            .context("module does not export a memory")
446    }
447
448    /// Replace a single exported function with the result of the provided builder function.
449    ///
450    /// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
451    /// used to build the function as necessary.
452    ///
453    /// For example, if you wanted to replace an exported function with a no-op,
454    ///
455    /// ```ignore
456    /// module.replace_exported_func(fid, |(body, arg_locals)| {
457    ///     builder.func_body().unreachable();
458    /// });
459    /// ```
460    ///
461    /// The arguments passed to the original function will be passed to the
462    /// new exported function that was built in your closure.
463    ///
464    /// This function returns the function ID of the *new* function,
465    /// after it has been inserted into the module as an export.
466    pub fn replace_exported_func(
467        &mut self,
468        fid: FunctionId,
469        builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
470    ) -> Result<FunctionId> {
471        let original_export_id = self
472            .exports
473            .get_exported_func(fid)
474            .map(|e| e.id())
475            .with_context(|| format!("no exported function with ID [{fid:?}]"))?;
476
477        if let Function {
478            kind: FunctionKind::Local(lf),
479            ..
480        } = self.funcs.get(fid)
481        {
482            // Retrieve the params & result types for the exported (local) function
483            let ty = self.types.get(lf.ty());
484            let (params, results) = (ty.params().to_vec(), ty.results().to_vec());
485
486            // Add the function produced by `fn_builder` as a local function
487            let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
488            let mut new_fn_body = builder.func_body();
489            builder_fn((&mut new_fn_body, &lf.args));
490            let func = builder.local_func(lf.args.clone());
491            let new_fn_id = self.funcs.add_local(func);
492
493            // Mutate the existing export to use the new local function
494            let export = self.exports.get_mut(original_export_id);
495            export.item = ExportItem::Function(new_fn_id);
496            Ok(new_fn_id)
497        } else {
498            bail!("cannot replace function [{fid:?}], it is not an exported function");
499        }
500    }
501
502    /// Replace a single imported function with the result of the provided builder function.
503    ///
504    /// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
505    /// used to build the function as necessary.
506    ///
507    /// For example, if you wanted to replace an imported function with a no-op,
508    ///
509    /// ```ignore
510    /// module.replace_imported_func(fid, |(body, arg_locals)| {
511    ///     builder.func_body().unreachable();
512    /// });
513    /// ```
514    ///
515    /// The arguments passed to the original function will be passed to the
516    /// new exported function that was built in your closure.
517    ///
518    /// This function returns the function ID of the *new* function, and
519    /// removes the existing import that has been replaced (the function will become local).
520    pub fn replace_imported_func(
521        &mut self,
522        fid: FunctionId,
523        builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
524    ) -> Result<FunctionId> {
525        let original_import_id = self
526            .imports
527            .get_imported_func(fid)
528            .map(|import| import.id())
529            .with_context(|| format!("no exported function with ID [{fid:?}]"))?;
530
531        if let Function {
532            kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
533            ..
534        } = self.funcs.get(fid)
535        {
536            // Retrieve the params & result types for the imported function
537            let ty = self.types.get(*tid);
538            let (params, results) = (ty.params().to_vec(), ty.results().to_vec());
539
540            // Build the list LocalIds used by args to match the original function
541            let args = params
542                .iter()
543                .map(|ty| self.locals.add(*ty))
544                .collect::<Vec<_>>();
545
546            // Build the new function
547            let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
548            let mut new_fn_body = builder.func_body();
549            builder_fn((&mut new_fn_body, &args));
550            let new_func_kind = FunctionKind::Local(builder.local_func(args));
551
552            // Mutate the existing function, changing it from a FunctionKind::ImportedFunction
553            // to the local function produced by running the provided `fn_builder`
554            let func = self.funcs.get_mut(fid);
555            func.kind = new_func_kind;
556
557            self.imports.delete(original_import_id);
558
559            Ok(fid)
560        } else {
561            bail!("cannot replace function [{fid:?}], it is not an imported function");
562        }
563    }
564}
565
566fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> {
567    // Extract all local functions because imported ones were already
568    // emitted as part of the import sectin. Find the size of each local
569    // function. Sort imported functions in order so that we can get their
570    // index in the function index space.
571    let mut functions = Vec::new();
572    for f in cx.module.funcs.iter() {
573        match &f.kind {
574            FunctionKind::Local(l) => functions.push((f.id(), l, l.size())),
575            FunctionKind::Import(_) => {}
576            FunctionKind::Uninitialized(_) => unreachable!(),
577        }
578    }
579
580    // Sort local functions from largest to smallest; we will emit them in
581    // this order. This helps load times, since wasm engines generally use
582    // the function as their level of granularity for parallelism. We want
583    // larger functions compiled before smaller ones because they will take
584    // longer to compile.
585    functions.sort_by_key(|(id, _, size)| (cmp::Reverse(*size), *id));
586
587    functions
588}
589
590fn collect_non_default_code_offsets(
591    code_transform: &mut BTreeMap<InstrLocId, usize>,
592    code_offset: usize,
593    map: Vec<(InstrLocId, usize)>,
594) {
595    for (src, dst) in map {
596        let dst = dst + code_offset;
597        if !src.is_default() {
598            code_transform.insert(src, dst);
599        }
600    }
601}
602
603impl Emit for ModuleFunctions {
604    fn emit(&self, cx: &mut EmitContext) {
605        log::debug!("emit code section");
606        let functions = used_local_functions(cx);
607        if functions.is_empty() {
608            return;
609        }
610
611        let mut wasm_code_section = wasm_encoder::CodeSection::new();
612        let generate_map = cx.module.config.preserve_code_transform;
613
614        // Functions can typically take awhile to serialize, so serialize
615        // everything in parallel. Afterwards we'll actually place all the
616        // functions together.
617        let bytes = maybe_parallel!(functions.(into_iter | into_par_iter))
618            .map(|(id, func, _size)| {
619                log::debug!("emit function {:?} {:?}", id, cx.module.funcs.get(id).name);
620                let mut wasm = Vec::new();
621                let mut map = if generate_map { Some(Vec::new()) } else { None };
622
623                let (locals_types, used_locals, local_indices) =
624                    func.emit_locals(cx.module, cx.indices);
625                let mut wasm_function = wasm_encoder::Function::new(locals_types);
626                func.emit_instructions(
627                    cx.indices,
628                    &local_indices,
629                    &mut wasm_function,
630                    map.as_mut(),
631                );
632                wasm_function.encode(&mut wasm);
633                (
634                    wasm,
635                    wasm_function.byte_len(),
636                    id,
637                    used_locals,
638                    local_indices,
639                    map,
640                )
641            })
642            .collect::<Vec<_>>();
643
644        let mut instruction_map = BTreeMap::new();
645        cx.indices.locals.reserve(bytes.len());
646
647        let mut offset_data = Vec::new();
648        for (wasm, byte_len, id, used_locals, local_indices, map) in bytes {
649            let leb_len = wasm.len() - byte_len;
650            wasm_code_section.raw(&wasm[leb_len..]);
651            cx.indices.locals.insert(id, local_indices);
652            cx.locals.insert(id, used_locals);
653            offset_data.push((byte_len, id, map, leb_len));
654        }
655        cx.wasm_module.section(&wasm_code_section);
656
657        let code_section_start_offset =
658            cx.wasm_module.as_slice().len() - wasm_code_section.byte_len();
659        let mut cur_offset = code_section_start_offset;
660
661        // update the map afterwards based on final offset differences
662        for (byte_len, id, map, leb_len) in offset_data {
663            // (this assumes the leb encodes the same)
664            let code_start_offset = cur_offset + leb_len;
665            cur_offset += leb_len + byte_len;
666            if let Some(map) = map {
667                collect_non_default_code_offsets(&mut instruction_map, code_start_offset, map);
668            }
669            cx.code_transform.function_ranges.push((
670                id,
671                Range {
672                    // inclusive leb part
673                    start: code_start_offset - leb_len,
674                    end: cur_offset,
675                },
676            ));
677        }
678        cx.code_transform.function_ranges.sort_by_key(|i| i.0);
679        // FIXME: code section start in DWARF debug information expects 2 bytes before actual code section start.
680        cx.code_transform.code_section_start = code_section_start_offset - 2;
681        cx.code_transform.instruction_map = instruction_map.into_iter().collect();
682    }
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Export, FunctionBuilder, Module};

    #[test]
    fn get_memory_id() {
        let mut module = Module::default();
        let expected_id = module.memories.add_local(false, false, 0, None, None);
        assert!(module.get_memory_id().is_ok_and(|id| id == expected_id));
    }

    /// `get_memory_id` fails when the module has no memory at all.
    #[test]
    fn get_memory_id_without_memory() {
        let module = Module::default();
        assert!(module.get_memory_id().is_err());
    }

    /// Running `replace_exported_func` with a closure that builds
    /// a function should point the existing export at the new function
    #[test]
    fn replace_exported_func() {
        let mut module = Module::default();

        // Create original function
        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
        let original_export_id = module.exports.add("dummy", original_fn_id);

        // Replace the existing function with a new one with a reversed const value
        let new_fn_id = module
            .replace_exported_func(original_fn_id, |(body, _)| {
                body.i32_const(4321).drop();
            })
            .expect("function replacement worked");

        assert!(
            module.exports.get_exported_func(original_fn_id).is_none(),
            "replaced function cannot be gotten by ID"
        );

        // Ensure the function was replaced
        match module
            .exports
            .get_exported_func(new_fn_id)
            .expect("failed to unwrap exported func")
        {
            exp @ Export {
                item: ExportItem::Function(fid),
                ..
            } => {
                assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
                assert_eq!(exp.id(), original_export_id, "export ID is unchanged");
            }
            _ => panic!("expected an Export with a Function inside"),
        }
    }

    /// Running `replace_exported_func` with a closure whose body is only
    /// `unreachable` should still point the existing export at the new function
    #[test]
    fn replace_exported_func_generated_no_op() {
        let mut module = Module::default();

        // Create original function
        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
        let original_export_id = module.exports.add("dummy", original_fn_id);

        // Replace the existing function with one whose body is just `unreachable`
        let new_fn_id = module
            .replace_exported_func(original_fn_id, |(body, _arg_locals)| {
                body.unreachable();
            })
            .expect("export function replacement worked");

        assert!(
            module.exports.get_exported_func(original_fn_id).is_none(),
            "replaced export function cannot be gotten by ID"
        );

        // Ensure the function was replaced
        match module
            .exports
            .get_exported_func(new_fn_id)
            .expect("failed to unwrap exported func")
        {
            exp @ Export {
                item: ExportItem::Function(fid),
                name,
                ..
            } => {
                assert_eq!(name, "dummy", "function name on export is unchanged");
                assert_eq!(*fid, new_fn_id, "retrieved function ID matches");
                assert_eq!(exp.id(), original_export_id, "export ID is unchanged");
            }
            _ => panic!("expected an Export with a Function inside"),
        }
    }

    /// `replace_exported_func` refuses a function that is not exported
    #[test]
    fn replace_exported_func_requires_export() {
        let mut module = Module::default();

        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);

        let result = module.replace_exported_func(fn_id, |(body, _)| {
            body.unreachable();
        });
        assert!(result.is_err(), "unexported function cannot be replaced");
    }

    /// Running `replace_imported_func` with a closure that builds
    /// a function should turn the import into a local function
    #[test]
    fn replace_imported_func() {
        let mut module = Module::default();

        // Create original import function
        let types = module.types.add(&[], &[]);
        let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

        // Replace the import with a local function that pushes and drops a constant
        let new_fn_id = module
            .replace_imported_func(original_fn_id, |(body, _)| {
                body.i32_const(4321).drop();
            })
            .expect("import fn replacement worked");

        assert_eq!(new_fn_id, original_fn_id, "function ID is preserved");

        assert!(
            !module.imports.iter().any(|i| i.id() == original_import_id),
            "original import is missing",
        );

        assert!(
            module.imports.get_imported_func(original_fn_id).is_none(),
            "replaced import function cannot be gotten by ID"
        );

        assert!(
            module.imports.get_imported_func(new_fn_id).is_none(),
            "new import function cannot be gotten by ID (it is now local)"
        );

        assert!(
            matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)),
            "new local function has the right kind"
        );
    }

    /// Running `replace_imported_func` with a closure whose body is only
    /// `unreachable` should still turn the import into a local function
    #[test]
    fn replace_imported_func_generated_no_op() {
        let mut module = Module::default();

        // Create original import function
        let types = module.types.add(&[], &[]);
        let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

        // Replace the import with a local function whose body is just `unreachable`
        let new_fn_id = module
            .replace_imported_func(original_fn_id, |(body, _arg_locals)| {
                body.unreachable();
            })
            .expect("import fn replacement worked");

        assert!(
            !module.imports.iter().any(|i| i.id() == original_import_id),
            "original import is missing",
        );

        assert!(
            module.imports.get_imported_func(original_fn_id).is_none(),
            "replaced import function cannot be gotten by ID"
        );

        assert!(
            module.imports.get_imported_func(new_fn_id).is_none(),
            "new import function cannot be gotten by ID (it is now local)"
        );

        assert!(
            matches!(module.funcs.get(new_fn_id).kind, FunctionKind::Local(_)),
            "new local function has the right kind"
        );
    }

    /// `replace_imported_func` refuses a function that is not imported
    #[test]
    fn replace_imported_func_requires_import() {
        let mut module = Module::default();

        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);

        let result = module.replace_imported_func(fn_id, |(body, _)| {
            body.unreachable();
        });
        assert!(result.is_err(), "local function cannot be replaced as an import");
    }
}