oxirs_arq/jit/join_compiler.rs
1//! Cranelift-based JIT compiler for hash-join key comparison (JIT phase c).
2//!
3//! Produces a native function with signature:
4//! ```text
5//! fn(left: *const f64, right: *const f64, n_keys: usize) -> i8
6//! ```
7//! Returns `1` if all key columns match, `0` otherwise.
8//!
9//! # Comparison modes
10//!
11//! Each key column chooses one of two modes:
//! - **Epsilon**: `|left[left_idx] - right[right_idx]| < 1e-9`
//!   (ordered float comparison, NaN → not equal).
13//! - **Exact**: bitcast both operands to `i64` and check `icmp Equal`. Under this mode,
14//! two values with identical bit patterns — including two copies of the *same* NaN — compare
15//! as equal, because IEEE 754 NaN bit patterns are just ordinary 64-bit integers to `icmp`.
16//!
17//! # Safety
18//!
19//! The compiled function is called via an `unsafe extern "C" fn` pointer.
20//! Callers must guarantee:
21//! - `left` and `right` are valid, aligned, non-null pointers to at least
22//! `max(spec.left_idx) + 1` and `max(spec.right_idx) + 1` `f64` values, respectively.
23//! - Both pointers remain valid for the duration of the call.
24
25use std::sync::Arc;
26
27use cranelift_codegen::ir::{
28 condcodes::{FloatCC, IntCC},
29 types, AbiParam, InstBuilder, MemFlags, Signature,
30};
31use cranelift_codegen::isa::CallConv;
32use cranelift_codegen::settings::{self, Configurable};
33use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
34use cranelift_jit::{JITBuilder, JITModule};
35use cranelift_module::{Linkage, Module};
36
37use super::filter_compiler::JITModuleOwner;
38
39// ---------------------------------------------------------------------------
40// Public types
41// ---------------------------------------------------------------------------
42
/// Configuration for a single join key column.
///
/// A spec names one column on each side of the join and the comparison mode used
/// for that pair. The two indices are independent: the key may live at different
/// positions in the left and right rows.
///
/// Specs are plain data, so they support equality and hashing; this allows a list
/// of specs to be compared or used as a lookup key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JoinKeySpec {
    /// Index into the left-row `f64` slice.
    pub left_idx: usize,
    /// Index into the right-row `f64` slice.
    pub right_idx: usize,
    /// If `true`, use epsilon comparison (`|a - b| < 1e-9`).
    /// If `false`, use exact bit equality (bitcast to `i64` then `icmp Equal`).
    pub numeric_epsilon: bool,
}
54
/// Signature of the native comparator produced by [`JoinCompiler::compile`].
///
/// Arguments, in order: pointer to the left row, pointer to the right row, and the
/// key count. The result is `1` when every key column matches and `0` otherwise.
///
/// # Safety
///
/// `left` and `right` must each point to enough `f64` values to cover all
/// `left_idx` / `right_idx` in the compiled spec. `n_keys` must equal the
/// number of [`JoinKeySpec`] entries the function was compiled with; the code
/// generated in this file does not read it, but it is part of the ABI.
type JoinKeyFn = unsafe extern "C" fn(*const f64, *const f64, usize) -> i8;
63
64/// A compiled join-key comparator produced by [`JoinCompiler::compile`].
65///
66/// The wrapped function returns `1` if all key columns match and `0` otherwise.
67///
68/// # Ownership / safety invariant
69///
70/// `fn_ptr` is only valid while `_module_owner` is alive.
71/// Cloning an `Arc<CompiledJoinKey>` extends the lifetime safely.
72pub struct CompiledJoinKey {
73 /// JIT-compiled function pointer.
74 fn_ptr: JoinKeyFn,
75 /// Key specs kept for `key_count()`.
76 specs: Vec<JoinKeySpec>,
77 /// Keeps the `JITModule` code pages alive.
78 _module_owner: Arc<JITModuleOwner>,
79}
80
81// SAFETY: JITModule code pages are read-only after finalisation;
82// the module is protected by `Arc<JITModuleOwner>`.
83unsafe impl Send for CompiledJoinKey {}
84unsafe impl Sync for CompiledJoinKey {}
85
86impl std::fmt::Debug for CompiledJoinKey {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("CompiledJoinKey")
89 .field("key_count", &self.specs.len())
90 .finish_non_exhaustive()
91 }
92}
93
94impl CompiledJoinKey {
95 /// Compare two rows using the compiled function.
96 ///
97 /// Returns `true` if all key columns match according to each column's comparison mode.
98 ///
99 /// # Safety
100 ///
101 /// `left` must have at least `max(spec.left_idx) + 1` elements and
102 /// `right` must have at least `max(spec.right_idx) + 1` elements.
103 pub fn compare(&self, left: &[f64], right: &[f64]) -> bool {
104 let n = self.specs.len();
105 // SAFETY:
106 // - `fn_ptr` is valid because `_module_owner` is still alive (we hold it).
107 // - `left.as_ptr()` and `right.as_ptr()` are valid for `n` reads
108 // when the caller obeys the documented precondition.
109 // - The compiled function does not mutate either slice.
110 unsafe { (self.fn_ptr)(left.as_ptr(), right.as_ptr(), n) == 1 }
111 }
112
113 /// The number of key columns this comparator was compiled for.
114 pub fn key_count(&self) -> usize {
115 self.specs.len()
116 }
117}
118
119// ---------------------------------------------------------------------------
120// Error type
121// ---------------------------------------------------------------------------
122
123/// Errors that can occur during JIT compilation of a join-key comparator.
124#[derive(Debug, thiserror::Error)]
125pub enum JoinCompilerError {
126 /// No key specs were provided.
127 #[error("join compiler requires at least one key spec")]
128 NoKeys,
129 /// Cranelift reported a codegen or linkage error.
130 #[error("JIT codegen error: {0}")]
131 CodegenError(String),
132 /// ISA builder failed to initialise.
133 #[error("JIT ISA init error: {0}")]
134 IsaInitError(String),
135}
136
137// ---------------------------------------------------------------------------
138// Compiler (ZST — new JITModule per compile call)
139// ---------------------------------------------------------------------------
140
/// Compiles join-key comparators to native machine code via Cranelift.
///
/// Each call to [`compile`](JoinCompiler::compile) creates a fresh `JITModule`
/// so that the resulting [`CompiledJoinKey`] is independently owned and can be
/// dropped without affecting other compiled functions.
///
/// The compiler itself is a zero-sized, stateless handle, so it is `Copy` and its
/// `Default` value is the same as [`JoinCompiler::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct JoinCompiler;
153
154impl JoinCompiler {
155 /// Create a new `JoinCompiler`.
156 pub fn new() -> Self {
157 JoinCompiler
158 }
159
160 /// Compile a multi-key join comparator.
161 ///
162 /// The resulting function returns `1` if **all** key columns match, `0` if any diverges.
163 ///
164 /// # Errors
165 ///
166 /// Returns [`JoinCompilerError::NoKeys`] if `specs` is empty.
167 /// Returns [`JoinCompilerError::CodegenError`] on Cranelift failures.
168 pub fn compile(&self, specs: &[JoinKeySpec]) -> Result<CompiledJoinKey, JoinCompilerError> {
169 if specs.is_empty() {
170 return Err(JoinCompilerError::NoKeys);
171 }
172 let module = build_jit_module()?;
173 let (fn_ptr, module) = compile_join_fn(module, specs)?;
174 let owner = Arc::new(JITModuleOwner::new(module));
175 Ok(CompiledJoinKey {
176 fn_ptr,
177 specs: specs.to_vec(),
178 _module_owner: owner,
179 })
180 }
181}
182
183// ---------------------------------------------------------------------------
184// Cranelift module setup
185// ---------------------------------------------------------------------------
186
187fn build_jit_module() -> Result<JITModule, JoinCompilerError> {
188 let mut flag_builder = settings::builder();
189 flag_builder
190 .set("use_colocated_libcalls", "false")
191 .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
192 flag_builder
193 .set("is_pic", "false")
194 .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
195 flag_builder
196 .set("opt_level", "speed")
197 .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
198
199 let flags = settings::Flags::new(flag_builder);
200 let isa = cranelift_native::builder()
201 .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?
202 .finish(flags)
203 .map_err(|e| JoinCompilerError::IsaInitError(e.to_string()))?;
204
205 let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
206 Ok(JITModule::new(builder))
207}
208
209// ---------------------------------------------------------------------------
210// Cranelift IR code generation
211// ---------------------------------------------------------------------------
212
213/// Compile a join-key comparator function into `module` and return the function
214/// pointer together with the (now finalized) module.
215fn compile_join_fn(
216 mut module: JITModule,
217 specs: &[JoinKeySpec],
218) -> Result<(JoinKeyFn, JITModule), JoinCompilerError> {
219 let ptr_type = module.isa().pointer_type();
220
221 // Signature: fn(*const f64, *const f64, usize) -> i8
222 let mut sig = Signature::new(CallConv::SystemV);
223 sig.params.push(AbiParam::new(ptr_type)); // left ptr
224 sig.params.push(AbiParam::new(ptr_type)); // right ptr
225 sig.params.push(AbiParam::new(ptr_type)); // n_keys (unused at runtime, kept for ABI)
226 sig.returns.push(AbiParam::new(types::I8));
227
228 let func_id = module
229 .declare_function("join_key_fn", Linkage::Local, &sig)
230 .map_err(|e| JoinCompilerError::CodegenError(e.to_string()))?;
231
232 {
233 let mut ctx = module.make_context();
234 ctx.func.signature = sig.clone();
235
236 let mut fn_builder_ctx = FunctionBuilderContext::new();
237 let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
238
239 let entry_block = builder.create_block();
240 builder.append_block_params_for_function_params(entry_block);
241 builder.switch_to_block(entry_block);
242 builder.seal_block(entry_block);
243
244 let left_ptr = builder.block_params(entry_block)[0];
245 let right_ptr = builder.block_params(entry_block)[1];
246 // Third param (n_keys) is only for ABI match; suppress "unused" at the IR level.
247 let _n_keys = builder.block_params(entry_block)[2];
248
249 // Emit one comparison per key and AND them together.
250 // `accumulator` starts as the i8 value `1` (all-match).
251 let one_i8 = builder.ins().iconst(types::I8, 1);
252 let mut accumulator = one_i8;
253
254 for spec in specs {
255 let key_ok = emit_key_comparison(&mut builder, spec, left_ptr, right_ptr)?;
256 // Logical AND of i8 values (both are 0 or 1)
257 accumulator = builder.ins().band(accumulator, key_ok);
258 }
259
260 builder.ins().return_(&[accumulator]);
261 builder.finalize();
262
263 module
264 .define_function(func_id, &mut ctx)
265 .map_err(|e| JoinCompilerError::CodegenError(format!("{e:?}")))?;
266 }
267
268 module
269 .finalize_definitions()
270 .map_err(|e| JoinCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
271
272 // SAFETY: The function was just defined and finalized above; the pointer is valid.
273 let raw_ptr = module.get_finalized_function(func_id);
274 // SAFETY: We built the function with exactly this signature.
275 let fn_ptr: JoinKeyFn = unsafe { std::mem::transmute(raw_ptr) };
276
277 Ok((fn_ptr, module))
278}
279
280/// Emit a single key-column comparison and return an `i8` value (1 = match, 0 = no match).
281fn emit_key_comparison(
282 builder: &mut FunctionBuilder<'_>,
283 spec: &JoinKeySpec,
284 left_ptr: cranelift_codegen::ir::Value,
285 right_ptr: cranelift_codegen::ir::Value,
286) -> Result<cranelift_codegen::ir::Value, JoinCompilerError> {
287 let left_offset = byte_offset(spec.left_idx)?;
288 let right_offset = byte_offset(spec.right_idx)?;
289
290 // Load left[left_idx] and right[right_idx] as f64.
291 // SAFETY comment for generated IR: caller guarantees indices are in-bounds.
292 let lv = builder
293 .ins()
294 .load(types::F64, MemFlags::trusted(), left_ptr, left_offset);
295 let rv = builder
296 .ins()
297 .load(types::F64, MemFlags::trusted(), right_ptr, right_offset);
298
299 if spec.numeric_epsilon {
300 // |lv - rv| < 1e-9
301 // Note: if either operand is NaN, fsub returns NaN; fabs(NaN) is NaN;
302 // fcmp LessThan(NaN, eps) = false → correctly returns 0 (not equal).
303 let diff = builder.ins().fsub(lv, rv);
304 let abs_diff = builder.ins().fabs(diff);
305 let eps = builder.ins().f64const(1e-9);
306 // fcmp returns I8 (0 or 1) in modern Cranelift — no bint needed.
307 let cmp = builder.ins().fcmp(FloatCC::LessThan, abs_diff, eps);
308 Ok(cmp)
309 } else {
310 // Exact bit equality: bitcast both to i64 then icmp Equal.
311 // Under this mode two NaN values with identical bit patterns compare as equal,
312 // because icmp treats them as plain integers — documented behaviour.
313 let li = builder.ins().bitcast(types::I64, MemFlags::new(), lv);
314 let ri = builder.ins().bitcast(types::I64, MemFlags::new(), rv);
315 // icmp returns I8 (0 or 1) in modern Cranelift.
316 let cmp = builder.ins().icmp(IntCC::Equal, li, ri);
317 Ok(cmp)
318 }
319}
320
321/// Convert a column index to a `i32` byte offset for Cranelift `load`.
322///
323/// Returns an error if the offset overflows `i32` (requires >268 million columns).
324fn byte_offset(idx: usize) -> Result<i32, JoinCompilerError> {
325 let byte = idx.checked_mul(std::mem::size_of::<f64>()).ok_or_else(|| {
326 JoinCompilerError::CodegenError(format!("column index {idx} overflows byte offset"))
327 })?;
328 i32::try_from(byte).map_err(|_| {
329 JoinCompilerError::CodegenError(format!(
330 "column index {idx} byte offset {} exceeds i32::MAX",
331 byte
332 ))
333 })
334}
335
336// ---------------------------------------------------------------------------
337// Unit tests (inline)
338// ---------------------------------------------------------------------------
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest column index whose byte offset still fits in an `i32`.
    const MAX_ENCODABLE_IDX: usize = (i32::MAX as usize) / std::mem::size_of::<f64>();

    fn compiler() -> JoinCompiler {
        JoinCompiler::new()
    }

    #[test]
    fn test_no_keys_error() {
        let result = compiler().compile(&[]);
        assert!(matches!(result, Err(JoinCompilerError::NoKeys)));
    }

    #[test]
    fn test_byte_offset_zero() {
        assert_eq!(byte_offset(0).expect("ok"), 0);
    }

    #[test]
    fn test_byte_offset_one() {
        assert_eq!(byte_offset(1).expect("ok"), 8);
    }

    #[test]
    fn test_byte_offset_largest_encodable_index() {
        let offset = byte_offset(MAX_ENCODABLE_IDX).expect("largest index must be encodable");
        assert_eq!(offset as usize, MAX_ENCODABLE_IDX * std::mem::size_of::<f64>());
    }

    #[test]
    fn test_byte_offset_exceeds_i32() {
        // One past the largest encodable index: the product fits in `usize` but not `i32`.
        let result = byte_offset(MAX_ENCODABLE_IDX + 1);
        match result {
            Err(JoinCompilerError::CodegenError(msg)) => {
                assert!(msg.contains("exceeds i32::MAX"), "unexpected message: {msg}");
            }
            other => panic!("expected CodegenError, got {other:?}"),
        }
    }

    #[test]
    fn test_byte_offset_multiplication_overflow() {
        // `usize::MAX * 8` cannot be represented, so the checked multiply fails first.
        let result = byte_offset(usize::MAX);
        match result {
            Err(JoinCompilerError::CodegenError(msg)) => {
                assert!(msg.contains("overflows byte offset"), "unexpected message: {msg}");
            }
            other => panic!("expected CodegenError, got {other:?}"),
        }
    }
}