// oxirs_arq/jit/project_compiler.rs

//! Cranelift-based JIT compiler for SPARQL PROJECT column extraction (JIT phase d).
//!
//! Produces a native function with signature:
//! ```text
//! fn(src_ptr: *const f64, src_len: usize, dst_ptr: *mut f64, dst_len: usize) -> i8
//! ```
//! Returns `1` on success, or `0` if either the source or the destination slice is
//! too short to satisfy the requested column mapping.
//!
//! # Operator semantics
//!
//! The PROJECT operator selects and reorders columns from a source row into a
//! destination row.  Each [`ProjectSpec`] entry contributes one `f64` value
//! to the output by naming a source-row column index.  Output order follows
//! the order of the `specs` slice.
//!
//! # Safety
//!
//! The compiled function is called via an `unsafe extern "C" fn` pointer.
//! Callers must guarantee:
//! - `src_ptr` is a valid, aligned, non-null pointer to at least `src_len` `f64` values.
//! - `dst_ptr` is a valid, aligned, non-null pointer to at least `dst_len` `f64` values.
//! - Both pointers remain valid for the duration of the call.

25use std::sync::Arc;
26
27use cranelift_codegen::ir::{
28    condcodes::IntCC, types, AbiParam, InstBuilder, MemFlags, Signature, Value,
29};
30use cranelift_codegen::isa::CallConv;
31use cranelift_codegen::settings::{self, Configurable};
32use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
33use cranelift_jit::{JITBuilder, JITModule};
34use cranelift_module::{Linkage, Module};
35
36use super::filter_compiler::JITModuleOwner;
37
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------
41
/// Describes a single output column of a PROJECT operation.
///
/// The position of a spec inside the slice handed to the compiler selects the
/// output slot; `src_idx` names the source-row column that is copied into it.
///
/// The type is a plain value: it can be copied, compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectSpec {
    /// Zero-based index into the source row's `f64` slice that this output slot reads from.
    pub src_idx: usize,
}
48
/// Native entry point emitted by [`ProjectCompiler::compile`].
///
/// Arguments, in order: source pointer, source length, destination pointer and
/// destination length, where both lengths count `f64` elements.  The function
/// returns `1` once every projected column has been written and `0` when one
/// of its bounds checks rejects the call.
///
/// # Safety
///
/// - The first pointer must address at least `src_len` readable `f64` values.
/// - The third pointer must address at least `dst_len` writable `f64` values.
/// - `dst_len` must be `≥ specs.len()` for the projection to succeed.
type ProjectFn = unsafe extern "C" fn(*const f64, usize, *mut f64, usize) -> i8;
57
58/// A compiled PROJECT-operator column extractor produced by [`ProjectCompiler::compile`].
59///
60/// # Ownership / safety invariant
61///
62/// `fn_ptr` is only valid while `_module_owner` is alive.
63pub struct CompiledProject {
64    /// JIT-compiled function pointer.
65    fn_ptr: ProjectFn,
66    /// Kept for [`output_width`](CompiledProject::output_width).
67    specs: Vec<ProjectSpec>,
68    /// Keeps the `JITModule` code pages alive.
69    _module_owner: Arc<JITModuleOwner>,
70}
71
72// SAFETY: JITModule code pages are read-only after finalisation;
73// the module is protected by `Arc<JITModuleOwner>`.
74unsafe impl Send for CompiledProject {}
75unsafe impl Sync for CompiledProject {}
76
77impl std::fmt::Debug for CompiledProject {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("CompiledProject")
80            .field("output_width", &self.specs.len())
81            .finish_non_exhaustive()
82    }
83}
84
85impl CompiledProject {
86    /// Number of output columns this projector produces.
87    pub fn output_width(&self) -> usize {
88        self.specs.len()
89    }
90
91    /// Extract projected columns from `src` into `dst`.
92    ///
93    /// `dst` is resized to [`output_width`](Self::output_width) before the call.
94    /// Returns `false` if `src` is too short to satisfy any required column index.
95    ///
96    /// # Safety invariant
97    ///
98    /// The compiled function is invoked via a raw function pointer that is valid
99    /// because `_module_owner` keeps the `JITModule` pages alive for the lifetime
100    /// of `self`.
101    pub fn extract(&self, src: &[f64], dst: &mut Vec<f64>) -> bool {
102        let out_len = self.specs.len();
103        dst.resize(out_len, 0.0);
104        // SAFETY:
105        // - `fn_ptr` is valid because `_module_owner` is still alive (we hold it in self).
106        // - `src.as_ptr()` is a valid, aligned pointer to `src.len()` f64 values.
107        // - `dst.as_mut_ptr()` is a valid, aligned, writable pointer to `dst.len()` f64 values.
108        // - Both slices remain valid for the duration of this call.
109        let result = unsafe { (self.fn_ptr)(src.as_ptr(), src.len(), dst.as_mut_ptr(), dst.len()) };
110        result == 1
111    }
112}
113
// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------
117
118/// Errors that can occur during JIT compilation of a PROJECT extractor.
119#[derive(Debug, thiserror::Error)]
120pub enum ProjectCompilerError {
121    /// Cranelift reported a codegen or linkage error.
122    #[error("JIT codegen error: {0}")]
123    CodegenError(String),
124    /// ISA builder failed to initialise.
125    #[error("JIT ISA init error: {0}")]
126    IsaInitError(String),
127    /// Function declaration or linkage failed.
128    #[error("JIT linkage error: {0}")]
129    LinkageError(String),
130}
131
// ---------------------------------------------------------------------------
// Compiler (ZST — new JITModule per compile call)
// ---------------------------------------------------------------------------
135
/// Compiles PROJECT column-extraction operations to native machine code via Cranelift.
///
/// The compiler is a stateless zero-sized type.  Each call to
/// [`compile`](ProjectCompiler::compile) creates a fresh `JITModule`, so the
/// resulting [`CompiledProject`] is independently owned and can be dropped
/// without affecting other compiled functions.
///
/// `Default` is derived rather than hand-written, and `Debug`/`Clone`/`Copy`
/// are provided so the compiler can be embedded in other derived types.
#[derive(Debug, Default, Clone, Copy)]
pub struct ProjectCompiler;
148
149impl ProjectCompiler {
150    /// Create a new `ProjectCompiler`.
151    pub fn new() -> Self {
152        ProjectCompiler
153    }
154
155    /// Compile a column-extraction function for the given `specs`.
156    ///
157    /// If `specs` is empty the compiled function is a no-op that always returns success.
158    pub fn compile(
159        &mut self,
160        specs: &[ProjectSpec],
161    ) -> Result<CompiledProject, ProjectCompilerError> {
162        let module = build_jit_module()?;
163        let (fn_ptr, module) = compile_project_fn(module, specs)?;
164        let owner = Arc::new(JITModuleOwner::new(module));
165        Ok(CompiledProject {
166            fn_ptr,
167            specs: specs.to_vec(),
168            _module_owner: owner,
169        })
170    }
171}
172
// ---------------------------------------------------------------------------
// Cranelift module setup
// ---------------------------------------------------------------------------
176
177fn build_jit_module() -> Result<JITModule, ProjectCompilerError> {
178    let mut flag_builder = settings::builder();
179    flag_builder
180        .set("use_colocated_libcalls", "false")
181        .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
182    flag_builder
183        .set("is_pic", "false")
184        .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
185    flag_builder
186        .set("opt_level", "speed")
187        .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
188
189    let flags = settings::Flags::new(flag_builder);
190    let isa = cranelift_native::builder()
191        .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?
192        .finish(flags)
193        .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?;
194
195    let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
196    Ok(JITModule::new(builder))
197}
198
// ---------------------------------------------------------------------------
// Cranelift IR code generation
// ---------------------------------------------------------------------------
202
203/// Compile a project function into `module` and return the function pointer
204/// together with the (now finalized) module.
205fn compile_project_fn(
206    mut module: JITModule,
207    specs: &[ProjectSpec],
208) -> Result<(ProjectFn, JITModule), ProjectCompilerError> {
209    // Use the host pointer width for all params (usize / pointer types)
210    let ptr_type = module.isa().pointer_type();
211
212    // Signature: fn(*const f64, usize, *mut f64, usize) -> i8
213    let mut sig = Signature::new(CallConv::SystemV);
214    sig.params.push(AbiParam::new(ptr_type)); // src_ptr: *const f64
215    sig.params.push(AbiParam::new(ptr_type)); // src_len: usize
216    sig.params.push(AbiParam::new(ptr_type)); // dst_ptr: *mut f64
217    sig.params.push(AbiParam::new(ptr_type)); // dst_len: usize
218    sig.returns.push(AbiParam::new(types::I8)); // return: i8 (1 = ok, 0 = bounds error)
219
220    let func_id = module
221        .declare_function("project_fn", Linkage::Local, &sig)
222        .map_err(|e| ProjectCompilerError::LinkageError(e.to_string()))?;
223
224    {
225        let mut ctx = module.make_context();
226        ctx.func.signature = sig.clone();
227
228        let mut fn_builder_ctx = FunctionBuilderContext::new();
229        let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
230
231        if specs.is_empty() {
232            // Fast path: no columns to project — emit a single-block return-1 function.
233            emit_trivial_success(&mut builder, ptr_type);
234        } else {
235            emit_project_body(&mut builder, specs, ptr_type)?;
236        }
237
238        builder.finalize();
239
240        module
241            .define_function(func_id, &mut ctx)
242            .map_err(|e| ProjectCompilerError::CodegenError(format!("{e:?}")))?;
243    }
244
245    module
246        .finalize_definitions()
247        .map_err(|e| ProjectCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
248
249    // SAFETY: The function was just defined and finalized above; the pointer is valid.
250    let raw_ptr = module.get_finalized_function(func_id);
251    // SAFETY: We built the function with exactly this signature.
252    let fn_ptr: ProjectFn = unsafe { std::mem::transmute(raw_ptr) };
253
254    Ok((fn_ptr, module))
255}
256
257/// Emit a trivial single-block function that immediately returns 1 (success).
258///
259/// Used when there are no specs to process.
260fn emit_trivial_success(builder: &mut FunctionBuilder<'_>, ptr_type: types::Type) {
261    let entry_block = builder.create_block();
262    builder.append_block_params_for_function_params(entry_block);
263    builder.switch_to_block(entry_block);
264    builder.seal_block(entry_block);
265
266    // Consume all parameters to satisfy the ABI (4 params for project_fn).
267    // This suppresses any "unused parameter" IR warnings from Cranelift.
268    let _src_ptr: Value = builder.block_params(entry_block)[0];
269    let _src_len: Value = builder.block_params(entry_block)[1];
270    let _dst_ptr: Value = builder.block_params(entry_block)[2];
271    let _dst_len: Value = builder.block_params(entry_block)[3];
272
273    // Suppress unused type warning
274    let _ = ptr_type;
275
276    let one = builder.ins().iconst(types::I8, 1);
277    builder.ins().return_(&[one]);
278}
279
280/// Emit the multi-block project body with bounds checks.
281///
282/// Control flow:
283/// ```text
284///   entry → [bounds_fail if src_len <= max_src_idx || dst_len < specs.len()]
285///   entry → body → return 1
286///   bounds_fail → return 0
287/// ```
288fn emit_project_body(
289    builder: &mut FunctionBuilder<'_>,
290    specs: &[ProjectSpec],
291    ptr_type: types::Type,
292) -> Result<(), ProjectCompilerError> {
293    // Compute the maximum source index needed across all specs.
294    // SAFETY: specs is non-empty at this call site.
295    let max_src_idx = specs.iter().map(|s| s.src_idx).max().unwrap_or(0);
296
297    // Create blocks: entry, bounds_fail, body.
298    let entry_block = builder.create_block();
299    let bounds_fail_block = builder.create_block();
300    let body_block = builder.create_block();
301
302    // ---- Entry block ----
303    builder.append_block_params_for_function_params(entry_block);
304    builder.switch_to_block(entry_block);
305    builder.seal_block(entry_block);
306
307    let src_ptr = builder.block_params(entry_block)[0];
308    let src_len = builder.block_params(entry_block)[1];
309    let dst_ptr = builder.block_params(entry_block)[2];
310    let dst_len = builder.block_params(entry_block)[3];
311
312    // Bounds check 1: src_len <= max_src_idx  (unsigned ≤, i.e., src cannot hold the largest src_idx)
313    // We need src_len > max_src_idx, i.e., src_len >= max_src_idx + 1.
314    // Failure condition: src_len <= max_src_idx (UnsignedLessThanOrEqual).
315    let max_src_val = builder.ins().iconst(ptr_type, max_src_idx as i64);
316    let src_too_short = builder
317        .ins()
318        .icmp(IntCC::UnsignedLessThanOrEqual, src_len, max_src_val);
319
320    // Bounds check 2: dst_len < specs.len()
321    // Failure condition: dst_len < specs.len() (UnsignedLessThan).
322    let dst_need = builder.ins().iconst(ptr_type, specs.len() as i64);
323    let dst_too_short = builder
324        .ins()
325        .icmp(IntCC::UnsignedLessThan, dst_len, dst_need);
326
327    // Combine: either condition → fail
328    let any_fail = builder.ins().bor(src_too_short, dst_too_short);
329
330    // brif: if any_fail → bounds_fail_block, else → body_block
331    builder
332        .ins()
333        .brif(any_fail, bounds_fail_block, &[], body_block, &[]);
334
335    // ---- Bounds-fail block ----
336    builder.switch_to_block(bounds_fail_block);
337    builder.seal_block(bounds_fail_block);
338    let zero_i8 = builder.ins().iconst(types::I8, 0);
339    builder.ins().return_(&[zero_i8]);
340
341    // ---- Body block ----
342    builder.switch_to_block(body_block);
343    builder.seal_block(body_block);
344
345    // For each spec: load src[src_idx], store to dst[dst_i].
346    for (dst_i, spec) in specs.iter().enumerate() {
347        // Source load: src_ptr + src_idx * 8
348        let src_offset = builder
349            .ins()
350            .iconst(ptr_type, (spec.src_idx * std::mem::size_of::<f64>()) as i64);
351        let src_addr = builder.ins().iadd(src_ptr, src_offset);
352        // SAFETY comment for generated IR: bounds were verified above; load is safe.
353        let val = builder.ins().load(types::F64, MemFlags::new(), src_addr, 0);
354
355        // Destination store: dst_ptr + dst_i * 8
356        let dst_offset = builder
357            .ins()
358            .iconst(ptr_type, (dst_i * std::mem::size_of::<f64>()) as i64);
359        let dst_addr = builder.ins().iadd(dst_ptr, dst_offset);
360        // SAFETY comment for generated IR: bounds were verified above; store is safe.
361        builder.ins().store(MemFlags::new(), val, dst_addr, 0);
362    }
363
364    let one_i8 = builder.ins().iconst(types::I8, 1);
365    builder.ins().return_(&[one_i8]);
366
367    Ok(())
368}
369
// ---------------------------------------------------------------------------
// Unit tests (inline)
// ---------------------------------------------------------------------------
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `specs`, runs one extraction over `src` into an empty buffer,
    /// and returns `(output_width, success, destination)`.
    fn run(specs: &[ProjectSpec], src: &[f64]) -> (usize, bool, Vec<f64>) {
        let compiled = ProjectCompiler::new().compile(specs).expect("compile ok");
        let mut dst = Vec::new();
        let ok = compiled.extract(src, &mut dst);
        (compiled.output_width(), ok, dst)
    }

    /// With no specs the extractor succeeds and produces an empty row.
    #[test]
    fn test_empty_specs_compiles_and_succeeds() {
        let (width, ok, dst) = run(&[], &[1.0, 2.0, 3.0]);
        assert_eq!(width, 0);
        assert!(ok);
        assert!(dst.is_empty());
    }

    /// A single spec copies exactly the named source column.
    #[test]
    fn test_single_col_extract() {
        let (width, ok, dst) = run(&[ProjectSpec { src_idx: 1 }], &[10.0, 20.0, 30.0]);
        assert_eq!(width, 1);
        assert!(ok);
        assert_eq!(dst, vec![20.0]);
    }

    /// Output order follows spec order, not source order.
    #[test]
    fn test_reorder_extract() {
        let specs: Vec<ProjectSpec> = [2usize, 0, 1]
            .iter()
            .map(|&src_idx| ProjectSpec { src_idx })
            .collect();
        let (_, ok, dst) = run(&specs, &[1.0, 2.0, 3.0]);
        assert!(ok);
        assert_eq!(dst, vec![3.0, 1.0, 2.0]);
    }

    /// A source row that lacks the requested column is rejected.
    #[test]
    fn test_src_bounds_fail() {
        // Only 2 elements are supplied; index 5 is out of bounds.
        let (_, ok, _) = run(&[ProjectSpec { src_idx: 5 }], &[1.0, 2.0]);
        assert!(!ok);
    }
}
425}