use std::process::Command;
use inkwell::attributes::AttributeLoc;
use inkwell::context::Context;
use inkwell::execution_engine::ExecutionEngine;
use inkwell::memory_buffer::MemoryBuffer;
use inkwell::module::Module as LlvmModule;
use inkwell::targets::{
CodeModel, InitializationConfig, RelocMode, Target, TargetMachine, TargetTriple,
};
use inkwell::OptimizationLevel;
use crate::codegen::{emit_module_funcs_closed_world, ConstPool, ENTRY_SYMBOL};
use crate::error::LlvmError;
/// Native signature the JIT entry point is looked up with: one `i64` in, one `i64` out.
type EntryArity1 = unsafe extern "C" fn(i64) -> i64;

/// A Relon module that has been co-compiled with its host shim and JIT-compiled in-process.
///
/// Produced by [`cocompile_legacy_i64`]. The two IR snapshots are kept for inspection/debugging;
/// the engine is what actually executes code.
pub struct CocompiledModule {
    /// Textual LLVM IR after the host shim was linked in and the `default<O3>` pipeline ran.
    pub ir_after_opt: String,
    /// Textual LLVM IR as emitted for the Relon functions, before the host shim was linked in.
    pub ir_before_link: String,
    /// JIT execution engine built from the optimized module.
    engine: ExecutionEngine<'static>,
}

impl CocompiledModule {
    /// Invokes the JIT-compiled entry function (`ENTRY_SYMBOL`) with `arg` and returns its result.
    ///
    /// # Errors
    /// Returns [`LlvmError::Codegen`] if the entry symbol cannot be resolved in the engine.
    pub fn run_i64(&self, arg: i64) -> Result<i64, LlvmError> {
        // SAFETY (assumed, not checked here): the entry function was emitted with the
        // `extern "C" fn(i64) -> i64` shape described by `EntryArity1`; a mismatch is UB.
        let entry_fn = unsafe { self.engine.get_function::<EntryArity1>(ENTRY_SYMBOL) }
            .map_err(|e| LlvmError::Codegen(format!("cocompile: entry lookup: {e}")))?;
        // SAFETY: relies on the same signature assumption as the lookup above.
        let result = unsafe { entry_fn.call(arg) };
        Ok(result)
    }
}
/// Compiles the Relon IR module `ir` together with a Rust host shim into one LLVM module,
/// optimizes it with `default<O3>`, and JIT-compiles the result.
///
/// Steps:
/// 1. Emit the entry function plus every other function of `ir` (passed as helpers together
///    with their original IR indices) via [`emit_module_funcs_closed_world`].
/// 2. Snapshot the textual IR into `ir_before_link`.
/// 3. Compile `host_shim_src` with `rustc`, link it in, and mark the linked definitions of
///    `ir.imports` as `alwaysinline`.
/// 4. Run the O3 pipeline, snapshot the IR into `ir_after_opt`, and build a JIT engine.
///
/// # Errors
/// Returns [`LlvmError::Codegen`] when `ir` has no entry function, the entry index is out of
/// range, a function index does not fit in `u32`, or any codegen/link/optimize/JIT step fails.
/// Errors from [`ConstPool::from_module`] are propagated with `?`.
///
/// # Leaks
/// The LLVM `Context` is intentionally leaked so the returned engine can be `'static`. It is
/// created only after the IR-only validation above, so those failures do not leak it. Every
/// successful call, and every failure in a later step, still leaks one context.
pub fn cocompile_legacy_i64(
    ir: &relon_ir::ir::Module,
    host_shim_src: &str,
) -> Result<CocompiledModule, LlvmError> {
    let entry_idx = ir
        .entry_func_index
        .ok_or_else(|| LlvmError::Codegen("cocompile: IR module has no entry function".into()))?;
    // Checked lookup: a malformed module must yield an error, not a panic.
    let entry = ir.funcs.get(entry_idx).ok_or_else(|| {
        LlvmError::Codegen(format!(
            "cocompile: entry function index {entry_idx} out of range ({} funcs)",
            ir.funcs.len()
        ))
    })?;

    // Every non-entry function becomes a helper; keep its original IR index next to it.
    let mut helpers: Vec<&relon_ir::ir::Func> =
        Vec::with_capacity(ir.funcs.len().saturating_sub(1));
    let mut helper_ir_indices: Vec<u32> = Vec::with_capacity(helpers.capacity());
    for (idx, func) in ir.funcs.iter().enumerate() {
        if idx == entry_idx {
            continue;
        }
        let ir_index = u32::try_from(idx).map_err(|_| {
            LlvmError::Codegen(format!("cocompile: function index {idx} does not fit in u32"))
        })?;
        helpers.push(func);
        helper_ir_indices.push(ir_index);
    }

    // Built before the context exists so a failure here does not leak an LLVM context.
    let const_pool = ConstPool::from_module(ir)?;

    // Deliberately leaked: the execution engine stored in `CocompiledModule` is `'static`,
    // so the context it borrows from must live for the rest of the process.
    let ctx: &'static Context = Box::leak(Box::new(Context::create()));
    let module = ctx.create_module("relon_llvm_cocompile");
    emit_module_funcs_closed_world(
        ctx,
        &module,
        entry,
        0,
        &const_pool,
        &helpers,
        Some(&helper_ir_indices),
        &[],
        &[],
        &ir.imports,
    )?;
    let ir_before_link = module.print_to_string().to_string();
    link_and_inline_host_shim(&module, host_shim_src, &ir.imports)?;
    run_default_o3_pipeline(&module)?;
    let ir_after_opt = module.print_to_string().to_string();
    let engine = module
        .create_jit_execution_engine(OptimizationLevel::Aggressive)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: create JIT engine: {e}")))?;
    Ok(CocompiledModule {
        ir_after_opt,
        ir_before_link,
        engine,
    })
}
/// Links the host shim into `module` for the native target and marks the linked definitions
/// of `imports` as `alwaysinline`.
///
/// Thin wrapper over `link_and_inline_host_shim_for_target` with `HostShimTarget::Native`.
pub(crate) fn link_and_inline_host_shim(
    module: &LlvmModule<'_>,
    host_shim_src: &str,
    imports: &[relon_ir::ir::NativeImport],
) -> Result<(), LlvmError> {
    let native = HostShimTarget::Native;
    link_and_inline_host_shim_for_target(module, host_shim_src, imports, native)
}
/// Target the host shim is compiled for before being linked into the Relon module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum HostShimTarget {
    /// The host's own target (`rustc` is invoked without `--target`).
    Native,
    /// `wasm32-unknown-unknown`.
    Wasm32,
}
/// Compiles `host_shim_src` for `wasm32-unknown-unknown`, links it into `module`, and marks the
/// linked definitions of the *pure* imports as `alwaysinline`.
///
/// `effectful[i]` flags whether `imports[i]` is effectful; effectful imports do not get the
/// attribute. Imports with no corresponding entry in `effectful` are treated as pure. Imports
/// that resolve only to a declaration (no body) are left untouched.
///
/// The shim's scratch directory (the parent of the emitted `.ll`) is removed, best effort, as
/// soon as the IR has been parsed or has failed to load, so repeated calls do not accumulate
/// temp files.
///
/// # Errors
/// Returns [`LlvmError::Codegen`] if compiling, reading, parsing, or linking the shim fails.
pub(crate) fn link_and_inline_host_shim_wasm_pure_only(
    module: &LlvmModule<'_>,
    host_shim_src: &str,
    imports: &[relon_ir::ir::NativeImport],
    effectful: &[bool],
) -> Result<(), LlvmError> {
    let ctx = module.get_context();
    let host_ll = compile_host_shim_to_textual_ir(host_shim_src, HostShimTarget::Wasm32)?;
    let parsed = MemoryBuffer::create_from_file(&host_ll)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: read host wasm .ll: {e}")))
        .and_then(|buffer| {
            ctx.create_module_from_ir(buffer).map_err(|e| {
                LlvmError::Codegen(format!("cocompile: parse host wasm textual IR: {e}"))
            })
        });
    // The textual IR is no longer needed once parsing has finished (successfully or not).
    // Assumes the .ll lives in its own per-call scratch dir, as the shim compiler creates it.
    if let Some(scratch_dir) = host_ll.parent() {
        let _ = std::fs::remove_dir_all(scratch_dir);
    }
    let host_module = parsed?;
    module
        .link_in_module(host_module)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: wasm link_in_module: {e}")))?;
    let always_inline = ctx.create_enum_attribute(
        inkwell::attributes::Attribute::get_named_enum_kind_id("alwaysinline"),
        0,
    );
    for (idx, import) in imports.iter().enumerate() {
        // Effectful imports keep their call boundary.
        if effectful.get(idx).copied().unwrap_or(false) {
            continue;
        }
        if let Some(host_fn) = module.get_function(&import.name) {
            // Only definitions can be inlined; skip bare declarations.
            if host_fn.get_first_basic_block().is_some() {
                host_fn.add_attribute(AttributeLoc::Function, always_inline);
            }
        }
    }
    Ok(())
}
/// Compiles `host_shim_src` for `target`, links the result into `module`, and marks every linked
/// definition named after an entry of `imports` as `alwaysinline`.
///
/// Imports that resolve only to a declaration (no body) are left untouched. The shim's scratch
/// directory (the parent of the emitted `.ll`) is removed, best effort, as soon as the IR has
/// been parsed or has failed to load, so repeated calls do not accumulate temp files.
///
/// # Errors
/// Returns [`LlvmError::Codegen`] if compiling, reading, parsing, or linking the shim fails.
fn link_and_inline_host_shim_for_target(
    module: &LlvmModule<'_>,
    host_shim_src: &str,
    imports: &[relon_ir::ir::NativeImport],
    target: HostShimTarget,
) -> Result<(), LlvmError> {
    let ctx = module.get_context();
    let host_ll = compile_host_shim_to_textual_ir(host_shim_src, target)?;
    let parsed = MemoryBuffer::create_from_file(&host_ll)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: read host .ll: {e}")))
        .and_then(|buffer| {
            ctx.create_module_from_ir(buffer)
                .map_err(|e| LlvmError::Codegen(format!("cocompile: parse host textual IR: {e}")))
        });
    // The textual IR is no longer needed once parsing has finished (successfully or not).
    // Assumes the .ll lives in its own per-call scratch dir, as the shim compiler creates it.
    if let Some(scratch_dir) = host_ll.parent() {
        let _ = std::fs::remove_dir_all(scratch_dir);
    }
    let host_module = parsed?;
    module
        .link_in_module(host_module)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: link_in_module: {e}")))?;
    let always_inline = ctx.create_enum_attribute(
        inkwell::attributes::Attribute::get_named_enum_kind_id("alwaysinline"),
        0,
    );
    for import in imports {
        if let Some(host_fn) = module.get_function(&import.name) {
            // Only definitions can be inlined; skip bare declarations.
            if host_fn.get_first_basic_block().is_some() {
                host_fn.add_attribute(AttributeLoc::Function, always_inline);
            }
        }
    }
    Ok(())
}
/// Writes `host_shim_src` into a fresh scratch directory and compiles it with `rustc` to textual
/// LLVM IR, returning the path of the emitted `host_shim.ll`.
///
/// The scratch directory is `<temp_dir>/relon_cocompile_<pid>_<seq>`. `<seq>` comes from a
/// process-wide atomic counter, so calls within one process use distinct directories. `rustc`
/// builds a `cdylib` with `-O -Ccodegen-units=1`. It cross-compiles to
/// `wasm32-unknown-unknown` when `target` is [`HostShimTarget::Wasm32`]. Paths are passed to
/// `rustc` as `OsStr`s, so non-UTF-8 temp paths are supported.
///
/// # Errors
/// Returns [`LlvmError::Codegen`] in these cases:
/// - the directory cannot be created;
/// - the source cannot be written;
/// - `rustc` cannot be spawned;
/// - `rustc` exits unsuccessfully (its stderr is included).
///
/// On a write or compile failure the scratch directory is removed (best effort). On success it
/// is left in place; callers may delete the returned file's parent directory once done with it.
fn compile_host_shim_to_textual_ir(
    host_shim_src: &str,
    target: HostShimTarget,
) -> Result<std::path::PathBuf, LlvmError> {
    static SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
    let seq = SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!("relon_cocompile_{}_{seq}", std::process::id()));
    std::fs::create_dir_all(&dir)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: mkdir tmp: {e}")))?;
    let rs_path = dir.join("host_shim.rs");
    let ll_path = dir.join("host_shim.ll");
    match write_and_compile_host_shim(host_shim_src, target, &rs_path, &ll_path) {
        Ok(()) => Ok(ll_path),
        Err(err) => {
            // Best effort: don't leave half-populated scratch directories behind.
            let _ = std::fs::remove_dir_all(&dir);
            Err(err)
        }
    }
}

/// Writes the shim source to `rs_path` and runs `rustc --emit=llvm-ir` to produce `ll_path`.
fn write_and_compile_host_shim(
    host_shim_src: &str,
    target: HostShimTarget,
    rs_path: &std::path::Path,
    ll_path: &std::path::Path,
) -> Result<(), LlvmError> {
    std::fs::write(rs_path, host_shim_src)
        .map_err(|e| LlvmError::Codegen(format!("cocompile: write shim: {e}")))?;
    let mut rustc = Command::new("rustc");
    rustc.args(["--emit=llvm-ir", "--crate-type=cdylib", "-O", "-Ccodegen-units=1"]);
    if matches!(target, HostShimTarget::Wasm32) {
        rustc.args(["--target", "wasm32-unknown-unknown"]);
    }
    // Paths go through as OsStr: no UTF-8 requirement, no panic on exotic temp dirs.
    rustc.arg(rs_path).arg("-o").arg(ll_path);
    let output = rustc
        .output()
        .map_err(|e| LlvmError::Codegen(format!("cocompile: spawn rustc: {e}")))?;
    if !output.status.success() {
        return Err(LlvmError::Codegen(format!(
            "cocompile: rustc --emit=llvm-ir failed: {}",
            String::from_utf8_lossy(&output.stderr)
        )));
    }
    Ok(())
}
fn run_default_o3_pipeline(module: &LlvmModule<'_>) -> Result<(), LlvmError> {
Target::initialize_native(&InitializationConfig::default())
.map_err(|e| LlvmError::Codegen(format!("cocompile: initialize_native: {e}")))?;
let triple_str = TargetMachine::get_default_triple();
let target = Target::from_triple(&triple_str)
.map_err(|e| LlvmError::Codegen(format!("cocompile: target from_triple: {e}")))?;
let cpu = TargetMachine::get_host_cpu_name();
let features = TargetMachine::get_host_cpu_features();
let triple = TargetTriple::create(
triple_str
.as_str()
.to_str()
.map_err(|e| LlvmError::Codegen(format!("cocompile: triple utf8: {e}")))?,
);
let machine = target
.create_target_machine(
&triple,
cpu.to_str().unwrap_or(""),
features.to_str().unwrap_or(""),
OptimizationLevel::Aggressive,
RelocMode::Default,
CodeModel::JITDefault,
)
.ok_or_else(|| LlvmError::Codegen("cocompile: create_target_machine null".into()))?;
let opts = inkwell::passes::PassBuilderOptions::create();
module
.run_passes("default<O3>", &machine, opts)
.map_err(|e| LlvmError::Codegen(format!("cocompile: run_passes O3: {e}")))?;
Ok(())
}