use crate::{Rule, RuleAtom, RuleEngine, Term};
use anyhow::Result;
use scirs2_core::metrics::Timer;
use std::collections::{HashMap, HashSet};
use tracing::{debug, info};
// Process-wide timers, one per query strategy that is timed in this file.
// `lazy_static!` builds each timer the first time it is dereferenced.
// The direct, forward and backward query paths each call `.start()` on their
// timer and keep the returned value bound until the function returns.
// NOTE(review): presumably the value returned by `start()` records the elapsed
// time when it is dropped - confirm against `scirs2_core::metrics::Timer`.
lazy_static::lazy_static! {
static ref QUERY_DIRECT_TIMER: Timer = Timer::new("sparql_query_direct".to_string());
static ref QUERY_FORWARD_TIMER: Timer = Timer::new("sparql_query_forward".to_string());
static ref QUERY_BACKWARD_TIMER: Timer = Timer::new("sparql_query_backward".to_string());
}
/// Strategy used by `SparqlRuleIntegration` to answer a set of query patterns.
///
/// The enum carries no data, so it is `Copy` and can be compared, hashed and
/// passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryMode {
    /// Match the patterns against the facts stored in the engine; no rules are applied.
    Direct,
    /// Forward-chain over the stored facts (the result is cached) and filter the outcome.
    ForwardReasoning,
    /// Prove each fully bound pattern by backward chaining; if any pattern has an
    /// unbound position the query is answered by forward reasoning instead.
    BackwardReasoning,
    /// Backward reasoning when every position of every pattern is bound,
    /// forward reasoning otherwise. This is the mode a new integration starts in.
    Hybrid,
    /// Forward-chain only when rules registered for the queried predicates exist,
    /// bypassing the materialization cache.
    LazyMaterialization,
}
/// A triple pattern in which every position is either bound to a string or
/// left open (`None`) as a wildcard.
///
/// `QueryPattern::default()` is the pattern with all three positions unbound.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct QueryPattern {
    /// Required subject value, or `None` to accept any subject.
    pub subject: Option<String>,
    /// Required predicate value, or `None` to accept any predicate.
    pub predicate: Option<String>,
    /// Required object value, or `None` to accept any object.
    pub object: Option<String>,
}
impl QueryPattern {
pub fn new(subject: Option<String>, predicate: Option<String>, object: Option<String>) -> Self {
Self {
subject,
predicate,
object,
}
}
pub fn matches(&self, atom: &RuleAtom) -> bool {
if let RuleAtom::Triple {
subject,
predicate,
object,
} = atom
{
self.matches_term(&self.subject, subject)
&& self.matches_term(&self.predicate, predicate)
&& self.matches_term(&self.object, object)
} else {
false
}
}
fn matches_term(&self, pattern: &Option<String>, term: &Term) -> bool {
match pattern {
None => true, Some(pat) => match term {
Term::Constant(c) => c == pat,
Term::Literal(l) => l == pat,
Term::Variable(_) => true, _ => false,
},
}
}
}
/// Answers triple-pattern queries against a `RuleEngine`, optionally applying
/// forward or backward reasoning depending on the configured `QueryMode`.
pub struct SparqlRuleIntegration {
/// Engine that supplies the facts and performs forward/backward chaining.
engine: RuleEngine,
/// Strategy used by `query_with_reasoning` to dispatch a query.
mode: QueryMode,
/// Maps a registered pattern string to the names of the rules registered for it.
pattern_rules: HashMap<String, Vec<String>>,
/// Counters updated by the query paths.
stats: IntegrationStats,
/// Result of the last forward-chaining run, if one has been stored.
materialized_cache: Option<Vec<RuleAtom>>,
/// Fingerprint of the engine's facts at the time `materialized_cache` was filled;
/// reset to 0 when the cache is invalidated.
facts_hash: u64,
}
impl SparqlRuleIntegration {
pub fn new(engine: RuleEngine) -> Self {
Self {
engine,
mode: QueryMode::Hybrid,
pattern_rules: HashMap::new(),
stats: IntegrationStats::default(),
materialized_cache: None,
facts_hash: 0,
}
}
/// Selects the strategy used by subsequent calls to `query_with_reasoning`
/// and logs the new mode at info level.
pub fn set_mode(&mut self, mode: QueryMode) {
    self.mode = mode;
    info!("Setting query mode to {:?}", self.mode);
}
/// Returns the query mode currently in effect.
pub fn get_mode(&self) -> &QueryMode {
&self.mode
}
/// Associates `rule_name` with the textual `pattern`.
///
/// Several rules may be registered for the same pattern; they are kept in
/// registration order and duplicates are not removed.
pub fn register_pattern_rule(&mut self, pattern: String, rule_name: String) {
    debug!("Registering rule '{}' for pattern '{}'", rule_name, pattern);
    let rules_for_pattern = self.pattern_rules.entry(pattern).or_default();
    rules_for_pattern.push(rule_name);
}
/// Answers `patterns` using the strategy selected by the current query mode.
///
/// Returns the atoms produced by the chosen strategy, or the error that
/// strategy reported.
pub fn query_with_reasoning(&mut self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
// Counted before dispatch, so queries that later fail are included in the total.
self.stats.total_queries += 1;
match self.mode {
QueryMode::Direct => self.query_direct(patterns),
QueryMode::ForwardReasoning => self.query_with_forward(patterns),
QueryMode::BackwardReasoning => self.query_with_backward(patterns),
QueryMode::Hybrid => self.query_hybrid(patterns),
QueryMode::LazyMaterialization => self.query_lazy(patterns),
}
}
/// Matches `patterns` against the facts stored in the engine without applying any rules.
///
/// Returns an empty vector when `patterns` is empty; otherwise returns every
/// stored fact that matches at least one pattern. Result sets with more than
/// `DEDUP_THRESHOLD` entries are additionally passed through a
/// `crate::simd_ops` de-duplication helper; smaller result sets are returned
/// exactly as collected.
fn query_direct(&self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
    /// Result sets larger than this are handed to the de-duplication helpers.
    const DEDUP_THRESHOLD: usize = 100;

    let _timer = QUERY_DIRECT_TIMER.start();

    // Nothing can match an empty pattern list, so return before asking the
    // engine for its facts at all.
    if patterns.is_empty() {
        return Ok(Vec::new());
    }

    let facts = self.engine.get_facts();

    // Single-pattern queries keep their own path and their own helper.
    if let [pattern] = patterns {
        let results: Vec<RuleAtom> = facts
            .into_iter()
            .filter(|fact| pattern.matches(fact))
            .collect();
        if results.len() > DEDUP_THRESHOLD {
            use crate::simd_ops::BatchProcessor;
            let processor = BatchProcessor::default();
            return Ok(processor.deduplicate(results));
        }
        return Ok(results);
    }

    let mut results: Vec<RuleAtom> = facts
        .into_iter()
        .filter(|fact| patterns.iter().any(|pattern| pattern.matches(fact)))
        .collect();
    if results.len() > DEDUP_THRESHOLD {
        use crate::simd_ops::SimdMatcher;
        let matcher = SimdMatcher::new();
        matcher.batch_deduplicate(&mut results);
    }
    Ok(results)
}
/// Answers `patterns` from the forward-chained (materialized) fact set.
///
/// Every materialized fact that matches at least one pattern is returned, in
/// materialization order. The forward-query counter is only incremented when
/// materialization succeeded.
fn query_with_forward(&mut self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
    let _timer = QUERY_FORWARD_TIMER.start();
    let mut matching = self.get_materialized_facts()?;
    // Filter in place: keep only facts accepted by some pattern.
    matching.retain(|fact| patterns.iter().any(|pattern| pattern.matches(fact)));
    self.stats.forward_reasoning_queries += 1;
    Ok(matching)
}
/// Answers `patterns` by backward chaining.
///
/// Every pattern is first turned into a ground goal. If any pattern has an
/// unbound position, no goal can be built for it and the whole query is
/// delegated to `query_with_forward`. This check now happens before any
/// backward chaining, so no proof work is done (and no proof error can
/// surface) for a query that ends up being answered by forward reasoning.
///
/// Otherwise the returned vector contains, in pattern order, each goal the
/// engine could prove. The backward-query counter is incremented only when
/// the query was actually answered by backward chaining.
///
/// # Errors
///
/// Propagates the first error reported by the engine's backward chainer, or
/// by forward reasoning when the query is delegated.
fn query_with_backward(&mut self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
    let _timer = QUERY_BACKWARD_TIMER.start();

    // Collecting into `Option<Vec<_>>` stops at the first pattern without a goal.
    let goals: Option<Vec<RuleAtom>> = patterns
        .iter()
        .map(|pattern| self.pattern_to_goal(pattern))
        .collect();
    let goals = match goals {
        Some(goals) => goals,
        None => return self.query_with_forward(patterns),
    };

    let mut results = Vec::with_capacity(goals.len());
    for goal in goals {
        if self.engine.backward_chain(&goal)? {
            results.push(goal);
        }
    }
    self.stats.backward_reasoning_queries += 1;
    Ok(results)
}
/// Chooses between backward and forward reasoning based on the patterns.
///
/// Backward reasoning is used when every position of every pattern is bound
/// (this includes an empty pattern list); as soon as one position is unbound
/// the query is answered by forward reasoning.
fn query_hybrid(&mut self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
    let fully_bound = patterns.iter().all(|pattern| {
        pattern.subject.is_some() && pattern.predicate.is_some() && pattern.object.is_some()
    });
    if fully_bound {
        self.query_with_backward(patterns)
    } else {
        self.query_with_forward(patterns)
    }
}
/// Answers `patterns` by forward chaining only when registered rules are
/// relevant to the queried predicates. The materialization cache is neither
/// read nor written.
///
/// Starting from the engine's stored facts, one forward-chaining pass is run
/// per relevant rule name, each pass consuming the output of the previous one.
/// With no relevant rules the stored facts are filtered as they are. The
/// result contains every fact that matches at least one pattern.
///
/// NOTE(review): the rule name itself is not passed to the engine, so each
/// pass runs whatever `forward_chain` applies by default - confirm this is
/// the intended meaning of "lazy".
///
/// # Errors
///
/// Propagates the first error reported by the engine's forward chainer.
fn query_lazy(&mut self, patterns: &[QueryPattern]) -> Result<Vec<RuleAtom>> {
    let relevant_rules = self.find_relevant_rules(patterns);
    // `get_facts` already hands back an owned vector, so it is used as the
    // working set directly instead of being cloned a second time.
    let mut results = self.engine.get_facts();
    for _rule_name in relevant_rules {
        results = self.engine.forward_chain(&results)?;
    }
    let filtered = results
        .into_iter()
        .filter(|fact| patterns.iter().any(|p| p.matches(fact)))
        .collect();
    self.stats.lazy_queries += 1;
    Ok(filtered)
}
/// Collects the names of all registered rules that are relevant to `patterns`.
///
/// A registration is relevant when its pattern string contains the predicate
/// of one of the query patterns as a substring. Query patterns without a
/// predicate select nothing. Each rule name appears at most once; the order
/// of the returned names is unspecified.
fn find_relevant_rules(&self, patterns: &[QueryPattern]) -> Vec<String> {
    let mut selected: HashSet<String> = HashSet::new();
    let bound_predicates = patterns.iter().filter_map(|pattern| pattern.predicate.as_ref());
    for predicate in bound_predicates {
        for (registered_pattern, rule_names) in &self.pattern_rules {
            if registered_pattern.contains(predicate.as_str()) {
                selected.extend(rule_names.iter().cloned());
            }
        }
    }
    selected.into_iter().collect()
}
/// Converts a fully bound pattern into a ground triple usable as a backward-chaining goal.
///
/// Returns `None` as soon as any of the three positions is unbound. All three
/// values are wrapped as `Term::Constant`.
///
/// NOTE(review): the object is emitted as a constant even when the caller
/// meant a literal - confirm the engine treats the two alike when proving goals.
fn pattern_to_goal(&self, pattern: &QueryPattern) -> Option<RuleAtom> {
    // Matching on all three positions at once removes the need for
    // `is_some()` checks followed by `expect()`.
    match (&pattern.subject, &pattern.predicate, &pattern.object) {
        (Some(subject), Some(predicate), Some(object)) => Some(RuleAtom::Triple {
            subject: Term::Constant(subject.clone()),
            predicate: Term::Constant(predicate.clone()),
            object: Term::Constant(object.clone()),
        }),
        _ => None,
    }
}
/// Returns mutable access to the wrapped rule engine.
///
/// This accessor does not touch the materialization cache. The cache is only
/// refreshed when the sampled fact fingerprint changes, so call
/// `invalidate_cache` after edits that the fingerprint may not pick up.
pub fn engine_mut(&mut self) -> &mut RuleEngine {
&mut self.engine
}
/// Returns shared access to the wrapped rule engine.
pub fn engine(&self) -> &RuleEngine {
&self.engine
}
/// Returns the counters accumulated since construction or the last `reset_stats`.
pub fn get_stats(&self) -> &IntegrationStats {
&self.stats
}
/// Resets every counter to its default (zero) value.
/// The query mode, registered rules and materialization cache are left untouched.
pub fn reset_stats(&mut self) {
self.stats = IntegrationStats::default();
}
/// Computes a cheap fingerprint of `facts` used to decide whether the
/// materialization cache is still valid.
///
/// The fingerprint covers the number of facts, the `Debug` rendering of the
/// first `SAMPLE_SIZE` facts and, when there are more than `SAMPLE_SIZE`
/// facts, the rendering of the last `SAMPLE_SIZE` facts as well.
///
/// NOTE(review): because only the two ends are sampled, replacing a fact in
/// the middle of a long list leaves the fingerprint unchanged; callers that
/// make such edits need `invalidate_cache`.
fn compute_facts_hash(&self, facts: &[RuleAtom]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Number of facts sampled from each end of the list.
    const SAMPLE_SIZE: usize = 10;

    let total = facts.len();
    let head = &facts[..total.min(SAMPLE_SIZE)];
    // The tail is only sampled for lists longer than the head sample.
    let tail: &[RuleAtom] = if total > SAMPLE_SIZE {
        &facts[total - SAMPLE_SIZE..]
    } else {
        &[]
    };

    let mut state = DefaultHasher::new();
    total.hash(&mut state);
    for fact in head.iter().chain(tail) {
        format!("{:?}", fact).hash(&mut state);
    }
    state.finish()
}
/// Discards the cached materialization so the next forward query re-runs
/// forward chaining, and resets the stored fact fingerprint to 0.
pub fn invalidate_cache(&mut self) {
    self.facts_hash = 0;
    self.materialized_cache = None;
    debug!("Materialization cache invalidated");
}
/// Returns the forward-chained fact set, reusing the cached copy when the
/// engine's facts still produce the fingerprint recorded with the cache.
///
/// A hit returns a clone of the cached facts and bumps `cache_hits`. A miss
/// bumps `cache_misses`, runs forward chaining over the current facts and,
/// on success, stores the result together with the new fingerprint.
///
/// # Errors
///
/// Propagates any error from the engine's forward chainer; the cache is left
/// as it was in that case.
fn get_materialized_facts(&mut self) -> Result<Vec<RuleAtom>> {
    let base_facts = self.engine.get_facts();
    let fingerprint = self.compute_facts_hash(&base_facts);

    match &self.materialized_cache {
        Some(cached) if fingerprint == self.facts_hash => {
            debug!("Using cached materialized facts ({} facts)", cached.len());
            self.stats.cache_hits += 1;
            return Ok(cached.clone());
        }
        _ => {}
    }

    debug!("Cache miss - materializing facts");
    self.stats.cache_misses += 1;
    let derived = self.engine.forward_chain(&base_facts)?;
    self.facts_hash = fingerprint;
    self.materialized_cache = Some(derived.clone());
    Ok(derived)
}
}
/// Counters describing how queries were served by the integration layer.
#[derive(Debug, Clone, Default)]
pub struct IntegrationStats {
    /// Queries submitted through `query_with_reasoning`, including failed ones.
    pub total_queries: usize,
    /// Queries answered successfully by forward reasoning.
    pub forward_reasoning_queries: usize,
    /// Queries answered by backward chaining without falling back to forward reasoning.
    pub backward_reasoning_queries: usize,
    /// Queries answered successfully by lazy materialization.
    pub lazy_queries: usize,
    /// Forward queries served from the materialization cache.
    pub cache_hits: usize,
    /// Forward queries that had to re-run forward chaining.
    pub cache_misses: usize,
}

impl std::fmt::Display for IntegrationStats {
    /// Writes all counters on a single line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without deciding how to
        // print it becomes a compile error.
        let Self {
            total_queries: total,
            forward_reasoning_queries: forward,
            backward_reasoning_queries: backward,
            lazy_queries: lazy,
            cache_hits: hits,
            cache_misses: misses,
        } = self;
        write!(
            f,
            "Total: {}, Forward: {}, Backward: {}, Lazy: {}, Cache(hits/misses): {}/{}",
            total, forward, backward, lazy, hits, misses
        )
    }
}
/// Checks whether query patterns could be produced by a set of rules and
/// keeps count of rewrite requests.
pub struct QueryRewriter {
    /// Rules whose heads are inspected by `can_rewrite`.
    rules: Vec<Rule>,
    /// Number of times `rewrite` has been called.
    rewrites: usize,
}

impl QueryRewriter {
    /// Creates a rewriter over `rules` with a rewrite count of zero.
    pub fn new(rules: Vec<Rule>) -> Self {
        let rewrites = 0;
        QueryRewriter { rules, rewrites }
    }

    /// Returns `true` when at least one pattern is matched by a head atom of
    /// at least one rule. An empty pattern list or rule list yields `false`.
    pub fn can_rewrite(&self, patterns: &[QueryPattern]) -> bool {
        patterns.iter().any(|pattern| {
            self.rules
                .iter()
                .any(|rule| self.rule_derives_pattern(rule, pattern))
        })
    }

    /// Returns `true` when some atom in the head of `rule` matches `pattern`.
    fn rule_derives_pattern(&self, rule: &Rule, pattern: &QueryPattern) -> bool {
        for head_atom in rule.head.iter() {
            if pattern.matches(head_atom) {
                return true;
            }
        }
        false
    }

    /// Records a rewrite request and returns `patterns`.
    ///
    /// The patterns are currently handed back unchanged; only the counter is updated.
    pub fn rewrite(&mut self, patterns: Vec<QueryPattern>) -> Vec<QueryPattern> {
        self.rewrites += 1;
        patterns
    }

    /// Returns how many times `rewrite` has been called.
    pub fn get_rewrite_count(&self) -> usize {
        self.rewrites
    }
}
// Unit tests for pattern matching, mode handling, rule registration,
// direct querying, the query rewriter and the statistics counters.
#[cfg(test)]
mod tests {
use super::*;
// A pattern with bound subject and predicate and an open object matches a
// triple carrying the same subject and predicate.
#[test]
fn test_query_pattern_matching() {
let pattern = QueryPattern::new(
Some("john".to_string()),
Some("knows".to_string()),
None, );
let atom = RuleAtom::Triple {
subject: Term::Constant("john".to_string()),
predicate: Term::Constant("knows".to_string()),
object: Term::Constant("mary".to_string()),
};
assert!(pattern.matches(&atom));
}
// A freshly built integration starts in hybrid mode.
#[test]
fn test_sparql_integration_creation() {
let engine = RuleEngine::new();
let integration = SparqlRuleIntegration::new(engine);
assert_eq!(*integration.get_mode(), QueryMode::Hybrid);
}
// `set_mode` is reflected by `get_mode`.
#[test]
fn test_query_mode_setting() {
let engine = RuleEngine::new();
let mut integration = SparqlRuleIntegration::new(engine);
integration.set_mode(QueryMode::ForwardReasoning);
assert_eq!(*integration.get_mode(), QueryMode::ForwardReasoning);
}
// Registering one rule creates exactly one pattern entry.
// The private `pattern_rules` map is reachable because this is a child module.
#[test]
fn test_pattern_rule_registration() -> Result<(), Box<dyn std::error::Error>> {
let engine = RuleEngine::new();
let mut integration = SparqlRuleIntegration::new(engine);
integration.register_pattern_rule("?s rdf:type ?o".to_string(), "typing_rule".to_string());
assert_eq!(integration.pattern_rules.len(), 1);
Ok(())
}
// In direct mode a stored fact is returned when it matches the pattern.
#[test]
fn test_direct_query() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = RuleEngine::new();
engine.add_fact(RuleAtom::Triple {
subject: Term::Constant("john".to_string()),
predicate: Term::Constant("knows".to_string()),
object: Term::Constant("mary".to_string()),
});
let mut integration = SparqlRuleIntegration::new(engine);
integration.set_mode(QueryMode::Direct);
let patterns = vec![QueryPattern::new(
Some("john".to_string()),
Some("knows".to_string()),
None,
)];
let results = integration.query_with_reasoning(&patterns)?;
assert_eq!(results.len(), 1);
Ok(())
}
// A rule whose head has predicate "derived" makes a pattern on that
// predicate rewritable; the variable subject and object in the head match
// the open positions of the pattern.
#[test]
fn test_query_rewriter() {
let rule = Rule {
name: "test_rule".to_string(),
body: vec![],
head: vec![RuleAtom::Triple {
subject: Term::Variable("X".to_string()),
predicate: Term::Constant("derived".to_string()),
object: Term::Variable("Y".to_string()),
}],
};
let rewriter = QueryRewriter::new(vec![rule]);
let pattern = QueryPattern::new(None, Some("derived".to_string()), None);
assert!(rewriter.can_rewrite(&[pattern]));
}
// One forward query bumps both the total and the forward counters.
// The query result is discarded; the forward counter is only incremented
// when the forward query succeeds, so that assertion also covers success.
#[test]
fn test_integration_stats() {
let engine = RuleEngine::new();
let mut integration = SparqlRuleIntegration::new(engine);
let patterns = vec![QueryPattern::new(None, None, None)];
integration.set_mode(QueryMode::ForwardReasoning);
let _ = integration.query_with_reasoning(&patterns);
let stats = integration.get_stats();
assert_eq!(stats.total_queries, 1);
assert_eq!(stats.forward_reasoning_queries, 1);
}
}