use std::sync::Arc;
use cranelift_codegen::ir::{
condcodes::{FloatCC, IntCC},
types, AbiParam, InstBuilder, MemFlags, Signature,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
use super::filter_compiler::JITModuleOwner;
/// Describes one pair of columns that must agree for a left row and a right row to join.
#[derive(Debug, Clone)]
pub struct JoinKeySpec {
/// Element index of the key column inside the left row of `f64` values.
pub left_idx: usize,
/// Element index of the key column inside the right row of `f64` values.
pub right_idx: usize,
/// When `true` the generated code accepts the pair if `|left - right| < 1e-9`;
/// when `false` it requires the two values to have identical 64-bit patterns.
pub numeric_epsilon: bool,
}
/// Signature of the JIT-compiled comparator: `(left_row, right_row, key_count) -> i8`.
///
/// The generated function returns `1` when every configured key matches. The third
/// argument is accepted but not read by the generated code in this file.
type JoinKeyFn = unsafe extern "C" fn(*const f64, *const f64, usize) -> i8;
/// A JIT-compiled join-key comparator bundled with the module that backs its code.
pub struct CompiledJoinKey {
// Entry point of the generated comparison function.
fn_ptr: JoinKeyFn,
// Copy of the specs the function was compiled from; its length is what `key_count` reports.
specs: Vec<JoinKeySpec>,
// Never read; held so the owning JIT module is not dropped while this value exists.
// NOTE(review): presumably dropping `JITModuleOwner` releases the memory `fn_ptr`
// points into — confirm against `filter_compiler`.
_module_owner: Arc<JITModuleOwner>,
}
// SAFETY (NOTE(review) — to be confirmed): these impls assert that a `CompiledJoinKey`
// may be moved to and shared between threads. The visible fields are a function pointer,
// a `Vec<JoinKeySpec>` and an `Arc<JITModuleOwner>`; whether this is sound depends on
// `JITModuleOwner` (defined in `filter_compiler`, not shown here) tolerating being
// referenced and dropped from any thread. Verify before relying on it.
unsafe impl Send for CompiledJoinKey {}
unsafe impl Sync for CompiledJoinKey {}
impl std::fmt::Debug for CompiledJoinKey {
    /// Prints only the number of keys; the function pointer and the module owner are
    /// left out and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_count = self.specs.len();
        let mut out = f.debug_struct("CompiledJoinKey");
        out.field("key_count", &key_count);
        out.finish_non_exhaustive()
    }
}
impl CompiledJoinKey {
    /// Returns `true` when every configured key column of `left` matches the
    /// corresponding column of `right`.
    ///
    /// The generated code reads `left[spec.left_idx]` and `right[spec.right_idx]`
    /// through raw pointers without any bounds check, so the rows are validated here
    /// first. If either row is too short to contain every configured key column the
    /// rows cannot match and `false` is returned instead of reading out of bounds.
    pub fn compare(&self, left: &[f64], right: &[f64]) -> bool {
        let rows_cover_all_keys = self
            .specs
            .iter()
            .all(|spec| spec.left_idx < left.len() && spec.right_idx < right.len());
        if !rows_cover_all_keys {
            return false;
        }

        let n = self.specs.len();
        // SAFETY: `fn_ptr` is stored together with the specs it was compiled from (the
        // fields are private to this module), and the check above guarantees that every
        // index the generated code loads from lies inside `left` / `right`. Both pointers
        // come from live `&[f64]` slices and are therefore valid and aligned for `f64`.
        unsafe { (self.fn_ptr)(left.as_ptr(), right.as_ptr(), n) == 1 }
    }

    /// Number of key specs this comparator was compiled with.
    pub fn key_count(&self) -> usize {
        self.specs.len()
    }
}
/// Errors produced while compiling a join-key comparator.
#[derive(Debug, thiserror::Error)]
pub enum JoinCompilerError {
/// `compile` was called with an empty slice of key specs.
#[error("join compiler requires at least one key spec")]
NoKeys,
/// A Cranelift setting could not be applied, the function could not be declared,
/// defined or finalized, or a column index does not fit into an `i32` byte offset.
#[error("JIT codegen error: {0}")]
CodegenError(String),
/// The native ISA builder could not be created or finished.
#[error("JIT ISA init error: {0}")]
IsaInitError(String),
}
/// Stateless front end that turns a list of [`JoinKeySpec`]s into a JIT-compiled comparator.
///
/// The type carries no data, so the derived `Default` yields the same unit value as
/// `JoinCompiler::new()`.
#[derive(Default)]
pub struct JoinCompiler;
impl JoinCompiler {
pub fn new() -> Self {
JoinCompiler
}
pub fn compile(&self, specs: &[JoinKeySpec]) -> Result<CompiledJoinKey, JoinCompilerError> {
if specs.is_empty() {
return Err(JoinCompilerError::NoKeys);
}
let module = build_jit_module()?;
let (fn_ptr, module) = compile_join_fn(module, specs)?;
let owner = Arc::new(JITModuleOwner::new(module));
Ok(CompiledJoinKey {
fn_ptr,
specs: specs.to_vec(),
_module_owner: owner,
})
}
}
/// Builds an empty [`JITModule`] targeting the host machine.
///
/// Settings applied, in order: `use_colocated_libcalls=false`, `is_pic=false`,
/// `opt_level=speed`. A rejected setting is reported as `CodegenError`; failures while
/// creating or finishing the native ISA are reported as `IsaInitError`.
fn build_jit_module() -> Result<JITModule, JoinCompilerError> {
    const SETTINGS: [(&str, &str); 3] = [
        ("use_colocated_libcalls", "false"),
        ("is_pic", "false"),
        ("opt_level", "speed"),
    ];

    let mut flag_builder = settings::builder();
    for &(name, value) in SETTINGS.iter() {
        flag_builder
            .set(name, value)
            .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
    }
    let flags = settings::Flags::new(flag_builder);

    let isa_builder = cranelift_native::builder()
        .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?;
    let isa = isa_builder
        .finish(flags)
        .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?;

    let jit_builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
    Ok(JITModule::new(jit_builder))
}
fn compile_join_fn(
mut module: JITModule,
specs: &[JoinKeySpec],
) -> Result<(JoinKeyFn, JITModule), JoinCompilerError> {
let ptr_type = module.isa().pointer_type();
let mut sig = Signature::new(CallConv::SystemV);
sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.returns.push(AbiParam::new(types::I8));
let func_id = module
.declare_function("join_key_fn", Linkage::Local, &sig)
.map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
{
let mut ctx = module.make_context();
ctx.func.signature = sig.clone();
let mut fn_builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let left_ptr = builder.block_params(entry_block)[0];
let right_ptr = builder.block_params(entry_block)[1];
let _n_keys = builder.block_params(entry_block)[2];
let one_i8 = builder.ins().iconst(types::I8, 1);
let mut accumulator = one_i8;
for spec in specs {
let key_ok = emit_key_comparison(&mut builder, spec, left_ptr, right_ptr)?;
accumulator = builder.ins().band(accumulator, key_ok);
}
builder.ins().return_(&[accumulator]);
builder.finalize();
module
.define_function(func_id, &mut ctx)
.map_err(|e| JoinCompilerError::CodegenError(format!("{e:?}")))?;
}
module
.finalize_definitions()
.map_err(|e| JoinCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
let raw_ptr = module.get_finalized_function(func_id);
let fn_ptr: JoinKeyFn = unsafe { std::mem::transmute(raw_ptr) };
Ok((fn_ptr, module))
}
/// Emits the instructions that compare one key column of the left and right rows.
///
/// Both values are loaded as `f64` from `left_ptr` / `right_ptr` at the byte offsets
/// derived from the spec's column indices. The returned value is the result of the
/// comparison instruction:
///
/// * `numeric_epsilon == true`: `|left - right| < 1e-9`, an ordered comparison, so a NaN
///   difference never matches;
/// * `numeric_epsilon == false`: integer equality of the raw 64-bit patterns, so `0.0`
///   and `-0.0` are treated as different keys.
///
/// # Errors
///
/// Returns `CodegenError` when either column index cannot be turned into an `i32` offset.
fn emit_key_comparison(
    builder: &mut FunctionBuilder<'_>,
    spec: &JoinKeySpec,
    left_ptr: cranelift_codegen::ir::Value,
    right_ptr: cranelift_codegen::ir::Value,
) -> Result<cranelift_codegen::ir::Value, JoinCompilerError> {
    let left_offset = byte_offset(spec.left_idx)?;
    let right_offset = byte_offset(spec.right_idx)?;

    let load_flags = MemFlags::trusted();
    let lhs = builder
        .ins()
        .load(types::F64, load_flags, left_ptr, left_offset);
    let rhs = builder
        .ins()
        .load(types::F64, load_flags, right_ptr, right_offset);

    let matched = if spec.numeric_epsilon {
        let delta = builder.ins().fsub(lhs, rhs);
        let magnitude = builder.ins().fabs(delta);
        let tolerance = builder.ins().f64const(1e-9);
        builder
            .ins()
            .fcmp(FloatCC::LessThan, magnitude, tolerance)
    } else {
        // Reinterpret the floats as integers and compare the bit patterns.
        let lhs_bits = builder.ins().bitcast(types::I64, MemFlags::new(), lhs);
        let rhs_bits = builder.ins().bitcast(types::I64, MemFlags::new(), rhs);
        builder.ins().icmp(IntCC::Equal, lhs_bits, rhs_bits)
    };
    Ok(matched)
}
/// Converts a column index into the byte offset of that `f64` within a row.
///
/// # Errors
///
/// Returns `CodegenError` when `idx * size_of::<f64>()` overflows `usize` or when the
/// product does not fit into an `i32`.
fn byte_offset(idx: usize) -> Result<i32, JoinCompilerError> {
    const ELEMENT_SIZE: usize = std::mem::size_of::<f64>();

    let byte = match idx.checked_mul(ELEMENT_SIZE) {
        Some(product) => product,
        None => {
            return Err(JoinCompilerError::CodegenError(format!(
                "column index {idx} overflows byte offset"
            )));
        }
    };

    match i32::try_from(byte) {
        Ok(offset) => Ok(offset),
        Err(_) => Err(JoinCompilerError::CodegenError(format!(
            "column index {idx} byte offset {} exceeds i32::MAX",
            byte
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh compiler instance for the tests below.
    fn compiler() -> JoinCompiler {
        JoinCompiler::new()
    }

    /// Compiling with no key specs must be rejected with `NoKeys`.
    #[test]
    fn test_no_keys_error() {
        let outcome = compiler().compile(&[]);
        assert!(matches!(outcome, Err(JoinCompilerError::NoKeys)));
    }

    /// Column 0 sits at the very start of the row.
    #[test]
    fn test_byte_offset_zero() {
        let offset = byte_offset(0).expect("ok");
        assert_eq!(offset, 0);
    }

    /// Column 1 starts one `f64` (8 bytes) into the row.
    #[test]
    fn test_byte_offset_one() {
        let offset = byte_offset(1).expect("ok");
        assert_eq!(offset, 8);
    }
}