use crate::{RuleAtom, Term};
use scirs2_core::parallel_ops;
use std::collections::HashMap;
pub struct SimdMatcher;
impl SimdMatcher {
pub fn new() -> Self {
Self
}
#[inline]
pub fn fast_term_hash(&self, term: &Term) -> u64 {
match term {
Term::Constant(s) | Term::Literal(s) | Term::Variable(s) => {
self.simd_string_hash(s.as_bytes())
}
_ => 0,
}
}
#[inline]
fn simd_string_hash(&self, bytes: &[u8]) -> u64 {
if bytes.len() <= 16 {
return self.fast_hash_small(bytes);
}
let mut hash: u64 = 0xcbf29ce484222325;
let chunks = bytes.chunks_exact(16);
let remainder = chunks.remainder();
for chunk in chunks {
hash = self.process_chunk_simd(hash, chunk);
}
for &byte in remainder {
hash ^= byte as u64;
hash = hash.wrapping_mul(0x100000001b3); }
hash
}
#[inline]
fn fast_hash_small(&self, bytes: &[u8]) -> u64 {
let mut hash: u64 = 0xcbf29ce484222325;
for &byte in bytes {
hash ^= byte as u64;
hash = hash.wrapping_mul(0x100000001b3);
}
hash
}
#[inline]
fn process_chunk_simd(&self, mut hash: u64, chunk: &[u8]) -> u64 {
let mut data = [0u64; 2];
for (i, byte_chunk) in chunk.chunks(8).enumerate() {
if byte_chunk.len() == 8 {
data[i] = u64::from_le_bytes([
byte_chunk[0],
byte_chunk[1],
byte_chunk[2],
byte_chunk[3],
byte_chunk[4],
byte_chunk[5],
byte_chunk[6],
byte_chunk[7],
]);
}
}
hash ^= data[0];
hash = hash.wrapping_mul(0x100000001b3);
hash ^= data[1];
hash = hash.wrapping_mul(0x100000001b3);
hash
}
pub fn batch_deduplicate(&self, facts: &mut Vec<RuleAtom>) {
if facts.len() < 2 {
return;
}
facts.sort_unstable_by(|a, b| self.fast_fact_compare(a, b));
facts.dedup_by(|a, b| self.facts_equal_simd(a, b));
}
#[inline]
fn fast_fact_compare(&self, a: &RuleAtom, b: &RuleAtom) -> std::cmp::Ordering {
if std::mem::discriminant(a) != std::mem::discriminant(b) {
return format!("{:?}", a).cmp(&format!("{:?}", b));
}
if let (
RuleAtom::Triple {
subject: s1,
predicate: p1,
object: o1,
},
RuleAtom::Triple {
subject: s2,
predicate: p2,
object: o2,
},
) = (a, b)
{
let hash1 = self
.fast_term_hash(s1)
.wrapping_add(self.fast_term_hash(p1))
.wrapping_add(self.fast_term_hash(o1));
let hash2 = self
.fast_term_hash(s2)
.wrapping_add(self.fast_term_hash(p2))
.wrapping_add(self.fast_term_hash(o2));
hash1.cmp(&hash2)
} else {
format!("{:?}", a).cmp(&format!("{:?}", b))
}
}
#[inline]
fn facts_equal_simd(&self, a: &RuleAtom, b: &RuleAtom) -> bool {
if std::mem::discriminant(a) != std::mem::discriminant(b) {
return false;
}
if let (
RuleAtom::Triple {
subject: s1,
predicate: p1,
object: o1,
},
RuleAtom::Triple {
subject: s2,
predicate: p2,
object: o2,
},
) = (a, b)
{
let hash1 = self.fast_term_hash(s1);
let hash2 = self.fast_term_hash(s2);
if hash1 != hash2 {
return false;
}
self.terms_equal(s1, s2) && self.terms_equal(p1, p2) && self.terms_equal(o1, o2)
} else {
format!("{:?}", a) == format!("{:?}", b)
}
}
#[inline]
fn terms_equal(&self, a: &Term, b: &Term) -> bool {
match (a, b) {
(Term::Constant(a), Term::Constant(b)) => a == b,
(Term::Literal(a), Term::Literal(b)) => a == b,
(Term::Variable(a), Term::Variable(b)) => a == b,
_ => false,
}
}
pub fn parallel_filter<F>(&self, facts: Vec<RuleAtom>, predicate: F) -> Vec<RuleAtom>
where
F: Fn(&RuleAtom) -> bool + Sync + Send,
{
if facts.len() < 1000 {
return facts.into_iter().filter(predicate).collect();
}
parallel_ops::parallel_map(&facts, |fact| {
if predicate(fact) {
Some(fact.clone())
} else {
None
}
})
.into_iter()
.flatten()
.collect()
}
}
impl Default for SimdMatcher {
fn default() -> Self {
Self::new()
}
}
pub struct BatchProcessor {
matcher: SimdMatcher,
batch_size: usize,
}
impl BatchProcessor {
pub fn new(batch_size: usize) -> Self {
Self {
matcher: SimdMatcher::new(),
batch_size,
}
}
pub fn process_batches<F, R>(&self, facts: &[RuleAtom], mut processor: F) -> Vec<R>
where
F: FnMut(&[RuleAtom]) -> Vec<R>,
R: Clone,
{
let mut results = Vec::with_capacity(facts.len());
for batch in facts.chunks(self.batch_size) {
let batch_results = processor(batch);
results.extend(batch_results);
}
results
}
pub fn deduplicate(&self, facts: Vec<RuleAtom>) -> Vec<RuleAtom> {
let mut deduped = facts;
self.matcher.batch_deduplicate(&mut deduped);
deduped
}
}
impl Default for BatchProcessor {
fn default() -> Self {
Self::new(256) }
}
pub struct SimdUnifier {
matcher: SimdMatcher,
}
impl SimdUnifier {
pub fn new() -> Self {
Self {
matcher: SimdMatcher::new(),
}
}
#[inline]
pub fn unify_terms(
&self,
term1: &Term,
term2: &Term,
substitution: &mut HashMap<String, Term>,
) -> bool {
if let (Term::Variable(v1), Term::Variable(v2)) = (term1, term2) {
return self.unify_variables_simd(v1, v2, substitution);
}
if let Term::Variable(var) = term1 {
return self.unify_variable_simd(var, term2, substitution);
}
if let Term::Variable(var) = term2 {
return self.unify_variable_simd(var, term1, substitution);
}
self.unify_constants_simd(term1, term2)
}
#[inline]
fn unify_variables_simd(
&self,
var1: &str,
var2: &str,
substitution: &mut HashMap<String, Term>,
) -> bool {
if var1 == var2 {
return true;
}
match (substitution.get(var1), substitution.get(var2)) {
(Some(t1), Some(t2)) => {
self.terms_equal_simd(t1, t2)
}
(Some(t1), None) => {
substitution.insert(var2.to_string(), t1.clone());
true
}
(None, Some(t2)) => {
substitution.insert(var1.to_string(), t2.clone());
true
}
(None, None) => {
substitution.insert(var1.to_string(), Term::Variable(var2.to_string()));
true
}
}
}
#[inline]
fn unify_variable_simd(
&self,
var: &str,
term: &Term,
substitution: &mut HashMap<String, Term>,
) -> bool {
if let Some(bound_term) = substitution.get(var).cloned() {
return self.unify_terms(&bound_term, term, substitution);
}
substitution.insert(var.to_string(), term.clone());
true
}
#[inline]
fn unify_constants_simd(&self, term1: &Term, term2: &Term) -> bool {
match (term1, term2) {
(Term::Constant(c1), Term::Constant(c2)) => {
let hash1 = self.matcher.simd_string_hash(c1.as_bytes());
let hash2 = self.matcher.simd_string_hash(c2.as_bytes());
if hash1 != hash2 {
return false;
}
c1 == c2
}
(Term::Literal(l1), Term::Literal(l2)) => {
let hash1 = self.matcher.simd_string_hash(l1.as_bytes());
let hash2 = self.matcher.simd_string_hash(l2.as_bytes());
if hash1 != hash2 {
return false;
}
l1 == l2
}
(Term::Function { name: n1, args: a1 }, Term::Function { name: n2, args: a2 }) => {
if n1 != n2 || a1.len() != a2.len() {
return false;
}
let mut temp_sub = HashMap::new();
for (arg1, arg2) in a1.iter().zip(a2.iter()) {
if !self.unify_terms(arg1, arg2, &mut temp_sub) {
return false;
}
}
true
}
_ => false,
}
}
#[inline]
fn terms_equal_simd(&self, term1: &Term, term2: &Term) -> bool {
match (term1, term2) {
(Term::Constant(c1), Term::Constant(c2)) => {
let hash1 = self.matcher.simd_string_hash(c1.as_bytes());
let hash2 = self.matcher.simd_string_hash(c2.as_bytes());
hash1 == hash2 && c1 == c2
}
(Term::Literal(l1), Term::Literal(l2)) => {
let hash1 = self.matcher.simd_string_hash(l1.as_bytes());
let hash2 = self.matcher.simd_string_hash(l2.as_bytes());
hash1 == hash2 && l1 == l2
}
(Term::Variable(v1), Term::Variable(v2)) => v1 == v2,
_ => false,
}
}
pub fn batch_unify(
&self,
term_pairs: &[(Term, Term)],
substitution: &mut HashMap<String, Term>,
) -> bool {
if term_pairs.len() < 10 {
return term_pairs
.iter()
.all(|(t1, t2)| self.unify_terms(t1, t2, substitution));
}
term_pairs
.iter()
.all(|(t1, t2)| self.unify_terms(t1, t2, substitution))
}
pub fn apply_substitution(&self, term: &Term, substitution: &HashMap<String, Term>) -> Term {
Self::apply_substitution_impl(term, substitution)
}
fn apply_substitution_impl(term: &Term, substitution: &HashMap<String, Term>) -> Term {
match term {
Term::Variable(var) => {
if let Some(bound_term) = substitution.get(var) {
Self::apply_substitution_impl(bound_term, substitution)
} else {
term.clone()
}
}
Term::Function { name, args } => {
let new_args: Vec<_> = args
.iter()
.map(|arg| Self::apply_substitution_impl(arg, substitution))
.collect();
Term::Function {
name: name.clone(),
args: new_args,
}
}
_ => term.clone(),
}
}
pub fn batch_apply_substitution(
&self,
terms: &[Term],
substitution: &HashMap<String, Term>,
) -> Vec<Term> {
if terms.len() < 100 {
terms
.iter()
.map(|t| self.apply_substitution(t, substitution))
.collect()
} else {
parallel_ops::parallel_map(terms, |t| self.apply_substitution(t, substitution))
}
}
}
impl Default for SimdUnifier {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_matcher_creation() {
let _matcher = SimdMatcher::new();
}
#[test]
fn test_fast_term_hash() {
let matcher = SimdMatcher::new();
let term1 = Term::Constant("test".to_string());
let term2 = Term::Constant("test".to_string());
let term3 = Term::Constant("different".to_string());
let hash1 = matcher.fast_term_hash(&term1);
let hash2 = matcher.fast_term_hash(&term2);
let hash3 = matcher.fast_term_hash(&term3);
assert_eq!(hash1, hash2);
assert_ne!(hash1, hash3);
}
#[test]
fn test_batch_deduplicate() {
let matcher = SimdMatcher::new();
let mut facts = vec![
RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
},
RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
},
RuleAtom::Triple {
subject: Term::Constant("c".to_string()),
predicate: Term::Constant("q".to_string()),
object: Term::Constant("d".to_string()),
},
];
matcher.batch_deduplicate(&mut facts);
assert_eq!(facts.len(), 2);
}
#[test]
fn test_batch_processor() {
let processor = BatchProcessor::new(2);
let facts = vec![
RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
},
RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
},
];
let results = processor.process_batches(&facts, |batch| batch.to_vec());
assert_eq!(results.len(), 2);
}
#[test]
fn test_parallel_filter() {
let matcher = SimdMatcher::new();
let facts: Vec<RuleAtom> = (0..100)
.map(|i| RuleAtom::Triple {
subject: Term::Constant(format!("entity_{}", i)),
predicate: Term::Constant("hasProperty".to_string()),
object: Term::Constant(format!("value_{}", i)),
})
.collect();
let filtered = matcher.parallel_filter(facts, |fact| {
if let RuleAtom::Triple {
subject: Term::Constant(s),
..
} = fact
{
s.contains("entity_1")
} else {
false
}
});
assert!(!filtered.is_empty());
}
#[test]
fn test_simd_unifier_creation() {
let _unifier = SimdUnifier::new();
}
#[test]
fn test_simd_unify_constants() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let t1 = Term::Constant("test".to_string());
let t2 = Term::Constant("test".to_string());
let t3 = Term::Constant("different".to_string());
assert!(unifier.unify_terms(&t1, &t2, &mut sub));
assert!(!unifier.unify_terms(&t1, &t3, &mut sub));
}
#[test]
fn test_simd_unify_variable_constant() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let var = Term::Variable("X".to_string());
let const_term = Term::Constant("value".to_string());
assert!(unifier.unify_terms(&var, &const_term, &mut sub));
assert_eq!(sub.get("X"), Some(&const_term));
}
#[test]
fn test_simd_unify_variables() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let var1 = Term::Variable("X".to_string());
let var2 = Term::Variable("Y".to_string());
assert!(unifier.unify_terms(&var1, &var2, &mut sub));
assert!(sub.contains_key("X") || sub.contains_key("Y"));
}
#[test]
fn test_simd_unify_with_existing_binding() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let var = Term::Variable("X".to_string());
let val1 = Term::Constant("value1".to_string());
let val2 = Term::Constant("value1".to_string());
let val3 = Term::Constant("value2".to_string());
assert!(unifier.unify_terms(&var, &val1, &mut sub));
assert!(unifier.unify_terms(&var, &val2, &mut sub));
assert!(!unifier.unify_terms(&var, &val3, &mut sub));
}
#[test]
fn test_simd_apply_substitution() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
sub.insert("X".to_string(), Term::Constant("value".to_string()));
let term = Term::Variable("X".to_string());
let result = unifier.apply_substitution(&term, &sub);
assert_eq!(result, Term::Constant("value".to_string()));
}
#[test]
fn test_simd_apply_substitution_nested() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
sub.insert("X".to_string(), Term::Variable("Y".to_string()));
sub.insert("Y".to_string(), Term::Constant("value".to_string()));
let term = Term::Variable("X".to_string());
let result = unifier.apply_substitution(&term, &sub);
assert_eq!(result, Term::Constant("value".to_string()));
}
#[test]
fn test_simd_batch_unify() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let pairs = vec![
(
Term::Variable("X".to_string()),
Term::Constant("a".to_string()),
),
(
Term::Variable("Y".to_string()),
Term::Constant("b".to_string()),
),
];
assert!(unifier.batch_unify(&pairs, &mut sub));
assert_eq!(sub.get("X"), Some(&Term::Constant("a".to_string())));
assert_eq!(sub.get("Y"), Some(&Term::Constant("b".to_string())));
}
#[test]
fn test_simd_batch_apply_substitution() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
sub.insert("X".to_string(), Term::Constant("value".to_string()));
let terms = vec![
Term::Variable("X".to_string()),
Term::Constant("other".to_string()),
Term::Variable("X".to_string()),
];
let results = unifier.batch_apply_substitution(&terms, &sub);
assert_eq!(results.len(), 3);
assert_eq!(results[0], Term::Constant("value".to_string()));
assert_eq!(results[1], Term::Constant("other".to_string()));
assert_eq!(results[2], Term::Constant("value".to_string()));
}
#[test]
fn test_simd_unify_functions() {
let unifier = SimdUnifier::new();
let mut sub = HashMap::new();
let func1 = Term::Function {
name: "f".to_string(),
args: vec![Term::Constant("a".to_string())],
};
let func2 = Term::Function {
name: "f".to_string(),
args: vec![Term::Constant("a".to_string())],
};
let func3 = Term::Function {
name: "g".to_string(),
args: vec![Term::Constant("a".to_string())],
};
assert!(unifier.unify_terms(&func1, &func2, &mut sub));
assert!(!unifier.unify_terms(&func1, &func3, &mut sub));
}
}