Skip to main content

wasm_pvm/translate/
mod.rs

1// Address calculations and jump offsets often require wrapping/truncation.
2#![allow(
3    clippy::cast_possible_truncation,
4    clippy::cast_possible_wrap,
5    clippy::cast_sign_loss
6)]
7
8pub mod adapter_merge;
9pub mod dead_function_elimination;
10pub use crate::memory_layout;
11pub mod wasm_module;
12
13use std::collections::HashMap;
14
15use crate::pvm::Instruction;
16use crate::{Error, Result, SpiProgram};
17
18pub use wasm_module::WasmModule;
19
/// Action to take when a WASM import is called.
///
/// Used as the value type of `CompileOptions::import_map`: each unresolved,
/// non-intrinsic import must map to one of these actions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportAction {
    /// Emit a trap (unreachable) instruction.
    Trap,
    /// Emit a no-op (return 0 for functions with return values).
    Nop,
}
28
/// Flags to enable/disable individual compiler optimizations.
/// All optimizations are enabled by default.
#[derive(Debug, Clone)]
// Each pass is an independent on/off toggle, so a struct of bools (one doc
// comment per pass) is the clearest representation despite the clippy lint.
#[allow(clippy::struct_excessive_bools)]
pub struct OptimizationFlags {
    /// Run LLVM optimization passes (mem2reg, instcombine, simplifycfg, gvn, dce).
    /// When false, also disables inlining (all LLVM passes are skipped).
    pub llvm_passes: bool,
    /// Run peephole optimizer (fallthrough removal, dead code elimination).
    pub peephole: bool,
    /// Enable per-block register cache (store-load forwarding).
    pub register_cache: bool,
    /// Fuse `ICmp` + Branch into a single PVM branch instruction.
    pub icmp_branch_fusion: bool,
    /// Only save/restore callee-saved registers (r9-r12) that are actually used.
    pub shrink_wrap_callee_saves: bool,
    /// Eliminate SP-relative stores whose target offset is never loaded from.
    pub dead_store_elimination: bool,
    /// Skip redundant `LoadImm`/`LoadImm64` when the register already holds the constant.
    pub constant_propagation: bool,
    /// Inline small functions at the LLVM IR level to eliminate call overhead.
    pub inlining: bool,
    /// Propagate register cache across single-predecessor block boundaries.
    pub cross_block_cache: bool,
    /// Allocate long-lived SSA values to physical registers (r5, r6) across block boundaries.
    pub register_allocation: bool,
    /// Eliminate unreachable functions not called from entry points or the function table.
    pub dead_function_elimination: bool,
    /// Eliminate unconditional jumps to the immediately following block (fallthrough).
    pub fallthrough_jumps: bool,
}
60
61impl Default for OptimizationFlags {
62    fn default() -> Self {
63        Self {
64            llvm_passes: true,
65            peephole: true,
66            register_cache: true,
67            icmp_branch_fusion: true,
68            shrink_wrap_callee_saves: true,
69            dead_store_elimination: true,
70            constant_propagation: true,
71            inlining: true,
72            cross_block_cache: true,
73            register_allocation: true,
74            dead_function_elimination: true,
75            fallthrough_jumps: true,
76        }
77    }
78}
79
/// Options for compilation.
///
/// The `Default` value has no import map, no adapter, empty metadata, and
/// all optimizations enabled (see `OptimizationFlags::default`).
#[derive(Debug, Clone, Default)]
pub struct CompileOptions {
    /// Mapping from import function names to actions.
    /// When provided, all imports (except known intrinsics like `host_call_N` and `pvm_ptr`)
    /// must have a mapping or compilation will fail with `UnresolvedImport`.
    pub import_map: Option<HashMap<String, ImportAction>>,
    /// WAT source for an adapter module whose exports replace matching main imports.
    /// Applied before the text-based import map, so the two compose.
    pub adapter: Option<String>,
    /// Metadata blob to prepend to the SPI output.
    /// Typically contains the source filename and compiler version.
    pub metadata: Vec<u8>,
    /// Optimization flags controlling which compiler passes are enabled.
    pub optimizations: OptimizationFlags,
}
96
97// Re-export register constants from abi module
98pub use crate::abi::{ARGS_LEN_REG, ARGS_PTR_REG, RETURN_ADDR_REG, STACK_PTR_REG};
99
100// ── Call fixup types (shared with LLVM backend) ──
101
/// Pending patch for a direct call site, resolved once all function byte
/// offsets are known. Stored alongside an `instr_base` that both instruction
/// indices are relative to.
#[derive(Debug, Clone)]
pub struct CallFixup {
    /// Index (relative to `instr_base`) of the instruction that loads the
    /// pre-assigned return address.
    pub return_addr_instr: usize,
    /// Index (relative to `instr_base`) of the jump instruction whose offset
    /// is patched. For `LoadImmJump` this is the same instruction as
    /// `return_addr_instr`.
    pub jump_instr: usize,
    /// Local index of the callee; used to look up its entry in `function_offsets`.
    pub target_func: u32,
}
108
/// Pending patch for an indirect (`call_indirect`) call site: only the
/// jump-table return slot needs filling, since the target is resolved at
/// runtime through the dispatch table.
#[derive(Debug, Clone)]
pub struct IndirectCallFixup {
    /// Index (relative to `instr_base`) of the return-address load instruction.
    pub return_addr_instr: usize,
    // For `LoadImmJumpInd`, this equals `return_addr_instr`.
    pub jump_ind_instr: usize,
}
115
/// `RO_DATA` region size is 64KB (0x10000 to 0x1FFFF)
// Bounds the call_indirect dispatch table plus appended passive data
// segments; overflow is reported during passive-segment offset calculation.
const RO_DATA_SIZE: usize = 64 * 1024;
118
119/// Check if an import name is a known compiler intrinsic (`host_call_N` variants, `pvm_ptr`).
120fn is_known_intrinsic(name: &str) -> bool {
121    if name == "pvm_ptr" || name == "host_call_r8" {
122        return true;
123    }
124    if let Some(suffix) = name.strip_prefix("host_call_") {
125        // host_call_0..6 or host_call_0b..6b
126        let digits = suffix.strip_suffix('b').unwrap_or(suffix);
127        if let Ok(n) = digits.parse::<u8>() {
128            return n <= crate::abi::MAX_HOST_CALL_DATA_ARGS;
129        }
130    }
131    false
132}
133
134pub fn compile(wasm: &[u8]) -> Result<SpiProgram> {
135    compile_with_options(wasm, &CompileOptions::default())
136}
137
138pub fn compile_with_options(wasm: &[u8], options: &CompileOptions) -> Result<SpiProgram> {
139    // Default mappings applied when no explicit import map is provided.
140    const DEFAULT_MAPPINGS: &[&str] = &["abort"];
141
142    // Apply adapter merge if provided (produces a new WASM binary with fewer imports).
143    let merged_wasm;
144    let wasm = if let Some(adapter_wat) = &options.adapter {
145        merged_wasm = adapter_merge::merge_adapter(wasm, adapter_wat)?;
146        &merged_wasm
147    } else {
148        wasm
149    };
150
151    let module = WasmModule::parse(wasm)?;
152
153    // Validate that all imports are resolved.
154    for name in &module.imported_func_names {
155        if is_known_intrinsic(name) {
156            continue;
157        }
158        if let Some(import_map) = &options.import_map {
159            if import_map.contains_key(name) {
160                continue;
161            }
162        } else if DEFAULT_MAPPINGS.contains(&name.as_str()) {
163            continue;
164        }
165        return Err(Error::UnresolvedImport(format!(
166            "import '{name}' has no mapping. Provide a mapping via --imports or add it to the import map."
167        )));
168    }
169
170    compile_via_llvm(&module, options)
171}
172
/// Compile a parsed [`WasmModule`] to an [`SpiProgram`] via the LLVM pipeline.
///
/// Phases:
/// 0. optional reachability analysis for dead-function elimination,
/// 1. WASM → LLVM IR translation,
/// 2. lowering-context construction (layout of passive segments in `RO_DATA`),
/// 3. per-function LLVM IR → PVM bytecode emission (main/secondary first),
/// 4. call-fixup resolution and jump-table construction,
/// 5. dispatch table / RO_DATA / RW_DATA assembly and SPI packaging.
pub fn compile_via_llvm(module: &WasmModule, options: &CompileOptions) -> Result<SpiProgram> {
    use crate::llvm_backend::{self, LoweringContext};
    use crate::llvm_frontend;
    use inkwell::context::Context;

    // Phase 0: Dead function elimination — compute reachable set.
    let reachable_locals = if options.optimizations.dead_function_elimination {
        Some(dead_function_elimination::reachable_functions(module)?)
    } else {
        None
    };

    // Phase 1: WASM → LLVM IR
    let context = Context::create();
    let llvm_module = llvm_frontend::translate_wasm_to_llvm(
        &context,
        module,
        options.optimizations.llvm_passes,
        options.optimizations.inlining,
        reachable_locals.as_ref(),
    )?;

    // Calculate RO_DATA offsets and lengths for passive data segments
    let mut data_segment_offsets = std::collections::HashMap::new();
    let mut data_segment_lengths = std::collections::HashMap::new();
    let mut current_ro_offset = if module.function_table.is_empty() {
        1 // dummy byte if no function table
    } else {
        module.function_table.len() * 8 // jump_ref + type_idx per entry
    };

    let mut data_segment_length_addrs = std::collections::HashMap::new();
    // Ordinal among passive segments only; determines each segment's
    // length-slot address in the globals region.
    let mut passive_ordinal = 0usize;

    for (idx, seg) in module.data_segments.iter().enumerate() {
        if seg.offset.is_none() {
            // Check that segment fits within RO_DATA region
            if current_ro_offset + seg.data.len() > RO_DATA_SIZE {
                return Err(Error::Internal(format!(
                    "passive data segment {} (size {}) would overflow RO_DATA region ({} bytes used of {})",
                    idx,
                    seg.data.len(),
                    current_ro_offset,
                    RO_DATA_SIZE
                )));
            }
            data_segment_offsets.insert(idx as u32, current_ro_offset as u32);
            data_segment_lengths.insert(idx as u32, seg.data.len() as u32);
            data_segment_length_addrs.insert(
                idx as u32,
                memory_layout::data_segment_length_offset(module.globals.len(), passive_ordinal),
            );
            current_ro_offset += seg.data.len();
            passive_ordinal += 1;
        }
    }

    // Phase 2: Build lowering context
    let ctx = LoweringContext {
        wasm_memory_base: module.wasm_memory_base,
        num_globals: module.globals.len(),
        function_signatures: module.function_signatures.clone(),
        type_signatures: module.type_signatures.clone(),
        function_table: module.function_table.clone(),
        num_imported_funcs: module.num_imported_funcs as usize,
        imported_func_names: module.imported_func_names.clone(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        data_segment_offsets,
        data_segment_lengths,
        data_segment_length_addrs,
        wasm_import_map: options.import_map.clone(),
        optimizations: options.optimizations.clone(),
    };

    // Phase 3: LLVM IR → PVM bytecode for each function
    let mut all_instructions: Vec<Instruction> = Vec::new();
    let mut all_call_fixups: Vec<(usize, CallFixup)> = Vec::new();
    let mut all_indirect_call_fixups: Vec<(usize, IndirectCallFixup)> = Vec::new();
    let mut function_offsets: Vec<usize> = vec![0; module.functions.len()];
    let mut next_call_return_idx: usize = 0;

    // Entry header: Jump to main (PC=0) + Trap or secondary Jump (PC=5).
    // When there's no secondary entry, we omit the Fallthrough padding (6 bytes instead of 10).
    all_instructions.push(Instruction::Jump { offset: 0 });
    if module.has_secondary_entry {
        all_instructions.push(Instruction::Jump { offset: 0 });
    } else {
        all_instructions.push(Instruction::Trap);
    }

    // Build emission order: main first, then secondary (if any), then remaining in index order.
    // This places main immediately after the entry header, minimizing the entry Jump distance.
    let mut emission_order: Vec<usize> = Vec::with_capacity(module.functions.len());
    emission_order.push(module.main_func_local_idx);
    if let Some(secondary_idx) = module.secondary_entry_local_idx
        && secondary_idx != module.main_func_local_idx
    {
        emission_order.push(secondary_idx);
    }
    for idx in 0..module.functions.len() {
        if idx != module.main_func_local_idx && module.secondary_entry_local_idx != Some(idx) {
            emission_order.push(idx);
        }
    }

    for &local_func_idx in &emission_order {
        // Dead functions: emit a single Trap as a placeholder.
        // The function offset is still recorded so dispatch table indices stay valid.
        if reachable_locals
            .as_ref()
            .is_some_and(|r| !r.contains(&local_func_idx))
        {
            let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
            function_offsets[local_func_idx] = func_start_offset;
            all_instructions.push(Instruction::Trap);
            continue;
        }

        let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
        let fn_name = format!("wasm_func_{global_func_idx}");
        let llvm_func = llvm_module
            .get_function(&fn_name)
            .ok_or_else(|| Error::Internal(format!("missing LLVM function: {fn_name}")))?;

        let is_main = local_func_idx == module.main_func_local_idx;
        let is_secondary = module.secondary_entry_local_idx == Some(local_func_idx);
        let is_entry = is_main || is_secondary;

        // Byte offset of this function = total encoded size of everything emitted so far.
        let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
        function_offsets[local_func_idx] = func_start_offset;

        // If entry function and there's a start function, call it first.
        if let Some(start_local_idx) = module.start_func_local_idx.filter(|_| is_entry) {
            // Save r7 and r8 to stack.
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: -16,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_LEN_REG,
                offset: 8,
            });

            // Call start function using LoadImmJump (combined load + jump).
            // Return address values are pre-assigned as (idx + 1) * 2; see
            // return_addr_jump_table_idx.
            let call_return_addr = ((next_call_return_idx + 1) * 2) as i32;
            next_call_return_idx += 1;
            let current_instr_idx = all_instructions.len();
            all_instructions.push(Instruction::LoadImmJump {
                reg: RETURN_ADDR_REG,
                value: call_return_addr,
                offset: 0, // patched during fixup resolution
            });

            all_call_fixups.push((
                current_instr_idx,
                CallFixup {
                    target_func: start_local_idx as u32,
                    return_addr_instr: 0,
                    jump_instr: 0, // same instruction for LoadImmJump
                },
            ));

            // Restore r7 and r8.
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_PTR_REG,
                base: STACK_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_LEN_REG,
                base: STACK_PTR_REG,
                offset: 8,
            });
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: 16,
            });
        }

        let translation = llvm_backend::lower_function(
            llvm_func,
            &ctx,
            is_entry,
            global_func_idx,
            next_call_return_idx,
        )?;
        next_call_return_idx += translation.num_call_returns;

        // Rebase the function's fixups onto the global instruction stream.
        let instr_base = all_instructions.len();
        for fixup in translation.call_fixups {
            all_call_fixups.push((
                instr_base,
                CallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_instr: fixup.jump_instr,
                    target_func: fixup.target_func,
                },
            ));
        }
        for fixup in translation.indirect_call_fixups {
            all_indirect_call_fixups.push((
                instr_base,
                IndirectCallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_ind_instr: fixup.jump_ind_instr,
                },
            ));
        }

        all_instructions.extend(translation.instructions);
    }

    // Phase 4: Resolve call fixups and build jump table.
    let (jump_table, func_entry_jump_table_base) = resolve_call_fixups(
        &mut all_instructions,
        &all_call_fixups,
        &all_indirect_call_fixups,
        &function_offsets,
    )?;

    // Patch entry header jumps.
    let main_offset = function_offsets[module.main_func_local_idx] as i32;
    if let Instruction::Jump { offset } = &mut all_instructions[0] {
        *offset = main_offset;
    }

    if let Some(secondary_idx) = module.secondary_entry_local_idx {
        // -5 compensates for the secondary Jump sitting at PC=5 (after the
        // first entry-header Jump); offsets are relative to the jump's own PC.
        let secondary_offset = function_offsets[secondary_idx] as i32 - 5;
        if let Instruction::Jump { offset } = &mut all_instructions[1] {
            *offset = secondary_offset;
        }
    }

    // Phase 5: Build dispatch table for call_indirect.
    let mut ro_data = vec![0u8];
    if !module.function_table.is_empty() {
        ro_data.clear();
        for &func_idx in &module.function_table {
            // Null entries and imported functions get (u32::MAX, u32::MAX)
            // sentinels; local functions get (jump_ref, type_idx).
            if func_idx == u32::MAX || (func_idx as usize) < module.num_imported_funcs as usize {
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
            } else {
                let local_func_idx = func_idx as usize - module.num_imported_funcs as usize;
                let jump_ref = 2 * (func_entry_jump_table_base + local_func_idx + 1) as u32;
                ro_data.extend_from_slice(&jump_ref.to_le_bytes());
                let type_idx = *module
                    .function_type_indices
                    .get(local_func_idx)
                    .unwrap_or(&u32::MAX);
                ro_data.extend_from_slice(&type_idx.to_le_bytes());
            }
        }
    }

    // Append passive data segments to RO_DATA.
    // NOTE: This loop must iterate data_segments in the same order as the offset
    // calculation loop above, since data_segment_offsets indices depend on it.
    for seg in &module.data_segments {
        if seg.offset.is_none() {
            ro_data.extend_from_slice(&seg.data);
        }
    }

    let blob = crate::pvm::ProgramBlob::new(all_instructions).with_jump_table(jump_table);
    let rw_data_section = build_rw_data(
        &module.data_segments,
        &module.global_init_values,
        module.memory_limits.initial_pages,
        module.wasm_memory_base,
        &ctx.data_segment_length_addrs,
        &ctx.data_segment_lengths,
    );

    let heap_pages = calculate_heap_pages(
        rw_data_section.len(),
        module.wasm_memory_base,
        module.memory_limits.initial_pages,
        module.functions.len(),
    )?;

    Ok(SpiProgram::new(blob)
        .with_heap_pages(heap_pages)
        .with_ro_data(ro_data)
        .with_rw_data(rw_data_section)
        .with_metadata(options.metadata.clone()))
}
469
470/// Calculate the number of 4KB PVM heap pages needed after `rw_data`.
471///
472/// `heap_pages` tells the runtime how many zero-initialized writable pages to allocate
473/// immediately after the `rw_data` blob. This covers the initial WASM linear memory,
474/// globals, and spilled locals that aren't already covered by `rw_data`.
475///
476/// By computing this **after** `build_rw_data()`, we use the actual (trimmed) `rw_data`
477/// length instead of guessing with headroom.
478///
479/// We add 1 extra page beyond the exact initial memory requirement. This ensures that
480/// the first `memory.grow` / sbrk allocation has a pre-allocated page available at the
481/// boundary of the initial WASM memory. Without it, PVM-in-PVM execution fails because
482/// the inner interpreter's page-fault handling at the exact heap boundary doesn't
483/// correctly propagate through the outer PVM.
484fn calculate_heap_pages(
485    rw_data_len: usize,
486    wasm_memory_base: i32,
487    initial_pages: u32,
488    num_functions: usize,
489) -> Result<u16> {
490    use wasm_module::MIN_INITIAL_WASM_PAGES;
491
492    let initial_pages = initial_pages.max(MIN_INITIAL_WASM_PAGES);
493    let wasm_memory_initial_end = wasm_memory_base as usize + (initial_pages as usize) * 64 * 1024;
494
495    let spilled_locals_end = memory_layout::SPILLED_LOCALS_BASE as usize
496        + num_functions * memory_layout::SPILLED_LOCALS_PER_FUNC as usize;
497
498    let end = spilled_locals_end.max(wasm_memory_initial_end);
499    let total_bytes = end - memory_layout::GLOBAL_MEMORY_BASE as usize;
500    let rw_pages = rw_data_len.div_ceil(4096);
501    let total_pages = total_bytes.div_ceil(4096);
502    let heap_pages = total_pages.saturating_sub(rw_pages) + 1;
503
504    u16::try_from(heap_pages).map_err(|_| {
505        Error::Internal(format!(
506            "heap size {heap_pages} pages exceeds u16::MAX ({}) — module too large",
507            u16::MAX
508        ))
509    })
510}
511
512/// Build the `rw_data` section from WASM data segments and global initializers.
513pub(crate) fn build_rw_data(
514    data_segments: &[wasm_module::DataSegment],
515    global_init_values: &[i32],
516    initial_memory_pages: u32,
517    wasm_memory_base: i32,
518    data_segment_length_addrs: &std::collections::HashMap<u32, i32>,
519    data_segment_lengths: &std::collections::HashMap<u32, u32>,
520) -> Vec<u8> {
521    // Calculate the minimum size needed for globals
522    // +1 for the compiler-managed memory size global, plus passive segment lengths
523    let num_passive_segments = data_segment_length_addrs.len();
524    let globals_end =
525        memory_layout::globals_region_size(global_init_values.len(), num_passive_segments);
526
527    // Calculate the size needed for data segments
528    let wasm_to_rw_offset = wasm_memory_base as u32 - 0x30000;
529
530    let data_end = data_segments
531        .iter()
532        .filter_map(|seg| {
533            seg.offset
534                .map(|off| wasm_to_rw_offset + off + seg.data.len() as u32)
535        })
536        .max()
537        .unwrap_or(0) as usize;
538
539    let total_size = globals_end.max(data_end);
540
541    if total_size == 0 {
542        return Vec::new();
543    }
544
545    let mut rw_data = vec![0u8; total_size];
546
547    // Initialize user globals
548    for (i, &value) in global_init_values.iter().enumerate() {
549        let offset = i * 4;
550        if offset + 4 <= rw_data.len() {
551            rw_data[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
552        }
553    }
554
555    // Initialize compiler-managed memory size global (right after user globals)
556    let mem_size_offset = global_init_values.len() * 4;
557    if mem_size_offset + 4 <= rw_data.len() {
558        rw_data[mem_size_offset..mem_size_offset + 4]
559            .copy_from_slice(&initial_memory_pages.to_le_bytes());
560    }
561
562    // Initialize passive data segment effective lengths (right after memory size global).
563    // These are used by memory.init for bounds checking and zeroed by data.drop.
564    for (&seg_idx, &addr) in data_segment_length_addrs {
565        if let Some(&length) = data_segment_lengths.get(&seg_idx) {
566            // addr is absolute PVM address; convert to rw_data offset
567            let rw_offset = (addr - memory_layout::GLOBAL_MEMORY_BASE) as usize;
568            if rw_offset + 4 <= rw_data.len() {
569                rw_data[rw_offset..rw_offset + 4].copy_from_slice(&length.to_le_bytes());
570            }
571        }
572    }
573
574    // Copy data segments to their WASM memory locations
575    for seg in data_segments {
576        if let Some(offset) = seg.offset {
577            let rw_offset = (wasm_to_rw_offset + offset) as usize;
578            if rw_offset + seg.data.len() <= rw_data.len() {
579                rw_data[rw_offset..rw_offset + seg.data.len()].copy_from_slice(&seg.data);
580            }
581        }
582    }
583
584    // Trim trailing zeros to reduce SPI size. Heap pages are zero-initialized,
585    // so omitted high-address zero bytes are semantically equivalent.
586    if let Some(last_non_zero) = rw_data.iter().rposition(|&b| b != 0) {
587        rw_data.truncate(last_non_zero + 1);
588    } else {
589        rw_data.clear();
590    }
591
592    rw_data
593}
594
595/// Extract the pre-assigned jump-table index from a return-address load instruction.
596///
597/// Call return addresses are pre-assigned as `(idx + 1) * 2` at emission time.
598/// This helper recovers `idx` so that `resolve_call_fixups` can write the byte
599/// offset into the correct jump-table slot instead of appending in list order
600/// (which would desync when a function mixes direct and indirect calls).
601///
602/// Direct calls use `LoadImmJump`, while indirect calls use either `LoadImm` (legacy
603/// two-instruction sequence) or `LoadImmJumpInd` (combined return-addr load + jump).
604fn return_addr_jump_table_idx(
605    instructions: &[Instruction],
606    return_addr_instr: usize,
607) -> Result<usize> {
608    let value = match instructions.get(return_addr_instr) {
609        Some(
610            Instruction::LoadImmJump { value, .. }
611            | Instruction::LoadImm { value, .. }
612            | Instruction::LoadImmJumpInd { value, .. },
613        ) => Some(*value),
614        _ => None,
615    };
616    match value {
617        Some(v) if v > 0 && v % 2 == 0 => Ok((v as usize / 2) - 1),
618        _ => Err(Error::Internal(format!(
619            "expected LoadImmJump/LoadImm/LoadImmJumpInd((idx+1)*2) at return_addr_instr {return_addr_instr}, got {:?}",
620            instructions.get(return_addr_instr)
621        ))),
622    }
623}
624
/// Resolve all direct/indirect call fixups and build the jump table.
///
/// Returns `(jump_table, func_entry_base)`: the jump table holds one
/// call-return byte offset per pre-assigned slot followed by one entry per
/// function offset; `func_entry_base` is the index of the first function
/// entry (used to form dispatch-table `jump_ref`s).
///
/// NOTE(review): byte offsets are computed by summing `encode().len()` over
/// instruction prefixes for every fixup, which is O(call_sites × instructions).
/// This also assumes patching a `LoadImmJump` offset does not change its
/// encoded length mid-loop — TODO confirm against `Instruction::encode`.
fn resolve_call_fixups(
    instructions: &mut [Instruction],
    call_fixups: &[(usize, CallFixup)],
    indirect_call_fixups: &[(usize, IndirectCallFixup)],
    function_offsets: &[usize],
) -> Result<(Vec<u32>, usize)> {
    // Count total call-return entries by finding the maximum pre-assigned index.
    // Entries are written at their pre-assigned slot so mixed direct/indirect
    // call ordering within a function is preserved correctly.
    let mut num_call_returns: usize = 0;

    for (instr_base, fixup) in call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }
    for (instr_base, fixup) in indirect_call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }

    let mut jump_table: Vec<u32> = vec![0u32; num_call_returns];

    // Call return addresses (LoadImmJump/LoadImm/LoadImmJumpInd values) are pre-assigned at emission time,
    // so we only need to compute byte offsets for the jump table and patch Jump targets.
    // Write each entry at its pre-assigned index to keep values in sync.
    for (instr_base, fixup) in call_fixups {
        let target_offset = function_offsets
            .get(fixup.target_func as usize)
            .ok_or_else(|| {
                Error::Unsupported(format!("call to unknown function {}", fixup.target_func))
            })?;

        let jump_idx = instr_base + fixup.jump_instr;

        // Return address = byte offset after the LoadImmJump instruction.
        let return_addr_offset: usize = instructions[..=jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;

        // Verify pre-assigned jump table address matches actual index.
        let expected_addr = ((slot + 1) * 2) as i32;
        debug_assert!(
            matches!(&instructions[jump_idx], Instruction::LoadImmJump { value, .. } if *value == expected_addr),
            "pre-assigned jump table address mismatch: expected {expected_addr}, got {:?}",
            &instructions[jump_idx]
        );

        // Patch the offset field of LoadImmJump.
        // The jump offset is relative to the start of the LoadImmJump itself.
        let jump_start_offset: usize = instructions[..jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();
        let relative_offset = (*target_offset as i32) - (jump_start_offset as i32);

        if let Instruction::LoadImmJump { offset, .. } = &mut instructions[jump_idx] {
            *offset = relative_offset;
        }
    }

    // Indirect calls: the runtime resolves the target via the dispatch table,
    // so only the return-address slot needs filling here.
    for (instr_base, fixup) in indirect_call_fixups {
        let jump_ind_idx = instr_base + fixup.jump_ind_instr;

        let return_addr_offset: usize = instructions[..=jump_ind_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;
    }

    // Append one jump-table entry per function so call_indirect can target
    // function entries via (func_entry_base + local_idx).
    let func_entry_base = jump_table.len();
    for &offset in function_offsets {
        jump_table.push(offset as u32);
    }

    Ok((jump_table, func_entry_base))
}
707
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::build_rw_data;
    use super::memory_layout;
    use super::wasm_module::DataSegment;

    // ── build_rw_data tests ──

    #[test]
    fn build_rw_data_trims_all_zero_tail_to_empty() {
        // No globals, no segments: everything is zero and gets trimmed away.
        let rw = build_rw_data(&[], &[], 0, 0x30000, &HashMap::new(), &HashMap::new());
        assert!(rw.is_empty());
    }

    #[test]
    fn build_rw_data_preserves_internal_zeros_and_trims_trailing_zeros() {
        // Active segment at offset 0; trailing zeros trimmed, interior zero kept.
        let data_segments = vec![DataSegment {
            offset: Some(0),
            data: vec![1, 0, 2, 0, 0],
        }];

        let rw = build_rw_data(
            &data_segments,
            &[],
            0,
            0x30000,
            &HashMap::new(),
            &HashMap::new(),
        );

        assert_eq!(rw, vec![1, 0, 2]);
    }

    #[test]
    fn build_rw_data_keeps_non_zero_passive_length_bytes() {
        // One passive segment whose length slot (value 7) sits at offset 4.
        let mut addrs = HashMap::new();
        addrs.insert(0u32, memory_layout::GLOBAL_MEMORY_BASE + 4);
        let mut lengths = HashMap::new();
        lengths.insert(0u32, 7u32);

        let rw = build_rw_data(&[], &[], 0, 0x30000, &addrs, &lengths);

        assert_eq!(rw, vec![0, 0, 0, 0, 7]);
    }

    // ── calculate_heap_pages tests ──

    #[test]
    fn heap_pages_with_empty_rw_data_equals_total_pages_plus_one() {
        // wasm_memory_base = 0x33000 (typical), initial_pages = 0 (clamped to 16)
        // end = 0x33000 + 16*64*1024 = 0x33000 + 0x100000 = 0x133000
        // total_bytes = 0x133000 - 0x30000 = 0x103000 = 1060864
        // total_pages = ceil(1060864 / 4096) = 259
        // rw_pages = 0, heap_pages = 259 + 1 = 260
        let pages = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 260);
    }

    #[test]
    fn heap_pages_reduced_by_rw_data_pages() {
        // Same scenario but with 8192 bytes of rw_data (2 pages)
        let pages_no_rw = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        let pages_with_rw = super::calculate_heap_pages(8192, 0x33000, 0, 10).unwrap();
        assert_eq!(pages_no_rw - pages_with_rw, 2);
    }

    #[test]
    fn heap_pages_saturates_at_one_for_large_rw_data() {
        // rw_data that covers more than total_pages still gets +1 headroom
        let pages = super::calculate_heap_pages(2 * 1024 * 1024, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 1);
    }

    #[test]
    fn heap_pages_respects_initial_pages() {
        // initial_pages = 32 (larger than MIN_INITIAL_WASM_PAGES=16)
        // end = 0x33000 + 32*64*1024 = 0x33000 + 0x200000 = 0x233000
        // total_bytes = 0x233000 - 0x30000 = 0x203000
        // total_pages = ceil(0x203000 / 4096) = 515
        // heap_pages = 515 + 1 = 516
        let pages = super::calculate_heap_pages(0, 0x33000, 32, 10).unwrap();
        assert_eq!(pages, 516);
    }
}
791}