use std::collections::HashMap;
use std::sync::Arc;
use cranelift_codegen::ir::{
condcodes::FloatCC, types, AbiParam, InstBuilder, MemFlags, Signature, Value,
};
use cranelift_codegen::isa::CallConv;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
use crate::algebra::{BinaryOperator, Expression, Literal, UnaryOperator};
/// Maps SPARQL variable names to their slot index in the `f64` value array a
/// compiled filter reads. `lower_expr` hands out slots densely (0, 1, 2, …) in
/// first-occurrence order.
pub type VarIndexMap = HashMap<String, usize>;
/// Numeric intermediate representation of a FILTER expression, restricted to
/// the subset the JIT compiler can handle (produced by `try_lower`).
///
/// Every operand is evaluated as an `f64`. Comparisons and logical operators
/// yield truth values. Where a truth value is needed from a number, the number
/// counts as true iff it is non-zero and not NaN.
#[derive(Debug, Clone)]
pub enum FilterExpr {
    /// A numeric constant converted from a numeric (or untyped) literal.
    Literal(f64),
    /// A variable reference, resolved to an `f64` slot through a [`VarIndexMap`].
    Variable(String),
    /// A binary arithmetic, comparison or logical operation.
    BinOp {
        op: BinOp,
        left: Box<FilterExpr>,
        right: Box<FilterExpr>,
    },
    /// Logical negation of the operand's truth value.
    UnaryNot(Box<FilterExpr>),
    /// A single-argument numeric builtin function.
    Builtin {
        func: BuiltinFunc,
        arg: Box<FilterExpr>,
    },
}
/// Binary operators supported by the JIT.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    // Arithmetic on `f64` operands.
    Add,
    Sub,
    Mul,
    Div,
    // Comparisons of `f64` operands; the result is a truth value.
    Lt,
    Gt,
    Le,
    Ge,
    Eq,
    Ne,
    // Logical connectives on truth values. The code generator evaluates both
    // operands (no short-circuiting).
    // NOTE(review): SPARQL's error-aware three-valued logic is not modelled.
    And,
    Or,
}
/// Single-argument numeric builtins the JIT can compile (matched
/// case-insensitively by name during lowering).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BuiltinFunc {
    /// `ABS(x)`: absolute value.
    Abs,
    /// `CEIL(x)`: smallest integer not less than `x`.
    Ceil,
    /// `FLOOR(x)`: largest integer not greater than `x`.
    Floor,
    /// `ROUND(x)`: nearest integer, with ties toward positive infinity (SPARQL semantics).
    Round,
}
/// Datatype IRIs whose literals are lowered to `f64` constants.
///
/// SPARQL 1.1 treats as numeric the four primitives `xsd:integer`,
/// `xsd:decimal`, `xsd:float` and `xsd:double` plus every XSD type derived from
/// them. This list covers all twelve built-in types derived from `xsd:integer`.
///
/// NOTE(review): values beyond ±2^53 (e.g. large `xsd:long`/`xsd:unsignedLong`)
/// lose precision when converted to `f64`.
const XSD_NUMERIC_TYPES: &[&str] = &[
    "http://www.w3.org/2001/XMLSchema#integer",
    "http://www.w3.org/2001/XMLSchema#decimal",
    "http://www.w3.org/2001/XMLSchema#double",
    "http://www.w3.org/2001/XMLSchema#float",
    "http://www.w3.org/2001/XMLSchema#int",
    "http://www.w3.org/2001/XMLSchema#long",
    "http://www.w3.org/2001/XMLSchema#short",
    "http://www.w3.org/2001/XMLSchema#byte",
    "http://www.w3.org/2001/XMLSchema#unsignedInt",
    "http://www.w3.org/2001/XMLSchema#unsignedLong",
    "http://www.w3.org/2001/XMLSchema#nonNegativeInteger",
    "http://www.w3.org/2001/XMLSchema#positiveInteger",
    // Integer-derived types previously missing from the list.
    "http://www.w3.org/2001/XMLSchema#unsignedShort",
    "http://www.w3.org/2001/XMLSchema#unsignedByte",
    "http://www.w3.org/2001/XMLSchema#nonPositiveInteger",
    "http://www.w3.org/2001/XMLSchema#negativeInteger",
];
/// Converts a literal to an `f64` constant if it is numeric, otherwise `None`.
///
/// * Language-tagged literals are never numeric.
/// * Typed literals must carry a datatype listed in [`XSD_NUMERIC_TYPES`].
/// * Untyped literals are accepted whenever their text parses as an `f64`.
///
/// In every accepted case the lexical form is parsed with `str::parse::<f64>`.
///
/// NOTE(review): accepting untyped literals presumably exists because the query
/// parser yields bare numbers without a datatype. If untyped literals are RDF
/// simple strings instead, `"5"` would be treated as the number 5 here, where
/// SPARQL would raise a type error. Confirm against the parser.
fn literal_to_f64(lit: &Literal) -> Option<f64> {
    if lit.language.is_some() {
        return None;
    }
    if let Some(dt) = &lit.datatype {
        if !XSD_NUMERIC_TYPES.contains(&dt.as_str()) {
            return None;
        }
    }
    // Single parse for both the typed and untyped cases.
    lit.value.parse::<f64>().ok()
}
pub fn try_lower(expr: &Expression) -> Option<(FilterExpr, VarIndexMap)> {
let mut var_map: VarIndexMap = HashMap::new();
let filter_expr = lower_expr(expr, &mut var_map)?;
Some((filter_expr, var_map))
}
fn lower_expr(expr: &Expression, var_map: &mut VarIndexMap) -> Option<FilterExpr> {
match expr {
Expression::Literal(lit) => {
let v = literal_to_f64(lit)?;
Some(FilterExpr::Literal(v))
}
Expression::Variable(var) => {
let name = var.name().to_string();
let next_idx = var_map.len();
let idx = *var_map.entry(name.clone()).or_insert(next_idx);
debug_assert_eq!(var_map[&name], idx);
Some(FilterExpr::Variable(name))
}
Expression::Binary { op, left, right } => {
let jit_op = match op {
BinaryOperator::Add => BinOp::Add,
BinaryOperator::Subtract => BinOp::Sub,
BinaryOperator::Multiply => BinOp::Mul,
BinaryOperator::Divide => BinOp::Div,
BinaryOperator::Less => BinOp::Lt,
BinaryOperator::Greater => BinOp::Gt,
BinaryOperator::LessEqual => BinOp::Le,
BinaryOperator::GreaterEqual => BinOp::Ge,
BinaryOperator::Equal => BinOp::Eq,
BinaryOperator::NotEqual => BinOp::Ne,
BinaryOperator::And => BinOp::And,
BinaryOperator::Or => BinOp::Or,
_ => return None,
};
let l = lower_expr(left, var_map)?;
let r = lower_expr(right, var_map)?;
Some(FilterExpr::BinOp {
op: jit_op,
left: Box::new(l),
right: Box::new(r),
})
}
Expression::Unary { op, operand } => match op {
UnaryOperator::Not => {
let inner = lower_expr(operand, var_map)?;
Some(FilterExpr::UnaryNot(Box::new(inner)))
}
UnaryOperator::Plus => lower_expr(operand, var_map),
UnaryOperator::Minus => {
let inner = lower_expr(operand, var_map)?;
Some(FilterExpr::BinOp {
op: BinOp::Sub,
left: Box::new(FilterExpr::Literal(0.0)),
right: Box::new(inner),
})
}
_ => None,
},
Expression::Function { name, args } => {
let func = match name.to_uppercase().as_str() {
"ABS" => BuiltinFunc::Abs,
"CEIL" => BuiltinFunc::Ceil,
"FLOOR" => BuiltinFunc::Floor,
"ROUND" => BuiltinFunc::Round,
_ => return None,
};
if args.len() != 1 {
return None;
}
let arg = lower_expr(&args[0], var_map)?;
Some(FilterExpr::Builtin {
func,
arg: Box::new(arg),
})
}
_ => None,
}
}
/// Native signature of a compiled filter: `(values, len) -> status`.
///
/// `values` points at `len` `f64` slots laid out by the filter's
/// [`VarIndexMap`]. `CompiledFilter::evaluate` reads a returned status of `1` as
/// true and `0` as false, and maps any other status to `None`.
type FilterFn = unsafe extern "C" fn(*const f64, usize) -> i8;
/// A filter expression compiled to native code by [`FilterCompiler`].
pub struct CompiledFilter {
    /// Entry point of the JIT-compiled code held by `_module_owner`.
    fn_ptr: FilterFn,
    /// Slot layout the code was compiled against; `evaluate` uses it to place
    /// each bound variable's value.
    /// NOTE(review): this field is `pub`. Mutating it after compilation
    /// desynchronizes it from the slots the compiled code reads.
    pub var_map: VarIndexMap,
    /// Keeps the JIT module that owns the code behind `fn_ptr` alive for as
    /// long as this filter exists.
    _module_owner: Arc<JITModuleOwner>,
}
unsafe impl Send for CompiledFilter {}
unsafe impl Sync for CompiledFilter {}
impl std::fmt::Debug for CompiledFilter {
    /// Prints only the variable layout. The raw code pointer and the module
    /// owner are elided (hence the trailing `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CompiledFilter");
        out.field("var_map", &self.var_map);
        out.finish_non_exhaustive()
    }
}
impl CompiledFilter {
    /// Runs the compiled filter against one solution binding.
    ///
    /// Every variable in `var_map` must be bound in `binding`. If any is
    /// missing, `None` is returned without running native code. Otherwise the
    /// values are laid out by slot index and passed to the compiled function.
    /// Its status `1` becomes `Some(true)`, `0` becomes `Some(false)`, and any
    /// other status yields `None`.
    ///
    /// # Panics
    /// Panics if a slot index in `var_map` is not below `var_map.len()`. Slots
    /// assigned by `try_lower` are always dense.
    pub fn evaluate(&self, binding: &HashMap<String, f64>) -> Option<bool> {
        let mut slots = vec![0.0f64; self.var_map.len()];
        for (var_name, &slot) in &self.var_map {
            // The right-hand side is evaluated first, so an unbound variable
            // returns `None` before the slot is indexed.
            slots[slot] = *binding.get(var_name)?;
        }
        // SAFETY: `fn_ptr` points into code owned by `_module_owner`, which
        // lives as long as `self`. The pointer/length pair describes `slots`,
        // which stays alive and unmodified for the duration of the call.
        // NOTE(review): `var_map` is public. If a caller shrinks it after
        // compilation, `slots` can be shorter than what the code reads unless
        // the generated code checks `len`. Confirm.
        let status = unsafe { (self.fn_ptr)(slots.as_ptr(), slots.len()) };
        match status {
            0 => Some(false),
            1 => Some(true),
            _ => None,
        }
    }
}
/// Owns the [`JITModule`] whose executable memory backs a compiled filter's
/// function pointer, and releases that memory when dropped.
///
/// cranelift-jit does not free a module's code/data memory when the
/// `JITModule` value is dropped. The memory is only released by the `unsafe`
/// `JITModule::free_memory`. Without this `Drop` impl, every compiled filter
/// would leak its code pages.
///
/// Contract for users of [`JITModuleOwner::new`]: no function pointer obtained
/// from the wrapped module may be called after the owner is dropped.
/// `CompiledFilter` upholds this by keeping an `Arc` to the owner next to its
/// private `fn_ptr` and calling it only through `&self`.
pub(crate) struct JITModuleOwner {
    // `Option` only so that `Drop` can move the module out: `free_memory`
    // consumes it by value. It is `Some` for the owner's whole life.
    module: Option<JITModule>,
}

impl JITModuleOwner {
    /// Takes ownership of a finalized module.
    pub(crate) fn new(module: JITModule) -> Self {
        JITModuleOwner {
            module: Some(module),
        }
    }
}

impl Drop for JITModuleOwner {
    fn drop(&mut self) {
        if let Some(module) = self.module.take() {
            // SAFETY: the owner is being dropped, so (per the type's contract)
            // no holder of a function pointer from this module can still call
            // it, and no compiled code from it is executing.
            unsafe { module.free_memory() };
        }
    }
}

// SAFETY: the wrapped module is never accessed through `&self`. It is only
// touched in `Drop`, which has exclusive access. Sharing references across
// threads therefore exposes nothing, and moving the owner merely moves where
// the memory is eventually freed.
unsafe impl Send for JITModuleOwner {}
unsafe impl Sync for JITModuleOwner {}
/// Errors produced while compiling a [`FilterExpr`] to native code.
#[derive(Debug, thiserror::Error)]
pub enum FilterCompilerError {
    /// The expression cannot be compiled as given, e.g. it references a
    /// variable that has no slot in the supplied [`VarIndexMap`].
    #[error("expression not in JIT-supported subset: {0}")]
    UnsupportedExpression(String),
    /// Code generation failed. Causes include:
    /// - a rejected Cranelift setting;
    /// - a function definition Cranelift refused;
    /// - a failure while finalizing definitions;
    /// - an operand of unexpected IR type during coercion.
    #[error("JIT codegen error: {0}")]
    CodegenError(String),
    /// Declaring the compiled function in the JIT module failed.
    #[error("JIT linkage error: {0}")]
    LinkageError(String),
    /// The host instruction-set backend could not be detected or configured.
    #[error("JIT ISA init error: {0}")]
    IsaInitError(String),
}
/// Compiles [`FilterExpr`] trees to native code with Cranelift.
///
/// The compiler itself holds no state; each `compile` call builds its own JIT
/// module. Being a unit struct, it is freely `Copy`able and printable.
#[derive(Debug, Default, Clone, Copy)]
pub struct FilterCompiler;
impl FilterCompiler {
pub fn new() -> Self {
FilterCompiler
}
pub fn compile(
&self,
expr: &FilterExpr,
var_map: VarIndexMap,
) -> Result<Option<CompiledFilter>, FilterCompilerError> {
let module = build_jit_module()?;
let (fn_ptr, module) = compile_filter_fn(module, expr, &var_map)?;
let owner = Arc::new(JITModuleOwner::new(module));
Ok(Some(CompiledFilter {
fn_ptr,
var_map,
_module_owner: owner,
}))
}
}
/// Creates an empty JIT module targeting the host CPU.
///
/// Cranelift is configured for speed-optimized, non-PIC code with
/// non-colocated libcalls.
///
/// # Errors
/// * `CodegenError` if a setting is rejected.
/// * `IsaInitError` if the host ISA cannot be detected or built.
fn build_jit_module() -> Result<JITModule, FilterCompilerError> {
    let mut flags_builder = settings::builder();
    for (setting, value) in [
        ("use_colocated_libcalls", "false"),
        ("is_pic", "false"),
        ("opt_level", "speed"),
    ] {
        flags_builder
            .set(setting, value)
            .map_err(|e| FilterCompilerError::CodegenError(e.to_string()))?;
    }
    let shared_flags = settings::Flags::new(flags_builder);

    let isa_builder = cranelift_native::builder()
        .map_err(|e| FilterCompilerError::IsaInitError(e.to_string()))?;
    let isa = isa_builder
        .finish(shared_flags)
        .map_err(|e| FilterCompilerError::IsaInitError(e.to_string()))?;

    let jit_builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
    Ok(JITModule::new(jit_builder))
}
fn compile_filter_fn(
mut module: JITModule,
expr: &FilterExpr,
var_map: &VarIndexMap,
) -> Result<(FilterFn, JITModule), FilterCompilerError> {
let ptr_type = module.isa().pointer_type();
let mut sig = Signature::new(CallConv::SystemV);
sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.returns.push(AbiParam::new(types::I8));
let func_id = module
.declare_function("filter_fn", Linkage::Local, &sig)
.map_err(|e| FilterCompilerError::LinkageError(e.to_string()))?;
{
let mut ctx = module.make_context();
ctx.func.signature = sig.clone();
let mut fn_builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let ptr_val = builder.block_params(entry_block)[0];
let _n_val = builder.block_params(entry_block)[1];
let result_val = emit_expr(&mut builder, expr, var_map, ptr_val, ptr_type)?;
let result_i8 = coerce_to_i8(&mut builder, result_val)?;
builder.ins().return_(&[result_i8]);
builder.finalize();
module
.define_function(func_id, &mut ctx)
.map_err(|e| FilterCompilerError::CodegenError(format!("{e:?}")))?;
}
module
.finalize_definitions()
.map_err(|e| FilterCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
let raw_ptr = module.get_finalized_function(func_id);
let fn_ptr: FilterFn = unsafe { std::mem::transmute(raw_ptr) };
Ok((fn_ptr, module))
}
/// Emits Cranelift IR computing `expr` and returns the resulting SSA value.
///
/// The value is `F64` for numeric sub-expressions. For truth values
/// (comparisons, logical operators, `!`) it is an `I8`. Callers normalize with
/// `coerce_to_f64` / `coerce_to_i8` as needed.
///
/// * `ptr_val`: the function's `values: *const f64` parameter; variable `i` is
///   loaded from byte offset `i * 8`.
/// * `ptr_type`: the target pointer type, threaded through recursive calls.
///
/// # Errors
/// * `UnsupportedExpression` if a variable has no slot in `var_map`.
/// * `CodegenError` if a slot's byte offset does not fit the load's `i32`
///   immediate, or if coercion meets an unexpected type.
fn emit_expr(
    builder: &mut FunctionBuilder<'_>,
    expr: &FilterExpr,
    var_map: &VarIndexMap,
    ptr_val: Value,
    ptr_type: types::Type,
) -> Result<Value, FilterCompilerError> {
    match expr {
        FilterExpr::Literal(v) => Ok(builder.ins().f64const(*v)),
        FilterExpr::Variable(name) => {
            let idx = var_map.get(name).copied().ok_or_else(|| {
                FilterCompilerError::UnsupportedExpression(format!(
                    "variable '{}' not found in var_map",
                    name
                ))
            })?;
            // A silently truncated `as i32` offset would make the generated
            // code load from the wrong (possibly negative) address, so reject
            // offsets that do not fit instead.
            let byte_offset = idx
                .checked_mul(std::mem::size_of::<f64>())
                .and_then(|offset| i32::try_from(offset).ok())
                .ok_or_else(|| {
                    FilterCompilerError::CodegenError(format!(
                        "variable '{name}' slot {idx} is out of range for a load offset"
                    ))
                })?;
            Ok(builder
                .ins()
                .load(types::F64, MemFlags::trusted(), ptr_val, byte_offset))
        }
        FilterExpr::BinOp { op, left, right } => {
            emit_binop(builder, *op, left, right, var_map, ptr_val, ptr_type)
        }
        FilterExpr::UnaryNot(inner) => {
            let inner_val = emit_expr(builder, inner, var_map, ptr_val, ptr_type)?;
            let bool_val = coerce_to_i8(builder, inner_val)?;
            // The coerced value is 0 or 1, so xor with 1 flips it.
            Ok(builder.ins().bxor_imm(bool_val, 1))
        }
        FilterExpr::Builtin { func, arg } => {
            let arg_val = emit_expr(builder, arg, var_map, ptr_val, ptr_type)?;
            let f_val = coerce_to_f64(builder, arg_val)?;
            let result = match func {
                BuiltinFunc::Abs => builder.ins().fabs(f_val),
                BuiltinFunc::Ceil => builder.ins().ceil(f_val),
                BuiltinFunc::Floor => builder.ins().floor(f_val),
                BuiltinFunc::Round => emit_sparql_round(builder, f_val),
            };
            Ok(result)
        }
    }
}
fn emit_sparql_round(builder: &mut FunctionBuilder<'_>, x: Value) -> Value {
let half = builder.ins().f64const(0.5);
let shifted = builder.ins().fadd(x, half);
builder.ins().floor(shifted)
}
fn emit_binop(
builder: &mut FunctionBuilder<'_>,
op: BinOp,
left: &FilterExpr,
right: &FilterExpr,
var_map: &VarIndexMap,
ptr_val: Value,
ptr_type: types::Type,
) -> Result<Value, FilterCompilerError> {
match op {
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => {
let lv = emit_expr(builder, left, var_map, ptr_val, ptr_type)?;
let rv = emit_expr(builder, right, var_map, ptr_val, ptr_type)?;
let lf = coerce_to_f64(builder, lv)?;
let rf = coerce_to_f64(builder, rv)?;
let result = match op {
BinOp::Add => builder.ins().fadd(lf, rf),
BinOp::Sub => builder.ins().fsub(lf, rf),
BinOp::Mul => builder.ins().fmul(lf, rf),
BinOp::Div => builder.ins().fdiv(lf, rf),
_ => unreachable!(),
};
Ok(result)
}
BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge | BinOp::Eq | BinOp::Ne => {
let lv = emit_expr(builder, left, var_map, ptr_val, ptr_type)?;
let rv = emit_expr(builder, right, var_map, ptr_val, ptr_type)?;
let lf = coerce_to_f64(builder, lv)?;
let rf = coerce_to_f64(builder, rv)?;
let float_cc = match op {
BinOp::Lt => FloatCC::LessThan,
BinOp::Gt => FloatCC::GreaterThan,
BinOp::Le => FloatCC::LessThanOrEqual,
BinOp::Ge => FloatCC::GreaterThanOrEqual,
BinOp::Eq => FloatCC::Equal,
BinOp::Ne => FloatCC::NotEqual,
_ => unreachable!(),
};
let cmp = builder.ins().fcmp(float_cc, lf, rf);
Ok(cmp)
}
BinOp::And | BinOp::Or => {
let lv = emit_expr(builder, left, var_map, ptr_val, ptr_type)?;
let rv = emit_expr(builder, right, var_map, ptr_val, ptr_type)?;
let lb = coerce_to_i8(builder, lv)?;
let rb = coerce_to_i8(builder, rv)?;
let result = match op {
BinOp::And => builder.ins().band(lb, rb),
BinOp::Or => builder.ins().bor(lb, rb),
_ => unreachable!(),
};
Ok(result)
}
}
}
/// Converts `val` to an `I8` truth value.
///
/// An `I8` value is returned unchanged. An `F64` becomes true iff it is
/// non-zero and not NaN: `OrderedNotEqual` is false for unordered operands.
///
/// # Errors
/// `CodegenError` for any other IR type.
fn coerce_to_i8(
    builder: &mut FunctionBuilder<'_>,
    val: Value,
) -> Result<Value, FilterCompilerError> {
    match builder.func.dfg.value_type(val) {
        types::I8 => Ok(val),
        types::F64 => {
            let zero = builder.ins().f64const(0.0);
            Ok(builder.ins().fcmp(FloatCC::OrderedNotEqual, val, zero))
        }
        other => Err(FilterCompilerError::CodegenError(format!(
            "coerce_to_i8: unexpected type {:?}",
            other
        ))),
    }
}
/// Converts `val` to an `F64`.
///
/// An `F64` value is returned unchanged. An `I8` (a truth value from a
/// comparison or logical op) is zero-extended to `I32` and then converted as
/// an unsigned integer.
///
/// # Errors
/// `CodegenError` for any other IR type.
fn coerce_to_f64(
    builder: &mut FunctionBuilder<'_>,
    val: Value,
) -> Result<Value, FilterCompilerError> {
    match builder.func.dfg.value_type(val) {
        types::F64 => Ok(val),
        types::I8 => {
            let widened = builder.ins().uextend(types::I32, val);
            Ok(builder.ins().fcvt_from_uint(types::F64, widened))
        }
        other => Err(FilterCompilerError::CodegenError(format!(
            "coerce_to_f64: unexpected type {:?}",
            other
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::{BinaryOperator, Expression, Literal, UnaryOperator};
    use oxirs_core::model::{NamedNode, Variable as CoreVariable};
    use std::collections::HashMap;

    fn xsd_integer() -> NamedNode {
        NamedNode::new("http://www.w3.org/2001/XMLSchema#integer").expect("valid XSD URI")
    }

    fn int_lit(v: i64) -> Expression {
        Expression::Literal(Literal::typed(v.to_string(), xsd_integer()))
    }

    fn var(name: &str) -> Expression {
        Expression::Variable(CoreVariable::new(name).expect("valid variable"))
    }

    /// Lowers and JIT-compiles `expr`, panicking if either step fails.
    fn compile(expr: &Expression) -> CompiledFilter {
        let (fe, vm) = try_lower(expr).expect("should lower");
        FilterCompiler::new()
            .compile(&fe, vm)
            .expect("should compile")
            .expect("compile returns Some on success")
    }

    /// Builds a solution binding from `(name, value)` pairs.
    fn binding(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn lower_integer_literal() {
        let expr = int_lit(42);
        let (fe, vm) = try_lower(&expr).expect("should lower");
        assert!(vm.is_empty());
        assert!(matches!(fe, FilterExpr::Literal(v) if (v - 42.0).abs() < 1e-9));
    }

    #[test]
    fn lower_variable() {
        let expr = var("x");
        let (fe, vm) = try_lower(&expr).expect("should lower");
        assert_eq!(vm.len(), 1);
        assert!(vm.contains_key("x"));
        assert!(matches!(fe, FilterExpr::Variable(n) if n == "x"));
    }

    #[test]
    fn lower_repeated_variable_shares_slot() {
        let expr = Expression::Binary {
            op: BinaryOperator::Add,
            left: Box::new(var("x")),
            right: Box::new(var("x")),
        };
        let (_, vm) = try_lower(&expr).expect("should lower");
        assert_eq!(vm.len(), 1);
        assert_eq!(vm["x"], 0);
    }

    #[test]
    fn lower_comparison() {
        let expr = Expression::Binary {
            op: BinaryOperator::Greater,
            left: Box::new(var("x")),
            right: Box::new(int_lit(5)),
        };
        let (fe, _vm) = try_lower(&expr).expect("should lower");
        assert!(matches!(fe, FilterExpr::BinOp { op: BinOp::Gt, .. }));
    }

    #[test]
    fn lower_lang_tagged_literal_fails() {
        let expr = Expression::Literal(Literal::with_language(
            "hello".to_string(),
            "en".to_string(),
        ));
        assert!(
            try_lower(&expr).is_none(),
            "lang-tagged literal should not lower"
        );
    }

    #[test]
    fn lower_iri_fails() {
        let expr = Expression::Iri(xsd_integer());
        assert!(try_lower(&expr).is_none(), "IRI should not lower");
    }

    #[test]
    fn lower_unary_not() {
        let inner = Expression::Binary {
            op: BinaryOperator::Greater,
            left: Box::new(var("x")),
            right: Box::new(int_lit(0)),
        };
        let expr = Expression::Unary {
            op: UnaryOperator::Not,
            operand: Box::new(inner),
        };
        let (fe, _vm) = try_lower(&expr).expect("should lower");
        assert!(matches!(fe, FilterExpr::UnaryNot(_)));
    }

    #[test]
    fn lower_abs_builtin() {
        let expr = Expression::Function {
            name: "ABS".to_string(),
            args: vec![var("x")],
        };
        let (fe, _) = try_lower(&expr).expect("should lower");
        assert!(matches!(
            fe,
            FilterExpr::Builtin {
                func: BuiltinFunc::Abs,
                ..
            }
        ));
    }

    #[test]
    fn lower_builtin_wrong_arity_fails() {
        let expr = Expression::Function {
            name: "ABS".to_string(),
            args: vec![var("x"), var("y")],
        };
        assert!(try_lower(&expr).is_none(), "ABS takes exactly one argument");
    }

    #[test]
    fn lower_unknown_function_fails() {
        let expr = Expression::Function {
            name: "STRLEN".to_string(),
            args: vec![var("x")],
        };
        assert!(try_lower(&expr).is_none(), "STRLEN is not JIT-supported");
    }

    #[test]
    fn jit_evaluates_comparison() {
        let filter = compile(&Expression::Binary {
            op: BinaryOperator::Greater,
            left: Box::new(var("x")),
            right: Box::new(int_lit(5)),
        });
        assert_eq!(filter.evaluate(&binding(&[("x", 7.0)])), Some(true));
        assert_eq!(filter.evaluate(&binding(&[("x", 3.0)])), Some(false));
        assert_eq!(filter.evaluate(&binding(&[("x", f64::NAN)])), Some(false));
        assert_eq!(filter.evaluate(&HashMap::new()), None, "unbound variable");
    }

    #[test]
    fn jit_round_ties_toward_positive_infinity() {
        // ROUND(?x) = ?expected
        let filter = compile(&Expression::Binary {
            op: BinaryOperator::Equal,
            left: Box::new(Expression::Function {
                name: "ROUND".to_string(),
                args: vec![var("x")],
            }),
            right: Box::new(var("expected")),
        });
        for (x, expected) in [(2.5, 3.0), (-2.5, -2.0), (2.4, 2.0), (-2.6, -3.0)] {
            assert_eq!(
                filter.evaluate(&binding(&[("x", x), ("expected", expected)])),
                Some(true),
                "ROUND({x}) should be {expected}"
            );
        }
    }
}