1use crate::model::{Object, Predicate, Subject, Triple, TriplePattern};
41use crate::model::{ObjectPattern, PredicatePattern, SubjectPattern};
42use crate::OxirsError;
43
44use scirs2_core::ndarray_ext::{Array1, Array2};
46
47#[cfg(feature = "parallel")]
49use rayon::prelude::*;
50
51use scirs2_core::metrics::{Counter, Timer};
53
54use std::sync::atomic::{AtomicU64, Ordering};
56use std::sync::Arc;
57
58pub type Result<T> = std::result::Result<T, OxirsError>;
60
/// Point-in-time snapshot of the statistics collected by `SimdTripleMatcher`.
///
/// All fields are plain values copied when `SimdTripleMatcher::stats` is called,
/// so a snapshot never changes afterwards. `PartialEq` is derived so that two
/// snapshots can be compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct MatcherStats {
    /// Current value of the matcher's match counter.
    pub total_matches: u64,
    /// Number of triples handed to the scalar and SIMD matching paths.
    pub total_triples_processed: u64,
    /// Sum reported by the SIMD timer, multiplied by 1e9 and truncated to an
    /// integer (nanoseconds, assuming the timer reports seconds — TODO confirm).
    pub simd_time_ns: u64,
    /// Sum reported by the scalar timer, converted the same way as `simd_time_ns`.
    pub scalar_time_ns: u64,
    /// Number of times the SIMD matching path was entered.
    pub simd_calls: u64,
    /// Number of times the scalar matching path was entered.
    pub scalar_calls: u64,
    /// Scalar mean time divided by SIMD mean time; `1.0` when either mean is
    /// not strictly positive (i.e. no usable measurement yet).
    pub avg_speedup: f64,
}
79
80pub struct SimdTripleMatcher {
92 chunk_size: usize,
94 match_counter: Arc<Counter>,
96 simd_timer: Arc<Timer>,
98 scalar_timer: Arc<Timer>,
100 triples_processed: Arc<AtomicU64>,
102 simd_calls: Arc<AtomicU64>,
104 scalar_calls: Arc<AtomicU64>,
106}
107
108impl SimdTripleMatcher {
109 pub fn new() -> Self {
111 let match_counter = Arc::new(Counter::new("simd_triple_matches".to_string()));
112 let simd_timer = Arc::new(Timer::new("simd_matching".to_string()));
113 let scalar_timer = Arc::new(Timer::new("scalar_matching".to_string()));
114
115 Self {
116 chunk_size: Self::optimal_chunk_size(),
117 match_counter,
118 simd_timer,
119 scalar_timer,
120 triples_processed: Arc::new(AtomicU64::new(0)),
121 simd_calls: Arc::new(AtomicU64::new(0)),
122 scalar_calls: Arc::new(AtomicU64::new(0)),
123 }
124 }
125
126 pub fn with_chunk_size(chunk_size: usize) -> Self {
128 let mut matcher = Self::new();
129 matcher.chunk_size = chunk_size;
130 matcher
131 }
132
133 pub fn stats(&self) -> MatcherStats {
135 let simd_stats = self.simd_timer.get_stats();
136 let scalar_stats = self.scalar_timer.get_stats();
137
138 let simd_time_ns = (simd_stats.sum * 1_000_000_000.0) as u64;
139 let scalar_time_ns = (scalar_stats.sum * 1_000_000_000.0) as u64;
140 let simd_calls = self.simd_calls.load(Ordering::Relaxed);
141 let scalar_calls = self.scalar_calls.load(Ordering::Relaxed);
142
143 let avg_speedup = if simd_stats.mean > 0.0 && scalar_stats.mean > 0.0 {
145 scalar_stats.mean / simd_stats.mean
146 } else {
147 1.0
148 };
149
150 MatcherStats {
151 total_matches: self.match_counter.get(),
152 total_triples_processed: self.triples_processed.load(Ordering::Relaxed),
153 simd_time_ns,
154 scalar_time_ns,
155 simd_calls,
156 scalar_calls,
157 avg_speedup,
158 }
159 }
160
161 pub fn reset_stats(&self) {
166 self.triples_processed.store(0, Ordering::Relaxed);
167 self.simd_calls.store(0, Ordering::Relaxed);
168 self.scalar_calls.store(0, Ordering::Relaxed);
169 }
170
/// Picks the number of triples per chunk for the current target architecture.
///
/// * x86 / x86_64: 16 when AVX-512F is detected at runtime, otherwise 8.
/// * aarch64: 4.
/// * any other architecture: 8.
fn optimal_chunk_size() -> usize {
    // Exactly one of the three bindings below survives conditional compilation.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let lanes = if is_x86_feature_detected!("avx512f") {
        16
    } else {
        8
    };

    #[cfg(target_arch = "aarch64")]
    let lanes = 4;

    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
    let lanes = 8;

    lanes
}
193
194 pub fn match_batch(&self, pattern: &TriplePattern, triples: &[Triple]) -> Result<Vec<usize>> {
208 if triples.is_empty() {
209 return Ok(Vec::new());
210 }
211
212 if triples.len() < self.chunk_size * 2 {
214 return Ok(self.match_scalar(pattern, triples));
215 }
216
217 self.match_simd(pattern, triples)
219 }
220
221 fn match_scalar(&self, pattern: &TriplePattern, triples: &[Triple]) -> Vec<usize> {
223 let _timer_guard = self.scalar_timer.start();
225 self.scalar_calls.fetch_add(1, Ordering::Relaxed);
226 self.triples_processed
227 .fetch_add(triples.len() as u64, Ordering::Relaxed);
228
229 let matches: Vec<usize> = triples
230 .iter()
231 .enumerate()
232 .filter_map(|(idx, triple)| {
233 if pattern.matches(triple) {
234 Some(idx)
235 } else {
236 None
237 }
238 })
239 .collect();
240
241 self.match_counter.add(matches.len() as u64);
242 matches
243 }
244
245 fn match_simd(&self, pattern: &TriplePattern, triples: &[Triple]) -> Result<Vec<usize>> {
247 let _timer_guard = self.simd_timer.start();
249 self.simd_calls.fetch_add(1, Ordering::Relaxed);
250 self.triples_processed
251 .fetch_add(triples.len() as u64, Ordering::Relaxed);
252
253 let mut matches = Vec::with_capacity(triples.len() / 4); let pattern_mask = self.pattern_to_mask(pattern);
257
258 #[cfg(feature = "parallel")]
259 {
260 if triples.len() > self.chunk_size * 8 {
262 return self.match_simd_parallel(pattern, triples, &pattern_mask);
263 }
264 }
265
266 for (chunk_idx, chunk) in triples.chunks(self.chunk_size).enumerate() {
268 let base_idx = chunk_idx * self.chunk_size;
269
270 let triple_masks = self.triples_to_masks(chunk);
272
273 let match_results = self.simd_compare_masks(&pattern_mask, &triple_masks)?;
275
276 for (i, &matched) in match_results.iter().enumerate() {
278 if matched != 0.0 {
279 matches.push(base_idx + i);
280 }
281 }
282 }
283
284 self.match_counter.add(matches.len() as u64);
285 Ok(matches)
286 }
287
/// Parallel variant of the chunked matching path.
///
/// The batch is split into groups of `chunk_size * 4` triples that are
/// processed on the rayon thread pool. Within a group the triples are compared
/// chunk by chunk, and every mask candidate is confirmed with `pattern.matches`
/// because the hashed masks are lossy.
///
/// Returned indices are positions in `triples` (not positions inside a group)
/// and are in ascending order regardless of thread scheduling.
///
/// # Errors
///
/// Returns `OxirsError::Query` if any mask comparison fails. The original error
/// is carried as its `Debug` rendering so that no `Send` bound is required on
/// `OxirsError`.
#[cfg(feature = "parallel")]
fn match_simd_parallel(
    &self,
    pattern: &TriplePattern,
    triples: &[Triple],
    pattern_mask: &[f32; 3],
) -> Result<Vec<usize>> {
    let chunk_size = self.chunk_size;
    let group_size = chunk_size * 4;

    // Collecting an indexed parallel iterator keeps the groups in input order.
    let per_group: std::result::Result<Vec<Vec<usize>>, String> = triples
        .par_chunks(group_size)
        .enumerate()
        .map(
            |(group_idx, group)| -> std::result::Result<Vec<usize>, String> {
                // Offset of this group within the whole batch.
                let group_base = group_idx * group_size;
                let mut local_matches = Vec::new();

                for (chunk_idx, chunk) in group.chunks(chunk_size).enumerate() {
                    let base_idx = group_base + chunk_idx * chunk_size;
                    let triple_masks = self.triples_to_masks(chunk);
                    let match_results = self
                        .simd_compare_masks(pattern_mask, &triple_masks)
                        .map_err(|e| format!("{:?}", e))?;

                    for (offset, (&flag, triple)) in
                        match_results.iter().zip(chunk).enumerate()
                    {
                        if flag != 0.0 && pattern.matches(triple) {
                            local_matches.push(base_idx + offset);
                        }
                    }
                }

                Ok(local_matches)
            },
        )
        .collect();

    let final_matches: Vec<usize> = per_group
        .map_err(|detail| {
            OxirsError::Query(format!("Parallel mask comparison failed: {}", detail))
        })?
        .into_iter()
        .flatten()
        .collect();

    self.match_counter.add(final_matches.len() as u64);
    Ok(final_matches)
}
335
336 fn pattern_to_mask(&self, pattern: &TriplePattern) -> [f32; 3] {
342 let subject_mask = match &pattern.subject {
343 None => 0.0, Some(SubjectPattern::Variable(_)) => 0.0, Some(SubjectPattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
346 Some(SubjectPattern::BlankNode(bn)) => self.hash_term(bn.as_str()),
347 };
348
349 let predicate_mask = match &pattern.predicate {
350 None => 0.0,
351 Some(PredicatePattern::Variable(_)) => 0.0,
352 Some(PredicatePattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
353 };
354
355 let object_mask = match &pattern.object {
356 None => 0.0,
357 Some(ObjectPattern::Variable(_)) => 0.0,
358 Some(ObjectPattern::NamedNode(nn)) => self.hash_term(nn.as_str()),
359 Some(ObjectPattern::BlankNode(bn)) => self.hash_term(bn.as_str()),
360 Some(ObjectPattern::Literal(lit)) => self.hash_term(lit.value()),
361 };
362
363 [subject_mask, predicate_mask, object_mask]
364 }
365
366 fn triples_to_masks(&self, triples: &[Triple]) -> Vec<[f32; 3]> {
368 triples
369 .iter()
370 .map(|triple| {
371 [
372 self.hash_subject(triple.subject()),
373 self.hash_predicate(triple.predicate()),
374 self.hash_object(triple.object()),
375 ]
376 })
377 .collect()
378 }
379
380 fn simd_compare_masks(
385 &self,
386 pattern: &[f32; 3],
387 triple_masks: &[[f32; 3]],
388 ) -> Result<Vec<f32>> {
389 if triple_masks.is_empty() {
390 return Ok(Vec::new());
391 }
392
393 if triple_masks.len() < 4 {
395 return Ok(self.scalar_compare_masks(pattern, triple_masks));
396 }
397
398 let num_triples = triple_masks.len();
400 let mut triple_matrix = Vec::with_capacity(num_triples * 3);
401 for mask in triple_masks {
402 triple_matrix.extend_from_slice(mask);
403 }
404
405 let triple_array = Array2::from_shape_vec((num_triples, 3), triple_matrix)
407 .map_err(|e| OxirsError::Query(format!("Failed to create triple array: {}", e)))?;
408
409 let pattern_array = Array1::from_vec(pattern.to_vec());
411
412 let mut results = vec![1.0; num_triples];
414
415 for (i, triple_view) in triple_array.outer_iter().enumerate() {
417 let mut matches = true;
418
419 for j in 0..3 {
421 let pattern_val = pattern_array[j];
422 let triple_val = triple_view[j];
423
424 if pattern_val == 0.0 {
426 continue;
427 }
428
429 if (pattern_val - triple_val).abs() >= 0.0001 {
431 matches = false;
432 break;
433 }
434 }
435
436 results[i] = if matches { 1.0 } else { 0.0 };
437 }
438
439 Ok(results)
440 }
441
442 fn scalar_compare_masks(&self, pattern: &[f32; 3], triple_masks: &[[f32; 3]]) -> Vec<f32> {
444 triple_masks
445 .iter()
446 .map(|triple_mask| {
447 let matches_all = (0..3).all(|j| {
448 let pattern_val = pattern[j];
449 let triple_val = triple_mask[j];
450
451 if pattern_val == 0.0 {
453 return true;
454 }
455
456 (pattern_val - triple_val).abs() < 0.0001
458 });
459
460 if matches_all {
461 1.0
462 } else {
463 0.0
464 }
465 })
466 .collect()
467 }
468
469 #[allow(dead_code)]
471 fn matches_mask(&self, pattern: &[f32; 3], triple: &Triple) -> bool {
472 let triple_mask = [
473 self.hash_subject(triple.subject()),
474 self.hash_predicate(triple.predicate()),
475 self.hash_object(triple.object()),
476 ];
477
478 (0..3).all(|i| {
479 let pattern_val = pattern[i];
480 let triple_val = triple_mask[i];
481
482 pattern_val == 0.0 || (pattern_val - triple_val).abs() < 0.0001
484 })
485 }
486
487 fn hash_term(&self, term: &str) -> f32 {
492 use std::collections::hash_map::DefaultHasher;
493 use std::hash::{Hash, Hasher};
494
495 let mut hasher = DefaultHasher::new();
496 term.hash(&mut hasher);
497 let hash = hasher.finish();
498
499 ((hash % (i32::MAX as u64)) as f32) + 1.0
502 }
503
504 fn hash_subject(&self, subject: &Subject) -> f32 {
506 match subject {
507 Subject::NamedNode(nn) => self.hash_term(nn.as_str()),
508 Subject::BlankNode(bn) => self.hash_term(bn.as_str()),
509 Subject::Variable(v) => self.hash_term(v.as_str()),
510 Subject::QuotedTriple(qt) => {
511 let repr = format!("<<{:?}>>", qt);
513 self.hash_term(&repr)
514 }
515 }
516 }
517
518 fn hash_predicate(&self, predicate: &Predicate) -> f32 {
520 match predicate {
521 Predicate::NamedNode(nn) => self.hash_term(nn.as_str()),
522 Predicate::Variable(v) => self.hash_term(v.as_str()),
523 }
524 }
525
526 fn hash_object(&self, object: &Object) -> f32 {
528 match object {
529 Object::NamedNode(nn) => self.hash_term(nn.as_str()),
530 Object::BlankNode(bn) => self.hash_term(bn.as_str()),
531 Object::Literal(lit) => self.hash_term(lit.value()),
532 Object::Variable(v) => self.hash_term(v.as_str()),
533 Object::QuotedTriple(qt) => {
534 let repr = format!("<<{:?}>>", qt);
536 self.hash_term(&repr)
537 }
538 }
539 }
540
541 pub fn estimate_selectivity(&self, pattern: &TriplePattern, _total_triples: usize) -> f32 {
545 let num_wildcards = pattern.subject.is_none() as i32
546 + pattern.predicate.is_none() as i32
547 + pattern.object.is_none() as i32;
548
549 match num_wildcards {
551 3 => 1.0, 2 => 0.5, 1 => 0.1, 0 => 0.001, _ => 0.5,
556 }
557 }
558}
559
560impl Default for SimdTripleMatcher {
561 fn default() -> Self {
562 Self::new()
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Literal, NamedNode};

    /// A freshly built matcher uses one of the supported chunk widths.
    #[test]
    fn test_simd_matcher_creation() {
        let matcher = SimdTripleMatcher::new();
        assert!((4..=16).contains(&matcher.chunk_size));
    }

    /// Matching against no triples yields no indices.
    #[test]
    fn test_match_empty_batch() {
        let matcher = SimdTripleMatcher::new();
        let pattern = TriplePattern::new(None, None, None);
        let triples: Vec<Triple> = Vec::new();

        let matches = matcher
            .match_batch(&pattern, &triples)
            .expect("operation should succeed");
        assert!(matches.is_empty());
    }

    /// A fully unconstrained pattern selects every triple in the batch.
    #[test]
    fn test_match_all_pattern() -> Result<()> {
        let matcher = SimdTripleMatcher::new();
        let pattern = TriplePattern::new(None, None, None);

        let s = Subject::NamedNode(NamedNode::new("http://example.org/s")?);
        let p = Predicate::NamedNode(NamedNode::new("http://example.org/p")?);
        let o = Object::Literal(Literal::new("test"));

        let triples: Vec<Triple> = (0..3)
            .map(|_| Triple::new(s.clone(), p.clone(), o.clone()))
            .collect();

        let matches = matcher.match_batch(&pattern, &triples)?;
        assert_eq!(matches.len(), 3);
        Ok(())
    }

    /// Term hashes are strictly positive and differ for these two IRIs.
    #[test]
    fn test_hash_term_non_zero() {
        let matcher = SimdTripleMatcher::new();
        let first = matcher.hash_term("http://example.org/test");
        let second = matcher.hash_term("http://example.org/other");

        assert!(first > 0.0);
        assert!(second > 0.0);
        assert_ne!(first, second);
    }

    /// The platform chunk size stays within the supported range.
    #[test]
    fn test_optimal_chunk_size() {
        let size = SimdTripleMatcher::optimal_chunk_size();
        assert!((4..=16).contains(&size));
    }

    /// Selectivity is 1.0 with nothing bound and 0.001 with everything bound.
    #[test]
    fn test_estimate_selectivity() {
        let matcher = SimdTripleMatcher::new();

        let pattern_all = TriplePattern::new(None, None, None);
        assert_eq!(matcher.estimate_selectivity(&pattern_all, 1000), 1.0);

        let s =
            SubjectPattern::NamedNode(NamedNode::new("http://example.org/s").expect("valid IRI"));
        let p =
            PredicatePattern::NamedNode(NamedNode::new("http://example.org/p").expect("valid IRI"));
        let o = ObjectPattern::Literal(Literal::new("test"));
        let pattern_none = TriplePattern::new(Some(s), Some(p), Some(o));
        assert_eq!(matcher.estimate_selectivity(&pattern_none, 1000), 0.001);
    }
}