wit_walrus/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use std::borrow::Cow;
3use std::collections::HashMap;
4use walrus::{passes::Roots, CustomSection, IdsToIndices, IndicesToIds, Module};
5use wit_schema_version::SECTION_NAME;
6
7#[derive(Debug, Default)]
8pub struct WasmInterfaceTypes {
9    pub types: Types,
10    pub imports: Imports,
11    pub implements: Implements,
12    pub exports: Exports,
13    pub funcs: Funcs,
14}
15
16mod exports;
17mod funcs;
18mod implements;
19mod imports;
20mod types;
21pub use self::exports::*;
22pub use self::funcs::*;
23pub use self::implements::*;
24pub use self::imports::*;
25pub use self::types::*;
26
27impl CustomSection for WasmInterfaceTypes {
28    fn name(&self) -> &str {
29        SECTION_NAME
30    }
31
32    fn data(&self, indices: &IdsToIndices) -> Cow<'_, [u8]> {
33        let mut writer = wit_writer::Writer::new();
34        let mut wids = WitIdsToIndices::default();
35        self.encode_types(&mut writer, &mut wids);
36        self.encode_imports(&mut writer, &mut wids);
37        self.encode_funcs(&mut writer, &mut wids, indices);
38        self.encode_exports(&mut writer, &wids);
39        self.encode_implements(&mut writer, &wids, indices);
40        writer.into_payload().into()
41    }
42
43    fn add_gc_roots(&self, roots: &mut Roots) {
44        for i in self.implements.iter() {
45            roots.push_func(i.core_func);
46        }
47        for f in self.funcs.iter() {
48            let instrs = match &f.kind {
49                FuncKind::Local(instrs) => instrs,
50                _ => continue,
51            };
52            for instr in instrs {
53                match instr {
54                    Instruction::CallCore(f) | Instruction::DeferCallCore(f) => {
55                        roots.push_func(*f);
56                    }
57                    Instruction::MemoryToString(mem) => {
58                        roots.push_memory(*mem);
59                    }
60                    Instruction::StringToMemory { mem, malloc } => {
61                        roots.push_memory(*mem).push_func(*malloc);
62                    }
63                    _ => {}
64                }
65            }
66        }
67    }
68}
69
70#[derive(Default)]
71struct WitIndicesToIds {
72    types: Vec<TypeId>,
73    funcs: Vec<FuncId>,
74}
75
76impl WitIndicesToIds {
77    fn ty(&self, ty: u32) -> Result<TypeId> {
78        self.types
79            .get(ty as usize)
80            .cloned()
81            .ok_or_else(|| anyhow!("adapter type index out of bounds: {}", ty))
82    }
83
84    fn func(&self, ty: u32) -> Result<FuncId> {
85        self.funcs
86            .get(ty as usize)
87            .cloned()
88            .ok_or_else(|| anyhow!("adapter func index out of bounds: {}", ty))
89    }
90}
91
92#[derive(Default)]
93struct WitIdsToIndices {
94    types: HashMap<TypeId, u32>,
95    funcs: HashMap<FuncId, u32>,
96}
97
98impl WitIdsToIndices {
99    fn push_ty(&mut self, ty: TypeId) {
100        self.types.insert(ty, self.types.len() as u32);
101    }
102
103    fn ty(&self, ty: TypeId) -> u32 {
104        self.types
105            .get(&ty)
106            .cloned()
107            .unwrap_or_else(|| panic!("reference to dead type found {:?}", ty))
108    }
109
110    fn push_func(&mut self, func: FuncId) {
111        self.funcs.insert(func, self.funcs.len() as u32);
112    }
113
114    fn func(&self, f: FuncId) -> u32 {
115        self.funcs
116            .get(&f)
117            .cloned()
118            .unwrap_or_else(|| panic!("reference to dead function found {:?}", f))
119    }
120}
121
122/// Callback for the `ModuleConfig::on_parse` function in `walrus` to act as a
123/// convenience to parse the wasm interface types custom section, if present.
124pub fn on_parse(module: &mut Module, ids: &IndicesToIds) -> Result<()> {
125    let section = match module.customs.remove_raw(SECTION_NAME) {
126        Some(s) => s,
127        None => return Ok(()),
128    };
129    let mut parser = wit_parser::Parser::new(0, &section.data)
130        .context("failed parsing wasm interface types header")?;
131    let mut section = WasmInterfaceTypes::default();
132    let mut wids = WitIndicesToIds::default();
133    while !parser.is_empty() {
134        let s = parser
135            .section()
136            .context("failed parsing wasm interface types section header")?;
137        match s {
138            wit_parser::Section::Type(t) => section.parse_types(t, &mut wids)?,
139            wit_parser::Section::Import(t) => section.parse_imports(t, &mut wids)?,
140            wit_parser::Section::Func(t) => section.parse_funcs(t, ids, &mut wids)?,
141            wit_parser::Section::Implement(t) => section.parse_implements(t, ids, &mut wids)?,
142            wit_parser::Section::Export(t) => section.parse_exports(t, &mut wids)?,
143        }
144    }
145
146    module.customs.add(section);
147    Ok(())
148}