use crate::{Rule, RuleAtom, Term};
use anyhow::Result;
use scirs2_core::metrics::{Counter, Gauge, Timer};
use std::collections::HashMap;
/// Placeholder for a GPU context handle. No GPU backend is wired up, so this is `()`.
type GpuContext = ();
/// Matches rule body patterns against fact sets.
///
/// The GPU path is currently never enabled, so all matching falls back to the CPU path.
pub struct GpuRuleMatcher {
    /// Registered rules, in registration order; a rule's index here is its `rule_idx`.
    rules: Vec<Rule>,
    // Never read in this module: `new` stores `None` and no method consults it.
    #[allow(dead_code)]
    gpu_context: Option<GpuContext>,
    /// Pattern hash of each triple fact -> index of that fact in the slice last passed to
    /// `build_fact_index`. Cleared and rebuilt on every `match_facts` call; a later fact with
    /// the same hash overwrites an earlier one.
    fact_hashes: HashMap<u64, usize>,
    /// One descriptor per body atom of every registered rule, in registration order.
    pattern_cache: Vec<PatternDescriptor>,
    /// Counters updated by `match_facts`.
    metrics: MatcherMetrics,
    /// Minimum number of facts before `match_facts` takes the GPU path (only if `use_gpu`).
    batch_size: usize,
    /// Whether the GPU path may be used; `new` always sets this to `false`.
    use_gpu: bool,
}
/// Flat numeric summary of one rule body atom, built by `build_pattern_descriptor`.
#[derive(Debug, Clone)]
struct PatternDescriptor {
    /// Index of the owning rule in `GpuRuleMatcher::rules`.
    rule_idx: usize,
    /// Position of the atom within that rule's body.
    pattern_idx: usize,
    /// Atom kind: 0 = triple, 1 = builtin, 2 = any other atom variant.
    pattern_type: u8,
    /// Term kind code from `term_type` (0 constant, 1 variable, 2 literal, 3 function).
    /// Non-triple atoms store 0 here, which coincides with the code for a constant.
    subject_type: u8,
    /// Term kind code of the predicate (same encoding as `subject_type`).
    predicate_type: u8,
    /// Term kind code of the object (same encoding as `subject_type`).
    object_type: u8,
    /// `compute_pattern_hash` of the triple's terms; 0 for non-triple atoms.
    hash: u64,
}
/// Metrics recorded by `GpuRuleMatcher::match_facts`.
pub struct MatcherMetrics {
    /// Incremented once per `match_facts` call (counts calls, not individual matches).
    total_matches: Counter,
    /// Incremented when a `match_facts` call is routed to the GPU path.
    gpu_matches: Counter,
    /// Incremented when a `match_facts` call is routed to the CPU path.
    cpu_matches: Counter,
    // Created by `MatcherMetrics::new` but never updated anywhere in this module.
    #[allow(dead_code)]
    gpu_memory: Gauge,
    // Created by `MatcherMetrics::new` but never updated anywhere in this module.
    #[allow(dead_code)]
    match_timer: Timer,
}
impl MatcherMetrics {
    /// Creates every metric under its fixed name, in declaration order.
    fn new() -> Self {
        let total_matches = Counter::new(String::from("gpu_total_matches"));
        let gpu_matches = Counter::new(String::from("gpu_matches"));
        let cpu_matches = Counter::new(String::from("gpu_cpu_fallback_matches"));
        let gpu_memory = Gauge::new(String::from("gpu_memory_usage"));
        let match_timer = Timer::new(String::from("gpu_match_time"));
        Self {
            total_matches,
            gpu_matches,
            cpu_matches,
            gpu_memory,
            match_timer,
        }
    }
}
impl GpuRuleMatcher {
    /// Number of facts scanned per chunk on the CPU path.
    const CPU_CHUNK_SIZE: usize = 100;
    /// 64-bit FNV-1a offset basis.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    /// 64-bit FNV-1a prime.
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    /// Terminates every triple position in `compute_pattern_hash`.
    ///
    /// `0xFF` never occurs in UTF-8 text, so it cannot be confused with term content.
    const SLOT_SEPARATOR: u8 = 0xff;

    /// Creates a matcher with no rules.
    ///
    /// No GPU backend is initialised: `gpu_context` is `None` and `use_gpu` is `false`, so
    /// every [`Self::match_facts`] call takes the CPU path. The GPU batch threshold starts
    /// at 1024 facts.
    ///
    /// # Errors
    /// Currently never fails; it always returns `Ok`.
    pub fn new() -> Result<Self> {
        Ok(Self {
            rules: Vec::new(),
            gpu_context: None,
            fact_hashes: HashMap::new(),
            pattern_cache: Vec::new(),
            metrics: MatcherMetrics::new(),
            batch_size: 1024,
            use_gpu: false,
        })
    }

    /// Registers `rule` and caches one [`PatternDescriptor`] per atom of its body.
    pub fn add_rule(&mut self, rule: Rule) {
        let rule_idx = self.rules.len();
        for (pattern_idx, atom) in rule.body.iter().enumerate() {
            let descriptor = self.build_pattern_descriptor(rule_idx, pattern_idx, atom);
            self.pattern_cache.push(descriptor);
        }
        self.rules.push(rule);
    }

    /// Registers each rule in `rules`, in order (see [`Self::add_rule`]).
    pub fn add_rules(&mut self, rules: Vec<Rule>) {
        for rule in rules {
            self.add_rule(rule);
        }
    }

    /// Summarises one body atom as a numeric [`PatternDescriptor`].
    ///
    /// Triples record the kind of each term plus a pattern hash. Builtins (`pattern_type` 1)
    /// and all other atoms (`pattern_type` 2) record only their kind.
    fn build_pattern_descriptor(
        &self,
        rule_idx: usize,
        pattern_idx: usize,
        atom: &RuleAtom,
    ) -> PatternDescriptor {
        let (pattern_type, subject_type, predicate_type, object_type, hash) = match atom {
            RuleAtom::Triple {
                subject,
                predicate,
                object,
            } => (
                0,
                self.term_type(subject),
                self.term_type(predicate),
                self.term_type(object),
                self.compute_pattern_hash(subject, predicate, object),
            ),
            RuleAtom::Builtin { .. } => (1, 0, 0, 0, 0),
            _ => (2, 0, 0, 0, 0),
        };
        PatternDescriptor {
            rule_idx,
            pattern_idx,
            pattern_type,
            subject_type,
            predicate_type,
            object_type,
            hash,
        }
    }

    /// Numeric kind code of a term: 0 constant, 1 variable, 2 literal, 3 function.
    fn term_type(&self, term: &Term) -> u8 {
        match term {
            Term::Constant(_) => 0,
            Term::Variable(_) => 1,
            Term::Literal(_) => 2,
            Term::Function { .. } => 3,
        }
    }

    /// Hashes the ground parts of a triple with 64-bit FNV-1a.
    ///
    /// Each of the three positions contributes one slot:
    /// - A constant or literal adds its kind code (from `term_type`), then its UTF-8 bytes.
    /// - A variable or function term adds nothing.
    ///
    /// Every slot ends with [`Self::SLOT_SEPARATOR`]. The hash therefore depends on the
    /// position a value occupies, and adjacent strings cannot run together (e.g. `"ab","c"`
    /// vs `"a","bc"`). Constants and literals with the same text also hash differently.
    fn compute_pattern_hash(&self, subject: &Term, predicate: &Term, object: &Term) -> u64 {
        let mut hash = Self::FNV_OFFSET_BASIS;
        for term in [subject, predicate, object] {
            match term {
                Term::Constant(text) | Term::Literal(text) => {
                    hash = self.fnv1a_hash(hash, &[self.term_type(term)]);
                    hash = self.fnv1a_hash(hash, text.as_bytes());
                }
                _ => {}
            }
            hash = self.fnv1a_hash(hash, &[Self::SLOT_SEPARATOR]);
        }
        hash
    }

    /// Continues a 64-bit FNV-1a hash from state `hash` over `data`.
    fn fnv1a_hash(&self, hash: u64, data: &[u8]) -> u64 {
        data.iter().fold(hash, |state, &byte| {
            (state ^ u64::from(byte)).wrapping_mul(Self::FNV_PRIME)
        })
    }

    /// Matches `facts` against every registered rule.
    ///
    /// Rebuilds the fact hash index first. The call then goes to the GPU path when it is
    /// enabled and `facts.len()` reaches the batch size, and to the CPU path otherwise.
    /// `total_matches` counts calls to this method, not individual matches.
    ///
    /// # Errors
    /// Propagates errors from the selected path; the CPU path never fails.
    pub fn match_facts(&mut self, facts: &[RuleAtom]) -> Result<Vec<RuleMatch>> {
        self.metrics.total_matches.inc();
        self.build_fact_index(facts);
        if self.use_gpu && facts.len() >= self.batch_size {
            self.metrics.gpu_matches.inc();
            self.gpu_match_facts(facts)
        } else {
            self.metrics.cpu_matches.inc();
            self.cpu_match_facts(facts)
        }
    }

    /// Rebuilds `fact_hashes`, mapping each triple fact's pattern hash to its index.
    ///
    /// Non-triple facts are skipped. A later fact with the same hash replaces the earlier one.
    fn build_fact_index(&mut self, facts: &[RuleAtom]) {
        self.fact_hashes.clear();
        for (idx, fact) in facts.iter().enumerate() {
            if let RuleAtom::Triple {
                subject,
                predicate,
                object,
            } = fact
            {
                let hash = self.compute_pattern_hash(subject, predicate, object);
                self.fact_hashes.insert(hash, idx);
            }
        }
    }

    /// GPU matching path. No kernel exists yet, so this delegates to the CPU matcher.
    fn gpu_match_facts(&mut self, facts: &[RuleAtom]) -> Result<Vec<RuleMatch>> {
        self.cpu_match_facts(facts)
    }

    /// CPU matching path.
    ///
    /// A rule matches a fact when *any* atom of the rule's body matches it; body atoms are
    /// not joined. Each match records:
    /// - the fact's index within `facts`, and
    /// - the variable bindings from the first body atom that matched.
    ///
    /// Results are ordered by chunk of [`Self::CPU_CHUNK_SIZE`] facts, then by rule, then by
    /// fact within the chunk.
    fn cpu_match_facts(&self, facts: &[RuleAtom]) -> Result<Vec<RuleMatch>> {
        let mut matches = Vec::new();
        for (chunk_idx, chunk) in facts.chunks(Self::CPU_CHUNK_SIZE).enumerate() {
            let chunk_start = chunk_idx * Self::CPU_CHUNK_SIZE;
            for (rule_idx, rule) in self.rules.iter().enumerate() {
                for (offset, fact) in chunk.iter().enumerate() {
                    if let Some(substitutions) = self.match_body(&rule.body, fact) {
                        matches.push(RuleMatch {
                            rule_idx,
                            fact_idx: chunk_start + offset,
                            substitutions,
                            confidence: 1.0,
                        });
                    }
                }
            }
        }
        Ok(matches)
    }

    /// Returns the bindings of the first atom in `body` that matches `fact`, or `None` if
    /// no atom matches.
    fn match_body(&self, body: &[RuleAtom], fact: &RuleAtom) -> Option<HashMap<String, Term>> {
        body.iter().find_map(|pattern| self.bind_atom(pattern, fact))
    }

    /// Returns whether triple `pattern` matches triple `fact` (see [`Self::bind_atom`]).
    fn atom_matches(&self, pattern: &RuleAtom, fact: &RuleAtom) -> bool {
        self.bind_atom(pattern, fact).is_some()
    }

    /// Unifies triple `pattern` against triple `fact`, position by position.
    ///
    /// Returns the variable bindings on success. Non-triple atoms never match. A variable
    /// that occurs more than once must bind to the same term at every occurrence.
    fn bind_atom(&self, pattern: &RuleAtom, fact: &RuleAtom) -> Option<HashMap<String, Term>> {
        match (pattern, fact) {
            (
                RuleAtom::Triple {
                    subject: ps,
                    predicate: pp,
                    object: po,
                },
                RuleAtom::Triple {
                    subject: fs,
                    predicate: fp,
                    object: fo,
                },
            ) => {
                let mut bindings = HashMap::new();
                for (pattern_term, fact_term) in [(ps, fs), (pp, fp), (po, fo)] {
                    if !self.bind_term(pattern_term, fact_term, &mut bindings) {
                        return None;
                    }
                }
                Some(bindings)
            }
            _ => None,
        }
    }

    /// Matches one pattern term against one fact term, recording variable bindings.
    ///
    /// - A fresh variable binds to `fact`.
    /// - An already-bound variable must equal its earlier binding.
    /// - Any other term defers to [`Self::term_matches`].
    fn bind_term(&self, pattern: &Term, fact: &Term, bindings: &mut HashMap<String, Term>) -> bool {
        if let Term::Variable(name) = pattern {
            if let Some(bound) = bindings.get(name) {
                return Self::same_ground_term(bound, fact);
            }
            bindings.insert(name.clone(), fact.clone());
            return true;
        }
        self.term_matches(pattern, fact)
    }

    /// Structural equality for string-carrying terms of the same kind.
    ///
    /// `Function` terms are not inspected and never compare equal. A repeated variable bound
    /// to a function term therefore rejects the match.
    fn same_ground_term(a: &Term, b: &Term) -> bool {
        match (a, b) {
            (Term::Constant(x), Term::Constant(y))
            | (Term::Literal(x), Term::Literal(y))
            | (Term::Variable(x), Term::Variable(y)) => x == y,
            _ => false,
        }
    }

    /// Single-term match without binding.
    ///
    /// - A variable matches anything.
    /// - Constants and literals match only the same kind with equal text.
    /// - Function patterns never match.
    fn term_matches(&self, pattern: &Term, fact: &Term) -> bool {
        match (pattern, fact) {
            (Term::Variable(_), _) => true,
            (Term::Constant(p), Term::Constant(f)) | (Term::Literal(p), Term::Literal(f)) => p == f,
            _ => false,
        }
    }

    /// Flattens triple facts into `[subject, predicate, object]` floats via `term_to_float`.
    ///
    /// Non-triple facts are skipped.
    // Not called anywhere in this module yet (no GPU backend).
    #[allow(dead_code)]
    fn facts_to_gpu_format(&self, facts: &[RuleAtom]) -> Result<Vec<f32>> {
        let mut data = Vec::with_capacity(facts.len() * 3);
        for fact in facts {
            if let RuleAtom::Triple {
                subject,
                predicate,
                object,
            } = fact
            {
                data.push(self.term_to_float(subject));
                data.push(self.term_to_float(predicate));
                data.push(self.term_to_float(object));
            }
        }
        Ok(data)
    }

    /// Flattens the pattern cache into 7 floats per descriptor.
    ///
    /// NOTE(review): `hash as f32` keeps only about 24 significant bits of the 64-bit hash,
    /// so distinct hashes can become equal here.
    // Not called anywhere in this module yet (no GPU backend).
    #[allow(dead_code)]
    fn patterns_to_gpu_format(&self) -> Result<Vec<f32>> {
        let mut data = Vec::with_capacity(self.pattern_cache.len() * 7);
        for pattern in &self.pattern_cache {
            data.push(pattern.rule_idx as f32);
            data.push(pattern.pattern_idx as f32);
            data.push(pattern.pattern_type as f32);
            data.push(pattern.subject_type as f32);
            data.push(pattern.predicate_type as f32);
            data.push(pattern.object_type as f32);
            data.push(pattern.hash as f32);
        }
        Ok(data)
    }

    /// Encodes a term as a float.
    ///
    /// - Constants map to `[0, 1e6)` (FNV hash mod 1e6).
    /// - Literals map to the same range shifted by `1e6`.
    /// - Variables map to `-1.0`; function terms map to `0.0`.
    // Not called anywhere in this module yet (no GPU backend).
    #[allow(dead_code)]
    fn term_to_float(&self, term: &Term) -> f32 {
        match term {
            Term::Constant(s) => {
                let hash = self.fnv1a_hash(Self::FNV_OFFSET_BASIS, s.as_bytes());
                (hash % 1000000) as f32
            }
            Term::Variable(_) => -1.0,
            Term::Literal(s) => {
                let hash = self.fnv1a_hash(Self::FNV_OFFSET_BASIS, s.as_bytes());
                (hash % 1000000) as f32 + 1000000.0
            }
            _ => 0.0,
        }
    }

    /// Returns the matcher's metrics.
    pub fn get_metrics(&self) -> &MatcherMetrics {
        &self.metrics
    }

    /// Sets the minimum fact count for the GPU path.
    ///
    /// This has no effect while the GPU path is disabled.
    pub fn set_batch_size(&mut self, size: usize) {
        self.batch_size = size;
    }

    /// Returns whether the GPU path is enabled.
    pub fn is_gpu_available(&self) -> bool {
        self.use_gpu
    }

    /// Runs [`Self::match_facts`] on each fact set, returning one result list per set.
    ///
    /// # Errors
    /// Stops at and returns the first error from `match_facts`.
    pub fn batch_match(&mut self, fact_sets: &[Vec<RuleAtom>]) -> Result<Vec<Vec<RuleMatch>>> {
        fact_sets
            .iter()
            .map(|facts| self.match_facts(facts))
            .collect()
    }
}
/// One rule/fact match reported by `GpuRuleMatcher`.
#[derive(Debug, Clone)]
pub struct RuleMatch {
    /// Index of the matching rule, in registration order.
    pub rule_idx: usize,
    /// Intended as the index of the matched fact in the input slice.
    /// NOTE(review): verify that the producing matcher fills this in.
    pub fact_idx: usize,
    /// Variable name -> bound term for this match.
    /// May be empty if the producing matcher does not compute bindings.
    pub substitutions: HashMap<String, Term>,
    /// Match confidence; the CPU matcher in this module always reports `1.0`.
    pub confidence: f64,
}
impl Default for GpuRuleMatcher {
    /// Equivalent to [`GpuRuleMatcher::new`].
    ///
    /// # Panics
    /// Panics if construction fails, which `new` does not currently do.
    fn default() -> Self {
        match Self::new() {
            Ok(matcher) => matcher,
            Err(err) => panic!("GpuRuleMatcher::new should not fail: {:?}", err),
        }
    }
}
/// Iterative inference driver built on [`GpuRuleMatcher`] (see `infer`).
pub struct GpuForwardChainer {
    /// Matcher holding the registered rules.
    matcher: GpuRuleMatcher,
    /// Current fact store; reset to the caller's facts at the start of every `infer` call.
    facts: Vec<RuleAtom>,
    /// Upper bound on matching passes per `infer` call; `new` sets it to 100.
    max_iterations: usize,
}
impl GpuForwardChainer {
    /// Creates a chainer with no rules, no facts, and an iteration cap of 100.
    ///
    /// # Errors
    /// Propagates any error from [`GpuRuleMatcher::new`].
    pub fn new() -> Result<Self> {
        let matcher = GpuRuleMatcher::new()?;
        Ok(Self {
            matcher,
            facts: Vec::new(),
            max_iterations: 100,
        })
    }

    /// Registers `rules` with the underlying matcher.
    pub fn add_rules(&mut self, rules: Vec<Rule>) {
        self.matcher.add_rules(rules);
    }

    /// Runs inference from `initial_facts` and returns the final fact store.
    ///
    /// The store is first reset to `initial_facts`. Each pass matches the store against the
    /// rules. The loop stops when the iteration cap is reached, when no rule matches, or when
    /// a pass derives nothing new.
    ///
    /// NOTE(review): rule heads are never instantiated from the matches, so a pass never
    /// derives anything. At most one matching pass runs, and the result always equals
    /// `initial_facts`.
    ///
    /// # Errors
    /// Propagates errors from [`GpuRuleMatcher::match_facts`].
    pub fn infer(&mut self, initial_facts: &[RuleAtom]) -> Result<Vec<RuleAtom>> {
        self.facts = initial_facts.to_vec();
        for _pass in 0..self.max_iterations {
            let matches = self.matcher.match_facts(&self.facts)?;
            if matches.is_empty() {
                break;
            }
            // Placeholder: deriving head atoms from `matches` is not implemented yet.
            let derived: Vec<RuleAtom> = Vec::new();
            if derived.is_empty() {
                break;
            }
            self.facts.extend(derived);
        }
        Ok(self.facts.clone())
    }

    /// Sets the maximum number of matching passes per `infer` call.
    pub fn set_max_iterations(&mut self, max: usize) {
        self.max_iterations = max;
    }
}
impl Default for GpuForwardChainer {
    /// Equivalent to [`GpuForwardChainer::new`].
    ///
    /// # Panics
    /// Panics if construction fails, which `new` does not currently do.
    fn default() -> Self {
        match Self::new() {
            Ok(chainer) => chainer,
            Err(err) => panic!("GpuForwardChainer::new should not fail: {:?}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// FNV-1a offset basis, as used by the matcher's hashing.
    const FNV_BASIS: u64 = 0xcbf29ce484222325;

    fn constant(name: &str) -> Term {
        Term::Constant(name.to_string())
    }

    fn variable(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    fn literal(value: &str) -> Term {
        Term::Literal(value.to_string())
    }

    fn triple(subject: Term, predicate: Term, object: Term) -> RuleAtom {
        RuleAtom::Triple {
            subject,
            predicate,
            object,
        }
    }

    /// Rule with a single body atom and a single head atom.
    fn simple_rule(name: &str, body: RuleAtom, head: RuleAtom) -> Rule {
        Rule {
            name: name.to_string(),
            body: vec![body],
            head: vec![head],
        }
    }

    /// `?X type Person => ?X isHuman true`.
    fn person_rule() -> Rule {
        simple_rule(
            "test_rule",
            triple(variable("X"), constant("type"), constant("Person")),
            triple(variable("X"), constant("isHuman"), constant("true")),
        )
    }

    /// `john type Person`.
    fn john_is_person() -> RuleAtom {
        triple(constant("john"), constant("type"), constant("Person"))
    }

    #[test]
    fn test_gpu_matcher_creation() {
        assert!(GpuRuleMatcher::new().is_ok());
    }

    #[test]
    fn test_add_rule() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.add_rule(person_rule());
        assert_eq!(matcher.rules.len(), 1);
        Ok(())
    }

    #[test]
    fn test_pattern_descriptor_creation() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        let atom = triple(constant("john"), variable("P"), literal("value"));
        let descriptor = matcher.build_pattern_descriptor(0, 0, &atom);
        assert_eq!(descriptor.pattern_type, 0);
        assert_eq!(descriptor.subject_type, 0);
        assert_eq!(descriptor.predicate_type, 1);
        assert_eq!(descriptor.object_type, 2);
        Ok(())
    }

    #[test]
    fn test_fnv1a_hash() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        let first = matcher.fnv1a_hash(FNV_BASIS, b"test");
        let repeat = matcher.fnv1a_hash(FNV_BASIS, b"test");
        let other = matcher.fnv1a_hash(FNV_BASIS, b"different");
        assert_eq!(first, repeat);
        assert_ne!(first, other);
        Ok(())
    }

    #[test]
    fn test_term_type() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        assert_eq!(matcher.term_type(&constant("x")), 0);
        assert_eq!(matcher.term_type(&variable("X")), 1);
        assert_eq!(matcher.term_type(&literal("lit")), 2);
        Ok(())
    }

    #[test]
    fn test_term_matches() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        assert!(matcher.term_matches(&variable("X"), &constant("john")));
        assert!(matcher.term_matches(&constant("john"), &constant("john")));
        assert!(!matcher.term_matches(&constant("john"), &constant("mary")));
        Ok(())
    }

    #[test]
    fn test_atom_matches() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        let pattern = triple(variable("X"), constant("type"), constant("Person"));
        let typed_fact = triple(constant("john"), constant("type"), constant("Person"));
        let unrelated_fact = triple(constant("john"), constant("likes"), constant("coffee"));
        assert!(matcher.atom_matches(&pattern, &typed_fact));
        assert!(!matcher.atom_matches(&pattern, &unrelated_fact));
        Ok(())
    }

    #[test]
    fn test_cpu_match_facts() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.add_rule(person_rule());
        let matches = matcher.match_facts(&[john_is_person()])?;
        assert!(!matches.is_empty());
        Ok(())
    }

    #[test]
    fn test_batch_match() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.add_rule(person_rule());
        let fact_sets = vec![vec![john_is_person()], vec![john_is_person()]];
        let results = matcher.batch_match(&fact_sets)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }

    #[test]
    fn test_metrics_tracking() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.add_rule(person_rule());
        matcher.match_facts(&[john_is_person()])?;
        let _metrics = matcher.get_metrics();
        Ok(())
    }

    #[test]
    fn test_batch_size_setting() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.set_batch_size(2048);
        assert_eq!(matcher.batch_size, 2048);
        Ok(())
    }

    #[test]
    fn test_gpu_forward_chainer_creation() {
        assert!(GpuForwardChainer::new().is_ok());
    }

    #[test]
    fn test_gpu_forward_chainer_add_rules() -> TestResult {
        let mut chainer = GpuForwardChainer::new()?;
        chainer.add_rules(vec![person_rule()]);
        assert_eq!(chainer.matcher.rules.len(), 1);
        Ok(())
    }

    #[test]
    fn test_max_iterations_setting() -> TestResult {
        let mut chainer = GpuForwardChainer::new()?;
        chainer.set_max_iterations(50);
        assert_eq!(chainer.max_iterations, 50);
        Ok(())
    }

    #[test]
    fn test_pattern_cache() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        let rule = person_rule();
        let expected = rule.body.len();
        matcher.add_rule(rule);
        assert_eq!(matcher.pattern_cache.len(), expected);
        Ok(())
    }

    #[test]
    fn test_fact_index_building() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        matcher.build_fact_index(&[john_is_person()]);
        assert!(!matcher.fact_hashes.is_empty());
        Ok(())
    }

    #[test]
    fn test_compute_pattern_hash() -> TestResult {
        let matcher = GpuRuleMatcher::new()?;
        let hash_of = || {
            matcher.compute_pattern_hash(&constant("john"), &constant("type"), &constant("Person"))
        };
        assert_eq!(hash_of(), hash_of());
        Ok(())
    }

    #[test]
    fn test_multiple_rules() -> TestResult {
        let mut matcher = GpuRuleMatcher::new()?;
        let first = simple_rule(
            "rule1",
            triple(variable("X"), constant("p1"), variable("Y")),
            triple(variable("X"), constant("q1"), variable("Y")),
        );
        let second = simple_rule(
            "rule2",
            triple(variable("A"), constant("p2"), variable("B")),
            triple(variable("A"), constant("q2"), variable("B")),
        );
        matcher.add_rules(vec![first, second]);
        assert_eq!(matcher.rules.len(), 2);
        Ok(())
    }
}