use super::ir::*;
use crate::{StarResult, StarStore, StarTriple};
use scirs2_core::jit::{
CompilationHints, CompiledKernel, ComputeIntensity, DataType, JitBackend, JitCompiler,
JitConfig, JitError, KernelLanguage, KernelSource, MemoryPattern, OptimizationLevel,
ParallelizationHints, TargetArchitecture,
};
use std::collections::HashMap;
use std::time::Instant;
use tracing::{debug, info, instrument};
pub struct SparqlJitCompiler {
jit_compiler: JitCompiler,
kernel_cache: HashMap<String, CompiledKernel>,
}
impl SparqlJitCompiler {
pub fn new() -> Result<Self, JitError> {
let config = JitConfig {
backend: JitBackend::Llvm,
target_arch: Self::detect_architecture(),
optimization_level: OptimizationLevel::O3,
enable_caching: true,
enable_profiling: true,
max_cache_size: 256 * 1024 * 1024, compilation_timeout: std::time::Duration::from_secs(30),
adaptive_optimization: true,
custom_flags: Vec::new(),
};
let jit_compiler = JitCompiler::new(config)?;
Ok(Self {
jit_compiler,
kernel_cache: HashMap::new(),
})
}
fn detect_architecture() -> TargetArchitecture {
#[cfg(target_arch = "x86_64")]
{
TargetArchitecture::X86_64
}
#[cfg(target_arch = "aarch64")]
{
TargetArchitecture::Arm64
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
TargetArchitecture::Generic
}
}
#[instrument(skip(self))]
pub fn parse_to_ir(&self, query: &str) -> Result<IrQueryPlan, JitError> {
debug!("Parsing SPARQL-star query to IR");
let ir_op = if query.contains("<<") && query.contains(">>") {
self.parse_quoted_pattern(query)?
} else if query.contains("?s") && query.contains("?p") && query.contains("?o") {
IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
}
} else {
IrOp::SeqScan
};
let mut plan = IrQueryPlan::new(ir_op);
plan.estimate_cost();
plan.analyze_parallelism();
Ok(plan)
}
fn parse_quoted_pattern(&self, _query: &str) -> Result<IrOp, JitError> {
let inner = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
};
Ok(IrOp::QuotedTriplePattern {
inner: Box::new(inner),
position: QuotePosition::Subject,
})
}
#[instrument(skip(self, plan))]
pub fn compile_ir(&mut self, plan: &IrQueryPlan) -> Result<String, JitError> {
let start = Instant::now();
let kernel_source = self.generate_kernel_source(plan)?;
info!(
"Generated kernel source ({} bytes) in {:?}",
kernel_source.source.len(),
start.elapsed()
);
let kernel_id = self.jit_compiler.compile_kernel(kernel_source)?;
info!("JIT compilation complete in {:?}", start.elapsed());
Ok(kernel_id)
}
fn generate_kernel_source(&self, plan: &IrQueryPlan) -> Result<KernelSource, JitError> {
let mut source = String::new();
source.push_str("// SPARQL-star JIT-compiled query\n");
source.push_str("// Generated from IR\n\n");
source.push_str("fn execute_query(store: &Store) -> Vec<Triple> {\n");
source.push_str(" let mut results = Vec::new();\n\n");
self.codegen_ir_op(&plan.root, &mut source, 1)?;
source.push_str("\n results\n");
source.push_str("}\n");
debug!("Generated kernel source:\n{}", source);
let kernel_id = format!("sparql_query_{:x}", Self::hash_query(&plan.root));
let hints = CompilationHints {
workload_size: Some(plan.memory_hints.estimated_results),
memory_pattern: if plan.memory_hints.sequential_access {
Some(MemoryPattern::Sequential)
} else {
Some(MemoryPattern::Random)
},
compute_intensity: Some(ComputeIntensity::Balanced),
parallelization: Some(ParallelizationHints {
work_group_size: None,
vector_width: Some(4), unroll_factor: Some(4),
auto_vectorize: true,
}),
target_hints: HashMap::new(),
};
Ok(KernelSource {
id: kernel_id,
source,
language: KernelLanguage::LlvmIr, entry_point: "execute_query".to_string(),
input_types: vec![DataType::Ptr(Box::new(DataType::U8))], output_types: vec![DataType::Ptr(Box::new(DataType::U8))], hints,
})
}
#[allow(clippy::only_used_in_recursion)]
fn codegen_ir_op(&self, op: &IrOp, output: &mut String, indent: usize) -> Result<(), JitError> {
let indent_str = " ".repeat(indent);
match op {
IrOp::TriplePattern {
subject,
predicate,
object,
} => {
output.push_str(&format!("{}// Triple pattern scan\n", indent_str));
output.push_str(&format!(
"{}for triple in store.scan_spo({:?}, {:?}, {:?}) {{\n",
indent_str, subject, predicate, object
));
output.push_str(&format!("{} results.push(triple);\n", indent_str));
output.push_str(&format!("{}}}\n", indent_str));
}
IrOp::QuotedTriplePattern { inner, position } => {
output.push_str(&format!("{}// Quoted triple pattern\n", indent_str));
output.push_str(&format!(
"{}for triple in store.scan_quoted({:?}) {{\n",
indent_str, position
));
self.codegen_ir_op(inner, output, indent + 1)?;
output.push_str(&format!("{}}}\n", indent_str));
}
IrOp::Join { left, right, .. } => {
output.push_str(&format!("{}// Join operation\n", indent_str));
output.push_str(&format!(
"{}let mut left_results = Vec::new();\n",
indent_str
));
self.codegen_ir_op(left, output, indent)?;
output.push_str(&format!(
"{}let mut right_results = Vec::new();\n",
indent_str
));
self.codegen_ir_op(right, output, indent)?;
output.push_str(&format!(
"{}results.extend(join(left_results, right_results));\n",
indent_str
));
}
IrOp::Filter { condition } => {
output.push_str(&format!("{}// Filter: {:?}\n", indent_str, condition));
output.push_str(&format!("{}results.retain(|t| filter(t));\n", indent_str));
}
IrOp::IndexScan { index_type, keys } => {
output.push_str(&format!(
"{}// Index scan: {:?} with keys {:?}\n",
indent_str, index_type, keys
));
output.push_str(&format!(
"{}for triple in store.index_scan({:?}, &keys) {{\n",
indent_str, index_type
));
output.push_str(&format!("{} results.push(triple);\n", indent_str));
output.push_str(&format!("{}}}\n", indent_str));
}
IrOp::SeqScan => {
output.push_str(&format!("{}// Sequential scan (fallback)\n", indent_str));
output.push_str(&format!(
"{}results.extend(store.all_triples());\n",
indent_str
));
}
_ => {
output.push_str(&format!("{}// Unsupported IR op: {:?}\n", indent_str, op));
}
}
Ok(())
}
#[instrument(skip(self, store))]
pub fn execute_compiled(
&self,
kernel_id: &str,
store: &StarStore,
) -> StarResult<Vec<StarTriple>> {
debug!("Executing compiled kernel: {}", kernel_id);
Ok(store.all_triples())
}
pub fn get_cached_kernel(&self, kernel_id: &str) -> Option<&CompiledKernel> {
self.kernel_cache.get(kernel_id)
}
pub fn clear_cache(&mut self) {
self.kernel_cache.clear();
self.jit_compiler.clear_cache();
}
pub fn stats(&self) -> CompilationStats {
CompilationStats {
total_compilations: self.kernel_cache.len(),
cache_size: self.kernel_cache.len(),
total_compilation_time_ms: 0, }
}
fn hash_query(op: &IrOp) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
format!("{:?}", op).hash(&mut hasher);
hasher.finish()
}
}
impl Default for SparqlJitCompiler {
fn default() -> Self {
Self::new().expect("Failed to create JIT compiler")
}
}
#[derive(Debug, Clone)]
pub struct CompilationStats {
pub total_compilations: usize,
pub cache_size: usize,
pub total_compilation_time_ms: u64,
}
pub struct QueryParser;
impl QueryParser {
pub fn parse(query: &str) -> Result<IrQueryPlan, JitError> {
let compiler = SparqlJitCompiler::new()?;
compiler.parse_to_ir(query)
}
pub fn complexity(query: &str) -> f64 {
let mut cost: f64 = 0.0;
if query.contains("<<") && query.contains(">>") {
cost += 15.0; }
if query.contains("JOIN") || query.contains("OPTIONAL") {
cost += 50.0; }
if query.contains("FILTER") {
cost += 5.0;
}
if query.contains("UNION") {
cost += 20.0;
}
cost.max(10.0) }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compiler_creation() {
let result = SparqlJitCompiler::new();
assert!(result.is_ok());
}
#[test]
fn test_parse_simple_pattern() {
let compiler = SparqlJitCompiler::new().unwrap();
let query = "SELECT * WHERE { ?s ?p ?o }";
let result = compiler.parse_to_ir(query);
assert!(result.is_ok());
let plan = result.unwrap();
assert!(plan.estimated_cost > 0.0);
}
#[test]
fn test_parse_quoted_pattern() {
let compiler = SparqlJitCompiler::new().unwrap();
let query = "SELECT * WHERE { << ?s ?p ?o >> ?meta ?value }";
let result = compiler.parse_to_ir(query);
assert!(result.is_ok());
let plan = result.unwrap();
assert!(plan.estimated_cost > 10.0); }
#[test]
fn test_compile_ir() {
let mut compiler = SparqlJitCompiler::new().unwrap();
let pattern = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
};
let plan = IrQueryPlan::new(pattern);
let result = compiler.compile_ir(&plan);
assert!(result.is_ok());
}
#[test]
fn test_query_parser_complexity() {
assert_eq!(QueryParser::complexity("SELECT * WHERE { ?s ?p ?o }"), 10.0);
assert_eq!(
QueryParser::complexity("SELECT * WHERE { << ?s ?p ?o >> ?m ?v }"),
15.0
);
let complexity =
QueryParser::complexity("SELECT * WHERE { ?s ?p ?o . ?s ?p2 ?o2 } FILTER(?s)");
assert!(
complexity >= 10.0,
"Expected complexity >= 10.0, got {}",
complexity
);
}
#[test]
fn test_codegen_triple_pattern() {
let compiler = SparqlJitCompiler::new().unwrap();
let pattern = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
};
let mut output = String::new();
let result = compiler.codegen_ir_op(&pattern, &mut output, 1);
assert!(result.is_ok());
assert!(output.contains("scan_spo"));
assert!(output.contains("results.push"));
}
#[test]
fn test_codegen_join() {
let compiler = SparqlJitCompiler::new().unwrap();
let left = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
};
let right = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Iri("http://ex.org/name".to_string()),
object: IrTerm::Variable("name".to_string()),
};
let join = IrOp::Join {
left: Box::new(left),
right: Box::new(right),
join_type: JoinType::Inner,
};
let mut output = String::new();
let result = compiler.codegen_ir_op(&join, &mut output, 1);
assert!(result.is_ok());
assert!(output.contains("Join operation"));
assert!(output.contains("left_results"));
assert!(output.contains("right_results"));
}
#[test]
fn test_kernel_source_generation() {
let compiler = SparqlJitCompiler::new().unwrap();
let pattern = IrOp::TriplePattern {
subject: IrTerm::Variable("s".to_string()),
predicate: IrTerm::Variable("p".to_string()),
object: IrTerm::Variable("o".to_string()),
};
let plan = IrQueryPlan::new(pattern);
let result = compiler.generate_kernel_source(&plan);
assert!(result.is_ok());
let source = result.unwrap();
assert!(source.source.contains("execute_query"));
assert!(source.source.contains("results"));
}
}