Skip to main content

walrus/module/
imports.rs

1//! A wasm module's imports.
2
3use anyhow::Context;
4
5use crate::emit::{Emit, EmitContext};
6use crate::parse::IndicesToIds;
7use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
8use crate::{FunctionId, GlobalId, MemoryId, Result, TableId, TagId};
9use crate::{Module, RefType, TypeId, ValType};
10
/// The id of an import.
///
/// Returned by [`ModuleImports::add`] and used to look imports up again with
/// [`ModuleImports::get`] / [`ModuleImports::get_mut`].
pub type ImportId = Id<Import>;
13
/// A named item imported into the wasm.
///
/// Instances are created by [`ModuleImports::add`], which stores the import's
/// own arena id in it.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct Import {
    // Not `pub`: it is filled in once by `ModuleImports::add` and exposed
    // read-only through `Import::id`.
    id: ImportId,
    /// The module name of this import.
    pub module: String,
    /// The name of this import.
    pub name: String,
    /// The kind of item being imported, carrying the id of the module entity
    /// (function, table, memory, global or tag) that represents it.
    pub kind: ImportKind,
}
25
26impl Tombstone for Import {
27    fn on_delete(&mut self) {
28        self.module = String::new();
29        self.name = String::new();
30    }
31}
32
33impl Import {
34    /// Get this import's identifier.
35    pub fn id(&self) -> ImportId {
36        self.id
37    }
38}
39
/// An imported item.
///
/// Each variant carries the id of the module entity the import provides. Every
/// variant can be built from its id type via `From`, which is what lets
/// [`ModuleImports::add`] accept `impl Into<ImportKind>`.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum ImportKind {
    /// An imported function.
    Function(FunctionId),
    /// An imported table.
    Table(TableId),
    /// An imported memory.
    Memory(MemoryId),
    /// An imported global.
    Global(GlobalId),
    /// An imported tag for exception handling.
    Tag(TagId),
}
54
/// The set of imports in a module.
///
/// Imports are stored in a `TombstoneArena` keyed by [`ImportId`]; they are
/// added with [`ModuleImports::add`] and removed with
/// [`ModuleImports::delete`] / [`ModuleImports::remove`].
#[derive(Debug, Default)]
pub struct ModuleImports {
    arena: TombstoneArena<Import>,
}
60
61impl ModuleImports {
62    /// Gets a reference to an import given its id
63    pub fn get(&self, id: ImportId) -> &Import {
64        &self.arena[id]
65    }
66
67    /// Gets a reference to an import given its id
68    pub fn get_mut(&mut self, id: ImportId) -> &mut Import {
69        &mut self.arena[id]
70    }
71
72    /// Removes an import from this module.
73    ///
74    /// It is up to you to ensure that any potential references to the deleted
75    /// import are also removed, eg `get_global` expressions.
76    pub fn delete(&mut self, id: ImportId) {
77        self.arena.delete(id);
78    }
79
80    /// Get a shared reference to this module's imports.
81    pub fn iter(&self) -> impl Iterator<Item = &Import> {
82        self.arena.iter().map(|(_, f)| f)
83    }
84
85    /// Get mutable references to this module's imports.
86    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Import> {
87        self.arena.iter_mut().map(|(_, f)| f)
88    }
89
90    /// Adds a new import to this module
91    pub fn add(&mut self, module: &str, name: &str, kind: impl Into<ImportKind>) -> ImportId {
92        self.arena.alloc_with_id(|id| Import {
93            id,
94            module: module.to_string(),
95            name: name.to_string(),
96            kind: kind.into(),
97        })
98    }
99
100    /// Get the import with the given module and name
101    pub fn find(&self, module: &str, name: &str) -> Option<ImportId> {
102        let import = self
103            .arena
104            .iter()
105            .find(|(_, import)| import.name == name && import.module == module);
106
107        Some(import?.0)
108    }
109
110    /// Retrieve an imported function by import name, including the module in which it resides
111    pub fn get_func(&self, module: impl AsRef<str>, name: impl AsRef<str>) -> Result<FunctionId> {
112        self.iter()
113            .find_map(|impt| match impt.kind {
114                ImportKind::Function(fid)
115                    if impt.module == module.as_ref() && impt.name == name.as_ref() =>
116                {
117                    Some(fid)
118                }
119                _ => None,
120            })
121            .with_context(|| format!("unable to find function import '{}'", name.as_ref()))
122    }
123
124    /// Retrieve an imported function by ID
125    pub fn get_imported_func(&self, fid: FunctionId) -> Option<&Import> {
126        self.arena.iter().find_map(|(_, import)| match import.kind {
127            ImportKind::Function(id) if fid == id => Some(import),
128            _ => None,
129        })
130    }
131
132    /// Delete an imported function by name from this module.
133    pub fn remove(&mut self, module: impl AsRef<str>, name: impl AsRef<str>) -> Result<()> {
134        let import = self
135            .iter()
136            .find(|e| e.module == module.as_ref() && e.name == name.as_ref())
137            .with_context(|| {
138                format!("failed to find imported func with name [{}]", name.as_ref())
139            })?;
140
141        self.delete(import.id());
142
143        Ok(())
144    }
145}
146
impl Module {
    /// Construct the import set for a wasm module.
    ///
    /// Each entry of `section` is added, in section order, through the matching
    /// `add_import_*` method. The resulting entity id is then pushed onto the
    /// matching index space in `ids` (`push_func`, `push_table`, ...).
    /// Presumably later sections rely on that order to map wasm indices back
    /// to these ids.
    ///
    /// # Errors
    ///
    /// Propagates failures from reading an entry, from `ids.get_type` (function
    /// and tag signatures), and from converting a table's element type or a
    /// global's value type.
    pub(crate) fn parse_imports(
        &mut self,
        section: wasmparser::ImportSectionReader,
        ids: &mut IndicesToIds,
    ) -> Result<()> {
        log::debug!("parse import section");
        for entry in section.into_imports() {
            let entry = entry?;
            match entry.ty {
                // Exact-type function imports are recorded like ordinary ones:
                // only the referenced type id is kept.
                wasmparser::TypeRef::Func(idx) | wasmparser::TypeRef::FuncExact(idx) => {
                    let ty = ids.get_type(idx)?;
                    let id = self.add_import_func(entry.module, entry.name, ty);
                    // `id.0` is the entity id (here a `FunctionId`), not the `ImportId`.
                    ids.push_func(id.0);
                }
                wasmparser::TypeRef::Table(t) => {
                    let id = self.add_import_table(
                        entry.module,
                        entry.name,
                        t.table64,
                        t.initial,
                        t.maximum,
                        RefType::from_wasmparser(t.element_type, ids, 0)?,
                    );
                    ids.push_table(id.0);
                }
                wasmparser::TypeRef::Memory(m) => {
                    let id = self.add_import_memory(
                        entry.module,
                        entry.name,
                        m.shared,
                        m.memory64,
                        m.initial,
                        m.maximum,
                        m.page_size_log2,
                    );
                    ids.push_memory(id.0);
                }
                wasmparser::TypeRef::Global(g) => {
                    let id = self.add_import_global(
                        entry.module,
                        entry.name,
                        ValType::from_wasmparser(&g.content_type, ids, 0)?,
                        g.mutable,
                        g.shared,
                    );
                    ids.push_global(id.0);
                }
                wasmparser::TypeRef::Tag(tag_type) => {
                    let ty = ids.get_type(tag_type.func_type_idx)?;
                    let id = self.add_import_tag(entry.module, entry.name, ty);
                    ids.push_tag(id.0);
                }
            }
        }

        Ok(())
    }

    /// Add an imported function to this module
    ///
    /// Returns the new function's id together with the id of the import entry
    /// that provides it.
    pub fn add_import_func(
        &mut self,
        module: &str,
        name: &str,
        ty: TypeId,
    ) -> (FunctionId, ImportId) {
        // Reserve the import id first so the function can record which import
        // provides it. The `self.imports.add` call below is expected to
        // allocate exactly this id, since nothing else touches the import
        // arena in between.
        let import = self.imports.arena.next_id();
        let func = self.funcs.add_import(ty, import);
        self.imports.add(module, name, func);
        (func, import)
    }

    /// Add an imported memory to this module
    ///
    /// Returns the new memory's id together with the id of its import entry.
    // One parameter per memory-type property plus the two import names; kept
    // flat to match the other `add_import_*` methods.
    #[allow(clippy::too_many_arguments)]
    pub fn add_import_memory(
        &mut self,
        module: &str,
        name: &str,
        shared: bool,
        memory64: bool,
        initial: u64,
        maximum: Option<u64>,
        page_size_log2: Option<u32>,
    ) -> (MemoryId, ImportId) {
        // Same id-reservation pattern as `add_import_func`.
        let import = self.imports.arena.next_id();
        let mem =
            self.memories
                .add_import(shared, memory64, initial, maximum, page_size_log2, import);
        self.imports.add(module, name, mem);
        (mem, import)
    }

    /// Add an imported table to this module
    ///
    /// Returns the new table's id together with the id of its import entry.
    /// No shared-table flag is accepted here.
    pub fn add_import_table(
        &mut self,
        module: &str,
        name: &str,
        table64: bool,
        initial: u64,
        maximum: Option<u64>,
        ty: RefType,
    ) -> (TableId, ImportId) {
        // Same id-reservation pattern as `add_import_func`.
        let import = self.imports.arena.next_id();
        let table = self
            .tables
            .add_import(table64, initial, maximum, ty, import);
        self.imports.add(module, name, table);
        (table, import)
    }

    /// Add an imported global to this module
    ///
    /// Returns the new global's id together with the id of its import entry.
    pub fn add_import_global(
        &mut self,
        module: &str,
        name: &str,
        ty: ValType,
        mutable: bool,
        shared: bool,
    ) -> (GlobalId, ImportId) {
        // Same id-reservation pattern as `add_import_func`.
        let import = self.imports.arena.next_id();
        let global = self.globals.add_import(ty, mutable, shared, import);
        self.imports.add(module, name, global);
        (global, import)
    }

    /// Add an imported tag to this module
    ///
    /// Returns the new tag's id together with the id of its import entry.
    pub fn add_import_tag(&mut self, module: &str, name: &str, ty: TypeId) -> (TagId, ImportId) {
        // Same id-reservation pattern as `add_import_func`.
        let import = self.imports.arena.next_id();
        let tag = self.tags.add_import(ty, import);
        self.imports.add(module, name, tag);
        (tag, import)
    }
}
281
282impl Emit for ModuleImports {
283    fn emit(&self, cx: &mut EmitContext) {
284        log::debug!("emit import section");
285
286        let mut wasm_import_section = wasm_encoder::ImportSection::new();
287
288        let count = self.iter().count();
289        if count == 0 {
290            return;
291        }
292
293        for import in self.iter() {
294            wasm_import_section.import(
295                &import.module,
296                &import.name,
297                match import.kind {
298                    ImportKind::Function(id) => {
299                        cx.indices.push_func(id);
300                        let ty = cx.module.funcs.get(id).ty();
301                        let idx = cx.indices.get_type_index(ty);
302                        wasm_encoder::EntityType::Function(idx)
303                    }
304                    ImportKind::Table(id) => {
305                        cx.indices.push_table(id);
306                        let table = cx.module.tables.get(id);
307                        wasm_encoder::EntityType::Table(wasm_encoder::TableType {
308                            element_type: table.element_ty.to_wasmencoder_ref_type(cx.indices),
309                            table64: table.table64,
310                            minimum: table.initial,
311                            maximum: table.maximum,
312                            shared: false,
313                        })
314                    }
315                    ImportKind::Memory(id) => {
316                        cx.indices.push_memory(id);
317                        let mem = cx.module.memories.get(id);
318                        wasm_encoder::EntityType::Memory(wasm_encoder::MemoryType {
319                            minimum: mem.initial,
320                            maximum: mem.maximum,
321                            memory64: mem.memory64,
322                            shared: mem.shared,
323                            page_size_log2: mem.page_size_log2,
324                        })
325                    }
326                    ImportKind::Global(id) => {
327                        cx.indices.push_global(id);
328                        let g = cx.module.globals.get(id);
329                        wasm_encoder::EntityType::Global(wasm_encoder::GlobalType {
330                            val_type: g.ty.to_wasmencoder_type(cx.indices),
331                            mutable: g.mutable,
332                            shared: g.shared,
333                        })
334                    }
335                    ImportKind::Tag(id) => {
336                        cx.indices.push_tag(id);
337                        let tag = cx.module.tags.get(id);
338                        let ty_idx = cx.indices.get_type_index(tag.ty);
339                        wasm_encoder::EntityType::Tag(wasm_encoder::TagType {
340                            kind: wasm_encoder::TagKind::Exception,
341                            func_type_idx: ty_idx,
342                        })
343                    }
344                },
345            );
346        }
347
348        cx.wasm_module.section(&wasm_import_section);
349    }
350}
351
352impl From<MemoryId> for ImportKind {
353    fn from(id: MemoryId) -> ImportKind {
354        ImportKind::Memory(id)
355    }
356}
357
358impl From<FunctionId> for ImportKind {
359    fn from(id: FunctionId) -> ImportKind {
360        ImportKind::Function(id)
361    }
362}
363
364impl From<GlobalId> for ImportKind {
365    fn from(id: GlobalId) -> ImportKind {
366        ImportKind::Global(id)
367    }
368}
369
370impl From<TableId> for ImportKind {
371    fn from(id: TableId) -> ImportKind {
372        ImportKind::Table(id)
373    }
374}
375
376impl From<TagId> for ImportKind {
377    fn from(id: TagId) -> ImportKind {
378        ImportKind::Tag(id)
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FunctionBuilder, Module};

    /// Builds a module with one local function that is also registered as the
    /// import `"mod"."dummy"`, returning the module, the function id and the
    /// import id.
    fn module_with_dummy_import() -> (Module, FunctionId, ImportId) {
        let mut module = Module::default();

        let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
        builder.func_body().i32_const(1234).drop();
        let fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
        let import_id = module.imports.add("mod", "dummy", fn_id);

        (module, fn_id, import_id)
    }

    #[test]
    fn get_imported_func() {
        let (module, fn_id, _) = module_with_dummy_import();

        let import = module
            .imports
            .get_imported_func(fn_id)
            .expect("function should be imported");
        assert_eq!(import.module, "mod");
        assert_eq!(import.name, "dummy");
    }

    #[test]
    fn get_func_by_name() {
        let (module, fn_id, _) = module_with_dummy_import();

        assert_eq!(module.imports.get_func("mod", "dummy").unwrap(), fn_id);
        // Both the module and the name must match.
        assert!(module.imports.get_func("other", "dummy").is_err());
        assert!(module.imports.get_func("mod", "missing").is_err());
    }

    #[test]
    fn find_by_module_and_name() {
        let (module, _, import_id) = module_with_dummy_import();

        assert_eq!(module.imports.find("mod", "dummy"), Some(import_id));
        assert_eq!(module.imports.find("other", "dummy"), None);
        assert_eq!(module.imports.find("mod", "missing"), None);
    }

    #[test]
    fn remove_deletes_the_import() {
        let (mut module, _, _) = module_with_dummy_import();

        assert!(module.imports.remove("mod", "missing").is_err());

        module.imports.remove("mod", "dummy").unwrap();
        assert_eq!(module.imports.find("mod", "dummy"), None);
        assert!(module.imports.get_func("mod", "dummy").is_err());
    }
}