use crate::{StarResult, StarTerm, StarTriple};
use rayon::prelude::*; use scirs2_core::parallel_ops::par_scope;
use scirs2_core::profiling::Profiler;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Executes [`QueryPlan`]s over a slice of [`StarTriple`]s.
///
/// Pattern scans and filtering run on rayon parallel iterators; joins split
/// their left input into chunks that are processed as `par_scope` tasks.
pub struct ParallelQueryExecutor {
/// Divisor used to size the per-task chunks of a join's left input.
worker_count: usize,
/// Profiler started and stopped around each `execute_parallel` call.
profiler: Arc<Mutex<Profiler>>,
/// Plan cache keyed by string.
/// NOTE(review): nothing in this file inserts into it; only `clear_cache` touches it.
query_cache: Arc<Mutex<HashMap<String, QueryPlan>>>,
}
/// A query to evaluate with `ParallelQueryExecutor::execute_parallel`.
#[derive(Debug, Clone)]
pub struct QueryPlan {
/// Triple patterns; each one is matched independently against the input triples.
pub patterns: Vec<TriplePattern>,
/// Join steps applied in order. Each step joins the accumulated result
/// (initially the bindings of the first pattern) with the bindings of the
/// pattern at `JoinOperation::right`.
pub joins: Vec<JoinOperation>,
/// Filters that every output binding must satisfy (all must pass).
pub filters: Vec<FilterOperation>,
/// Estimated plan cost.
/// NOTE(review): not read by the executor in this file.
pub cost: f64,
}
/// A single triple pattern.
///
/// `Some(term)` requires the triple's term in that position to equal `term`;
/// `None` leaves the position open. When `variable_name` is set, every open
/// position of a matching triple is bound in the resulting binding.
#[derive(Debug, Clone, PartialEq)]
pub struct TriplePattern {
/// Required subject, or `None` to accept any subject.
pub subject: Option<StarTerm>,
/// Required predicate, or `None` to accept any predicate.
pub predicate: Option<StarTerm>,
/// Required object, or `None` to accept any object.
pub object: Option<StarTerm>,
/// Prefix for the binding names of open positions: `"{name}Subject"`,
/// `"{name}Predicate"` and `"{name}Object"`. When `None`, a match produces
/// an empty binding.
pub variable_name: Option<String>,
}
/// One join step of a [`QueryPlan`].
#[derive(Debug, Clone)]
pub struct JoinOperation {
/// Index into `QueryPlan::patterns`.
/// NOTE(review): the executor only bounds-checks this index; the left
/// operand of every join is the result accumulated so far.
pub left: usize,
/// Index into `QueryPlan::patterns` whose bindings form the right operand.
pub right: usize,
/// Kind of join to perform.
pub join_type: JoinType,
/// Binding names that must agree between the two sides: each must be
/// bound to equal terms on both sides, or unbound on both.
pub join_vars: Vec<String>,
}
/// Kind of join performed by a [`JoinOperation`].
///
/// `Eq` and `Hash` are derived so join kinds can be compared exactly and used
/// as set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Keep only merged pairs of compatible left and right bindings.
    Inner,
    /// Like `Inner`, but also keep left bindings that have no compatible
    /// right binding.
    LeftOuter,
    /// Named after SPARQL `OPTIONAL`, which is defined as a left join.
    Optional,
}
/// A single filter of a [`QueryPlan`]; a binding is kept only if every
/// filter of the plan accepts it.
#[derive(Debug, Clone)]
pub struct FilterOperation {
/// The condition evaluated against each binding.
pub expression: FilterExpression,
}
/// A condition evaluated against a single [`QueryBinding`].
///
/// Every variant evaluates to `false` when the referenced binding name is
/// not present in the binding.
#[derive(Debug, Clone)]
pub enum FilterExpression {
/// `Equals(var, term)`: `var` is bound to a term equal to `term`.
Equals(String, StarTerm),
/// `Regex(var, pattern)`: presumably a regex match of `var` against `pattern`.
/// NOTE(review): `ParallelQueryExecutor::apply_filters` currently ignores
/// `pattern` and only checks that `var` is bound.
Regex(String, String),
/// `Bound(var)`: `var` is present in the binding.
Bound(String),
/// `NestingDepth(var, min, max)`: `var` is bound to a term whose
/// `nesting_depth()` lies in the inclusive range `min..=max`.
NestingDepth(String, usize, usize),
}
/// One solution row: binding names mapped to the terms they are bound to.
///
/// Pattern matching produces names of the form `"{variable}Subject"`,
/// `"{variable}Predicate"` and `"{variable}Object"`. `Default` yields an
/// empty binding, and `PartialEq` allows results to be compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct QueryBinding {
    /// Binding name → bound term.
    pub bindings: HashMap<String, StarTerm>,
}
impl ParallelQueryExecutor {
pub fn new() -> Self {
Self {
worker_count: num_cpus::get(),
profiler: Arc::new(Mutex::new(Profiler::new())),
query_cache: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn with_workers(worker_count: usize) -> Self {
Self {
worker_count,
profiler: Arc::new(Mutex::new(Profiler::new())),
query_cache: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn execute_parallel(
&self,
plan: &QueryPlan,
triples: &[StarTriple],
) -> StarResult<Vec<QueryBinding>> {
let profiler = Arc::clone(&self.profiler);
if let Ok(mut p) = profiler.lock() {
p.start();
}
let pattern_results: Vec<Vec<QueryBinding>> = plan
.patterns
.iter()
.map(|pattern| self.match_pattern_parallel(pattern, triples))
.collect::<StarResult<Vec<_>>>()?;
let mut result = if pattern_results.is_empty() {
Vec::new()
} else {
pattern_results[0].clone()
};
for join in &plan.joins {
if join.left < pattern_results.len() && join.right < pattern_results.len() {
result = self.parallel_join(
&result,
&pattern_results[join.right],
&join.join_vars,
join.join_type.clone(),
)?;
}
}
result = self.parallel_filter(&result, &plan.filters)?;
if let Ok(mut p) = profiler.lock() {
p.stop();
}
Ok(result)
}
fn match_pattern_parallel(
&self,
pattern: &TriplePattern,
triples: &[StarTriple],
) -> StarResult<Vec<QueryBinding>> {
let results: Vec<QueryBinding> = triples
.par_iter()
.filter_map(|triple| {
if self.matches_pattern(triple, pattern) {
Some(self.create_binding(triple, pattern))
} else {
None
}
})
.collect();
Ok(results)
}
fn matches_pattern(&self, triple: &StarTriple, pattern: &TriplePattern) -> bool {
if let Some(ref subject) = pattern.subject {
if &triple.subject != subject {
return false;
}
}
if let Some(ref predicate) = pattern.predicate {
if &triple.predicate != predicate {
return false;
}
}
if let Some(ref object) = pattern.object {
if &triple.object != object {
return false;
}
}
true
}
fn create_binding(&self, triple: &StarTriple, pattern: &TriplePattern) -> QueryBinding {
let mut bindings = HashMap::new();
if pattern.subject.is_none() {
if let Some(ref var) = pattern.variable_name {
bindings.insert(format!("{}Subject", var), triple.subject.clone());
}
}
if pattern.predicate.is_none() {
if let Some(ref var) = pattern.variable_name {
bindings.insert(format!("{}Predicate", var), triple.predicate.clone());
}
}
if pattern.object.is_none() {
if let Some(ref var) = pattern.variable_name {
bindings.insert(format!("{}Object", var), triple.object.clone());
}
}
QueryBinding { bindings }
}
fn parallel_join(
&self,
left: &[QueryBinding],
right: &[QueryBinding],
join_vars: &[String],
join_type: JoinType,
) -> StarResult<Vec<QueryBinding>> {
let results = Arc::new(Mutex::new(Vec::new()));
let results_clone = Arc::clone(&results);
par_scope(|s| {
let chunk_size = (left.len() / self.worker_count).max(10);
for chunk in left.chunks(chunk_size) {
let right_ref = right;
let join_vars_ref = join_vars;
let join_type_ref = join_type.clone();
let results_ref = Arc::clone(&results_clone);
s.spawn(move |_| {
let mut local_results = Vec::new();
for left_binding in chunk {
for right_binding in right_ref {
if self.bindings_compatible(left_binding, right_binding, join_vars_ref)
{
let mut merged = left_binding.clone();
for (k, v) in &right_binding.bindings {
merged.bindings.insert(k.clone(), v.clone());
}
local_results.push(merged);
}
}
if join_type_ref == JoinType::LeftOuter && local_results.is_empty() {
local_results.push(left_binding.clone());
}
}
if let Ok(mut results) = results_ref.lock() {
results.extend(local_results);
}
});
}
});
let final_results = Arc::try_unwrap(results).unwrap_or_else(|arc| {
let mutex = arc.lock().unwrap_or_else(|e| e.into_inner());
Mutex::new(mutex.clone())
});
Ok(final_results
.into_inner()
.expect("lock should not be poisoned"))
}
fn bindings_compatible(
&self,
left: &QueryBinding,
right: &QueryBinding,
join_vars: &[String],
) -> bool {
for var in join_vars {
match (left.bindings.get(var), right.bindings.get(var)) {
(Some(left_val), Some(right_val)) => {
if left_val != right_val {
return false;
}
}
(None, None) => continue,
_ => return false,
}
}
true
}
fn parallel_filter(
&self,
bindings: &[QueryBinding],
filters: &[FilterOperation],
) -> StarResult<Vec<QueryBinding>> {
if filters.is_empty() {
return Ok(bindings.to_vec());
}
let results: Vec<QueryBinding> = bindings
.par_iter()
.filter(|binding| self.apply_filters(binding, filters))
.cloned()
.collect();
Ok(results)
}
fn apply_filters(&self, binding: &QueryBinding, filters: &[FilterOperation]) -> bool {
filters.iter().all(|filter| match &filter.expression {
FilterExpression::Equals(var, value) => binding
.bindings
.get(var)
.map(|v| v == value)
.unwrap_or(false),
FilterExpression::Bound(var) => binding.bindings.contains_key(var),
FilterExpression::NestingDepth(var, min, max) => binding
.bindings
.get(var)
.map(|term| {
let depth = term.nesting_depth();
depth >= *min && depth <= *max
})
.unwrap_or(false),
FilterExpression::Regex(var, _pattern) => {
binding.bindings.contains_key(var)
}
})
}
pub fn get_statistics(&self) -> HashMap<String, u64> {
HashMap::new()
}
pub fn clear_cache(&self) {
if let Ok(mut cache) = self.query_cache.lock() {
cache.clear();
}
}
}
impl Default for ParallelQueryExecutor {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for pattern matching, end-to-end execution, joins and filters.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_executor_creation() {
        let executor = ParallelQueryExecutor::new();
        assert!(executor.worker_count > 0);
    }

    #[test]
    fn test_pattern_matching() {
        let executor = ParallelQueryExecutor::new();
        let triple = StarTriple::new(
            StarTerm::iri("http://example.org/s").unwrap(),
            StarTerm::iri("http://example.org/p").unwrap(),
            StarTerm::iri("http://example.org/o").unwrap(),
        );
        let pattern = TriplePattern {
            subject: Some(StarTerm::iri("http://example.org/s").unwrap()),
            predicate: None,
            object: None,
            variable_name: Some("x".to_string()),
        };
        assert!(executor.matches_pattern(&triple, &pattern));
    }

    #[test]
    fn test_parallel_execution() {
        let executor = ParallelQueryExecutor::new();
        let triples = vec![
            StarTriple::new(
                StarTerm::iri("http://example.org/s1").unwrap(),
                StarTerm::iri("http://example.org/p").unwrap(),
                StarTerm::iri("http://example.org/o1").unwrap(),
            ),
            StarTriple::new(
                StarTerm::iri("http://example.org/s2").unwrap(),
                StarTerm::iri("http://example.org/p").unwrap(),
                StarTerm::iri("http://example.org/o2").unwrap(),
            ),
        ];
        let plan = QueryPlan {
            patterns: vec![TriplePattern {
                subject: None,
                predicate: Some(StarTerm::iri("http://example.org/p").unwrap()),
                object: None,
                variable_name: Some("x".to_string()),
            }],
            joins: vec![],
            filters: vec![],
            cost: 1.0,
        };
        let results = executor.execute_parallel(&plan, &triples).unwrap();
        assert_eq!(results.len(), 2);
    }

    /// An inner join on a shared subject keeps only the subject present in
    /// both pattern results.
    #[test]
    fn test_inner_join_on_shared_variable() {
        let executor = ParallelQueryExecutor::with_workers(2);
        let triples = vec![
            StarTriple::new(
                StarTerm::iri("http://example.org/s1").unwrap(),
                StarTerm::iri("http://example.org/p1").unwrap(),
                StarTerm::iri("http://example.org/o1").unwrap(),
            ),
            StarTriple::new(
                StarTerm::iri("http://example.org/s2").unwrap(),
                StarTerm::iri("http://example.org/p1").unwrap(),
                StarTerm::iri("http://example.org/o2").unwrap(),
            ),
            StarTriple::new(
                StarTerm::iri("http://example.org/s1").unwrap(),
                StarTerm::iri("http://example.org/p2").unwrap(),
                StarTerm::iri("http://example.org/o3").unwrap(),
            ),
        ];
        let plan = QueryPlan {
            patterns: vec![
                TriplePattern {
                    subject: None,
                    predicate: Some(StarTerm::iri("http://example.org/p1").unwrap()),
                    object: None,
                    variable_name: Some("x".to_string()),
                },
                TriplePattern {
                    subject: None,
                    predicate: Some(StarTerm::iri("http://example.org/p2").unwrap()),
                    object: None,
                    variable_name: Some("x".to_string()),
                },
            ],
            joins: vec![JoinOperation {
                left: 0,
                right: 1,
                join_type: JoinType::Inner,
                join_vars: vec!["xSubject".to_string()],
            }],
            filters: vec![],
            cost: 1.0,
        };
        let results = executor.execute_parallel(&plan, &triples).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].bindings.get("xSubject"),
            Some(&StarTerm::iri("http://example.org/s1").unwrap())
        );
    }

    #[test]
    fn test_filter_application() {
        let executor = ParallelQueryExecutor::new();
        let mut bindings = HashMap::new();
        bindings.insert(
            "x".to_string(),
            StarTerm::iri("http://example.org/test").unwrap(),
        );
        let binding = QueryBinding { bindings };
        let filters = vec![FilterOperation {
            expression: FilterExpression::Bound("x".to_string()),
        }];
        assert!(executor.apply_filters(&binding, &filters));
    }

    /// Filters reject bindings that lack the variable or hold a different term.
    #[test]
    fn test_filters_reject_missing_or_different_values() {
        let executor = ParallelQueryExecutor::new();
        let mut bindings = HashMap::new();
        bindings.insert(
            "x".to_string(),
            StarTerm::iri("http://example.org/test").unwrap(),
        );
        let binding = QueryBinding { bindings };

        let unbound = vec![FilterOperation {
            expression: FilterExpression::Bound("y".to_string()),
        }];
        assert!(!executor.apply_filters(&binding, &unbound));

        let different = vec![FilterOperation {
            expression: FilterExpression::Equals(
                "x".to_string(),
                StarTerm::iri("http://example.org/other").unwrap(),
            ),
        }];
        assert!(!executor.apply_filters(&binding, &different));
    }
}