use crate::model::{Object, Predicate, Subject, Triple, TriplePattern};
use crate::model::{ObjectPattern, PredicatePattern, SubjectPattern};
use crate::OxirsError;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use scirs2_core::metrics::{Counter, Timer};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Result alias used throughout this module; the error type is always [`OxirsError`].
pub type Result<T> = core::result::Result<T, OxirsError>;
/// Snapshot of a `SimdTripleMatcher`'s counters and timers, as returned by its `stats` method.
#[derive(Debug, Clone, PartialEq)]
pub struct MatcherStats {
    /// Number of matching triples recorded by the matcher's match counter.
    pub total_matches: u64,
    /// Number of triples handed to either matching path.
    pub total_triples_processed: u64,
    /// Cumulative time spent in the mask-based ("SIMD") path, in nanoseconds.
    pub simd_time_ns: u64,
    /// Cumulative time spent in the scalar path, in nanoseconds.
    pub scalar_time_ns: u64,
    /// Number of batches routed to the mask-based path.
    pub simd_calls: u64,
    /// Number of batches routed to the scalar path.
    pub scalar_calls: u64,
    /// Mean scalar-call time divided by mean SIMD-call time; `1.0` when either is unknown.
    pub avg_speedup: f64,
}

impl Default for MatcherStats {
    /// An empty snapshot: every counter and timer is zero.
    ///
    /// `avg_speedup` is the neutral `1.0`, which is the same value `stats` reports when no
    /// timing data exists yet, rather than a misleading `0.0`.
    fn default() -> Self {
        Self {
            total_matches: 0,
            total_triples_processed: 0,
            simd_time_ns: 0,
            scalar_time_ns: 0,
            simd_calls: 0,
            scalar_calls: 0,
            avg_speedup: 1.0,
        }
    }
}
/// Batch triple-pattern matcher that pre-filters triples by comparing per-term hash masks.
///
/// Matching statistics are accumulated in the fields below and exposed through `stats`.
pub struct SimdTripleMatcher {
    /// Number of triples compared per mask batch (see `optimal_chunk_size`).
    chunk_size: usize,
    /// Running total of matches found, across both matching paths.
    match_counter: Arc<Counter>,
    /// Number of triples handed to either matching path.
    triples_processed: Arc<AtomicU64>,
    /// Number of batches routed to the mask-based path.
    simd_calls: Arc<AtomicU64>,
    /// Number of batches routed to the scalar path.
    scalar_calls: Arc<AtomicU64>,
    /// Timer covering calls to the mask-based path.
    simd_timer: Arc<Timer>,
    /// Timer covering calls to the scalar path.
    scalar_timer: Arc<Timer>,
}
impl SimdTripleMatcher {
pub fn new() -> Self {
let match_counter = Arc::new(Counter::new("simd_triple_matches".to_string()));
let simd_timer = Arc::new(Timer::new("simd_matching".to_string()));
let scalar_timer = Arc::new(Timer::new("scalar_matching".to_string()));
Self {
chunk_size: Self::optimal_chunk_size(),
match_counter,
simd_timer,
scalar_timer,
triples_processed: Arc::new(AtomicU64::new(0)),
simd_calls: Arc::new(AtomicU64::new(0)),
scalar_calls: Arc::new(AtomicU64::new(0)),
}
}
pub fn with_chunk_size(chunk_size: usize) -> Self {
let mut matcher = Self::new();
matcher.chunk_size = chunk_size;
matcher
}
pub fn stats(&self) -> MatcherStats {
let simd_stats = self.simd_timer.get_stats();
let scalar_stats = self.scalar_timer.get_stats();
let simd_time_ns = (simd_stats.sum * 1_000_000_000.0) as u64;
let scalar_time_ns = (scalar_stats.sum * 1_000_000_000.0) as u64;
let simd_calls = self.simd_calls.load(Ordering::Relaxed);
let scalar_calls = self.scalar_calls.load(Ordering::Relaxed);
let avg_speedup = if simd_stats.mean > 0.0 && scalar_stats.mean > 0.0 {
scalar_stats.mean / simd_stats.mean
} else {
1.0
};
MatcherStats {
total_matches: self.match_counter.get(),
total_triples_processed: self.triples_processed.load(Ordering::Relaxed),
simd_time_ns,
scalar_time_ns,
simd_calls,
scalar_calls,
avg_speedup,
}
}
pub fn reset_stats(&self) {
self.triples_processed.store(0, Ordering::Relaxed);
self.simd_calls.store(0, Ordering::Relaxed);
self.scalar_calls.store(0, Ordering::Relaxed);
}
fn optimal_chunk_size() -> usize {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512f") {
16 } else {
8 }
}
#[cfg(target_arch = "aarch64")]
{
4 }
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
{
8 }
}
pub fn match_batch(&self, pattern: &TriplePattern, triples: &[Triple]) -> Result<Vec<usize>> {
if triples.is_empty() {
return Ok(Vec::new());
}
if triples.len() < self.chunk_size * 2 {
return Ok(self.match_scalar(pattern, triples));
}
self.match_simd(pattern, triples)
}
fn match_scalar(&self, pattern: &TriplePattern, triples: &[Triple]) -> Vec<usize> {
let _timer_guard = self.scalar_timer.start();
self.scalar_calls.fetch_add(1, Ordering::Relaxed);
self.triples_processed
.fetch_add(triples.len() as u64, Ordering::Relaxed);
let matches: Vec<usize> = triples
.iter()
.enumerate()
.filter_map(|(idx, triple)| {
if pattern.matches(triple) {
Some(idx)
} else {
None
}
})
.collect();
self.match_counter.add(matches.len() as u64);
matches
}
fn match_simd(&self, pattern: &TriplePattern, triples: &[Triple]) -> Result<Vec<usize>> {
let _timer_guard = self.simd_timer.start();
self.simd_calls.fetch_add(1, Ordering::Relaxed);
self.triples_processed
.fetch_add(triples.len() as u64, Ordering::Relaxed);
let mut matches = Vec::with_capacity(triples.len() / 4);
let pattern_mask = self.pattern_to_mask(pattern);
#[cfg(feature = "parallel")]
{
if triples.len() > self.chunk_size * 8 {
return self.match_simd_parallel(pattern, triples, &pattern_mask);
}
}
for (chunk_idx, chunk) in triples.chunks(self.chunk_size).enumerate() {
let base_idx = chunk_idx * self.chunk_size;
let triple_masks = self.triples_to_masks(chunk);
let match_results = self.simd_compare_masks(&pattern_mask, &triple_masks)?;
for (i, &matched) in match_results.iter().enumerate() {
if matched != 0.0 {
matches.push(base_idx + i);
}
}
}
self.match_counter.add(matches.len() as u64);
Ok(matches)
}
#[cfg(feature = "parallel")]
fn match_simd_parallel(
&self,
_pattern: &TriplePattern,
triples: &[Triple],
pattern_mask: &[f32; 3],
) -> Result<Vec<usize>> {
use std::sync::Mutex;
let matches = Arc::new(Mutex::new(Vec::new()));
let chunk_size = self.chunk_size;
let chunks: Vec<&[Triple]> = triples.chunks(chunk_size * 4).collect();
chunks.par_iter().for_each(|chunk_group| {
let mut local_matches = Vec::new();
for (chunk_idx, chunk) in chunk_group.chunks(chunk_size).enumerate() {
let base_idx = chunk_idx * chunk_size;
let triple_masks = self.triples_to_masks(chunk);
if let Ok(match_results) = self.simd_compare_masks(pattern_mask, &triple_masks) {
for (i, &matched) in match_results.iter().enumerate() {
if matched != 0.0 {
local_matches.push(base_idx + i);
}
}
}
}
if let Ok(mut global) = matches.lock() {
global.extend(local_matches);
}
});
let final_matches = match Arc::try_unwrap(matches) {
Ok(mutex) => mutex.into_inner().unwrap_or_default(),
Err(arc) => arc.lock().expect("lock should not be poisoned").clone(),
};
self.match_counter.add(final_matches.len() as u64);
Ok(final_matches)
}
fn pattern_to_mask(&self, pattern: &TriplePattern) -> [f32; 3] {
let subject_mask = match &pattern.subject {
None => 0.0, Some(SubjectPattern::Variable(_)) => 0.0, Some(SubjectPattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
Some(SubjectPattern::BlankNode(bn)) => self.hash_term(bn.as_str()),
};
let predicate_mask = match &pattern.predicate {
None => 0.0,
Some(PredicatePattern::Variable(_)) => 0.0,
Some(PredicatePattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
};
let object_mask = match &pattern.object {
None => 0.0,
Some(ObjectPattern::Variable(_)) => 0.0,
Some(ObjectPattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
Some(ObjectPattern::BlankNode(bn)) => self.hash_term(bn.as_str()),
Some(ObjectPattern::Literal(lit)) => self.hash_term(lit.value()),
};
[subject_mask, predicate_mask, object_mask]
}
fn triples_to_masks(&self, triples: &[Triple]) -> Vec<[f32; 3]> {
triples
.iter()
.map(|triple| {
[
self.hash_subject(triple.subject()),
self.hash_predicate(triple.predicate()),
self.hash_object(triple.object()),
]
})
.collect()
}
fn simd_compare_masks(
&self,
pattern: &[f32; 3],
triple_masks: &[[f32; 3]],
) -> Result<Vec<f32>> {
if triple_masks.is_empty() {
return Ok(Vec::new());
}
if triple_masks.len() < 4 {
return Ok(self.scalar_compare_masks(pattern, triple_masks));
}
let num_triples = triple_masks.len();
let mut triple_matrix = Vec::with_capacity(num_triples * 3);
for mask in triple_masks {
triple_matrix.extend_from_slice(mask);
}
let triple_array = Array2::from_shape_vec((num_triples, 3), triple_matrix)
.map_err(|e| OxirsError::Query(format!("Failed to create triple array: {}", e)))?;
let pattern_array = Array1::from_vec(pattern.to_vec());
let mut results = vec![1.0; num_triples];
for (i, triple_view) in triple_array.outer_iter().enumerate() {
let mut matches = true;
for j in 0..3 {
let pattern_val = pattern_array[j];
let triple_val = triple_view[j];
if pattern_val == 0.0 {
continue;
}
if (pattern_val - triple_val).abs() >= 0.0001 {
matches = false;
break;
}
}
results[i] = if matches { 1.0 } else { 0.0 };
}
Ok(results)
}
fn scalar_compare_masks(&self, pattern: &[f32; 3], triple_masks: &[[f32; 3]]) -> Vec<f32> {
triple_masks
.iter()
.map(|triple_mask| {
let matches_all = (0..3).all(|j| {
let pattern_val = pattern[j];
let triple_val = triple_mask[j];
if pattern_val == 0.0 {
return true;
}
(pattern_val - triple_val).abs() < 0.0001
});
if matches_all {
1.0
} else {
0.0
}
})
.collect()
}
#[allow(dead_code)]
fn matches_mask(&self, pattern: &[f32; 3], triple: &Triple) -> bool {
let triple_mask = [
self.hash_subject(triple.subject()),
self.hash_predicate(triple.predicate()),
self.hash_object(triple.object()),
];
(0..3).all(|i| {
let pattern_val = pattern[i];
let triple_val = triple_mask[i];
pattern_val == 0.0 || (pattern_val - triple_val).abs() < 0.0001
})
}
fn hash_term(&self, term: &str) -> f32 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
term.hash(&mut hasher);
let hash = hasher.finish();
((hash % (i32::MAX as u64)) as f32) + 1.0
}
fn hash_subject(&self, subject: &Subject) -> f32 {
match subject {
Subject::NamedNode(nn) => self.hash_term(nn.as_str()),
Subject::BlankNode(bn) => self.hash_term(bn.as_str()),
Subject::Variable(v) => self.hash_term(v.as_str()),
Subject::QuotedTriple(qt) => {
let repr = format!("<<{:?}>>", qt);
self.hash_term(&repr)
}
}
}
fn hash_predicate(&self, predicate: &Predicate) -> f32 {
match predicate {
Predicate::NamedNode(nn) => self.hash_term(nn.as_str()),
Predicate::Variable(v) => self.hash_term(v.as_str()),
}
}
fn hash_object(&self, object: &Object) -> f32 {
match object {
Object::NamedNode(nn) => self.hash_term(nn.as_str()),
Object::BlankNode(bn) => self.hash_term(bn.as_str()),
Object::Literal(lit) => self.hash_term(lit.value()),
Object::Variable(v) => self.hash_term(v.as_str()),
Object::QuotedTriple(qt) => {
let repr = format!("<<{:?}>>", qt);
self.hash_term(&repr)
}
}
}
pub fn estimate_selectivity(&self, pattern: &TriplePattern, _total_triples: usize) -> f32 {
let num_wildcards = pattern.subject.is_none() as i32
+ pattern.predicate.is_none() as i32
+ pattern.object.is_none() as i32;
match num_wildcards {
3 => 1.0, 2 => 0.5, 1 => 0.1, 0 => 0.001, _ => 0.5,
}
}
}
impl Default for SimdTripleMatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Literal, NamedNode};

    #[test]
    fn test_simd_matcher_creation() {
        let matcher = SimdTripleMatcher::new();
        // Every supported target picks a chunk size between 4 and 16 inclusive.
        assert!((4..=16).contains(&matcher.chunk_size));
    }

    #[test]
    fn test_match_empty_batch() {
        let matcher = SimdTripleMatcher::new();
        let pattern = TriplePattern::new(None, None, None);
        let triples: Vec<Triple> = Vec::new();
        let found = matcher
            .match_batch(&pattern, &triples)
            .expect("operation should succeed");
        assert!(found.is_empty());
    }

    #[test]
    fn test_match_all_pattern() -> Result<()> {
        let matcher = SimdTripleMatcher::new();
        let wildcard = TriplePattern::new(None, None, None);
        let subject = Subject::NamedNode(NamedNode::new("http://example.org/s")?);
        let predicate = Predicate::NamedNode(NamedNode::new("http://example.org/p")?);
        let object = Object::Literal(Literal::new("test"));
        let triples: Vec<Triple> = (0..3)
            .map(|_| Triple::new(subject.clone(), predicate.clone(), object.clone()))
            .collect();
        let found = matcher.match_batch(&wildcard, &triples)?;
        assert_eq!(found.len(), 3);
        Ok(())
    }

    #[test]
    fn test_hash_term_non_zero() {
        let matcher = SimdTripleMatcher::new();
        let first = matcher.hash_term("http://example.org/test");
        let second = matcher.hash_term("http://example.org/other");
        assert!(first > 0.0 && second > 0.0);
        assert_ne!(first, second);
    }

    #[test]
    fn test_optimal_chunk_size() {
        let chosen = SimdTripleMatcher::optimal_chunk_size();
        assert!((4..=16).contains(&chosen));
    }

    #[test]
    fn test_estimate_selectivity() {
        let matcher = SimdTripleMatcher::new();

        let unbound = TriplePattern::new(None, None, None);
        assert_eq!(matcher.estimate_selectivity(&unbound, 1000), 1.0);

        let subject_iri = NamedNode::new("http://example.org/s").expect("valid IRI");
        let predicate_iri = NamedNode::new("http://example.org/p").expect("valid IRI");
        let fully_bound = TriplePattern::new(
            Some(SubjectPattern::NamedNode(subject_iri)),
            Some(PredicatePattern::NamedNode(predicate_iri)),
            Some(ObjectPattern::Literal(Literal::new("test"))),
        );
        assert_eq!(matcher.estimate_selectivity(&fully_bound, 1000), 0.001);
    }
}