1#![allow(
3 clippy::cast_possible_truncation,
4 clippy::cast_possible_wrap,
5 clippy::cast_sign_loss
6)]
7
8pub mod adapter_merge;
9pub mod dead_function_elimination;
10pub use crate::memory_layout;
11pub mod wasm_module;
12
13use std::collections::HashMap;
14
15use crate::pvm::Instruction;
16use crate::{Error, Result, SpiProgram};
17
18pub use wasm_module::WasmModule;
19
/// Action to take for a wasm import that has no real implementation.
///
/// Consumed by the lowering backend via `LoweringContext::wasm_import_map`;
/// variant names suggest calls are lowered to a trap or a no-op respectively
/// (the actual lowering lives in the backend).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportAction {
    /// Lower calls to this import as a trap.
    Trap,
    /// Lower calls to this import as a no-op.
    Nop,
}
28
/// Per-pass toggles for the compilation pipeline.
///
/// All passes are enabled by default (see the `Default` impl). Within this
/// module, `llvm_passes` and `inlining` are forwarded to the LLVM frontend,
/// `dead_function_elimination` gates reachability analysis, and the remaining
/// flags are passed through to the lowering backend via `LoweringContext`.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct OptimizationFlags {
    /// Run LLVM's own optimization pipeline on the generated IR.
    pub llvm_passes: bool,
    /// Peephole optimization over emitted instructions.
    pub peephole: bool,
    /// Cache values in registers within a block.
    pub register_cache: bool,
    /// Fuse integer comparisons with the branches that consume them.
    pub icmp_branch_fusion: bool,
    /// Only save callee-saved registers on paths that clobber them.
    pub shrink_wrap_callee_saves: bool,
    /// Remove stores whose values are never read.
    pub dead_store_elimination: bool,
    /// Fold constants through the instruction stream.
    pub constant_propagation: bool,
    /// Inline function calls (applied at the LLVM level).
    pub inlining: bool,
    /// Extend the register cache across basic-block boundaries.
    pub cross_block_cache: bool,
    /// Full register allocation pass.
    pub register_allocation: bool,
    /// Replace unreachable functions' bodies with a single trap.
    pub dead_function_elimination: bool,
    /// Elide jumps to the immediately following instruction.
    pub fallthrough_jumps: bool,
}
60
61impl Default for OptimizationFlags {
62 fn default() -> Self {
63 Self {
64 llvm_passes: true,
65 peephole: true,
66 register_cache: true,
67 icmp_branch_fusion: true,
68 shrink_wrap_callee_saves: true,
69 dead_store_elimination: true,
70 constant_propagation: true,
71 inlining: true,
72 cross_block_cache: true,
73 register_allocation: true,
74 dead_function_elimination: true,
75 fallthrough_jumps: true,
76 }
77 }
78}
79
/// Options accepted by [`compile_with_options`].
#[derive(Debug, Clone, Default)]
pub struct CompileOptions {
    /// Actions for otherwise-unresolved imports. When `None`, only known
    /// intrinsics and the built-in default mappings (currently `"abort"`)
    /// are accepted.
    pub import_map: Option<HashMap<String, ImportAction>>,
    /// Optional adapter module (WAT text) merged into the input wasm before
    /// parsing.
    pub adapter: Option<String>,
    /// Raw metadata bytes attached to the produced SPI program.
    pub metadata: Vec<u8>,
    /// Per-pass optimization toggles.
    pub optimizations: OptimizationFlags,
}
96
97pub use crate::abi::{ARGS_LEN_REG, ARGS_PTR_REG, RETURN_ADDR_REG, STACK_PTR_REG};
99
/// A direct-call site whose jump offset must be patched once final function
/// layout is known. Instruction indices are relative to a per-function base
/// recorded alongside the fixup.
#[derive(Debug, Clone)]
pub struct CallFixup {
    /// Index of the instruction that loads the return address; it must hold
    /// the pre-assigned jump-table value `(slot + 1) * 2`.
    pub return_addr_instr: usize,
    /// Index of the `LoadImmJump` whose offset is patched to the callee.
    pub jump_instr: usize,
    /// Local index of the called function.
    pub target_func: u32,
}
108
/// An indirect-call site; only its return-address jump-table slot needs to be
/// filled in (the target is resolved through the function table at runtime).
#[derive(Debug, Clone)]
pub struct IndirectCallFixup {
    /// Index of the instruction loading the return address (same encoding
    /// convention as [`CallFixup::return_addr_instr`]).
    pub return_addr_instr: usize,
    /// Index of the indirect-jump instruction following the call setup.
    pub jump_ind_instr: usize,
}
115
/// Capacity (64 KiB) of the read-only data region that holds the function
/// table followed by passive data segment payloads.
const RO_DATA_SIZE: usize = 64 * 1024;
118
119fn is_known_intrinsic(name: &str) -> bool {
121 if name == "pvm_ptr" || name == "host_call_r8" {
122 return true;
123 }
124 if let Some(suffix) = name.strip_prefix("host_call_") {
125 let digits = suffix.strip_suffix('b').unwrap_or(suffix);
127 if let Ok(n) = digits.parse::<u8>() {
128 return n <= crate::abi::MAX_HOST_CALL_DATA_ARGS;
129 }
130 }
131 false
132}
133
134pub fn compile(wasm: &[u8]) -> Result<SpiProgram> {
135 compile_with_options(wasm, &CompileOptions::default())
136}
137
138pub fn compile_with_options(wasm: &[u8], options: &CompileOptions) -> Result<SpiProgram> {
139 const DEFAULT_MAPPINGS: &[&str] = &["abort"];
141
142 let merged_wasm;
144 let wasm = if let Some(adapter_wat) = &options.adapter {
145 merged_wasm = adapter_merge::merge_adapter(wasm, adapter_wat)?;
146 &merged_wasm
147 } else {
148 wasm
149 };
150
151 let module = WasmModule::parse(wasm)?;
152
153 for name in &module.imported_func_names {
155 if is_known_intrinsic(name) {
156 continue;
157 }
158 if let Some(import_map) = &options.import_map {
159 if import_map.contains_key(name) {
160 continue;
161 }
162 } else if DEFAULT_MAPPINGS.contains(&name.as_str()) {
163 continue;
164 }
165 return Err(Error::UnresolvedImport(format!(
166 "import '{name}' has no mapping. Provide a mapping via --imports or add it to the import map."
167 )));
168 }
169
170 compile_via_llvm(&module, options)
171}
172
/// Lowers a parsed wasm module to a PVM SPI program via the LLVM pipeline.
///
/// Pipeline stages: (1) optionally compute the reachable-function set,
/// (2) translate wasm to LLVM IR, (3) lay out passive data segments in the
/// RO region, (4) lower each function to PVM instructions (entries first),
/// (5) resolve call fixups and build the jump table, then (6) assemble the
/// final program blob with its RO/RW data sections and heap-page count.
pub fn compile_via_llvm(module: &WasmModule, options: &CompileOptions) -> Result<SpiProgram> {
    use crate::llvm_backend::{self, LoweringContext};
    use crate::llvm_frontend;
    use inkwell::context::Context;

    // When dead-function elimination is enabled, only functions in this set
    // get real bodies; the rest are emitted as a lone Trap below.
    let reachable_locals = if options.optimizations.dead_function_elimination {
        Some(dead_function_elimination::reachable_functions(module)?)
    } else {
        None
    };

    let context = Context::create();
    let llvm_module = llvm_frontend::translate_wasm_to_llvm(
        &context,
        module,
        options.optimizations.llvm_passes,
        options.optimizations.inlining,
        reachable_locals.as_ref(),
    )?;

    // Passive data segments are packed into RO data right after the function
    // table (8 bytes per entry). Offset 0 is skipped when there is no table —
    // presumably so a zero offset can't collide with a valid segment address;
    // TODO(review): confirm why the sentinel starts at 1.
    let mut data_segment_offsets = std::collections::HashMap::new();
    let mut data_segment_lengths = std::collections::HashMap::new();
    let mut current_ro_offset = if module.function_table.is_empty() {
        1 } else {
        module.function_table.len() * 8 };

    let mut data_segment_length_addrs = std::collections::HashMap::new();
    // Ordinal among passive segments only (independent of the segment index).
    let mut passive_ordinal = 0usize;

    for (idx, seg) in module.data_segments.iter().enumerate() {
        // Segments without a static offset are passive (copied at runtime).
        if seg.offset.is_none() {
            if current_ro_offset + seg.data.len() > RO_DATA_SIZE {
                return Err(Error::Internal(format!(
                    "passive data segment {} (size {}) would overflow RO_DATA region ({} bytes used of {})",
                    idx,
                    seg.data.len(),
                    current_ro_offset,
                    RO_DATA_SIZE
                )));
            }
            data_segment_offsets.insert(idx as u32, current_ro_offset as u32);
            data_segment_lengths.insert(idx as u32, seg.data.len() as u32);
            // Address in global memory where the segment's runtime-visible
            // length word lives (written into the RW section later).
            data_segment_length_addrs.insert(
                idx as u32,
                memory_layout::data_segment_length_offset(module.globals.len(), passive_ordinal),
            );
            current_ro_offset += seg.data.len();
            passive_ordinal += 1;
        }
    }

    // Everything the per-function lowering needs, cloned out of the module.
    let ctx = LoweringContext {
        wasm_memory_base: module.wasm_memory_base,
        num_globals: module.globals.len(),
        function_signatures: module.function_signatures.clone(),
        type_signatures: module.type_signatures.clone(),
        function_table: module.function_table.clone(),
        num_imported_funcs: module.num_imported_funcs as usize,
        imported_func_names: module.imported_func_names.clone(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        data_segment_offsets,
        data_segment_lengths,
        data_segment_length_addrs,
        wasm_import_map: options.import_map.clone(),
        optimizations: options.optimizations.clone(),
    };

    let mut all_instructions: Vec<Instruction> = Vec::new();
    // Fixups are stored with the instruction-index base of their function so
    // they can be re-based onto the global stream during resolution.
    let mut all_call_fixups: Vec<(usize, CallFixup)> = Vec::new();
    let mut all_indirect_call_fixups: Vec<(usize, IndirectCallFixup)> = Vec::new();
    let mut function_offsets: Vec<usize> = vec![0; module.functions.len()];
    // Global counter handing out jump-table return slots across all functions.
    let mut next_call_return_idx: usize = 0;

    // Program header: slot 0 jumps to the main entry, slot 1 to the secondary
    // entry (or traps when there is none). Offsets are patched after layout.
    all_instructions.push(Instruction::Jump { offset: 0 });
    if module.has_secondary_entry {
        all_instructions.push(Instruction::Jump { offset: 0 });
    } else {
        all_instructions.push(Instruction::Trap);
    }

    // Emit entry points first, then everything else in index order.
    let mut emission_order: Vec<usize> = Vec::with_capacity(module.functions.len());
    emission_order.push(module.main_func_local_idx);
    if let Some(secondary_idx) = module.secondary_entry_local_idx
        && secondary_idx != module.main_func_local_idx
    {
        emission_order.push(secondary_idx);
    }
    for idx in 0..module.functions.len() {
        if idx != module.main_func_local_idx && module.secondary_entry_local_idx != Some(idx) {
            emission_order.push(idx);
        }
    }

    for &local_func_idx in &emission_order {
        // Unreachable functions still get a recorded offset (so table entries
        // remain valid) but their entire body is a single Trap.
        if reachable_locals
            .as_ref()
            .is_some_and(|r| !r.contains(&local_func_idx))
        {
            let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
            function_offsets[local_func_idx] = func_start_offset;
            all_instructions.push(Instruction::Trap);
            continue;
        }

        let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
        // Naming convention shared with the LLVM frontend.
        let fn_name = format!("wasm_func_{global_func_idx}");
        let llvm_func = llvm_module
            .get_function(&fn_name)
            .ok_or_else(|| Error::Internal(format!("missing LLVM function: {fn_name}")))?;

        let is_main = local_func_idx == module.main_func_local_idx;
        let is_secondary = module.secondary_entry_local_idx == Some(local_func_idx);
        let is_entry = is_main || is_secondary;

        // Byte offset of this function = total encoded size emitted so far.
        let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
        function_offsets[local_func_idx] = func_start_offset;

        // Entry points invoke the wasm `start` function first, spilling the
        // argument registers to the stack so the call can't clobber them.
        if let Some(start_local_idx) = module.start_func_local_idx.filter(|_| is_entry) {
            // Reserve 16 bytes of stack for the two saved registers.
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: -16,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_LEN_REG,
                offset: 8,
            });

            // Return address is encoded as (slot + 1) * 2 — the same scheme
            // return_addr_jump_table_idx decodes during fixup resolution.
            let call_return_addr = ((next_call_return_idx + 1) * 2) as i32;
            next_call_return_idx += 1;
            let current_instr_idx = all_instructions.len();
            all_instructions.push(Instruction::LoadImmJump {
                reg: RETURN_ADDR_REG,
                value: call_return_addr,
                offset: 0, });

            // Both fixup indices are 0 because the LoadImmJump at
            // current_instr_idx serves as return-address load and jump alike.
            all_call_fixups.push((
                current_instr_idx,
                CallFixup {
                    target_func: start_local_idx as u32,
                    return_addr_instr: 0,
                    jump_instr: 0, },
            ));

            // Restore the argument registers and release the stack space.
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_PTR_REG,
                base: STACK_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_LEN_REG,
                base: STACK_PTR_REG,
                offset: 8,
            });
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: 16,
            });
        }

        let translation = llvm_backend::lower_function(
            llvm_func,
            &ctx,
            is_entry,
            global_func_idx,
            next_call_return_idx,
        )?;
        next_call_return_idx += translation.num_call_returns;

        // Re-base the function-local fixup indices onto the global stream.
        let instr_base = all_instructions.len();
        for fixup in translation.call_fixups {
            all_call_fixups.push((
                instr_base,
                CallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_instr: fixup.jump_instr,
                    target_func: fixup.target_func,
                },
            ));
        }
        for fixup in translation.indirect_call_fixups {
            all_indirect_call_fixups.push((
                instr_base,
                IndirectCallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_ind_instr: fixup.jump_ind_instr,
                },
            ));
        }

        all_instructions.extend(translation.instructions);
    }

    let (jump_table, func_entry_jump_table_base) = resolve_call_fixups(
        &mut all_instructions,
        &all_call_fixups,
        &all_indirect_call_fixups,
        &function_offsets,
    )?;

    // Patch the header jumps now that function byte offsets are final.
    let main_offset = function_offsets[module.main_func_local_idx] as i32;
    if let Instruction::Jump { offset } = &mut all_instructions[0] {
        *offset = main_offset;
    }

    if let Some(secondary_idx) = module.secondary_entry_local_idx {
        // NOTE(review): the -5 compensates for the secondary jump sitting
        // after the first header jump — presumably Jump encodes to 5 bytes;
        // confirm against the instruction encoder.
        let secondary_offset = function_offsets[secondary_idx] as i32 - 5;
        if let Instruction::Jump { offset } = &mut all_instructions[1] {
            *offset = secondary_offset;
        }
    }

    // RO data: the function table (jump-table ref + type index, 8 bytes per
    // entry) followed by passive segment payloads. A single zero byte keeps
    // the section non-empty when there is no table, matching the offset-1
    // sentinel chosen above.
    let mut ro_data = vec![0u8];
    if !module.function_table.is_empty() {
        ro_data.clear();
        for &func_idx in &module.function_table {
            // Empty slots and imported functions are not callable indirectly;
            // mark both words with u32::MAX.
            if func_idx == u32::MAX || (func_idx as usize) < module.num_imported_funcs as usize {
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
            } else {
                let local_func_idx = func_idx as usize - module.num_imported_funcs as usize;
                // Same (idx + 1) * 2 jump-table address encoding as calls use.
                let jump_ref = 2 * (func_entry_jump_table_base + local_func_idx + 1) as u32;
                ro_data.extend_from_slice(&jump_ref.to_le_bytes());
                // Type index enables runtime signature checks for call_indirect.
                let type_idx = *module
                    .function_type_indices
                    .get(local_func_idx)
                    .unwrap_or(&u32::MAX);
                ro_data.extend_from_slice(&type_idx.to_le_bytes());
            }
        }
    }

    // Passive segment payloads follow, in the order offsets were assigned.
    for seg in &module.data_segments {
        if seg.offset.is_none() {
            ro_data.extend_from_slice(&seg.data);
        }
    }

    let blob = crate::pvm::ProgramBlob::new(all_instructions).with_jump_table(jump_table);
    let rw_data_section = build_rw_data(
        &module.data_segments,
        &module.global_init_values,
        module.memory_limits.initial_pages,
        module.wasm_memory_base,
        &ctx.data_segment_length_addrs,
        &ctx.data_segment_lengths,
    );

    let heap_pages = calculate_heap_pages(
        rw_data_section.len(),
        module.wasm_memory_base,
        module.memory_limits.initial_pages,
        module.functions.len(),
    )?;

    Ok(SpiProgram::new(blob)
        .with_heap_pages(heap_pages)
        .with_ro_data(ro_data)
        .with_rw_data(rw_data_section)
        .with_metadata(options.metadata.clone()))
}
469
470fn calculate_heap_pages(
485 rw_data_len: usize,
486 wasm_memory_base: i32,
487 initial_pages: u32,
488 num_functions: usize,
489) -> Result<u16> {
490 use wasm_module::MIN_INITIAL_WASM_PAGES;
491
492 let initial_pages = initial_pages.max(MIN_INITIAL_WASM_PAGES);
493 let wasm_memory_initial_end = wasm_memory_base as usize + (initial_pages as usize) * 64 * 1024;
494
495 let spilled_locals_end = memory_layout::SPILLED_LOCALS_BASE as usize
496 + num_functions * memory_layout::SPILLED_LOCALS_PER_FUNC as usize;
497
498 let end = spilled_locals_end.max(wasm_memory_initial_end);
499 let total_bytes = end - memory_layout::GLOBAL_MEMORY_BASE as usize;
500 let rw_pages = rw_data_len.div_ceil(4096);
501 let total_pages = total_bytes.div_ceil(4096);
502 let heap_pages = total_pages.saturating_sub(rw_pages) + 1;
503
504 u16::try_from(heap_pages).map_err(|_| {
505 Error::Internal(format!(
506 "heap size {heap_pages} pages exceeds u16::MAX ({}) — module too large",
507 u16::MAX
508 ))
509 })
510}
511
512pub(crate) fn build_rw_data(
514 data_segments: &[wasm_module::DataSegment],
515 global_init_values: &[i32],
516 initial_memory_pages: u32,
517 wasm_memory_base: i32,
518 data_segment_length_addrs: &std::collections::HashMap<u32, i32>,
519 data_segment_lengths: &std::collections::HashMap<u32, u32>,
520) -> Vec<u8> {
521 let num_passive_segments = data_segment_length_addrs.len();
524 let globals_end =
525 memory_layout::globals_region_size(global_init_values.len(), num_passive_segments);
526
527 let wasm_to_rw_offset = wasm_memory_base as u32 - 0x30000;
529
530 let data_end = data_segments
531 .iter()
532 .filter_map(|seg| {
533 seg.offset
534 .map(|off| wasm_to_rw_offset + off + seg.data.len() as u32)
535 })
536 .max()
537 .unwrap_or(0) as usize;
538
539 let total_size = globals_end.max(data_end);
540
541 if total_size == 0 {
542 return Vec::new();
543 }
544
545 let mut rw_data = vec![0u8; total_size];
546
547 for (i, &value) in global_init_values.iter().enumerate() {
549 let offset = i * 4;
550 if offset + 4 <= rw_data.len() {
551 rw_data[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
552 }
553 }
554
555 let mem_size_offset = global_init_values.len() * 4;
557 if mem_size_offset + 4 <= rw_data.len() {
558 rw_data[mem_size_offset..mem_size_offset + 4]
559 .copy_from_slice(&initial_memory_pages.to_le_bytes());
560 }
561
562 for (&seg_idx, &addr) in data_segment_length_addrs {
565 if let Some(&length) = data_segment_lengths.get(&seg_idx) {
566 let rw_offset = (addr - memory_layout::GLOBAL_MEMORY_BASE) as usize;
568 if rw_offset + 4 <= rw_data.len() {
569 rw_data[rw_offset..rw_offset + 4].copy_from_slice(&length.to_le_bytes());
570 }
571 }
572 }
573
574 for seg in data_segments {
576 if let Some(offset) = seg.offset {
577 let rw_offset = (wasm_to_rw_offset + offset) as usize;
578 if rw_offset + seg.data.len() <= rw_data.len() {
579 rw_data[rw_offset..rw_offset + seg.data.len()].copy_from_slice(&seg.data);
580 }
581 }
582 }
583
584 if let Some(last_non_zero) = rw_data.iter().rposition(|&b| b != 0) {
587 rw_data.truncate(last_non_zero + 1);
588 } else {
589 rw_data.clear();
590 }
591
592 rw_data
593}
594
595fn return_addr_jump_table_idx(
605 instructions: &[Instruction],
606 return_addr_instr: usize,
607) -> Result<usize> {
608 let value = match instructions.get(return_addr_instr) {
609 Some(
610 Instruction::LoadImmJump { value, .. }
611 | Instruction::LoadImm { value, .. }
612 | Instruction::LoadImmJumpInd { value, .. },
613 ) => Some(*value),
614 _ => None,
615 };
616 match value {
617 Some(v) if v > 0 && v % 2 == 0 => Ok((v as usize / 2) - 1),
618 _ => Err(Error::Internal(format!(
619 "expected LoadImmJump/LoadImm/LoadImmJumpInd((idx+1)*2) at return_addr_instr {return_addr_instr}, got {:?}",
620 instructions.get(return_addr_instr)
621 ))),
622 }
623}
624
/// Patches call sites with final jump offsets and builds the jump table.
///
/// The jump table starts with one return-address slot per call site (filled
/// with the byte offset of the instruction after the call), followed by one
/// entry per function holding its byte offset. Returns the table and the
/// index where the per-function entries begin.
fn resolve_call_fixups(
    instructions: &mut [Instruction],
    call_fixups: &[(usize, CallFixup)],
    indirect_call_fixups: &[(usize, IndirectCallFixup)],
    function_offsets: &[usize],
) -> Result<(Vec<u32>, usize)> {
    // First pass: size the return-slot prefix from the highest slot used.
    let mut num_call_returns: usize = 0;

    for (instr_base, fixup) in call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }
    for (instr_base, fixup) in indirect_call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }

    let mut jump_table: Vec<u32> = vec![0u32; num_call_returns];

    for (instr_base, fixup) in call_fixups {
        let target_offset = function_offsets
            .get(fixup.target_func as usize)
            .ok_or_else(|| {
                Error::Unsupported(format!("call to unknown function {}", fixup.target_func))
            })?;

        let jump_idx = instr_base + fixup.jump_instr;

        // Byte offset of the instruction AFTER the jump — where the callee
        // returns to (note the inclusive ..=jump_idx range).
        let return_addr_offset: usize = instructions[..=jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;

        // Sanity-check that lowering pre-assigned the matching encoded slot.
        let expected_addr = ((slot + 1) * 2) as i32;
        debug_assert!(
            matches!(&instructions[jump_idx], Instruction::LoadImmJump { value, .. } if *value == expected_addr),
            "pre-assigned jump table address mismatch: expected {expected_addr}, got {:?}",
            &instructions[jump_idx]
        );

        // Jump offsets are relative to the start of the jump instruction.
        let jump_start_offset: usize = instructions[..jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();
        let relative_offset = (*target_offset as i32) - (jump_start_offset as i32);

        if let Instruction::LoadImmJump { offset, .. } = &mut instructions[jump_idx] {
            *offset = relative_offset;
        }
    }

    // Indirect calls only need their return slot filled; the target comes
    // from the function table at runtime.
    for (instr_base, fixup) in indirect_call_fixups {
        let jump_ind_idx = instr_base + fixup.jump_ind_instr;

        let return_addr_offset: usize = instructions[..=jump_ind_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;
    }

    // Append one entry per function so indirect calls can jump via the table.
    let func_entry_base = jump_table.len();
    for &offset in function_offsets {
        jump_table.push(offset as u32);
    }

    Ok((jump_table, func_entry_base))
}
707
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::build_rw_data;
    use super::memory_layout;
    use super::wasm_module::DataSegment;

    // With no globals, segments, or lengths, the section is entirely zero
    // and must be trimmed down to nothing.
    #[test]
    fn build_rw_data_trims_all_zero_tail_to_empty() {
        let rw = build_rw_data(&[], &[], 0, 0x30000, &HashMap::new(), &HashMap::new());
        assert!(rw.is_empty());
    }

    // Trailing zeros are trimmed, but zeros between non-zero bytes survive.
    #[test]
    fn build_rw_data_preserves_internal_zeros_and_trims_trailing_zeros() {
        let data_segments = vec![DataSegment {
            offset: Some(0),
            data: vec![1, 0, 2, 0, 0],
        }];

        let rw = build_rw_data(
            &data_segments,
            &[],
            0,
            0x30000,
            &HashMap::new(),
            &HashMap::new(),
        );

        assert_eq!(rw, vec![1, 0, 2]);
    }

    // A passive-segment length word written into the globals region keeps
    // the section alive up to (and including) its last non-zero byte.
    #[test]
    fn build_rw_data_keeps_non_zero_passive_length_bytes() {
        let mut addrs = HashMap::new();
        addrs.insert(0u32, memory_layout::GLOBAL_MEMORY_BASE + 4);
        let mut lengths = HashMap::new();
        lengths.insert(0u32, 7u32);

        let rw = build_rw_data(&[], &[], 0, 0x30000, &addrs, &lengths);

        assert_eq!(rw, vec![0, 0, 0, 0, 7]);
    }

    // Baseline heap size with no RW data: total pages above the global base
    // plus the constant +1.
    #[test]
    fn heap_pages_with_empty_rw_data_equals_total_pages_plus_one() {
        let pages = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 260);
    }

    // 8192 bytes of RW data = 2 pages, subtracted from the heap request.
    #[test]
    fn heap_pages_reduced_by_rw_data_pages() {
        let pages_no_rw = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        let pages_with_rw = super::calculate_heap_pages(8192, 0x33000, 0, 10).unwrap();
        assert_eq!(pages_no_rw - pages_with_rw, 2);
    }

    // saturating_sub keeps the result at the minimum of one page even when
    // RW data alone exceeds the computed total.
    #[test]
    fn heap_pages_saturates_at_one_for_large_rw_data() {
        let pages = super::calculate_heap_pages(2 * 1024 * 1024, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 1);
    }

    // A larger declared initial page count grows the heap accordingly
    // (32 wasm pages = 512 PVM pages beyond the 0-page baseline offsets).
    #[test]
    fn heap_pages_respects_initial_pages() {
        let pages = super::calculate_heap_pages(0, 0x33000, 32, 10).unwrap();
        assert_eq!(pages, 516);
    }
}