use crate::{Rule, RuleAtom, Term};
use anyhow::Result;
use scirs2_core::metrics::{Counter, Gauge};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tracing::{debug, info};
// Process-wide metrics for this module. Each static is created on first use
// by `lazy_static!` and registered under the string name passed to `new`.
lazy_static::lazy_static! {
// Incremented every time a fact is newly stored in a `LockFreeFactSet`.
static ref LOCKFREE_FACT_INSERTIONS: Counter = Counter::new("lockfree_fact_insertions".to_string());
// Incremented when a rule derives a fact that was not yet present in the
// per-iteration set of derived facts.
static ref LOCKFREE_RULE_APPLICATIONS: Counter = Counter::new("lockfree_rule_applications".to_string());
// Set to the worker count while rules are being executed and reset to 0 afterwards.
static ref LOCKFREE_ACTIVE_WORKERS: Gauge = Gauge::new("lockfree_active_workers".to_string());
// Incremented each time an insertion probe lands on a slot occupied by a
// different fact and has to move on.
static ref LOCKFREE_CAS_RETRIES: Counter = Counter::new("lockfree_cas_retries".to_string());
}
/// Forward-chaining rule engine that applies its rules to a shared fact set
/// using several parallel workers per iteration.
#[derive(Debug)]
pub struct LockFreeEngine {
/// Rules applied on every iteration, in insertion order.
rules: Vec<Rule>,
/// Number of parallel workers the rules are partitioned across.
num_workers: usize,
/// Upper bound on the number of inference iterations.
max_iterations: usize,
/// Flag toggled through `set_work_stealing`.
// NOTE(review): nothing in this file reads the flag during inference — confirm whether it is still needed.
work_stealing: bool,
}
impl Default for LockFreeEngine {
fn default() -> Self {
Self::new()
}
}
impl LockFreeEngine {
    /// Creates an engine with no rules, one worker per available hardware
    /// thread (4 when that number cannot be determined), a limit of 100
    /// iterations, and the work-stealing flag enabled.
    pub fn new() -> Self {
        let num_workers = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        Self::with_config(num_workers, 100)
    }

    /// Creates an engine with no rules and the given worker count and
    /// iteration limit; the work-stealing flag starts enabled.
    ///
    /// A `num_workers` of 0 is stored as given but treated as 1 when rules
    /// are executed.
    pub fn with_config(num_workers: usize, max_iterations: usize) -> Self {
        Self {
            rules: Vec::new(),
            num_workers,
            max_iterations,
            work_stealing: true,
        }
    }

    /// Appends a single rule to the engine.
    pub fn add_rule(&mut self, rule: Rule) {
        self.rules.push(rule);
    }

    /// Appends every rule in `rules`, preserving their order.
    pub fn add_rules(&mut self, rules: Vec<Rule>) {
        self.rules.extend(rules);
    }

    /// Stores the work-stealing flag.
    pub fn set_work_stealing(&mut self, enabled: bool) {
        self.work_stealing = enabled;
    }

    /// Runs forward-chaining inference starting from `initial_facts`.
    ///
    /// Each iteration applies all rules in parallel to the facts known so
    /// far, collects the derived facts in a scratch set, and merges them into
    /// the main set. Inference stops as soon as an iteration contributes no
    /// fact that was not already known, or after `max_iterations` iterations.
    ///
    /// # Returns
    /// Every fact in the main set: the initial facts plus everything derived.
    ///
    /// # Errors
    /// Propagates any error reported while executing the rules.
    pub fn infer(&self, initial_facts: &[RuleAtom]) -> Result<Vec<RuleAtom>> {
        info!(
            "Starting lock-free inference with {} workers",
            self.num_workers
        );
        let facts = Arc::new(LockFreeFactSet::new());
        let new_facts = Arc::new(LockFreeFactSet::new());
        let converged = AtomicBool::new(false);
        let mut iteration = 0usize;
        for fact in initial_facts {
            facts.insert(fact.clone());
        }
        while iteration < self.max_iterations {
            debug!("Lock-free iteration {}", iteration);
            new_facts.clear();
            self.execute_rules_lockfree(Arc::clone(&facts), Arc::clone(&new_facts), iteration)?;
            let merge_count = facts.merge_from(&new_facts);
            debug!(
                "Merged {} new facts in iteration {}",
                merge_count, iteration
            );
            // The scratch set also holds re-derivations of facts that are
            // already known, so "nothing derived" is the wrong fixpoint test:
            // the fixpoint is reached when the merge adds nothing.
            if merge_count == 0 {
                converged.store(true, Ordering::Release);
                break;
            }
            iteration += 1;
        }
        if !converged.load(Ordering::Acquire) {
            debug!(
                "Lock-free inference stopped at the limit of {} iterations before reaching a fixpoint",
                self.max_iterations
            );
        }
        let results = facts.to_vec();
        info!(
            "Lock-free inference completed after {} iterations with {} facts",
            iteration,
            results.len()
        );
        Ok(results)
    }

    /// Applies every rule once, splitting the rule list into contiguous
    /// chunks with one chunk per worker.
    ///
    /// Derived facts are written to `new_facts`; `facts` is only read.
    /// `iteration` is used for logging only.
    fn execute_rules_lockfree(
        &self,
        facts: Arc<LockFreeFactSet>,
        new_facts: Arc<LockFreeFactSet>,
        iteration: usize,
    ) -> Result<()> {
        // A configured worker count of 0 would underflow and divide by zero
        // in the chunk-size computation below.
        let worker_count = self.num_workers.max(1);
        let rules_per_worker = (self.rules.len() + worker_count - 1) / worker_count;
        LOCKFREE_ACTIVE_WORKERS.set(worker_count as f64);
        scirs2_core::parallel_ops::par_scope(|scope| {
            for worker_id in 0..worker_count {
                let rules = &self.rules;
                let facts = Arc::clone(&facts);
                let new_facts = Arc::clone(&new_facts);
                let start_idx = worker_id * rules_per_worker;
                let end_idx = std::cmp::min(start_idx + rules_per_worker, rules.len());
                scope.spawn(move |_| {
                    debug!(
                        "Worker {} processing rules {}-{} in iteration {}",
                        worker_id, start_idx, end_idx, iteration
                    );
                    // Workers past the end of the rule list get an empty range.
                    for rule_idx in start_idx..end_idx {
                        if let Some(rule) = rules.get(rule_idx) {
                            Self::apply_rule_lockfree(rule, &facts, &new_facts);
                        }
                    }
                });
            }
        });
        LOCKFREE_ACTIVE_WORKERS.set(0.0);
        Ok(())
    }

    /// Matches `rule` against a snapshot of `facts` and inserts every
    /// instantiated head atom into `new_facts`.
    fn apply_rule_lockfree(rule: &Rule, facts: &LockFreeFactSet, new_facts: &LockFreeFactSet) {
        // `try_match_body` rejects every rule whose body is not exactly one
        // atom, so skip the snapshot of the fact set for those rules.
        if rule.body.len() != 1 {
            return;
        }
        let all_facts = facts.to_vec();
        for fact in &all_facts {
            if let Some(substitution) = Self::try_match_body(rule, fact, &all_facts) {
                for head_atom in &rule.head {
                    if let Some(derived_fact) = Self::apply_substitution(head_atom, &substitution) {
                        if new_facts.insert(derived_fact) {
                            LOCKFREE_RULE_APPLICATIONS.inc();
                        }
                    }
                }
            }
        }
    }

    /// Tries to match the body of `rule` against `trigger_fact`.
    ///
    /// Only rules with exactly one body atom are supported; any other rule
    /// yields `None`. `_all_facts` is currently unused.
    fn try_match_body(
        rule: &Rule,
        trigger_fact: &RuleAtom,
        _all_facts: &[RuleAtom],
    ) -> Option<HashMap<String, Term>> {
        if rule.body.len() != 1 {
            return None;
        }
        let body_atom = &rule.body[0];
        Self::unify_atoms(body_atom, trigger_fact)
    }

    /// Unifies a triple pattern with a triple fact position by position and
    /// returns the resulting variable bindings. Any non-triple atom on either
    /// side yields `None`.
    fn unify_atoms(pattern: &RuleAtom, fact: &RuleAtom) -> Option<HashMap<String, Term>> {
        match (pattern, fact) {
            (
                RuleAtom::Triple {
                    subject: ps,
                    predicate: pp,
                    object: po,
                },
                RuleAtom::Triple {
                    subject: fs,
                    predicate: fp,
                    object: fo,
                },
            ) => {
                let mut sub = HashMap::new();
                if !Self::unify_term(ps, fs, &mut sub) {
                    return None;
                }
                if !Self::unify_term(pp, fp, &mut sub) {
                    return None;
                }
                if !Self::unify_term(po, fo, &mut sub) {
                    return None;
                }
                Some(sub)
            }
            _ => None,
        }
    }

    /// Unifies one pattern term with one fact term.
    ///
    /// A variable that is already bound must be bound to a term equal to
    /// `fact`; an unbound variable is bound to `fact`. Any other pattern term
    /// must be equal to `fact`.
    fn unify_term(pattern: &Term, fact: &Term, substitution: &mut HashMap<String, Term>) -> bool {
        match pattern {
            Term::Variable(var) => {
                if let Some(existing) = substitution.get(var) {
                    Self::terms_equal(existing, fact)
                } else {
                    substitution.insert(var.clone(), fact.clone());
                    true
                }
            }
            _ => Self::terms_equal(pattern, fact),
        }
    }

    /// Structural equality of two terms: same variant and equal contents.
    /// Function terms are equal when their names match and their arguments
    /// are pairwise equal.
    fn terms_equal(t1: &Term, t2: &Term) -> bool {
        match (t1, t2) {
            (Term::Constant(c1), Term::Constant(c2)) => c1 == c2,
            (Term::Literal(l1), Term::Literal(l2)) => l1 == l2,
            (Term::Variable(v1), Term::Variable(v2)) => v1 == v2,
            (
                Term::Function { name: n1, args: a1 },
                Term::Function { name: n2, args: a2 },
            ) => {
                n1 == n2
                    && a1.len() == a2.len()
                    && a1
                        .iter()
                        .zip(a2.iter())
                        .all(|(x, y)| Self::terms_equal(x, y))
            }
            _ => false,
        }
    }

    /// Instantiates a triple atom with `substitution`. Non-triple atoms
    /// yield `None`.
    fn apply_substitution(
        atom: &RuleAtom,
        substitution: &HashMap<String, Term>,
    ) -> Option<RuleAtom> {
        match atom {
            RuleAtom::Triple {
                subject,
                predicate,
                object,
            } => {
                let new_subject = Self::substitute_term(subject, substitution);
                let new_predicate = Self::substitute_term(predicate, substitution);
                let new_object = Self::substitute_term(object, substitution);
                Some(RuleAtom::Triple {
                    subject: new_subject,
                    predicate: new_predicate,
                    object: new_object,
                })
            }
            _ => None,
        }
    }

    /// Replaces a bound variable by its binding. Unbound variables and all
    /// other terms are copied unchanged; function arguments are not visited.
    fn substitute_term(term: &Term, substitution: &HashMap<String, Term>) -> Term {
        match term {
            Term::Variable(var) => substitution
                .get(var)
                .cloned()
                .unwrap_or_else(|| term.clone()),
            _ => term.clone(),
        }
    }

    /// Returns the configured number of workers.
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Sets the upper bound on the number of inference iterations.
    pub fn set_max_iterations(&mut self, max_iterations: usize) {
        self.max_iterations = max_iterations;
    }
}
/// One slot of the table backing `LockFreeFactSet`.
#[derive(Debug)]
struct FactSlot {
/// The stored fact, or `None` while the slot is free.
fact: Option<RuleAtom>,
/// Incremented each time the slot is filled or cleared.
version: u64,
}
/// Set of facts stored in a table of slots addressed by hash with linear probing.
///
/// The slot table itself sits behind a `std::sync::RwLock`; only the counters
/// are atomics.
#[derive(Debug)]
struct LockFreeFactSet {
/// Slot table; read and written only while holding the lock.
facts: Arc<std::sync::RwLock<Vec<FactSlot>>>,
/// Number of successful insertions since creation or the last `clear`.
insertion_count: Arc<AtomicU64>,
/// Incremented on every successful insertion and on every `clear`.
version: Arc<AtomicU64>,
/// Number of slots in the table.
capacity: Arc<AtomicUsize>,
/// Number of facts currently stored.
size: Arc<AtomicUsize>,
}
impl LockFreeFactSet {
fn new() -> Self {
Self::with_capacity(1024)
}
fn with_capacity(capacity: usize) -> Self {
let mut facts = Vec::with_capacity(capacity);
for _ in 0..capacity {
facts.push(FactSlot {
fact: None,
version: 0,
});
}
Self {
facts: Arc::new(std::sync::RwLock::new(facts)),
insertion_count: Arc::new(AtomicU64::new(0)),
version: Arc::new(AtomicU64::new(0)),
capacity: Arc::new(AtomicUsize::new(capacity)),
size: Arc::new(AtomicUsize::new(0)),
}
}
fn insert(&self, fact: RuleAtom) -> bool {
let hash = self.hash_fact(&fact);
let capacity = self.capacity.load(Ordering::Acquire);
let mut attempt = 0;
let max_attempts = 100;
while attempt < max_attempts {
let slot_idx = (hash as usize + attempt) % capacity;
let mut facts = self.facts.write().expect("lock should not be poisoned");
if facts[slot_idx].fact.is_none() {
facts[slot_idx].fact = Some(fact.clone());
facts[slot_idx].version += 1;
drop(facts);
self.size.fetch_add(1, Ordering::AcqRel);
self.insertion_count.fetch_add(1, Ordering::AcqRel);
self.version.fetch_add(1, Ordering::AcqRel);
LOCKFREE_FACT_INSERTIONS.inc();
return true;
}
if let Some(ref existing) = facts[slot_idx].fact {
if *existing == fact {
return false;
}
}
drop(facts);
attempt += 1;
LOCKFREE_CAS_RETRIES.inc();
}
let mut facts = self.facts.write().expect("lock should not be poisoned");
for slot in facts.iter() {
if let Some(ref existing) = slot.fact {
if *existing == fact {
return false;
}
}
}
for slot in facts.iter_mut() {
if slot.fact.is_none() {
slot.fact = Some(fact);
slot.version += 1;
drop(facts);
self.size.fetch_add(1, Ordering::AcqRel);
self.insertion_count.fetch_add(1, Ordering::AcqRel);
self.version.fetch_add(1, Ordering::AcqRel);
LOCKFREE_FACT_INSERTIONS.inc();
return true;
}
}
let new_slot = FactSlot {
fact: Some(fact),
version: 1,
};
facts.push(new_slot);
self.capacity.fetch_add(1, Ordering::AcqRel);
drop(facts);
self.size.fetch_add(1, Ordering::AcqRel);
self.insertion_count.fetch_add(1, Ordering::AcqRel);
self.version.fetch_add(1, Ordering::AcqRel);
LOCKFREE_FACT_INSERTIONS.inc();
true
}
fn hash_fact(&self, fact: &RuleAtom) -> u64 {
match fact {
RuleAtom::Triple {
subject,
predicate,
object,
} => {
let mut hash = 14695981039346656037u64;
hash = hash.wrapping_mul(1099511628211);
hash ^= self.hash_term(subject);
hash = hash.wrapping_mul(1099511628211);
hash ^= self.hash_term(predicate);
hash = hash.wrapping_mul(1099511628211);
hash ^= self.hash_term(object);
hash
}
_ => 0,
}
}
#[allow(clippy::only_used_in_recursion)]
fn hash_term(&self, term: &Term) -> u64 {
match term {
Term::Constant(s) | Term::Variable(s) | Term::Literal(s) => {
let mut hash = 0u64;
for byte in s.bytes() {
hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
}
hash
}
Term::Function { name, args } => {
let mut hash = self.hash_term(&Term::Constant(name.clone()));
for arg in args {
hash = hash.wrapping_mul(31).wrapping_add(self.hash_term(arg));
}
hash
}
}
}
fn is_empty(&self) -> bool {
self.size.load(Ordering::Acquire) == 0
}
fn clear(&self) {
let mut facts = self.facts.write().expect("lock should not be poisoned");
for slot in facts.iter_mut() {
slot.fact = None;
slot.version += 1;
}
drop(facts);
self.size.store(0, Ordering::Release);
self.insertion_count.store(0, Ordering::Release);
self.version.fetch_add(1, Ordering::AcqRel);
}
fn merge_from(&self, other: &LockFreeFactSet) -> usize {
let other_facts = other.to_vec();
let mut merge_count = 0;
for fact in other_facts {
if self.insert(fact) {
merge_count += 1;
}
}
merge_count
}
fn to_vec(&self) -> Vec<RuleAtom> {
let facts = self.facts.read().expect("lock should not be poisoned");
facts.iter().filter_map(|slot| slot.fact.clone()).collect()
}
#[allow(dead_code)]
fn insertion_count(&self) -> u64 {
self.insertion_count.load(Ordering::Acquire)
}
#[allow(dead_code)]
fn len(&self) -> usize {
self.size.load(Ordering::Acquire)
}
#[allow(dead_code)]
fn contains(&self, fact: &RuleAtom) -> bool {
let hash = self.hash_fact(fact);
let capacity = self.capacity.load(Ordering::Acquire);
let facts = self.facts.read().expect("lock should not be poisoned");
let primary_idx = (hash as usize) % capacity;
if let Some(ref existing) = facts[primary_idx].fact {
if *existing == *fact {
return true;
}
}
for i in 1..capacity {
let idx = (primary_idx + i) % capacity;
if let Some(ref existing) = facts[idx].fact {
if *existing == *fact {
return true;
}
}
}
false
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Return type shared by the fallible tests.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds a constant term.
    fn constant(value: impl Into<String>) -> Term {
        Term::Constant(value.into())
    }

    /// Builds a variable term.
    fn variable(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    /// Builds a triple whose three positions are all constants.
    fn ground(
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) -> RuleAtom {
        RuleAtom::Triple {
            subject: constant(subject),
            predicate: constant(predicate),
            object: constant(object),
        }
    }

    /// Builds the rule `?X <from> ?Y  =>  ?X <to> ?Y`.
    fn rewrite_rule(
        name: impl Into<String>,
        from: impl Into<String>,
        to: impl Into<String>,
    ) -> Rule {
        let pattern = |predicate: String| RuleAtom::Triple {
            subject: variable("X"),
            predicate: Term::Constant(predicate),
            object: variable("Y"),
        };
        Rule {
            name: name.into(),
            body: vec![pattern(from.into())],
            head: vec![pattern(to.into())],
        }
    }

    #[test]
    fn test_lockfree_basic_inference() -> TestResult {
        let mut engine = LockFreeEngine::new();
        engine.add_rule(rewrite_rule("test_rule", "p", "q"));

        let facts = vec![ground("a", "p", "b")];
        let results = engine.infer(&facts)?;

        assert!(results.len() >= 2);
        let expected = ground("a", "q", "b");
        assert!(results.iter().any(|derived| *derived == expected));
        Ok(())
    }

    #[test]
    fn test_lockfree_multiple_workers() -> TestResult {
        let mut engine = LockFreeEngine::with_config(4, 100);
        for i in 0..10 {
            engine.add_rule(rewrite_rule(
                format!("rule_{i}"),
                format!("p{i}"),
                format!("q{i}"),
            ));
        }

        let facts: Vec<RuleAtom> = (0..10).map(|i| ground("a", format!("p{i}"), "b")).collect();
        let results = engine.infer(&facts)?;

        assert!(results.len() >= 20);
        Ok(())
    }

    #[test]
    fn test_lockfree_fact_set() {
        let fact_set = LockFreeFactSet::new();
        assert!(fact_set.is_empty());

        let first = ground("a", "p", "b");
        let second = ground("c", "q", "d");

        assert!(fact_set.insert(first.clone()));
        // A second insertion of an equal fact is rejected.
        assert!(!fact_set.insert(first));
        assert!(fact_set.insert(second));
        assert_eq!(fact_set.insertion_count(), 2);

        let stored = fact_set.to_vec();
        assert_eq!(stored.len(), 2);
    }

    #[test]
    fn test_lockfree_merge() {
        let target = LockFreeFactSet::new();
        let source = LockFreeFactSet::new();
        let shared = ground("a", "p", "b");
        let extra = ground("c", "q", "d");

        target.insert(shared.clone());
        source.insert(extra);
        source.insert(shared);

        // Only the fact that `target` does not hold yet counts as merged.
        let merge_count = target.merge_from(&source);
        assert_eq!(merge_count, 1);

        let stored = target.to_vec();
        assert_eq!(stored.len(), 2);
    }

    #[test]
    fn test_lockfree_convergence() -> TestResult {
        let mut engine = LockFreeEngine::new();
        engine.add_rule(rewrite_rule("transitive", "p", "q"));

        let facts = vec![ground("a", "p", "b"), ground("b", "p", "c")];
        let results = engine.infer(&facts)?;

        assert!(results.len() >= 2);
        Ok(())
    }

    #[test]
    fn test_lockfree_empty_rules() -> TestResult {
        let engine = LockFreeEngine::new();
        let facts = vec![ground("a", "p", "b")];

        let results = engine.infer(&facts)?;

        assert_eq!(results.len(), 1);
        Ok(())
    }

    #[test]
    fn test_lockfree_configuration() {
        let mut engine = LockFreeEngine::with_config(8, 50);
        assert_eq!(engine.num_workers(), 8);

        engine.set_max_iterations(200);
        engine.set_work_stealing(false);

        assert_eq!(engine.max_iterations, 200);
        assert!(!engine.work_stealing);
    }

    #[test]
    fn test_lockfree_performance_scaling() -> TestResult {
        let mut engine = LockFreeEngine::with_config(4, 100);
        for i in 0..20 {
            engine.add_rule(rewrite_rule(
                format!("rule_{i}"),
                format!("p{i}"),
                format!("q{i}"),
            ));
        }

        let facts: Vec<RuleAtom> = (0..20)
            .flat_map(|i| {
                (0..10).map(move |j| {
                    ground(format!("entity_{j}"), format!("p{i}"), format!("value_{j}"))
                })
            })
            .collect();

        let start = std::time::Instant::now();
        let results = engine.infer(&facts)?;
        let duration = start.elapsed();

        assert!(results.len() >= 200);
        assert!(duration.as_secs() < 5);
        Ok(())
    }

    #[test]
    fn test_rule_atom_equality() {
        let atom1 = ground("a", "p", "b");
        let atom2 = ground("a", "p", "b");
        let atom3 = ground("c", "p", "b");

        assert_eq!(atom1, atom2);
        assert_ne!(atom1, atom3);

        let mut set = HashSet::new();
        assert!(set.insert(atom1.clone()));
        assert!(!set.insert(atom2));
        assert!(set.insert(atom3));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_lockfree_fact_set_contains() {
        let fact_set = LockFreeFactSet::with_capacity(64);
        let first = ground("a", "p", "b");
        let second = ground("c", "q", "d");

        assert!(fact_set.insert(first.clone()));
        assert!(fact_set.contains(&first));
        assert!(!fact_set.contains(&second));

        assert!(fact_set.insert(second.clone()));
        assert!(fact_set.contains(&second));
    }

    #[test]
    fn test_lockfree_fact_set_size() {
        let fact_set = LockFreeFactSet::with_capacity(64);
        assert_eq!(fact_set.len(), 0);

        for i in 0..10 {
            fact_set.insert(ground(format!("s{i}"), "p", format!("o{i}")));
        }

        assert_eq!(fact_set.len(), 10);
    }

    #[test]
    fn test_lockfree_hash_distribution() {
        let fact_set = LockFreeFactSet::with_capacity(128);
        for i in 0..50 {
            let fact = ground(
                format!("entity_{i}"),
                format!("relation_{}", i % 10),
                format!("value_{i}"),
            );
            let hash = fact_set.hash_fact(&fact);
            assert!(hash > 0);
        }
    }

    #[test]
    fn test_lockfree_concurrent_insertion() -> TestResult {
        use std::thread;

        let fact_set = Arc::new(LockFreeFactSet::with_capacity(256));
        // Collecting first makes sure all four threads are started before any join.
        let handles: Vec<_> = (0..4)
            .map(|thread_id| {
                let fact_set = Arc::clone(&fact_set);
                thread::spawn(move || {
                    for i in 0..25 {
                        fact_set.insert(ground(
                            format!("thread_{thread_id}_entity_{i}"),
                            "p",
                            format!("value_{i}"),
                        ));
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().map_err(|_| "thread panicked")?;
        }

        assert_eq!(fact_set.len(), 100);
        Ok(())
    }

    #[test]
    fn test_lockfree_duplicate_handling_concurrent() -> TestResult {
        use std::thread;

        let fact_set = Arc::new(LockFreeFactSet::with_capacity(64));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let fact_set = Arc::clone(&fact_set);
                thread::spawn(move || {
                    fact_set.insert(ground("shared", "p", "value"));
                })
            })
            .collect();

        for handle in handles {
            handle.join().map_err(|_| "thread panicked")?;
        }

        assert_eq!(fact_set.len(), 1);
        Ok(())
    }

    #[test]
    fn test_lockfree_capacity_growth() {
        let fact_set = LockFreeFactSet::with_capacity(4);
        for i in 0..20 {
            fact_set.insert(ground(format!("s{i}"), "p", format!("o{i}")));
        }
        assert_eq!(fact_set.len(), 20);
    }
}