use std::collections::HashMap;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{default_libcall_names, Linkage, Module};
use lower_ir_utils::{define_jit_fn, jit_export};
/// Creates a [`JITBuilder`] for the host ISA.
///
/// The Cranelift settings `use_colocated_libcalls` and `is_pic` are both forced to
/// `"false"` before the native ISA is finished, and Cranelift's default libcall names
/// are used. Panics if a setting is rejected or the host ISA is unsupported.
fn jit_builder() -> JITBuilder {
    let mut flags = settings::builder();
    for (name, value) in [("use_colocated_libcalls", "false"), ("is_pic", "false")] {
        flags.set(name, value).unwrap();
    }
    let native = cranelift_native::builder().unwrap();
    let isa = native.finish(settings::Flags::new(flags)).unwrap();
    JITBuilder::with_isa(isa, default_libcall_names())
}
/// Returns `2 * x`, wrapping around on `i64` overflow instead of panicking.
#[jit_export]
fn double_i64(x: i64) -> i64 {
    // In two's-complement wrapping arithmetic `x + x` is exactly `x * 2`.
    x.wrapping_add(x)
}
/// JIT-compiles a `wrap(i64) -> i64` function that forwards to the exported
/// `double_i64`, then checks ordinary values and the wrapping-overflow edges.
#[test]
fn calls_extern_taking_i64() {
    let mut jb = jit_builder();
    double_i64_jit::register(&mut jb);
    let mut module = JITModule::new(jb);
    let ext_id = double_i64_jit::declare(&mut module);
    let wrap_id = define_jit_fn!(
        &mut module,
        "wrap",
        Linkage::Export,
        fn(i64) -> i64,
        |bcx, module, params| double_i64_jit::call(bcx, module, ext_id, params[0]),
    )
    .unwrap();
    module.finalize_definitions().unwrap();
    // SAFETY: `wrap` was defined above with signature `fn(i64) -> i64`, and the
    // module's definitions have been finalized before the pointer is fetched.
    let f: extern "C" fn(i64) -> i64 =
        unsafe { std::mem::transmute(module.get_finalized_function(wrap_id)) };
    assert_eq!(f(21), 42);
    assert_eq!(f(-7), -14);
    assert_eq!(f(0), 0);
    // `double_i64` uses wrapping multiplication; make sure that behaviour is what
    // comes back through the JIT call path rather than a trap or panic.
    assert_eq!(f(i64::MAX), -2);
    assert_eq!(f(i64::MIN), 0);
}
/// Looks up `key` in the map behind `map_ptr`.
///
/// Returns the stored value. Returns `-1` when the key is absent or when `map_ptr`
/// is null. A null pointer is treated like an empty map instead of being dereferenced.
#[jit_export]
fn lookup(map_ptr: *const HashMap<String, i64>, key: &str) -> i64 {
    // SAFETY: the caller must pass either null or a pointer to a live `HashMap` that
    // is not mutated for the duration of this call. The visible test passes `&map` and
    // only mutates it between calls. `as_ref` turns null into `None` instead of
    // forming an invalid reference.
    unsafe { map_ptr.as_ref() }
        .and_then(|map| map.get(key).copied())
        .unwrap_or(-1)
}
/// JIT-compiles a `wrap` function that passes its map-pointer argument plus the
/// literal `"answer"` to the exported `lookup`, then checks hit and miss results.
#[test]
fn calls_extern_with_map_pointer_and_static_str() {
    let mut builder = jit_builder();
    lookup_jit::register(&mut builder);
    let mut module = JITModule::new(builder);
    let ext_id = lookup_jit::declare(&mut module);
    let wrap_id = define_jit_fn!(
        &mut module,
        "wrap",
        Linkage::Export,
        fn(*const HashMap<String, i64>) -> i64,
        |bcx, module, params| {
            lookup_jit::call(bcx, module, ext_id, params[0], "answer")
        },
    )
    .unwrap();
    module.finalize_definitions().unwrap();

    let code = module.get_finalized_function(wrap_id);
    // SAFETY: `wrap` was defined above with exactly this argument/return shape and the
    // module is finalized. `map` below outlives every call made through `wrap`.
    let wrap = unsafe {
        std::mem::transmute::<*const u8, extern "C" fn(*const HashMap<String, i64>) -> i64>(
            code,
        )
    };

    let mut map = HashMap::from([("answer".to_string(), 42i64), ("other".to_string(), 7)]);
    assert_eq!(wrap(&map), 42);
    map.remove("answer");
    assert_eq!(wrap(&map), -1);
}
/// Returns `n * x + 1.0`, computed as a separate multiply and add. Despite the
/// name, this is not a fused multiply-add.
#[jit_export]
fn fma_like(n: i32, x: f64) -> f64 {
    // Every `i32` is exactly representable as an `f64`, so this conversion is lossless.
    let scaled = f64::from(n) * x;
    scaled + 1.0
}
/// JIT-compiles a `wrap(i32, f64) -> f64` function forwarding to the exported
/// `fma_like`, exercising a call that mixes integer and float arguments.
#[test]
fn calls_extern_with_mixed_int_float() {
    let mut builder = jit_builder();
    fma_like_jit::register(&mut builder);
    let mut module = JITModule::new(builder);
    let ext_id = fma_like_jit::declare(&mut module);
    let wrap_id = define_jit_fn!(
        &mut module,
        "wrap",
        Linkage::Export,
        fn(i32, f64) -> f64,
        |bcx, module, params| fma_like_jit::call(bcx, module, ext_id, params[0], params[1]),
    )
    .unwrap();
    module.finalize_definitions().unwrap();

    let code = module.get_finalized_function(wrap_id);
    // SAFETY: `wrap` was defined above as `fn(i32, f64) -> f64` and the module is finalized.
    let wrap = unsafe { std::mem::transmute::<*const u8, extern "C" fn(i32, f64) -> f64>(code) };

    let expected: f64 = 3.0 * 0.5 + 1.0;
    assert_eq!(wrap(3, 0.5), expected);
}
/// C-layout record that `add_to_base` reads through a raw pointer.
#[repr(C)]
struct Config {
    /// Value to which `add_to_base` adds its integer argument.
    base: i64,
}
/// Returns `(*cfg).base + x`, wrapping on `i64` overflow.
#[jit_export]
fn add_to_base(cfg: *const Config, x: i64) -> i64 {
    // SAFETY: the caller must pass a non-null, aligned pointer to a live `Config`.
    // The visible test embeds `&CFG`, a `static`, which satisfies this.
    let cfg = unsafe { &*cfg };
    // Wrap like `double_i64` does. A debug-build overflow panic here would occur inside
    // a function invoked from JIT-generated code, presumably via an `extern "C"` shim
    // produced by `#[jit_export]` (TODO confirm), where it cannot unwind cleanly.
    cfg.base.wrapping_add(x)
}
/// Configuration whose address `embeds_raw_pointer_constant` passes to `add_to_base`.
static CFG: Config = Config { base: 100 };
/// JIT-compiles a `wrap(i64) -> i64` function that calls the exported `add_to_base`
/// with the address of `CFG` supplied from inside the code-generation closure.
#[test]
fn embeds_raw_pointer_constant() {
    let mut builder = jit_builder();
    add_to_base_jit::register(&mut builder);
    let mut module = JITModule::new(builder);
    let ext_id = add_to_base_jit::declare(&mut module);
    let wrap_id = define_jit_fn!(
        &mut module,
        "wrap",
        Linkage::Export,
        fn(i64) -> i64,
        |bcx, module, params| {
            let cfg_ptr: *const Config = &CFG;
            add_to_base_jit::call(bcx, module, ext_id, cfg_ptr, params[0])
        },
    )
    .unwrap();
    module.finalize_definitions().unwrap();

    let code = module.get_finalized_function(wrap_id);
    // SAFETY: `wrap` was defined above as `fn(i64) -> i64` and the module is finalized.
    let wrap = unsafe { std::mem::transmute::<*const u8, extern "C" fn(i64) -> i64>(code) };

    for (x, expected) in [(5, 105), (-50, 50)] {
        assert_eq!(wrap(x), expected);
    }
}
/// Returns the constant `42`. The test `calls_extern_with_no_args` uses it to cover
/// an export that takes no arguments.
#[jit_export]
fn answer() -> i64 {
    const ANSWER: i64 = 42;
    ANSWER
}
/// JIT-compiles a zero-argument `wrap() -> i64` function that forwards to the
/// exported `answer`.
#[test]
fn calls_extern_with_no_args() {
    let mut builder = jit_builder();
    answer_jit::register(&mut builder);
    let mut module = JITModule::new(builder);
    let ext_id = answer_jit::declare(&mut module);
    let wrap_id = define_jit_fn!(
        &mut module,
        "wrap",
        Linkage::Export,
        fn() -> i64,
        |bcx, module, _params| answer_jit::call(bcx, module, ext_id),
    )
    .unwrap();
    module.finalize_definitions().unwrap();

    let code = module.get_finalized_function(wrap_id);
    // SAFETY: `wrap` was defined above as `fn() -> i64` and the module is finalized.
    let wrap = unsafe { std::mem::transmute::<*const u8, extern "C" fn() -> i64>(code) };

    assert_eq!(wrap(), 42);
}