use crate::algebra::{Solution, Term, Variable};
use anyhow::Result;
use std::alloc::{alloc, dealloc, Layout};
use std::collections::HashMap;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
/// A fixed-capacity, array-backed work-stealing deque (Chase-Lev style).
///
/// The owner thread pushes and pops at the tail; other threads take from the
/// head via `steal`. Capacity is rounded up to a power of two so ring indices
/// can be reduced with a bitmask.
pub struct LockFreeWorkStealingQueue<T> {
    // Ring buffer of `capacity` slots; allocated in `new`, freed in `Drop`.
    buffer: AtomicPtr<T>,
    // Number of slots in the ring (always a power of two).
    capacity: usize,
    // Index of the next element to steal (monotonically increasing).
    head: AtomicUsize,
    // Index one past the most recently pushed element (monotonically increasing,
    // except for the speculative decrement/restore in `pop`).
    tail: AtomicUsize,
    // `capacity - 1`; maps monotonic indices onto ring slots.
    mask: usize,
}
impl<T> LockFreeWorkStealingQueue<T> {
    /// Creates a queue with at least `capacity` slots.
    ///
    /// Capacity is rounded up to the next power of two so ring indices can be
    /// reduced with `& mask` (an input of 0 still yields one slot).
    ///
    /// # Panics
    /// Panics if `T` is zero-sized (the raw allocator rejects zero-size
    /// layouts) and diverges via `handle_alloc_error` if allocation fails.
    pub fn new(capacity: usize) -> Self {
        // `alloc` with a zero-size layout is undefined behavior, so refuse
        // zero-sized `T` up front instead of silently invoking UB.
        assert!(
            std::mem::size_of::<T>() != 0,
            "LockFreeWorkStealingQueue does not support zero-sized types"
        );
        let capacity = capacity.next_power_of_two();
        let mask = capacity - 1;
        let layout = Layout::array::<T>(capacity).expect("Invalid layout");
        // SAFETY: `layout` is valid and non-zero-sized (checked above).
        let buffer = unsafe { alloc(layout) as *mut T };
        if buffer.is_null() {
            // Fix: previously the null pointer was stored and `push` would
            // later write through it — undefined behavior on OOM.
            std::alloc::handle_alloc_error(layout);
        }
        Self {
            buffer: AtomicPtr::new(buffer),
            capacity,
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
            mask,
        }
    }

    /// Appends `item` at the tail, or returns an error if the ring is full.
    ///
    /// NOTE(review): this assumes a single producer — concurrent `push`
    /// calls race on `tail`. Confirm callers honor the owner-thread contract.
    pub fn push(&self, item: T) -> Result<()> {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        if tail - head >= self.capacity {
            return Err(anyhow::anyhow!("Work queue is full"));
        }
        unsafe {
            let buffer = self.buffer.load(Ordering::Relaxed);
            let index = tail & self.mask;
            std::ptr::write(buffer.add(index), item);
        }
        // Release pairs with the Acquire loads in `pop`/`steal` so the slot
        // write above is visible before the new tail is observed.
        self.tail.store(tail + 1, Ordering::Release);
        Ok(())
    }

    /// Removes and returns the most recently pushed item (LIFO end).
    ///
    /// NOTE(review): Chase-Lev-style owner pop — assumes a single consumer
    /// thread; only `steal` may run concurrently with it.
    pub fn pop(&self) -> Option<T> {
        let tail = self.tail.load(Ordering::Relaxed);
        if tail == 0 {
            return None;
        }
        let new_tail = tail - 1;
        // Publish the speculative tail before inspecting head so stealers
        // back off the contended slot.
        self.tail.store(new_tail, Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        if new_tail > head {
            // More than one element remains: the slot cannot be contended.
            unsafe {
                let buffer = self.buffer.load(Ordering::Relaxed);
                let index = new_tail & self.mask;
                Some(std::ptr::read(buffer.add(index)))
            }
        } else if new_tail == head {
            // Exactly one element left: race against stealers via CAS on head.
            if self
                .head
                .compare_exchange_weak(head, head + 1, Ordering::SeqCst, Ordering::Relaxed)
                .is_ok()
            {
                unsafe {
                    let buffer = self.buffer.load(Ordering::Relaxed);
                    let index = head & self.mask;
                    Some(std::ptr::read(buffer.add(index)))
                }
            } else {
                // A stealer won the race; restore tail and report empty.
                self.tail.store(tail, Ordering::Relaxed);
                None
            }
        } else {
            // Queue was already empty; undo the speculative decrement.
            self.tail.store(tail, Ordering::Relaxed);
            None
        }
    }

    /// Removes and returns the oldest item (FIFO end); intended to be called
    /// from threads other than the owner.
    pub fn steal(&self) -> Option<T> {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        if head >= tail {
            return None;
        }
        unsafe {
            let buffer = self.buffer.load(Ordering::Relaxed);
            let index = head & self.mask;
            // NOTE(review): the slot is read BEFORE the CAS claims it; on CAS
            // failure the duplicate is `forget`-ed below to avoid a double
            // drop, but a concurrent `push` wrapping around could already be
            // overwriting the slot — confirm this matches the intended model.
            let item = std::ptr::read(buffer.add(index));
            if self
                .head
                .compare_exchange_weak(head, head + 1, Ordering::SeqCst, Ordering::Relaxed)
                .is_ok()
            {
                Some(item)
            } else {
                std::mem::forget(item);
                None
            }
        }
    }

    /// Returns `true` when no items are currently queued (best-effort
    /// snapshot under concurrency).
    pub fn is_empty(&self) -> bool {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        head >= tail
    }

    /// Returns a best-effort snapshot of the number of queued items.
    pub fn len(&self) -> usize {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);
        tail.saturating_sub(head)
    }
}
impl<T> Drop for LockFreeWorkStealingQueue<T> {
    /// Drains any remaining items so their destructors run, then releases
    /// the ring-buffer allocation.
    fn drop(&mut self) {
        // Drop every element still sitting in the queue.
        loop {
            if self.pop().is_none() {
                break;
            }
        }
        let buffer = self.buffer.load(Ordering::Relaxed);
        if buffer.is_null() {
            return;
        }
        // SAFETY: `buffer` was allocated in `new` with this exact layout and
        // is freed exactly once, here.
        unsafe {
            let layout = Layout::array::<T>(self.capacity).expect("Invalid layout");
            dealloc(buffer as *mut u8, layout);
        }
    }
}
/// A simple object pool that hands out boxed values and recycles them.
pub struct MemoryPool<T> {
    // Objects currently available for reuse.
    available: LockFreeWorkStealingQueue<Box<T>>,
    // Constructor invoked when the pool has no spare object.
    factory: fn() -> T,
    // Upper bound on how many objects the pool will retain at once.
    max_size: usize,
    // Best-effort count of objects currently held in `available`.
    current_size: AtomicUsize,
}
impl<T> MemoryPool<T> {
    /// Creates a pool pre-filled with `initial_size` objects built by
    /// `factory`; the pool never retains more than `max_size` objects.
    pub fn new(initial_size: usize, max_size: usize, factory: fn() -> T) -> Self {
        let pool = Self {
            available: LockFreeWorkStealingQueue::new(max_size),
            factory,
            max_size,
            current_size: AtomicUsize::new(0),
        };
        for _ in 0..initial_size {
            let obj = Box::new(factory());
            // Fix: the old code stored `initial_size` unconditionally on every
            // iteration, overcounting whenever a push failed. Count only the
            // objects that actually entered the queue.
            if pool.available.push(obj).is_ok() {
                pool.current_size.fetch_add(1, Ordering::Relaxed);
            }
        }
        pool
    }

    /// Borrows an object from the pool, constructing a fresh one when the
    /// pool is empty. The object is returned automatically on drop.
    pub fn acquire(&self) -> PooledObject<'_, T> {
        match self.available.steal() {
            Some(obj) => {
                // Fix: the counter was never decremented on acquire, so it
                // only ever grew and `return_object` eventually refused every
                // return even with an empty queue, defeating the pool.
                self.current_size.fetch_sub(1, Ordering::Relaxed);
                PooledObject {
                    object: Some(obj),
                    pool: self,
                }
            }
            None => {
                let obj = Box::new((self.factory)());
                PooledObject {
                    object: Some(obj),
                    pool: self,
                }
            }
        }
    }

    /// Puts an object back into the pool unless the retained count has
    /// reached `max_size` (in which case the object is simply dropped).
    fn return_object(&self, obj: Box<T>) {
        // Best-effort check: load + push is not atomic, so under heavy
        // contention the pool may briefly exceed/undershoot `max_size`.
        let current = self.current_size.load(Ordering::Relaxed);
        if current < self.max_size && self.available.push(obj).is_ok() {
            self.current_size.fetch_add(1, Ordering::Relaxed);
        }
    }
}
/// RAII guard for a value borrowed from a [`MemoryPool`]; the value is
/// handed back to the pool when the guard is dropped.
pub struct PooledObject<'a, T> {
    // `Some` for the guard's whole life; taken exactly once in `drop`.
    object: Option<Box<T>>,
    // Pool the object is returned to on drop.
    pool: &'a MemoryPool<T>,
}
impl<'a, T> PooledObject<'a, T> {
    /// Mutable access to the pooled value.
    pub fn get_mut(&mut self) -> &mut T {
        self.object
            .as_deref_mut()
            .expect("pooled object should be present")
    }

    /// Shared access to the pooled value.
    pub fn get(&self) -> &T {
        self.object
            .as_deref()
            .expect("pooled object should be present")
    }
}
impl<'a, T> Drop for PooledObject<'a, T> {
    /// Hands the value back to its pool when the guard goes out of scope.
    fn drop(&mut self) {
        match self.object.take() {
            Some(obj) => self.pool.return_object(obj),
            None => {}
        }
    }
}
/// Radix-partitioned hash join over batches of `Solution`s.
pub struct CacheFriendlyHashJoin {
    // Number of radix partitions (always a power of two).
    num_partitions: usize,
    #[allow(dead_code)]
    // log2(num_partitions); computed but currently unused.
    radix_bits: u32,
    // Reusable hash tables so per-partition joins avoid reallocating.
    hash_table_pool: MemoryPool<HashMap<u64, Vec<Solution>>>,
}
impl CacheFriendlyHashJoin {
    /// Creates a join operator; `num_partitions` is rounded up to a power of
    /// two and the table pool is pre-filled with one table per partition.
    pub fn new(num_partitions: usize) -> Self {
        let num_partitions = num_partitions.next_power_of_two();
        let radix_bits = num_partitions.trailing_zeros();
        Self {
            num_partitions,
            radix_bits,
            hash_table_pool: MemoryPool::new(num_partitions, num_partitions * 2, || {
                HashMap::with_capacity(1024)
            }),
        }
    }

    /// Partitions both inputs by join-key hash, joins matching partitions
    /// pairwise, and concatenates the results.
    ///
    /// NOTE(review): despite the name, partitions are processed sequentially
    /// here — confirm whether parallel execution was intended.
    pub fn join_parallel(
        &self,
        left_solutions: Vec<Solution>,
        right_solutions: Vec<Solution>,
        join_variables: &[Variable],
    ) -> Result<Vec<Solution>> {
        let left_partitions = self.partition_solutions(left_solutions, join_variables)?;
        let right_partitions = self.partition_solutions(right_solutions, join_variables)?;
        let results: Vec<_> = (0..self.num_partitions)
            .map(|i| self.join_partition(&left_partitions[i], &right_partitions[i], join_variables))
            .collect::<Result<Vec<_>>>()?;
        Ok(results.into_iter().flatten().collect())
    }

    /// Buckets solutions into `num_partitions` vectors using the low bits of
    /// the join-key hash (valid because `num_partitions` is a power of two).
    fn partition_solutions(
        &self,
        solutions: Vec<Solution>,
        join_variables: &[Variable],
    ) -> Result<Vec<Vec<Solution>>> {
        let mut partitions = vec![Vec::new(); self.num_partitions];
        for solution in solutions {
            let hash = self.compute_join_key_hash(&solution, join_variables);
            let partition_id = (hash as usize) & (self.num_partitions - 1);
            partitions[partition_id].push(solution);
        }
        Ok(partitions)
    }

    /// Classic build/probe hash join for one partition pair; the smaller
    /// side builds the hash table (taken from the pool and cleared first).
    fn join_partition(
        &self,
        left_partition: &[Solution],
        right_partition: &[Solution],
        join_variables: &[Variable],
    ) -> Result<Vec<Solution>> {
        if left_partition.is_empty() || right_partition.is_empty() {
            return Ok(Vec::new());
        }
        // Build on the smaller side; `build_left` records orientation so the
        // merge below preserves left-to-right argument order.
        let (build_side, probe_side, build_left) = if left_partition.len() <= right_partition.len()
        {
            (left_partition, right_partition, true)
        } else {
            (right_partition, left_partition, false)
        };
        let mut hash_table = self.hash_table_pool.acquire();
        // Pooled tables come back dirty; clear before reuse.
        hash_table.get_mut().clear();
        for solution in build_side {
            let key = self.compute_join_key_hash(solution, join_variables);
            hash_table
                .get_mut()
                .entry(key)
                .or_default()
                .push(solution.clone());
        }
        let mut results = Vec::new();
        for probe_solution in probe_side {
            let key = self.compute_join_key_hash(probe_solution, join_variables);
            if let Some(build_solutions) = hash_table.get().get(&key) {
                for build_solution in build_solutions {
                    // Re-verify term equality so u64 hash collisions cannot
                    // produce spurious join results.
                    if self.solutions_join_compatible(
                        build_solution,
                        probe_solution,
                        join_variables,
                    ) {
                        let joined = if build_left {
                            self.merge_solutions(build_solution, probe_solution)?
                        } else {
                            self.merge_solutions(probe_solution, build_solution)?
                        };
                        results.push(joined);
                    }
                }
            }
        }
        Ok(results)
    }

    /// Hashes the join-variable values of a solution.
    ///
    /// NOTE(review): values are folded across EVERY binding of the solution,
    /// so two solutions with equal join values but different binding counts
    /// hash differently and will never probe each other — confirm this is
    /// intended for multi-binding solutions.
    fn compute_join_key_hash(&self, solution: &Solution, join_variables: &[Variable]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        for binding in solution {
            for var in join_variables {
                if let Some(term) = binding.get(var) {
                    term.hash(&mut hasher);
                }
            }
        }
        hasher.finish()
    }

    /// Returns `false` iff ANY pair of bindings disagrees on a bound join
    /// variable; unbound variables are treated as compatible.
    fn solutions_join_compatible(
        &self,
        left: &Solution,
        right: &Solution,
        join_variables: &[Variable],
    ) -> bool {
        for left_binding in left {
            for right_binding in right {
                for var in join_variables {
                    if let (Some(left_term), Some(right_term)) =
                        (left_binding.get(var), right_binding.get(var))
                    {
                        if left_term != right_term {
                            return false;
                        }
                    }
                }
            }
        }
        true
    }

    /// Cross-products the bindings of the two solutions; for each pair the
    /// left binding wins on shared variables and the right fills the rest.
    fn merge_solutions(&self, left: &Solution, right: &Solution) -> Result<Solution> {
        let mut result = Vec::new();
        for left_binding in left {
            for right_binding in right {
                let mut merged_binding = left_binding.clone();
                for (var, term) in right_binding {
                    if !merged_binding.contains_key(var) {
                        merged_binding.insert(var.clone(), term.clone());
                    }
                }
                result.push(merged_binding);
            }
        }
        Ok(result)
    }
}
/// Stateless namespace for bulk, data-parallel helper operations.
pub struct SIMDOptimizedOps;
impl SIMDOptimizedOps {
    /// Tests each string for containment of `pattern`, in parallel.
    ///
    /// NOTE(review): no actual SIMD is used on either cfg path; the sse2
    /// variant merely chunks the work before handing it to rayon.
    #[cfg(target_feature = "sse2")]
    pub fn bulk_string_compare(strings: &[String], pattern: &str) -> Vec<bool> {
        use rayon::prelude::*;
        strings
            .par_chunks(256)
            .flat_map(|chunk| {
                chunk
                    .iter()
                    .map(|s| s.contains(pattern))
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    /// Fallback: per-string parallel containment test.
    #[cfg(not(target_feature = "sse2"))]
    pub fn bulk_string_compare(strings: &[String], pattern: &str) -> Vec<bool> {
        use rayon::prelude::*;
        strings.par_iter().map(|s| s.contains(pattern)).collect()
    }

    /// Hashes every term with `DefaultHasher`, processing chunks in parallel;
    /// output order matches input order.
    pub fn bulk_hash_compute(terms: &[Term]) -> Vec<u64> {
        use rayon::prelude::*;
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        terms
            .par_chunks(1024)
            .flat_map(|chunk| {
                chunk
                    .iter()
                    .map(|term| {
                        let mut hasher = DefaultHasher::new();
                        term.hash(&mut hasher);
                        hasher.finish()
                    })
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    /// Counts, per distinct `group_var` value, how many bindings across all
    /// solutions bind that value (parallel fold + reduce of partial maps).
    pub fn parallel_count_aggregate(
        solutions: &[Solution],
        group_var: &Variable,
    ) -> HashMap<Term, usize> {
        use rayon::prelude::*;
        solutions
            .par_iter()
            .flat_map(|solution| {
                solution
                    .par_iter()
                    .filter_map(|binding| binding.get(group_var).map(|term| (term.clone(), 1)))
            })
            .fold(HashMap::new, |mut acc, (term, count)| {
                *acc.entry(term).or_insert(0) += count;
                acc
            })
            .reduce(HashMap::new, |mut acc1, acc2| {
                for (term, count) in acc2 {
                    *acc1.entry(term).or_insert(0) += count;
                }
                acc1
            })
    }

    /// Element-wise equality of two term slices (truncates to the shorter
    /// length, per `zip` semantics).
    pub fn bulk_equality_check(terms1: &[Term], terms2: &[Term]) -> Vec<bool> {
        use rayon::prelude::*;
        terms1
            .par_iter()
            .zip(terms2.par_iter())
            .map(|(t1, t2)| t1 == t2)
            .collect()
    }

    /// Parses every literal value as `f64` and sums them in parallel.
    ///
    /// # Errors
    /// Returns an error if any literal fails to parse as a number.
    pub fn bulk_numeric_sum(literals: &[crate::algebra::Literal]) -> Result<f64> {
        use rayon::prelude::*;
        literals
            .par_iter()
            .map(|lit| lit.value.parse::<f64>())
            .try_fold(|| 0.0, |acc, val| val.map(|v| acc + v))
            .try_reduce(|| 0.0, |a, b| Ok(a + b))
            .map_err(|e| anyhow::anyhow!("Failed to parse numeric value: {}", e))
    }

    /// Keeps the solutions satisfying `predicate`, preserving input order.
    pub fn bulk_filter_solutions(
        solutions: &[Solution],
        predicate: fn(&Solution) -> bool,
    ) -> Vec<Solution> {
        use rayon::prelude::*;
        solutions
            .par_iter()
            .filter(|solution| predicate(solution))
            .cloned()
            .collect()
    }

    /// Projects each binding down to `variables`, dropping all other pairs
    /// (unbound variables are simply absent from the projected binding).
    pub fn bulk_project_solutions(solutions: &[Solution], variables: &[Variable]) -> Vec<Solution> {
        use rayon::prelude::*;
        solutions
            .par_iter()
            .map(|solution| {
                solution
                    .iter()
                    .map(|binding| {
                        let mut projected_binding = HashMap::new();
                        for var in variables {
                            if let Some(term) = binding.get(var) {
                                projected_binding.insert(var.clone(), term.clone());
                            }
                        }
                        projected_binding
                    })
                    .collect()
            })
            .collect()
    }

    /// Removes duplicate solutions (by 64-bit content hash), keeping the
    /// FIRST occurrence and preserving input order.
    pub fn bulk_deduplicate_solutions(solutions: Vec<Solution>) -> Vec<Solution> {
        use rayon::prelude::*;
        use std::collections::HashSet;
        // Fix: the previous version filtered inside a parallel iterator while
        // holding a Mutex<HashSet>, which serialized every element on the
        // lock anyway AND let an arbitrary duplicate win the race. Hash in
        // parallel, then filter sequentially so the result is deterministic.
        let hashes: Vec<u64> = solutions
            .par_iter()
            .map(Self::compute_solution_hash)
            .collect();
        let mut seen = HashSet::new();
        solutions
            .into_iter()
            .zip(hashes)
            .filter_map(|(solution, hash)| {
                if seen.insert(hash) {
                    Some(solution)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Hashes a solution's bindings with entries sorted by variable so the
    /// hash is independent of `HashMap` iteration order.
    ///
    /// NOTE(review): deduplicating by u64 hash means a collision silently
    /// drops a distinct solution — acceptable only if callers tolerate that.
    fn compute_solution_hash(solution: &Solution) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        for binding in solution {
            let mut sorted_items: Vec<_> = binding.iter().collect();
            sorted_items.sort_by(|a, b| a.0.cmp(b.0));
            sorted_items.hash(&mut hasher);
        }
        hasher.finish()
    }
}
/// Sort-merge join operator for `Solution` batches.
pub struct SortMergeJoin {
    #[allow(dead_code)]
    // Reserved for future external-sort spilling; currently unused.
    memory_threshold: usize,
    #[allow(dead_code)]
    // Optional directory for spill files; currently unused.
    temp_dir: Option<std::path::PathBuf>,
}
impl SortMergeJoin {
    /// Creates a join operator. `memory_threshold` is stored for a future
    /// external-sort implementation and is not consulted yet.
    pub fn new(memory_threshold: usize) -> Self {
        Self {
            memory_threshold,
            temp_dir: None,
        }
    }

    /// Like [`SortMergeJoin::new`], but records a directory for spill files
    /// (also unused for now).
    pub fn with_temp_dir(memory_threshold: usize, temp_dir: std::path::PathBuf) -> Self {
        Self {
            memory_threshold,
            temp_dir: Some(temp_dir),
        }
    }

    /// Sorts both inputs by join key, then merges them.
    pub fn join(
        &self,
        left_solutions: Vec<Solution>,
        right_solutions: Vec<Solution>,
        join_variables: &[Variable],
    ) -> Result<Vec<Solution>> {
        let sorted_left = self.sort_solutions(left_solutions, join_variables)?;
        let sorted_right = self.sort_solutions(right_solutions, join_variables)?;
        self.merge_sorted_solutions(sorted_left, sorted_right, join_variables)
    }

    /// Stable in-memory sort by join-key ordering.
    fn sort_solutions(
        &self,
        mut solutions: Vec<Solution>,
        join_variables: &[Variable],
    ) -> Result<Vec<Solution>> {
        solutions.sort_by(|a, b| self.compare_solutions_by_join_key(a, b, join_variables));
        Ok(solutions)
    }

    /// Orders two solutions by the join variables of their FIRST binding;
    /// missing bindings/terms sort before present ones.
    ///
    /// NOTE(review): bindings after the first are ignored here even though
    /// the merge phase joins all binding pairs — confirm first-binding
    /// ordering is sufficient for the data this operator sees.
    fn compare_solutions_by_join_key(
        &self,
        left: &Solution,
        right: &Solution,
        join_variables: &[Variable],
    ) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        let left_binding = left.first();
        let right_binding = right.first();
        match (left_binding, right_binding) {
            (Some(l_binding), Some(r_binding)) => {
                for var in join_variables {
                    let left_term = l_binding.get(var);
                    let right_term = r_binding.get(var);
                    let cmp = match (left_term, right_term) {
                        (Some(l), Some(r)) => self.compare_terms(l, r),
                        (Some(_), None) => Ordering::Greater,
                        (None, Some(_)) => Ordering::Less,
                        (None, None) => Ordering::Equal,
                    };
                    if cmp != Ordering::Equal {
                        return cmp;
                    }
                }
                Ordering::Equal
            }
            (Some(_), None) => Ordering::Greater,
            (None, Some(_)) => Ordering::Less,
            (None, None) => Ordering::Equal,
        }
    }

    /// Total order over terms: literals first (numerically when both parse
    /// as f64, lexically otherwise), then IRIs, blank nodes, quoted triples,
    /// property paths, and variables last.
    fn compare_terms(&self, left: &Term, right: &Term) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (left, right) {
            (Term::Literal(l), Term::Literal(r)) => {
                if let (Ok(l_num), Ok(r_num)) = (l.value.parse::<f64>(), r.value.parse::<f64>()) {
                    // Fix: `partial_cmp(..).unwrap_or(Equal)` made NaN compare
                    // Equal to EVERY number, yielding a non-total comparator
                    // that violates `sort_by`'s contract (modern std sorts may
                    // panic on it). `total_cmp` is a true total order on f64.
                    l_num.total_cmp(&r_num)
                } else {
                    l.value.cmp(&r.value)
                }
            }
            (Term::Iri(l), Term::Iri(r)) => l.as_str().cmp(r.as_str()),
            (Term::BlankNode(l), Term::BlankNode(r)) => l.as_str().cmp(r.as_str()),
            (Term::QuotedTriple(l), Term::QuotedTriple(r)) => {
                format!("{l}").cmp(&format!("{r}"))
            }
            (Term::PropertyPath(l), Term::PropertyPath(r)) => {
                format!("{l}").cmp(&format!("{r}"))
            }
            (
                Term::Literal(_),
                Term::Iri(_)
                | Term::BlankNode(_)
                | Term::QuotedTriple(_)
                | Term::PropertyPath(_)
                | Term::Variable(_),
            ) => Ordering::Less,
            (Term::Iri(_), Term::Literal(_)) => Ordering::Greater,
            (
                Term::Iri(_),
                Term::BlankNode(_)
                | Term::QuotedTriple(_)
                | Term::PropertyPath(_)
                | Term::Variable(_),
            ) => Ordering::Less,
            (Term::BlankNode(_), Term::Literal(_) | Term::Iri(_)) => Ordering::Greater,
            (
                Term::BlankNode(_),
                Term::QuotedTriple(_) | Term::PropertyPath(_) | Term::Variable(_),
            ) => Ordering::Less,
            (Term::QuotedTriple(_), Term::Literal(_) | Term::Iri(_) | Term::BlankNode(_)) => {
                Ordering::Greater
            }
            (Term::QuotedTriple(_), Term::PropertyPath(_) | Term::Variable(_)) => Ordering::Less,
            (
                Term::PropertyPath(_),
                Term::Literal(_) | Term::Iri(_) | Term::BlankNode(_) | Term::QuotedTriple(_),
            ) => Ordering::Greater,
            (Term::PropertyPath(_), Term::Variable(_)) => Ordering::Less,
            (Term::Variable(_), _) => Ordering::Greater,
        }
    }

    /// Merge phase: walks both sorted inputs with two cursors, expands each
    /// run of equal-key solutions on both sides, and emits the (filtered)
    /// cross product of every run pair.
    fn merge_sorted_solutions(
        &self,
        left_solutions: Vec<Solution>,
        right_solutions: Vec<Solution>,
        join_variables: &[Variable],
    ) -> Result<Vec<Solution>> {
        let mut result = Vec::new();
        let mut left_idx = 0;
        let mut right_idx = 0;
        while left_idx < left_solutions.len() && right_idx < right_solutions.len() {
            let left_solution = &left_solutions[left_idx];
            let right_solution = &right_solutions[right_idx];
            let cmp =
                self.compare_solutions_by_join_key(left_solution, right_solution, join_variables);
            match cmp {
                std::cmp::Ordering::Equal => {
                    // Extend the equal-key run on the left side.
                    let mut left_end = left_idx + 1;
                    while left_end < left_solutions.len()
                        && self.compare_solutions_by_join_key(
                            left_solution,
                            &left_solutions[left_end],
                            join_variables,
                        ) == std::cmp::Ordering::Equal
                    {
                        left_end += 1;
                    }
                    // Extend the equal-key run on the right side.
                    let mut right_end = right_idx + 1;
                    while right_end < right_solutions.len()
                        && self.compare_solutions_by_join_key(
                            right_solution,
                            &right_solutions[right_end],
                            join_variables,
                        ) == std::cmp::Ordering::Equal
                    {
                        right_end += 1;
                    }
                    // Cross product of the two runs.
                    for left_solution in left_solutions
                        .iter()
                        .skip(left_idx)
                        .take(left_end - left_idx)
                    {
                        for right_solution in right_solutions
                            .iter()
                            .skip(right_idx)
                            .take(right_end - right_idx)
                        {
                            if let Ok(Some(merged_solution)) = self.merge_solutions_if_compatible(
                                left_solution,
                                right_solution,
                                join_variables,
                            ) {
                                result.push(merged_solution);
                            }
                        }
                    }
                    left_idx = left_end;
                    right_idx = right_end;
                }
                std::cmp::Ordering::Less => {
                    left_idx += 1;
                }
                std::cmp::Ordering::Greater => {
                    right_idx += 1;
                }
            }
        }
        Ok(result)
    }

    /// Merges two solutions' bindings pairwise, skipping pairs that disagree
    /// on a bound join variable; returns `Ok(None)` when no binding pair was
    /// compatible.
    fn merge_solutions_if_compatible(
        &self,
        left: &Solution,
        right: &Solution,
        join_variables: &[Variable],
    ) -> Result<Option<Solution>> {
        let mut result = Vec::new();
        for left_binding in left {
            for right_binding in right {
                let mut compatible = true;
                for var in join_variables {
                    if let (Some(left_term), Some(right_term)) =
                        (left_binding.get(var), right_binding.get(var))
                    {
                        if left_term != right_term {
                            compatible = false;
                            break;
                        }
                    }
                }
                if compatible {
                    // Left binding wins on shared variables; right fills gaps.
                    let mut merged_binding = left_binding.clone();
                    for (var, term) in right_binding {
                        if !merged_binding.contains_key(var) {
                            merged_binding.insert(var.clone(), term.clone());
                        }
                    }
                    result.push(merged_binding);
                }
            }
        }
        if result.is_empty() {
            Ok(None)
        } else {
            Ok(Some(result))
        }
    }
}
/// Columnar (structure-of-arrays) storage for solution bindings.
pub struct CacheFriendlyStorage {
    // One value column per variable; rows are implied by position.
    columns: HashMap<Variable, Vec<Term>>,
    // Total number of bindings added so far.
    row_count: usize,
}
impl Default for CacheFriendlyStorage {
    /// Equivalent to [`CacheFriendlyStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl CacheFriendlyStorage {
    /// Creates an empty columnar store.
    pub fn new() -> Self {
        Self {
            columns: HashMap::new(),
            row_count: 0,
        }
    }

    /// Appends every binding of every solution into per-variable columns.
    ///
    /// NOTE(review): bindings that lack some variables leave those columns
    /// shorter than `row_count`, so later rows can become misaligned across
    /// columns — confirm inputs always bind the same variable set.
    pub fn add_solutions(&mut self, solutions: &[Solution]) {
        for solution in solutions {
            self.row_count += solution.len();
            for binding in solution {
                for (var, term) in binding {
                    let column = self.columns.entry(var.clone()).or_default();
                    column.push(term.clone());
                }
            }
        }
    }

    /// Returns the value column stored for `var`, if any.
    pub fn get_column(&self, var: &Variable) -> Option<&Vec<Term>> {
        self.columns.get(var)
    }

    /// Reassembles row-oriented solutions from the columns; a row for which
    /// every column has already ended produces no solution.
    pub fn to_solutions(&self) -> Vec<Solution> {
        if self.row_count == 0 {
            return Vec::new();
        }
        (0..self.row_count)
            .filter_map(|row| {
                let binding: HashMap<_, _> = self
                    .columns
                    .iter()
                    .filter_map(|(var, column)| {
                        column.get(row).map(|term| (var.clone(), term.clone()))
                    })
                    .collect();
                if binding.is_empty() {
                    None
                } else {
                    Some(vec![binding])
                }
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Variable;
    use oxirs_core::model::NamedNode;

    // Owner-side pop is LIFO: the most recently pushed value comes out first.
    #[test]
    fn test_lock_free_queue() {
        let queue = LockFreeWorkStealingQueue::new(16);
        queue.push(42).unwrap();
        queue.push(43).unwrap();
        assert_eq!(queue.pop(), Some(43));
        assert_eq!(queue.pop(), Some(42));
        assert_eq!(queue.pop(), None);
    }

    // Two concurrently held pooled objects are distinct: mutating one
    // (len 1) must not affect the other (len 0).
    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new(2, 10, HashMap::<String, i32>::new);
        let mut obj1 = pool.acquire();
        obj1.get_mut().insert("test".to_string(), 42);
        let obj2 = pool.acquire();
        assert_ne!(obj1.get().len(), obj2.get().len());
    }

    // A single matching binding on ?x on both sides must produce at least
    // one joined solution.
    #[test]
    fn test_cache_friendly_hash_join() {
        let join = CacheFriendlyHashJoin::new(4);
        let var_x = Variable::new("x").unwrap();
        let var_y = Variable::new("y").unwrap();
        let mut left_binding = HashMap::new();
        left_binding.insert(
            var_x.clone(),
            Term::Iri(NamedNode::new("http://example.org/1").unwrap()),
        );
        left_binding.insert(
            var_y.clone(),
            Term::Iri(NamedNode::new("http://example.org/a").unwrap()),
        );
        let left_solutions = vec![vec![left_binding]];
        let mut right_binding = HashMap::new();
        right_binding.insert(
            var_x.clone(),
            Term::Iri(NamedNode::new("http://example.org/1").unwrap()),
        );
        let right_solutions = vec![vec![right_binding]];
        let results = join
            .join_parallel(left_solutions, right_solutions, &[var_x])
            .unwrap();
        assert!(!results.is_empty());
    }

    // bulk_string_compare preserves input order and applies `contains`.
    #[test]
    fn test_simd_ops() {
        let strings = vec![
            "hello world".to_string(),
            "foo bar".to_string(),
            "hello rust".to_string(),
        ];
        let results = SIMDOptimizedOps::bulk_string_compare(&strings, "hello");
        assert_eq!(results, vec![true, false, true]);
    }

    // Round trip: one binding in, one column stored, one solution back out.
    #[test]
    fn test_cache_friendly_storage() {
        let mut storage = CacheFriendlyStorage::new();
        let var_x = Variable::new("x").unwrap();
        let mut binding = HashMap::new();
        binding.insert(
            var_x.clone(),
            Term::Iri(NamedNode::new("http://example.org/1").unwrap()),
        );
        let solutions = vec![vec![binding]];
        storage.add_solutions(&solutions);
        assert!(storage.get_column(&var_x).is_some());
        let recovered = storage.to_solutions();
        assert_eq!(recovered.len(), 1);
    }
}