lucet_runtime_internals/module/mock.rs

1use crate::error::Error;
2use crate::module::{AddrDetails, GlobalSpec, HeapSpec, Module, ModuleInternal, TableElement};
3use libc::c_void;
4use lucet_module::owned::{
5    OwnedExportFunction, OwnedFunctionMetadata, OwnedGlobalSpec, OwnedImportFunction,
6    OwnedLinearMemorySpec, OwnedModuleData, OwnedSparseData,
7};
8use lucet_module::{
9    FunctionHandle, FunctionIndex, FunctionPointer, FunctionSpec, ModuleData, ModuleFeatures,
10    Signature, TrapSite, UniqueSignatureIndex,
11};
12use std::collections::{BTreeMap, HashMap};
13use std::sync::Arc;
14
15#[derive(Default)]
16pub struct MockModuleBuilder {
17    heap_spec: HeapSpec,
18    sparse_page_data: Vec<Option<Vec<u8>>>,
19    globals: BTreeMap<usize, OwnedGlobalSpec>,
20    table_elements: BTreeMap<usize, TableElement>,
21    export_funcs: HashMap<&'static str, FunctionPointer>,
22    func_table: HashMap<(u32, u32), FunctionPointer>,
23    start_func: Option<FunctionPointer>,
24    function_manifest: Vec<FunctionSpec>,
25    function_info: Vec<OwnedFunctionMetadata>,
26    imports: Vec<OwnedImportFunction>,
27    exports: Vec<OwnedExportFunction>,
28    signatures: Vec<Signature>,
29}
30
31impl MockModuleBuilder {
32    pub fn new() -> Self {
33        const DEFAULT_HEAP_SPEC: HeapSpec = HeapSpec {
34            reserved_size: 4 * 1024 * 1024,
35            guard_size: 4 * 1024 * 1024,
36            initial_size: 64 * 1024,
37            max_size: Some(64 * 1024),
38        };
39        MockModuleBuilder::default().with_heap_spec(DEFAULT_HEAP_SPEC)
40    }
41
42    pub fn with_heap_spec(mut self, heap_spec: HeapSpec) -> Self {
43        self.heap_spec = heap_spec;
44        self
45    }
46
47    pub fn with_initial_heap(mut self, heap: &[u8]) -> Self {
48        self.sparse_page_data = heap
49            .chunks(4096)
50            .map(|page| {
51                if page.iter().all(|b| *b == 0) {
52                    None
53                } else {
54                    let mut page = page.to_vec();
55                    if page.len() < 4096 {
56                        page.resize(4096, 0);
57                    }
58                    Some(page)
59                }
60            })
61            .collect();
62        self
63    }
64
65    pub fn with_global(mut self, idx: u32, init_val: i64) -> Self {
66        self.globals
67            .insert(idx as usize, OwnedGlobalSpec::new_def(init_val, vec![]));
68        self
69    }
70
71    pub fn with_exported_global(mut self, idx: u32, init_val: i64, export_name: &str) -> Self {
72        self.globals.insert(
73            idx as usize,
74            OwnedGlobalSpec::new_def(init_val, vec![export_name.to_string()]),
75        );
76        self
77    }
78
79    pub fn with_import(mut self, idx: u32, import_module: &str, import_field: &str) -> Self {
80        self.globals.insert(
81            idx as usize,
82            OwnedGlobalSpec::new_import(
83                import_module.to_string(),
84                import_field.to_string(),
85                vec![],
86            ),
87        );
88        self
89    }
90
91    pub fn with_exported_import(
92        mut self,
93        idx: u32,
94        import_module: &str,
95        import_field: &str,
96        export_name: &str,
97    ) -> Self {
98        self.globals.insert(
99            idx as usize,
100            OwnedGlobalSpec::new_import(
101                import_module.to_string(),
102                import_field.to_string(),
103                vec![export_name.to_string()],
104            ),
105        );
106        self
107    }
108
109    pub fn with_table_element(mut self, idx: u32, element: &TableElement) -> Self {
110        self.table_elements.insert(idx as usize, element.clone());
111        self
112    }
113
114    fn record_sig(&mut self, sig: Signature) -> UniqueSignatureIndex {
115        let idx = self
116            .signatures
117            .iter()
118            .enumerate()
119            .find(|(_, v)| *v == &sig)
120            .map(|(key, _)| key)
121            .unwrap_or_else(|| {
122                self.signatures.push(sig);
123                self.signatures.len() - 1
124            });
125        UniqueSignatureIndex::from_u32(idx as u32)
126    }
127
128    pub fn with_export_func(mut self, export: MockExportBuilder) -> Self {
129        self.export_funcs.insert(export.sym(), export.func());
130        let sig_idx = self.record_sig(export.sig());
131        self.function_info.push(OwnedFunctionMetadata {
132            signature: sig_idx,
133            name: Some(export.sym().to_string()),
134        });
135        self.exports.push(OwnedExportFunction {
136            fn_idx: FunctionIndex::from_u32(self.function_manifest.len() as u32),
137            names: vec![export.sym().to_string()],
138        });
139        self.function_manifest.push(FunctionSpec::new(
140            export.func().as_usize() as u64,
141            export.func_len() as u32,
142            export.traps().as_ptr() as u64,
143            export.traps().len() as u64,
144        ));
145        self
146    }
147
148    pub fn with_exported_import_func(
149        mut self,
150        export_name: &'static str,
151        import_fn_ptr: FunctionPointer,
152        sig: Signature,
153    ) -> Self {
154        self.export_funcs.insert(export_name, import_fn_ptr);
155        let sig_idx = self.record_sig(sig);
156        self.function_info.push(OwnedFunctionMetadata {
157            signature: sig_idx,
158            name: Some(export_name.to_string()),
159        });
160        self.exports.push(OwnedExportFunction {
161            fn_idx: FunctionIndex::from_u32(self.function_manifest.len() as u32),
162            names: vec![export_name.to_string()],
163        });
164        self.function_manifest.push(FunctionSpec::new(
165            import_fn_ptr.as_usize() as u64,
166            0u32,
167            0u64,
168            0u64,
169        ));
170        self
171    }
172
173    pub fn with_table_func(mut self, table_idx: u32, func_idx: u32, func: FunctionPointer) -> Self {
174        self.func_table.insert((table_idx, func_idx), func);
175        self
176    }
177
178    pub fn with_start_func(mut self, func: FunctionPointer) -> Self {
179        self.start_func = Some(func);
180        self
181    }
182
183    pub fn build(self) -> Arc<dyn Module> {
184        assert!(
185            self.sparse_page_data.len() * 4096 <= self.heap_spec.initial_size as usize,
186            "heap must fit in heap spec initial size"
187        );
188
189        let table_elements = self
190            .table_elements
191            .into_iter()
192            .enumerate()
193            .map(|(expected_idx, (idx, te))| {
194                assert_eq!(
195                    idx, expected_idx,
196                    "table element indices must be contiguous starting from 0"
197                );
198                te
199            })
200            .collect();
201        let globals_spec = self
202            .globals
203            .into_iter()
204            .enumerate()
205            .map(|(expected_idx, (idx, gs))| {
206                assert_eq!(
207                    idx, expected_idx,
208                    "global indices must be contiguous starting from 0"
209                );
210                gs
211            })
212            .collect();
213        let owned_module_data = OwnedModuleData::new(
214            Some(OwnedLinearMemorySpec {
215                heap: self.heap_spec,
216                initializer: OwnedSparseData::new(self.sparse_page_data)
217                    .expect("sparse data pages are valid"),
218            }),
219            globals_spec,
220            self.function_info.clone(),
221            self.imports,
222            self.exports,
223            self.signatures,
224            ModuleFeatures::none(),
225        );
226        let serialized_module_data = owned_module_data
227            .to_ref()
228            .serialize()
229            .expect("serialization of module_data succeeds");
230        let module_data = ModuleData::deserialize(&serialized_module_data)
231            .map(|md| unsafe { std::mem::transmute(md) })
232            .expect("module data can be deserialized");
233        let mock = MockModule {
234            serialized_module_data,
235            module_data,
236            table_elements,
237            export_funcs: self.export_funcs,
238            func_table: self.func_table,
239            start_func: self.start_func,
240            function_manifest: self.function_manifest,
241        };
242        Arc::new(mock)
243    }
244}
245
246pub struct MockModule {
247    #[allow(dead_code)]
248    serialized_module_data: Vec<u8>,
249    module_data: ModuleData<'static>,
250    pub table_elements: Vec<TableElement>,
251    pub export_funcs: HashMap<&'static str, FunctionPointer>,
252    pub func_table: HashMap<(u32, u32), FunctionPointer>,
253    pub start_func: Option<FunctionPointer>,
254    pub function_manifest: Vec<FunctionSpec>,
255}
256
257unsafe impl Send for MockModule {}
258unsafe impl Sync for MockModule {}
259
260impl Module for MockModule {}
261
262impl ModuleInternal for MockModule {
263    fn is_instruction_count_instrumented(&self) -> bool {
264        self.module_data.features().instruction_count
265    }
266
267    fn heap_spec(&self) -> Option<&HeapSpec> {
268        self.module_data.heap_spec()
269    }
270
271    fn globals(&self) -> &[GlobalSpec<'_>] {
272        self.module_data.globals_spec()
273    }
274
275    fn get_sparse_page_data(&self, page: usize) -> Option<&[u8]> {
276        if let Some(ref sparse_data) = self.module_data.sparse_data() {
277            *sparse_data.get_page(page)
278        } else {
279            None
280        }
281    }
282
283    fn sparse_page_data_len(&self) -> usize {
284        self.module_data.sparse_data().map(|d| d.len()).unwrap_or(0)
285    }
286
287    fn table_elements(&self) -> Result<&[TableElement], Error> {
288        Ok(&self.table_elements)
289    }
290
291    fn get_export_func(&self, sym: &str) -> Result<FunctionHandle, Error> {
292        let ptr = *self
293            .export_funcs
294            .get(sym)
295            .ok_or(Error::SymbolNotFound(sym.to_string()))?;
296
297        Ok(self.function_handle_from_ptr(ptr))
298    }
299
300    fn get_func_from_idx(&self, table_id: u32, func_id: u32) -> Result<FunctionHandle, Error> {
301        let ptr = self
302            .func_table
303            .get(&(table_id, func_id))
304            .cloned()
305            .ok_or(Error::FuncNotFound(table_id, func_id))?;
306
307        Ok(self.function_handle_from_ptr(ptr))
308    }
309
310    fn get_start_func(&self) -> Result<Option<FunctionHandle>, Error> {
311        Ok(self
312            .start_func
313            .map(|start| self.function_handle_from_ptr(start)))
314    }
315
316    fn function_manifest(&self) -> &[FunctionSpec] {
317        &self.function_manifest
318    }
319
320    fn addr_details(&self, _addr: *const c_void) -> Result<Option<AddrDetails>, Error> {
321        // we can call `dladdr` on Rust code, but unless we inspect the stack I don't think there's
322        // a way to determine whether or not we're in "module" code; punt for now
323        Ok(None)
324    }
325
326    fn get_signature(&self, fn_id: FunctionIndex) -> &Signature {
327        self.module_data.get_signature(fn_id)
328    }
329}
330
331pub struct MockExportBuilder {
332    sym: &'static str,
333    func: FunctionPointer,
334    func_len: Option<usize>,
335    traps: Option<&'static [TrapSite]>,
336    sig: Signature,
337}
338
339impl MockExportBuilder {
340    pub fn new(name: &'static str, func: FunctionPointer) -> MockExportBuilder {
341        MockExportBuilder {
342            sym: name,
343            func: func,
344            func_len: None,
345            traps: None,
346            sig: Signature {
347                params: vec![],
348                ret_ty: None,
349            },
350        }
351    }
352
353    pub fn with_func_len(mut self, len: usize) -> MockExportBuilder {
354        self.func_len = Some(len);
355        self
356    }
357
358    pub fn with_traps(mut self, traps: &'static [TrapSite]) -> MockExportBuilder {
359        self.traps = Some(traps);
360        self
361    }
362
363    pub fn with_sig(mut self, sig: Signature) -> MockExportBuilder {
364        self.sig = sig;
365        self
366    }
367
368    pub fn sym(&self) -> &'static str {
369        self.sym
370    }
371    pub fn func(&self) -> FunctionPointer {
372        self.func
373    }
374    pub fn func_len(&self) -> usize {
375        self.func_len.unwrap_or(1)
376    }
377    pub fn traps(&self) -> &'static [TrapSite] {
378        self.traps.unwrap_or(&[])
379    }
380    pub fn sig(&self) -> Signature {
381        self.sig.clone()
382    }
383}