use std::cmp::Ordering;
use std::sync::Arc;
use cranelift_codegen::ir::{
condcodes::FloatCC, types, AbiParam, InstBuilder, MemFlags, Signature,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
use super::filter_compiler::JITModuleOwner;
/// One sort key of a compiled ordering.
///
/// Each key names a column of an `f64` row and the direction in which that
/// column is ordered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct OrderKeySpec {
    /// Zero-based index of the `f64` column this key reads from each row.
    pub col_idx: usize,
    /// `true` for ascending order; `false` swaps the less/greater results.
    pub ascending: bool,
}
/// Signature of the JIT-compiled comparator.
///
/// Arguments are a pointer to the left row, a pointer to the right row and a
/// count (the caller passes the number of key specs). The returned `i8` is
/// negative, zero or positive for less / equal / greater.
type OrderFn = unsafe extern "C" fn(*const f64, *const f64, usize) -> i8;
/// A row comparator compiled to native code for a fixed list of sort keys.
pub struct CompiledOrder {
// Entry point of the generated comparator.
fn_ptr: OrderFn,
// The key specs this comparator was compiled from.
specs: Vec<OrderKeySpec>,
// Holds the JIT module that produced `fn_ptr`; never read directly.
// NOTE(review): presumably this keeps the executable memory behind `fn_ptr`
// mapped for as long as this value lives — confirm against `JITModuleOwner`.
_module_owner: Arc<JITModuleOwner>,
}
// SAFETY (NOTE(review) — to be confirmed): the function pointer and the
// `Vec<OrderKeySpec>` are `Send + Sync` on their own, so these manual impls
// only matter for `Arc<JITModuleOwner>`. They are sound only if the owned JIT
// module is never mutated through a shared reference after compilation;
// verify that against `JITModuleOwner` in `filter_compiler`.
unsafe impl Send for CompiledOrder {}
unsafe impl Sync for CompiledOrder {}
impl std::fmt::Debug for CompiledOrder {
    /// Formats as `CompiledOrder { col_count: <n>, .. }`; the function pointer
    /// and the module owner are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_count = self.specs.len();
        let mut out = f.debug_struct("CompiledOrder");
        out.field("col_count", &key_count);
        out.finish_non_exhaustive()
    }
}
impl CompiledOrder {
    /// Compares two rows using the compiled key list.
    ///
    /// Keys are evaluated in the order they were given to the compiler; the
    /// first key whose columns differ decides the result.
    ///
    /// # Panics
    ///
    /// Panics if `left` or `right` is too short to contain every column
    /// referenced by the key specs. The generated code reads those columns
    /// through raw pointers, so a short row would otherwise be an
    /// out-of-bounds read.
    pub fn compare(&self, left: &[f64], right: &[f64]) -> Ordering {
        // Smallest row length that covers every referenced column.
        let required_len = self
            .specs
            .iter()
            .map(|spec| spec.col_idx.saturating_add(1))
            .max()
            .unwrap_or(0);
        assert!(
            left.len() >= required_len && right.len() >= required_len,
            "CompiledOrder::compare needs rows of at least {required_len} columns, got {} and {}",
            left.len(),
            right.len(),
        );

        let n = self.specs.len();
        // SAFETY: `fn_ptr` was produced together with `_module_owner`, which
        // is still held by `self`. The generated code only loads the `f64` at
        // index `col_idx` of each row for every spec, and the assertion above
        // guarantees those indices are in bounds for both slices.
        let r = unsafe { (self.fn_ptr)(left.as_ptr(), right.as_ptr(), n) };
        // Negative -> Less, zero -> Equal, positive -> Greater.
        r.cmp(&0)
    }

    /// Returns the number of key specs this comparator was compiled from.
    pub fn col_count(&self) -> usize {
        self.specs.len()
    }
}
/// Errors produced while compiling an ordering to native code.
#[derive(Debug, thiserror::Error)]
pub enum OrderCompilerError {
/// `compile` was called with an empty key list.
#[error("order compiler requires at least one key spec")]
NoKeys,
/// Code generation failed: setting a Cranelift flag, declaring, defining or
/// finalizing the function, or computing a column byte offset.
#[error("JIT codegen error: {0}")]
CodegenError(String),
/// The native ISA builder could not be created or finished.
#[error("JIT ISA init error: {0}")]
IsaInitError(String),
}
/// Stateless entry point for compiling key specs into a native comparator.
///
/// `Default` is derived rather than hand-written, and the type is `Debug`,
/// `Clone` and `Copy` so it can be embedded in other derived types.
#[derive(Debug, Clone, Copy, Default)]
pub struct OrderCompiler;
impl OrderCompiler {
pub fn new() -> Self {
OrderCompiler
}
pub fn compile(&self, specs: &[OrderKeySpec]) -> Result<CompiledOrder, OrderCompilerError> {
if specs.is_empty() {
return Err(OrderCompilerError::NoKeys);
}
let module = build_jit_module()?;
let (fn_ptr, module) = compile_order_fn(module, specs)?;
let owner = Arc::new(JITModuleOwner::new(module));
Ok(CompiledOrder {
fn_ptr,
specs: specs.to_vec(),
_module_owner: owner,
})
}
}
/// Builds an empty JIT module targeting the host machine.
///
/// # Errors
///
/// Returns `CodegenError` if a Cranelift setting is rejected and
/// `IsaInitError` if the native ISA cannot be created or finished.
fn build_jit_module() -> Result<JITModule, OrderCompilerError> {
    // Cranelift settings applied in this order before the ISA is built.
    const JIT_FLAGS: [(&str, &str); 3] = [
        ("use_colocated_libcalls", "false"),
        ("is_pic", "false"),
        ("opt_level", "speed"),
    ];

    let mut flag_builder = settings::builder();
    for &(name, value) in JIT_FLAGS.iter() {
        flag_builder
            .set(name, value)
            .map_err(|e| OrderCompilerError::CodegenError(e.to_string()))?;
    }
    let flags = settings::Flags::new(flag_builder);

    let isa_builder = cranelift_native::builder()
        .map_err(|e| OrderCompilerError::IsaInitError(e.to_string()))?;
    let isa = isa_builder
        .finish(flags)
        .map_err(|e| OrderCompilerError::IsaInitError(e.to_string()))?;

    let jit_builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
    Ok(JITModule::new(jit_builder))
}
fn compile_order_fn(
mut module: JITModule,
specs: &[OrderKeySpec],
) -> Result<(OrderFn, JITModule), OrderCompilerError> {
let ptr_type = module.isa().pointer_type();
let mut sig = Signature::new(CallConv::SystemV);
sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.returns.push(AbiParam::new(types::I8));
let func_id = module
.declare_function("order_fn", Linkage::Local, &sig)
.map_err(|e| OrderCompilerError::CodegenError(e.to_string()))?;
{
let mut ctx = module.make_context();
ctx.func.signature = sig.clone();
let mut fn_builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let left_ptr = builder.block_params(entry_block)[0];
let right_ptr = builder.block_params(entry_block)[1];
let _n_cols = builder.block_params(entry_block)[2];
let zero_i8 = builder.ins().iconst(types::I8, 0);
let mut accumulated = zero_i8;
for spec in specs {
let col_result = emit_col_comparison(&mut builder, spec, left_ptr, right_ptr)?;
let is_decided = builder.ins().icmp_imm(
cranelift_codegen::ir::condcodes::IntCC::NotEqual,
accumulated,
0,
);
accumulated = builder.ins().select(is_decided, accumulated, col_result);
}
builder.ins().return_(&[accumulated]);
builder.finalize();
module
.define_function(func_id, &mut ctx)
.map_err(|e| OrderCompilerError::CodegenError(format!("{e:?}")))?;
}
module
.finalize_definitions()
.map_err(|e| OrderCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
let raw_ptr = module.get_finalized_function(func_id);
let fn_ptr: OrderFn = unsafe { std::mem::transmute(raw_ptr) };
Ok((fn_ptr, module))
}
/// Emits the branch-free comparison of one key column and returns its `i8`
/// result value (`-1`, `0` or `1`).
///
/// Both rows are read at byte offset `col_idx * size_of::<f64>()`. For an
/// ascending key the result is `-1` when left < right, `0` when equal and `1`
/// when left > right; a descending key swaps `-1` and `1`.
///
/// NaN handling: a plain less-than / equal test treats NaN on either side as
/// "greater", so `(NaN, x)` and `(x, NaN)` would both be greater and
/// `(NaN, NaN)` would never be equal. That is not a total order and breaks
/// sorting. This version orders NaN after every number (before every number
/// for a descending key) and treats two NaNs as equal, so the comparison is
/// antisymmetric and transitive. Results for non-NaN inputs and for a NaN on
/// the left against a number are unchanged.
///
/// # Errors
///
/// Returns `CodegenError` if the column's byte offset does not fit the load
/// instruction's `i32` offset.
fn emit_col_comparison(
    builder: &mut FunctionBuilder<'_>,
    spec: &OrderKeySpec,
    left_ptr: cranelift_codegen::ir::Value,
    right_ptr: cranelift_codegen::ir::Value,
) -> Result<cranelift_codegen::ir::Value, OrderCompilerError> {
    let offset = col_byte_offset(spec.col_idx)?;
    let lv = builder
        .ins()
        .load(types::F64, MemFlags::trusted(), left_ptr, offset);
    let rv = builder
        .ins()
        .load(types::F64, MemFlags::trusted(), right_ptr, offset);

    // Direction is folded into the constants: descending swaps the two.
    let (less_imm, greater_imm) = if spec.ascending {
        (-1i64, 1i64)
    } else {
        (1i64, -1i64)
    };
    let less = builder.ins().iconst(types::I8, less_imm);
    let greater = builder.ins().iconst(types::I8, greater_imm);
    let equal = builder.ins().iconst(types::I8, 0);

    let is_lt = builder.ins().fcmp(FloatCC::LessThan, lv, rv);
    let is_eq = builder.ins().fcmp(FloatCC::Equal, lv, rv);
    // True when at least one operand is NaN.
    let is_unordered = builder.ins().fcmp(FloatCC::Unordered, lv, rv);
    // A value is unordered with itself exactly when it is NaN.
    let left_nan = builder.ins().fcmp(FloatCC::Unordered, lv, lv);
    let right_nan = builder.ins().fcmp(FloatCC::Unordered, rv, rv);

    // Unordered case: both NaN -> equal, only left NaN -> greater,
    // otherwise only the right side is NaN -> less.
    let left_nan_result = builder.ins().select(right_nan, equal, greater);
    let nan_result = builder.ins().select(left_nan, left_nan_result, less);

    // Ordered and not equal: strictly less or strictly greater.
    let ordered_result = builder.ins().select(is_lt, less, greater);

    let non_equal = builder
        .ins()
        .select(is_unordered, nan_result, ordered_result);
    let result = builder.ins().select(is_eq, equal, non_equal);
    Ok(result)
}
/// Converts a column index into the byte offset of that `f64` within a row.
///
/// # Errors
///
/// Returns `CodegenError` when `idx * size_of::<f64>()` overflows `usize` or
/// does not fit in an `i32`.
fn col_byte_offset(idx: usize) -> Result<i32, OrderCompilerError> {
    const ELEM_BYTES: usize = std::mem::size_of::<f64>();

    let byte = match idx.checked_mul(ELEM_BYTES) {
        Some(total) => total,
        None => {
            return Err(OrderCompilerError::CodegenError(format!(
                "column index {idx} overflows byte offset"
            )));
        }
    };

    match i32::try_from(byte) {
        Ok(offset) => Ok(offset),
        Err(_) => Err(OrderCompilerError::CodegenError(format!(
            "column index {idx} byte offset {} exceeds i32::MAX",
            byte
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared constructor so the tests do not repeat how a compiler is made.
    fn compiler() -> OrderCompiler {
        OrderCompiler::new()
    }

    /// An empty key list must be rejected with `NoKeys`.
    #[test]
    fn test_no_keys_error() {
        let outcome = compiler().compile(&[]);
        let rejected_as_no_keys = matches!(outcome, Err(OrderCompilerError::NoKeys));
        assert!(rejected_as_no_keys);
    }

    /// Column 0 starts at the beginning of the row.
    #[test]
    fn test_col_byte_offset_zero() {
        let offset = col_byte_offset(0).expect("ok");
        assert_eq!(offset, 0);
    }

    /// Column 2 sits two `f64` widths (16 bytes) into the row.
    #[test]
    fn test_col_byte_offset_two() {
        let offset = col_byte_offset(2).expect("ok");
        assert_eq!(offset, 16);
    }
}