Skip to main content

oxirs_arq/jit/
join_compiler.rs

1//! Cranelift-based JIT compiler for hash-join key comparison (JIT phase c).
2//!
3//! Produces a native function with signature:
4//! ```text
5//! fn(left: *const f64, right: *const f64, n_keys: usize) -> i8
6//! ```
7//! Returns `1` if all key columns match, `0` otherwise.
8//!
9//! # Comparison modes
10//!
11//! Each key column chooses one of two modes:
12//! - **Epsilon**: `|left[i] - right[i]| < 1e-9` (ordered float comparison, NaN → not equal).
13//! - **Exact**: bitcast both operands to `i64` and check `icmp Equal`. Under this mode,
14//!   two values with identical bit patterns — including two copies of the *same* NaN — compare
15//!   as equal, because IEEE 754 NaN bit patterns are just ordinary 64-bit integers to `icmp`.
16//!
17//! # Safety
18//!
19//! The compiled function is called via an `unsafe extern "C" fn` pointer.
20//! Callers must guarantee:
21//! - `left` and `right` are valid, aligned, non-null pointers to at least
22//!   `max(spec.left_idx) + 1` and `max(spec.right_idx) + 1` `f64` values, respectively.
23//! - Both pointers remain valid for the duration of the call.
24
25use std::sync::Arc;
26
27use cranelift_codegen::ir::{
28    condcodes::{FloatCC, IntCC},
29    types, AbiParam, InstBuilder, MemFlags, Signature,
30};
31use cranelift_codegen::isa::CallConv;
32use cranelift_codegen::settings::{self, Configurable};
33use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
34use cranelift_jit::{JITBuilder, JITModule};
35use cranelift_module::{Linkage, Module};
36
37use super::filter_compiler::JITModuleOwner;
38
39// ---------------------------------------------------------------------------
40// Public types
41// ---------------------------------------------------------------------------
42
/// Configuration for a single join key column.
///
/// Each spec names one element of the left row and one element of the right
/// row, and selects how that pair of `f64` values is compared.
#[derive(Debug, Clone)]
pub struct JoinKeySpec {
    /// Index into the left-row `f64` slice.
    ///
    /// Turned into a byte offset (`left_idx * size_of::<f64>()`) when the
    /// comparator is compiled; compilation fails with a codegen error if that
    /// offset does not fit in an `i32`.
    pub left_idx: usize,
    /// Index into the right-row `f64` slice.
    ///
    /// Subject to the same `i32` byte-offset limit as `left_idx`.
    pub right_idx: usize,
    /// If `true`, use epsilon comparison (`|a - b| < 1e-9`).
    /// If `false`, use exact bit equality (bitcast to `i64` then `icmp Equal`).
    pub numeric_epsilon: bool,
}
54
/// The C-ABI function pointer produced by [`JoinCompiler::compile`].
///
/// Arguments, in order: pointer to the left row, pointer to the right row,
/// and the number of keys. The result is `1` when every key column matches
/// and `0` otherwise.
///
/// # Safety
///
/// `left` and `right` must each point to enough `f64` values to cover all
/// `left_idx` / `right_idx` in the compiled spec. `n_keys` must equal the
/// number of [`JoinKeySpec`] entries the function was compiled with.
/// The generated code does not read `n_keys` at runtime (the key columns are
/// baked in at compile time); the parameter exists only as part of the ABI.
type JoinKeyFn = unsafe extern "C" fn(*const f64, *const f64, usize) -> i8;
63
/// A compiled join-key comparator produced by [`JoinCompiler::compile`].
///
/// The wrapped function returns `1` if all key columns match and `0` otherwise.
///
/// # Ownership / safety invariant
///
/// `fn_ptr` is only valid while `_module_owner` is alive.
/// Cloning an `Arc<CompiledJoinKey>` extends the lifetime safely.
pub struct CompiledJoinKey {
    /// JIT-compiled function pointer.
    ///
    /// Points into code owned by `_module_owner`; never expose it beyond the
    /// lifetime of this struct.
    fn_ptr: JoinKeyFn,
    /// Key specs kept for `key_count()`.
    ///
    /// This is a copy of the slice the comparator was compiled from, so its
    /// indices are exactly the ones the generated code reads.
    specs: Vec<JoinKeySpec>,
    /// Keeps the `JITModule` code pages alive.
    ///
    /// Never read directly (hence the leading underscore); it is held purely
    /// so the code behind `fn_ptr` stays alive as long as this struct does.
    _module_owner: Arc<JITModuleOwner>,
}
80
// SAFETY: JITModule code pages are read-only after finalisation;
// the module is protected by `Arc<JITModuleOwner>`.
// Every method on `CompiledJoinKey` takes `&self` and only calls `fn_ptr` or
// reads `specs`, so sharing a value between threads performs no mutation here.
// NOTE(review): soundness also depends on `JITModuleOwner` (defined in
// `filter_compiler`) being safe to drop from an arbitrary thread — confirm
// against its definition.
unsafe impl Send for CompiledJoinKey {}
unsafe impl Sync for CompiledJoinKey {}
85
86impl std::fmt::Debug for CompiledJoinKey {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("CompiledJoinKey")
89            .field("key_count", &self.specs.len())
90            .finish_non_exhaustive()
91    }
92}
93
94impl CompiledJoinKey {
95    /// Compare two rows using the compiled function.
96    ///
97    /// Returns `true` if all key columns match according to each column's comparison mode.
98    ///
99    /// # Safety
100    ///
101    /// `left` must have at least `max(spec.left_idx) + 1` elements and
102    /// `right` must have at least `max(spec.right_idx) + 1` elements.
103    pub fn compare(&self, left: &[f64], right: &[f64]) -> bool {
104        let n = self.specs.len();
105        // SAFETY:
106        // - `fn_ptr` is valid because `_module_owner` is still alive (we hold it).
107        // - `left.as_ptr()` and `right.as_ptr()` are valid for `n` reads
108        //   when the caller obeys the documented precondition.
109        // - The compiled function does not mutate either slice.
110        unsafe { (self.fn_ptr)(left.as_ptr(), right.as_ptr(), n) == 1 }
111    }
112
113    /// The number of key columns this comparator was compiled for.
114    pub fn key_count(&self) -> usize {
115        self.specs.len()
116    }
117}
118
119// ---------------------------------------------------------------------------
120// Error type
121// ---------------------------------------------------------------------------
122
/// Errors that can occur during JIT compilation of a join-key comparator.
#[derive(Debug, thiserror::Error)]
pub enum JoinCompilerError {
    /// No key specs were provided.
    ///
    /// Returned by `JoinCompiler::compile` when it is given an empty slice.
    #[error("join compiler requires at least one key spec")]
    NoKeys,
    /// Cranelift reported a codegen or linkage error.
    ///
    /// Also used for rejected Cranelift flag settings and for column indices
    /// whose byte offset cannot be represented as an `i32`. The payload is a
    /// human-readable description of the failure.
    #[error("JIT codegen error: {0}")]
    CodegenError(String),
    /// ISA builder failed to initialise.
    ///
    /// Raised while constructing the host ISA, before any IR is generated.
    #[error("JIT ISA init error: {0}")]
    IsaInitError(String),
}
136
137// ---------------------------------------------------------------------------
138// Compiler (ZST — new JITModule per compile call)
139// ---------------------------------------------------------------------------
140
/// Compiles join-key comparators to native machine code via Cranelift.
///
/// Each call to [`compile`](JoinCompiler::compile) creates a fresh `JITModule`
/// so that the resulting [`CompiledJoinKey`] is independently owned and can be
/// dropped without affecting other compiled functions.
///
/// The type is a zero-sized marker with no state, so `Default`, `Debug`,
/// `Clone` and `Copy` are derived rather than written by hand.
#[derive(Debug, Default, Clone, Copy)]
pub struct JoinCompiler;
153
154impl JoinCompiler {
155    /// Create a new `JoinCompiler`.
156    pub fn new() -> Self {
157        JoinCompiler
158    }
159
160    /// Compile a multi-key join comparator.
161    ///
162    /// The resulting function returns `1` if **all** key columns match, `0` if any diverges.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`JoinCompilerError::NoKeys`] if `specs` is empty.
167    /// Returns [`JoinCompilerError::CodegenError`] on Cranelift failures.
168    pub fn compile(&self, specs: &[JoinKeySpec]) -> Result<CompiledJoinKey, JoinCompilerError> {
169        if specs.is_empty() {
170            return Err(JoinCompilerError::NoKeys);
171        }
172        let module = build_jit_module()?;
173        let (fn_ptr, module) = compile_join_fn(module, specs)?;
174        let owner = Arc::new(JITModuleOwner::new(module));
175        Ok(CompiledJoinKey {
176            fn_ptr,
177            specs: specs.to_vec(),
178            _module_owner: owner,
179        })
180    }
181}
182
183// ---------------------------------------------------------------------------
184// Cranelift module setup
185// ---------------------------------------------------------------------------
186
187fn build_jit_module() -> Result<JITModule, JoinCompilerError> {
188    let mut flag_builder = settings::builder();
189    flag_builder
190        .set("use_colocated_libcalls", "false")
191        .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
192    flag_builder
193        .set("is_pic", "false")
194        .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
195    flag_builder
196        .set("opt_level", "speed")
197        .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
198
199    let flags = settings::Flags::new(flag_builder);
200    let isa = cranelift_native::builder()
201        .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?
202        .finish(flags)
203        .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?;
204
205    let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
206    Ok(JITModule::new(builder))
207}
208
209// ---------------------------------------------------------------------------
210// Cranelift IR code generation
211// ---------------------------------------------------------------------------
212
213/// Compile a join-key comparator function into `module` and return the function
214/// pointer together with the (now finalized) module.
215fn compile_join_fn(
216    mut module: JITModule,
217    specs: &[JoinKeySpec],
218) -> Result<(JoinKeyFn, JITModule), JoinCompilerError> {
219    let ptr_type = module.isa().pointer_type();
220
221    // Signature: fn(*const f64, *const f64, usize) -> i8
222    let mut sig = Signature::new(CallConv::SystemV);
223    sig.params.push(AbiParam::new(ptr_type)); // left ptr
224    sig.params.push(AbiParam::new(ptr_type)); // right ptr
225    sig.params.push(AbiParam::new(ptr_type)); // n_keys (unused at runtime, kept for ABI)
226    sig.returns.push(AbiParam::new(types::I8));
227
228    let func_id = module
229        .declare_function("join_key_fn", Linkage::Local, &sig)
230        .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
231
232    {
233        let mut ctx = module.make_context();
234        ctx.func.signature = sig.clone();
235
236        let mut fn_builder_ctx = FunctionBuilderContext::new();
237        let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
238
239        let entry_block = builder.create_block();
240        builder.append_block_params_for_function_params(entry_block);
241        builder.switch_to_block(entry_block);
242        builder.seal_block(entry_block);
243
244        let left_ptr = builder.block_params(entry_block)[0];
245        let right_ptr = builder.block_params(entry_block)[1];
246        // Third param (n_keys) is only for ABI match; suppress "unused" at the IR level.
247        let _n_keys = builder.block_params(entry_block)[2];
248
249        // Emit one comparison per key and AND them together.
250        // `accumulator` starts as the i8 value `1` (all-match).
251        let one_i8 = builder.ins().iconst(types::I8, 1);
252        let mut accumulator = one_i8;
253
254        for spec in specs {
255            let key_ok = emit_key_comparison(&mut builder, spec, left_ptr, right_ptr)?;
256            // Logical AND of i8 values (both are 0 or 1)
257            accumulator = builder.ins().band(accumulator, key_ok);
258        }
259
260        builder.ins().return_(&[accumulator]);
261        builder.finalize();
262
263        module
264            .define_function(func_id, &mut ctx)
265            .map_err(|e| JoinCompilerError::CodegenError(format!("{e:?}")))?;
266    }
267
268    module
269        .finalize_definitions()
270        .map_err(|e| JoinCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
271
272    // SAFETY: The function was just defined and finalized above; the pointer is valid.
273    let raw_ptr = module.get_finalized_function(func_id);
274    // SAFETY: We built the function with exactly this signature.
275    let fn_ptr: JoinKeyFn = unsafe { std::mem::transmute(raw_ptr) };
276
277    Ok((fn_ptr, module))
278}
279
280/// Emit a single key-column comparison and return an `i8` value (1 = match, 0 = no match).
281fn emit_key_comparison(
282    builder: &mut FunctionBuilder<'_>,
283    spec: &JoinKeySpec,
284    left_ptr: cranelift_codegen::ir::Value,
285    right_ptr: cranelift_codegen::ir::Value,
286) -> Result<cranelift_codegen::ir::Value, JoinCompilerError> {
287    let left_offset = byte_offset(spec.left_idx)?;
288    let right_offset = byte_offset(spec.right_idx)?;
289
290    // Load left[left_idx] and right[right_idx] as f64.
291    // SAFETY comment for generated IR: caller guarantees indices are in-bounds.
292    let lv = builder
293        .ins()
294        .load(types::F64, MemFlags::trusted(), left_ptr, left_offset);
295    let rv = builder
296        .ins()
297        .load(types::F64, MemFlags::trusted(), right_ptr, right_offset);
298
299    if spec.numeric_epsilon {
300        // |lv - rv| < 1e-9
301        // Note: if either operand is NaN, fsub returns NaN; fabs(NaN) is NaN;
302        // fcmp LessThan(NaN, eps) = false → correctly returns 0 (not equal).
303        let diff = builder.ins().fsub(lv, rv);
304        let abs_diff = builder.ins().fabs(diff);
305        let eps = builder.ins().f64const(1e-9);
306        // fcmp returns I8 (0 or 1) in modern Cranelift — no bint needed.
307        let cmp = builder.ins().fcmp(FloatCC::LessThan, abs_diff, eps);
308        Ok(cmp)
309    } else {
310        // Exact bit equality: bitcast both to i64 then icmp Equal.
311        // Under this mode two NaN values with identical bit patterns compare as equal,
312        // because icmp treats them as plain integers — documented behaviour.
313        let li = builder.ins().bitcast(types::I64, MemFlags::new(), lv);
314        let ri = builder.ins().bitcast(types::I64, MemFlags::new(), rv);
315        // icmp returns I8 (0 or 1) in modern Cranelift.
316        let cmp = builder.ins().icmp(IntCC::Equal, li, ri);
317        Ok(cmp)
318    }
319}
320
321/// Convert a column index to a `i32` byte offset for Cranelift `load`.
322///
323/// Returns an error if the offset overflows `i32` (requires >268 million columns).
324fn byte_offset(idx: usize) -> Result<i32, JoinCompilerError> {
325    let byte = idx.checked_mul(std::mem::size_of::<f64>()).ok_or_else(|| {
326        JoinCompilerError::CodegenError(format!("column index {idx} overflows byte offset"))
327    })?;
328    i32::try_from(byte).map_err(|_| {
329        JoinCompilerError::CodegenError(format!(
330            "column index {idx} byte offset {} exceeds i32::MAX",
331            byte
332        ))
333    })
334}
335
336// ---------------------------------------------------------------------------
337// Unit tests (inline)
338// ---------------------------------------------------------------------------
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared constructor so each test reads the same way.
    fn compiler() -> JoinCompiler {
        JoinCompiler::new()
    }

    #[test]
    fn test_no_keys_error() {
        let result = compiler().compile(&[]);
        assert!(matches!(result, Err(JoinCompilerError::NoKeys)));
    }

    #[test]
    fn test_byte_offset_zero() {
        assert_eq!(byte_offset(0).expect("ok"), 0);
    }

    #[test]
    fn test_byte_offset_one() {
        assert_eq!(byte_offset(1).expect("ok"), 8);
    }

    /// The largest index whose byte offset still fits in an `i32`.
    #[test]
    fn test_byte_offset_largest_representable() {
        let idx = (i32::MAX as usize) / std::mem::size_of::<f64>();
        // 268_435_455 * 8 == 2_147_483_640 == i32::MAX - 7
        assert_eq!(byte_offset(idx).expect("ok"), i32::MAX - 7);
    }

    /// One past the largest valid index must be rejected by the `i32` conversion.
    #[test]
    fn test_byte_offset_exceeds_i32() {
        let idx = (i32::MAX as usize) / std::mem::size_of::<f64>() + 1;
        assert!(matches!(
            byte_offset(idx),
            Err(JoinCompilerError::CodegenError(_))
        ));
    }

    /// An index so large that `idx * 8` overflows `usize` must be rejected too.
    #[test]
    fn test_byte_offset_multiplication_overflow() {
        assert!(matches!(
            byte_offset(usize::MAX),
            Err(JoinCompilerError::CodegenError(_))
        ));
    }
}
358
359    #[test]
360    fn test_byte_offset_one() {
361        assert_eq!(byte_offset(1).expect("ok"), 8);
362    }
363}