use std::sync::Arc;
use cranelift_codegen::ir::{
condcodes::IntCC, types, AbiParam, InstBuilder, MemFlags, Signature, Value,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
use super::filter_compiler::JITModuleOwner;
/// One output column of a projection.
///
/// Output column `i` of a compiled projection is a copy of `src[specs[i].src_idx]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectSpec {
    /// Index of the source element copied into this output column.
    pub src_idx: usize,
}
/// Signature of the JIT-compiled projection: `(src_ptr, src_len, dst_ptr, dst_len) -> status`.
///
/// Lengths are element counts of `f64` (not bytes). `status` is 1 when the projection was
/// written and 0 when a bounds check on `src_len`/`dst_len` failed.
type ProjectFn = unsafe extern "C" fn(*const f64, usize, *mut f64, usize) -> i8;
/// A column projection compiled to native code: writes `src[specs[i].src_idx]` into `dst[i]`.
pub struct CompiledProject {
    // Entry point of the generated code; only meaningful while `_module_owner` is alive.
    fn_ptr: ProjectFn,
    // The projection the code was generated from; its length is the output width.
    specs: Vec<ProjectSpec>,
    // Holds the JIT module that owns the code behind `fn_ptr`.
    // NOTE(review): presumably `JITModuleOwner` frees the module's memory on drop — confirm
    // in `filter_compiler`.
    _module_owner: Arc<JITModuleOwner>,
}
// SAFETY: the generated code behind `fn_ptr` only loads from and stores to the pointers it is
// given on each call and keeps no state of its own, so concurrent calls from several threads
// do not interact. `specs` is an ordinary `Vec` that is never mutated after construction.
// NOTE(review): these impls also assume `JITModuleOwner` is safe to share and to drop from any
// thread (it appears to exist only to keep the JIT module alive) — confirm in `filter_compiler`.
unsafe impl Send for CompiledProject {}
unsafe impl Sync for CompiledProject {}
/// Shows only the output width; the function pointer and JIT module are opaque.
impl std::fmt::Debug for CompiledProject {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = self.specs.len();
        let mut out = f.debug_struct("CompiledProject");
        out.field("output_width", &width);
        out.finish_non_exhaustive()
    }
}
impl CompiledProject {
    /// Number of values written by [`CompiledProject::extract`], one per projected column.
    pub fn output_width(&self) -> usize {
        let columns = &self.specs;
        columns.len()
    }

    /// Projects `src` into `dst`, resizing `dst` to [`CompiledProject::output_width`] first.
    ///
    /// Returns `true` when every output column was written. Returns `false` when `src` is too
    /// short for the largest source index. In that case the generated code returns before
    /// storing anything, so `dst` keeps its previous values (zero-filled where it grew).
    pub fn extract(&self, src: &[f64], dst: &mut Vec<f64>) -> bool {
        dst.resize(self.specs.len(), 0.0);
        let (src_ptr, src_len) = (src.as_ptr(), src.len());
        let (dst_ptr, dst_len) = (dst.as_mut_ptr(), dst.len());
        // SAFETY: both pointer/length pairs describe live slices borrowed for this call. The
        // generated code compares `src_len` and `dst_len` against the largest index it will
        // touch before any load or store. `fn_ptr` was produced for exactly this signature,
        // and `_module_owner` is assumed to keep its code mapped for our lifetime.
        let status = unsafe { (self.fn_ptr)(src_ptr, src_len, dst_ptr, dst_len) };
        status == 1
    }
}
/// Errors produced while JIT-compiling a projection.
#[derive(Debug, thiserror::Error)]
pub enum ProjectCompilerError {
    /// A Cranelift setting could not be applied, or the function failed to be defined or
    /// finalized.
    #[error("JIT codegen error: {0}")]
    CodegenError(String),
    /// The native target ISA could not be detected or built.
    #[error("JIT ISA init error: {0}")]
    IsaInitError(String),
    /// The generated function could not be declared in the JIT module.
    #[error("JIT linkage error: {0}")]
    LinkageError(String),
}
/// Compiles [`ProjectSpec`] lists into native projection functions via Cranelift's JIT.
///
/// The compiler carries no state, so `Default` is derived rather than hand-written.
#[derive(Debug, Default)]
pub struct ProjectCompiler;
impl ProjectCompiler {
pub fn new() -> Self {
ProjectCompiler
}
pub fn compile(
&mut self,
specs: &[ProjectSpec],
) -> Result<CompiledProject, ProjectCompilerError> {
let module = build_jit_module()?;
let (fn_ptr, module) = compile_project_fn(module, specs)?;
let owner = Arc::new(JITModuleOwner::new(module));
Ok(CompiledProject {
fn_ptr,
specs: specs.to_vec(),
_module_owner: owner,
})
}
}
/// Creates a JIT module targeting the host CPU with optimization level `speed`.
///
/// # Errors
/// `CodegenError` if a Cranelift flag is rejected; `IsaInitError` if the native ISA cannot be
/// detected or finished.
fn build_jit_module() -> Result<JITModule, ProjectCompilerError> {
    // Applied in this order; the first rejected flag aborts construction.
    const CODEGEN_FLAGS: [(&str, &str); 3] = [
        ("use_colocated_libcalls", "false"),
        ("is_pic", "false"),
        ("opt_level", "speed"),
    ];

    let mut flag_builder = settings::builder();
    for (name, value) in CODEGEN_FLAGS {
        flag_builder
            .set(name, value)
            .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
    }
    let shared_flags = settings::Flags::new(flag_builder);

    let isa_builder = cranelift_native::builder()
        .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?;
    let isa = isa_builder
        .finish(shared_flags)
        .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?;

    let jit_builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
    Ok(JITModule::new(jit_builder))
}
fn compile_project_fn(
mut module: JITModule,
specs: &[ProjectSpec],
) -> Result<(ProjectFn, JITModule), ProjectCompilerError> {
let ptr_type = module.isa().pointer_type();
let mut sig = Signature::new(CallConv::SystemV);
sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.returns.push(AbiParam::new(types::I8));
let func_id = module
.declare_function("project_fn", Linkage::Local, &sig)
.map_err(|e| ProjectCompilerError::LinkageError(e.to_string()))?;
{
let mut ctx = module.make_context();
ctx.func.signature = sig.clone();
let mut fn_builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
if specs.is_empty() {
emit_trivial_success(&mut builder, ptr_type);
} else {
emit_project_body(&mut builder, specs, ptr_type)?;
}
builder.finalize();
module
.define_function(func_id, &mut ctx)
.map_err(|e| ProjectCompilerError::CodegenError(format!("{e:?}")))?;
}
module
.finalize_definitions()
.map_err(|e| ProjectCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
let raw_ptr = module.get_finalized_function(func_id);
let fn_ptr: ProjectFn = unsafe { std::mem::transmute(raw_ptr) };
Ok((fn_ptr, module))
}
/// Emits a body that ignores its arguments and immediately returns 1 (success).
///
/// Used for an empty projection, where there is nothing to bounds-check or copy.
fn emit_trivial_success(builder: &mut FunctionBuilder<'_>, ptr_type: types::Type) {
    // No loads or stores are emitted, so the pointer type is not needed.
    let _ = ptr_type;

    let only_block = builder.create_block();
    builder.append_block_params_for_function_params(only_block);
    builder.switch_to_block(only_block);
    builder.seal_block(only_block);

    // Bind the four ABI parameters to make the (unused) signature explicit.
    let params: &[Value] = builder.block_params(only_block);
    let (_src_ptr, _src_len, _dst_ptr, _dst_len) = (params[0], params[1], params[2], params[3]);

    let success = builder.ins().iconst(types::I8, 1);
    builder.ins().return_(&[success]);
}
fn emit_project_body(
builder: &mut FunctionBuilder<'_>,
specs: &[ProjectSpec],
ptr_type: types::Type,
) -> Result<(), ProjectCompilerError> {
let max_src_idx = specs.iter().map(|s| s.src_idx).max().unwrap_or(0);
let entry_block = builder.create_block();
let bounds_fail_block = builder.create_block();
let body_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let src_ptr = builder.block_params(entry_block)[0];
let src_len = builder.block_params(entry_block)[1];
let dst_ptr = builder.block_params(entry_block)[2];
let dst_len = builder.block_params(entry_block)[3];
let max_src_val = builder.ins().iconst(ptr_type, max_src_idx as i64);
let src_too_short = builder
.ins()
.icmp(IntCC::UnsignedLessThanOrEqual, src_len, max_src_val);
let dst_need = builder.ins().iconst(ptr_type, specs.len() as i64);
let dst_too_short = builder
.ins()
.icmp(IntCC::UnsignedLessThan, dst_len, dst_need);
let any_fail = builder.ins().bor(src_too_short, dst_too_short);
builder
.ins()
.brif(any_fail, bounds_fail_block, &[], body_block, &[]);
builder.switch_to_block(bounds_fail_block);
builder.seal_block(bounds_fail_block);
let zero_i8 = builder.ins().iconst(types::I8, 0);
builder.ins().return_(&[zero_i8]);
builder.switch_to_block(body_block);
builder.seal_block(body_block);
for (dst_i, spec) in specs.iter().enumerate() {
let src_offset = builder
.ins()
.iconst(ptr_type, (spec.src_idx * std::mem::size_of::<f64>()) as i64);
let src_addr = builder.ins().iadd(src_ptr, src_offset);
let val = builder.ins().load(types::F64, MemFlags::new(), src_addr, 0);
let dst_offset = builder
.ins()
.iconst(ptr_type, (dst_i * std::mem::size_of::<f64>()) as i64);
let dst_addr = builder.ins().iadd(dst_ptr, dst_offset);
builder.ins().store(MemFlags::new(), val, dst_addr, 0);
}
let one_i8 = builder.ins().iconst(types::I8, 1);
builder.ins().return_(&[one_i8]);
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    fn compiler() -> ProjectCompiler {
        ProjectCompiler::new()
    }

    #[test]
    fn test_empty_specs_compiles_and_succeeds() {
        let cp = compiler().compile(&[]).expect("compile ok");
        assert_eq!(cp.output_width(), 0);
        let src = [1.0f64, 2.0, 3.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert!(dst.is_empty());
    }

    #[test]
    fn test_single_col_extract() {
        let specs = [ProjectSpec { src_idx: 1 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        assert_eq!(cp.output_width(), 1);
        let src = [10.0f64, 20.0, 30.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert_eq!(dst, vec![20.0]);
    }

    #[test]
    fn test_reorder_extract() {
        let specs = [
            ProjectSpec { src_idx: 2 },
            ProjectSpec { src_idx: 0 },
            ProjectSpec { src_idx: 1 },
        ];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0, 3.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert_eq!(dst, vec![3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_src_bounds_fail() {
        let specs = [ProjectSpec { src_idx: 5 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0];
        let mut dst = Vec::new();
        assert!(!cp.extract(&src, &mut dst));
    }

    // Pins the off-by-one boundary of the bounds check: index 2 needs exactly three elements.
    #[test]
    fn test_src_bounds_exact_fit() {
        let specs = [ProjectSpec { src_idx: 2 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = Vec::new();
        assert!(cp.extract(&[1.0, 2.0, 3.0], &mut dst));
        assert_eq!(dst, vec![3.0]);
        assert!(!cp.extract(&[1.0, 2.0], &mut dst));
    }

    #[test]
    fn test_empty_src_fails_nonempty_projection() {
        let specs = [ProjectSpec { src_idx: 0 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = Vec::new();
        assert!(!cp.extract(&[], &mut dst));
    }

    #[test]
    fn test_duplicate_columns() {
        let specs = [ProjectSpec { src_idx: 0 }, ProjectSpec { src_idx: 0 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = Vec::new();
        assert!(cp.extract(&[7.0], &mut dst));
        assert_eq!(dst, vec![7.0, 7.0]);
    }

    // A reused, over-sized buffer is truncated to the output width before being filled.
    #[test]
    fn test_dst_is_resized_to_output_width() {
        let specs = [ProjectSpec { src_idx: 1 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = vec![9.0; 5];
        assert!(cp.extract(&[4.0, 5.0], &mut dst));
        assert_eq!(dst, vec![5.0]);
    }
}