use crate::{Rule, RuleAtom};
use anyhow::Result;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use tracing::{debug, info};
type DerivationId = usize;
#[derive(Debug, Clone)]
pub struct DerivationStep {
pub id: DerivationId,
pub fact: RuleAtom,
pub rule: Option<Rule>,
pub premises: Vec<DerivationId>,
pub timestamp: u64,
pub method: DerivationMethod,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DerivationMethod {
Asserted,
ForwardChaining,
BackwardChaining,
Rete,
Rdfs,
Owl,
Swrl,
}
impl fmt::Display for DerivationMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DerivationMethod::Asserted => write!(f, "asserted"),
DerivationMethod::ForwardChaining => write!(f, "forward chaining"),
DerivationMethod::BackwardChaining => write!(f, "backward chaining"),
DerivationMethod::Rete => write!(f, "RETE"),
DerivationMethod::Rdfs => write!(f, "RDFS reasoning"),
DerivationMethod::Owl => write!(f, "OWL reasoning"),
DerivationMethod::Swrl => write!(f, "SWRL reasoning"),
}
}
}
#[derive(Debug, Clone)]
pub struct Explanation {
pub target: RuleAtom,
pub steps: Vec<DerivationStep>,
pub justification: Justification,
pub confidence: f64,
}
impl fmt::Display for Explanation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Explanation for: {:?}", self.target)?;
writeln!(f, "Confidence: {:.2}%", self.confidence * 100.0)?;
writeln!(f, "\nDerivation steps:")?;
for (i, step) in self.steps.iter().enumerate() {
writeln!(f, " {}. {:?} via {}", i + 1, step.fact, step.method)?;
if let Some(rule) = &step.rule {
writeln!(f, " Rule: {}", rule.name)?;
}
}
writeln!(f, "\nJustification:")?;
writeln!(f, "{}", self.justification)?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct Justification {
pub axioms: Vec<RuleAtom>,
pub rules: Vec<Rule>,
pub chain: Vec<String>,
}
impl fmt::Display for Justification {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, " Axioms:")?;
for axiom in &self.axioms {
writeln!(f, " - {:?}", axiom)?;
}
writeln!(f, " Rules:")?;
for rule in &self.rules {
writeln!(f, " - {}", rule.name)?;
}
writeln!(f, " Inference chain:")?;
for (i, step) in self.chain.iter().enumerate() {
writeln!(f, " {}. {}", i + 1, step)?;
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct ExplanationRequest {
pub target: RuleAtom,
pub explanation_type: ExplanationType,
pub max_depth: usize,
pub include_all_paths: bool,
}
#[derive(Debug, Clone, PartialEq)]
pub enum ExplanationType {
Why,
How,
AllPaths,
Minimal,
}
pub struct ExplanationEngine {
derivations: HashMap<DerivationId, DerivationStep>,
fact_index: HashMap<RuleAtom, Vec<DerivationId>>,
next_id: DerivationId,
timestamp: u64,
}
impl Default for ExplanationEngine {
fn default() -> Self {
Self::new()
}
}
impl ExplanationEngine {
pub fn new() -> Self {
Self {
derivations: HashMap::new(),
fact_index: HashMap::new(),
next_id: 0,
timestamp: 0,
}
}
pub fn record_assertion(&mut self, fact: RuleAtom) -> DerivationId {
let id = self.next_id;
self.next_id += 1;
self.timestamp += 1;
let step = DerivationStep {
id,
fact: fact.clone(),
rule: None,
premises: vec![],
timestamp: self.timestamp,
method: DerivationMethod::Asserted,
};
self.derivations.insert(id, step);
self.fact_index.entry(fact).or_default().push(id);
debug!("Recorded assertion: ID {}", id);
id
}
pub fn record_derivation(
&mut self,
fact: RuleAtom,
rule: Option<Rule>,
premises: Vec<DerivationId>,
method: DerivationMethod,
) -> DerivationId {
let id = self.next_id;
self.next_id += 1;
self.timestamp += 1;
let step = DerivationStep {
id,
fact: fact.clone(),
rule,
premises,
timestamp: self.timestamp,
method: method.clone(),
};
debug!("Recorded derivation: ID {} via {:?}", id, method);
self.derivations.insert(id, step);
self.fact_index.entry(fact).or_default().push(id);
id
}
pub fn explain_why(&self, target: &RuleAtom) -> Result<Explanation> {
info!("Generating 'why' explanation for: {:?}", target);
let request = ExplanationRequest {
target: target.clone(),
explanation_type: ExplanationType::Why,
max_depth: 100,
include_all_paths: false,
};
self.generate_explanation(&request)
}
pub fn explain_how(&self, target: &RuleAtom) -> Result<Explanation> {
info!("Generating 'how' explanation for: {:?}", target);
let request = ExplanationRequest {
target: target.clone(),
explanation_type: ExplanationType::How,
max_depth: 100,
include_all_paths: false,
};
self.generate_explanation(&request)
}
pub fn explain_all_paths(&self, target: &RuleAtom) -> Result<Explanation> {
info!("Finding all derivation paths for: {:?}", target);
let request = ExplanationRequest {
target: target.clone(),
explanation_type: ExplanationType::AllPaths,
max_depth: 100,
include_all_paths: true,
};
self.generate_explanation(&request)
}
pub fn explain_minimal(&self, target: &RuleAtom) -> Result<Explanation> {
info!("Finding minimal justification for: {:?}", target);
let request = ExplanationRequest {
target: target.clone(),
explanation_type: ExplanationType::Minimal,
max_depth: 100,
include_all_paths: false,
};
self.generate_explanation(&request)
}
fn generate_explanation(&self, request: &ExplanationRequest) -> Result<Explanation> {
let derivation_ids = self
.fact_index
.get(&request.target)
.ok_or_else(|| anyhow::anyhow!("Fact not found in derivation history"))?;
if derivation_ids.is_empty() {
return Err(anyhow::anyhow!("No derivations found for fact"));
}
let steps = self.collect_derivation_steps(derivation_ids[0], request.max_depth)?;
let justification = self.build_justification(&steps)?;
let confidence = self.calculate_confidence(&steps, &justification);
Ok(Explanation {
target: request.target.clone(),
steps,
justification,
confidence,
})
}
fn collect_derivation_steps(
&self,
target_id: DerivationId,
max_depth: usize,
) -> Result<Vec<DerivationStep>> {
let mut steps = Vec::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back((target_id, 0));
while let Some((id, depth)) = queue.pop_front() {
if depth > max_depth || !visited.insert(id) {
continue;
}
if let Some(step) = self.derivations.get(&id) {
steps.push(step.clone());
for &premise_id in &step.premises {
queue.push_back((premise_id, depth + 1));
}
}
}
steps.sort_by_key(|s| s.timestamp);
Ok(steps)
}
fn build_justification(&self, steps: &[DerivationStep]) -> Result<Justification> {
let mut axioms = Vec::new();
let mut rules = Vec::new();
let mut chain = Vec::new();
for step in steps {
match step.method {
DerivationMethod::Asserted => {
axioms.push(step.fact.clone());
chain.push(format!("Axiom: {:?}", step.fact));
}
_ => {
if let Some(rule) = &step.rule {
if !rules.iter().any(|r: &Rule| r.name == rule.name) {
rules.push(rule.clone());
}
chain.push(format!(
"Apply rule '{}' to derive {:?}",
rule.name, step.fact
));
} else {
chain.push(format!("Derive {:?} via {}", step.fact, step.method));
}
}
}
}
Ok(Justification {
axioms,
rules,
chain,
})
}
fn calculate_confidence(&self, steps: &[DerivationStep], justification: &Justification) -> f64 {
if steps.is_empty() {
return 1.0; }
let max_depth = self.compute_max_depth(steps);
let depth_confidence = 0.95_f64.powi(max_depth as i32);
let axiom_count = justification.axioms.len() as f64;
let total_facts = steps.len() as f64;
let axiom_ratio = if total_facts > 0.0 {
axiom_count / total_facts
} else {
1.0
};
let axiom_confidence = 0.7 + (0.3 * axiom_ratio);
let method_weights: HashMap<DerivationMethod, f64> = [
(DerivationMethod::Asserted, 1.0),
(DerivationMethod::Rdfs, 0.95),
(DerivationMethod::Owl, 0.90),
(DerivationMethod::ForwardChaining, 0.88),
(DerivationMethod::BackwardChaining, 0.88),
(DerivationMethod::Rete, 0.92),
(DerivationMethod::Swrl, 0.85),
]
.iter()
.cloned()
.collect();
let method_confidence = steps
.iter()
.filter_map(|step| method_weights.get(&step.method))
.copied()
.sum::<f64>()
/ steps.len() as f64;
let avg_premises = if !justification.rules.is_empty() {
justification
.rules
.iter()
.map(|r| r.body.len())
.sum::<usize>() as f64
/ justification.rules.len() as f64
} else {
1.0
};
let complexity_confidence = if avg_premises > 5.0 {
1.0 - ((avg_premises - 5.0) * 0.05).min(0.3) } else {
1.0
};
let combined_confidence =
(depth_confidence * axiom_confidence * method_confidence * complexity_confidence)
.powf(0.25);
combined_confidence.clamp(0.0, 1.0)
}
fn compute_max_depth(&self, steps: &[DerivationStep]) -> usize {
if steps.is_empty() {
return 0;
}
let mut depths: HashMap<DerivationId, usize> = HashMap::new();
for step in steps {
if step.method == DerivationMethod::Asserted {
depths.insert(step.id, 0);
}
}
let mut changed = true;
while changed {
changed = false;
for step in steps {
if depths.contains_key(&step.id) {
continue; }
if step.premises.iter().all(|p| depths.contains_key(p)) {
let max_premise_depth = step
.premises
.iter()
.filter_map(|p| depths.get(p))
.max()
.copied()
.unwrap_or(0);
depths.insert(step.id, max_premise_depth + 1);
changed = true;
}
}
}
depths.values().copied().max().unwrap_or(0)
}
pub fn to_dot(&self, target: &RuleAtom) -> Result<String> {
let mut dot = String::from("digraph Derivation {\n");
dot.push_str(" rankdir=BT;\n");
dot.push_str(" node [shape=box];\n\n");
if let Some(derivation_ids) = self.fact_index.get(target) {
let mut visited = HashSet::new();
for &id in derivation_ids {
self.add_to_dot(id, &mut dot, &mut visited);
}
}
dot.push_str("}\n");
Ok(dot)
}
fn add_to_dot(&self, id: DerivationId, dot: &mut String, visited: &mut HashSet<DerivationId>) {
if !visited.insert(id) {
return;
}
if let Some(step) = self.derivations.get(&id) {
let label = format!("{:?}", step.fact);
let color = match step.method {
DerivationMethod::Asserted => "lightblue",
_ => "lightgreen",
};
dot.push_str(&format!(
" n{} [label=\"{}\", fillcolor={}, style=filled];\n",
id, label, color
));
for &premise_id in &step.premises {
dot.push_str(&format!(" n{} -> n{};\n", premise_id, id));
self.add_to_dot(premise_id, dot, visited);
}
}
}
pub fn get_stats(&self) -> ExplanationStats {
let asserted_count = self
.derivations
.values()
.filter(|s| s.method == DerivationMethod::Asserted)
.count();
let derived_count = self.derivations.len() - asserted_count;
ExplanationStats {
total_derivations: self.derivations.len(),
asserted_facts: asserted_count,
derived_facts: derived_count,
unique_facts: self.fact_index.len(),
}
}
pub fn clear(&mut self) {
self.derivations.clear();
self.fact_index.clear();
self.next_id = 0;
self.timestamp = 0;
}
}
#[derive(Debug, Clone)]
pub struct ExplanationStats {
pub total_derivations: usize,
pub asserted_facts: usize,
pub derived_facts: usize,
pub unique_facts: usize,
}
impl fmt::Display for ExplanationStats {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Derivations: {}, Asserted: {}, Derived: {}, Unique: {}",
self.total_derivations, self.asserted_facts, self.derived_facts, self.unique_facts
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Term;
#[test]
fn test_record_assertion() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let fact = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
let id = engine.record_assertion(fact.clone());
assert_eq!(id, 0);
let step = engine.derivations.get(&id).ok_or("expected Some value")?;
assert_eq!(step.method, DerivationMethod::Asserted);
Ok(())
}
#[test]
fn test_record_derivation() {
let mut engine = ExplanationEngine::new();
let premise = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
let premise_id = engine.record_assertion(premise);
let derived = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("mortal".to_string()),
};
let rule = Rule {
name: "mortality".to_string(),
body: vec![],
head: vec![],
};
let derived_id = engine.record_derivation(
derived,
Some(rule),
vec![premise_id],
DerivationMethod::ForwardChaining,
);
assert_eq!(derived_id, 1);
}
#[test]
fn test_explain_why() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let fact = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
engine.record_assertion(fact.clone());
let explanation = engine.explain_why(&fact)?;
assert_eq!(explanation.steps.len(), 1);
assert_eq!(explanation.steps[0].method, DerivationMethod::Asserted);
Ok(())
}
#[test]
fn test_justification() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let premise = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
let premise_id = engine.record_assertion(premise.clone());
let derived = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("mortal".to_string()),
};
let rule = Rule {
name: "mortality".to_string(),
body: vec![],
head: vec![],
};
engine.record_derivation(
derived.clone(),
Some(rule),
vec![premise_id],
DerivationMethod::ForwardChaining,
);
let explanation = engine.explain_why(&derived)?;
assert!(!explanation.justification.axioms.is_empty());
assert!(!explanation.justification.rules.is_empty());
Ok(())
}
#[test]
fn test_stats() {
let mut engine = ExplanationEngine::new();
engine.record_assertion(RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
});
let stats = engine.get_stats();
assert_eq!(stats.asserted_facts, 1);
assert_eq!(stats.total_derivations, 1);
}
#[test]
fn test_confidence_asserted_fact() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let fact = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
engine.record_assertion(fact.clone());
let explanation = engine.explain_why(&fact)?;
assert_eq!(explanation.confidence, 1.0);
Ok(())
}
#[test]
fn test_confidence_single_derivation() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let premise = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("human".to_string()),
};
let premise_id = engine.record_assertion(premise.clone());
let derived = RuleAtom::Triple {
subject: Term::Constant("socrates".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("mortal".to_string()),
};
let rule = Rule {
name: "mortality".to_string(),
body: vec![premise],
head: vec![derived.clone()],
};
engine.record_derivation(
derived.clone(),
Some(rule),
vec![premise_id],
DerivationMethod::ForwardChaining,
);
let explanation = engine.explain_why(&derived)?;
assert!(
explanation.confidence > 0.8,
"Expected confidence > 0.8, got {}",
explanation.confidence
);
assert!(
explanation.confidence < 1.0,
"Expected confidence < 1.0, got {}",
explanation.confidence
);
Ok(())
}
#[test]
fn test_confidence_chain_depth() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let fact1 = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("b".to_string()),
};
let id1 = engine.record_assertion(fact1.clone());
let fact2 = RuleAtom::Triple {
subject: Term::Constant("b".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("c".to_string()),
};
let rule = Rule {
name: "transitive".to_string(),
body: vec![fact1.clone()],
head: vec![fact2.clone()],
};
let id2 = engine.record_derivation(
fact2.clone(),
Some(rule.clone()),
vec![id1],
DerivationMethod::ForwardChaining,
);
let fact3 = RuleAtom::Triple {
subject: Term::Constant("c".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("d".to_string()),
};
engine.record_derivation(
fact3.clone(),
Some(rule),
vec![id2],
DerivationMethod::ForwardChaining,
);
let explanation = engine.explain_why(&fact3)?;
assert!(
explanation.confidence < 0.905,
"Expected confidence < 0.905 for depth-2 chain, got {}",
explanation.confidence
);
Ok(())
}
#[test]
fn test_confidence_different_methods() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let premise = RuleAtom::Triple {
subject: Term::Constant("x".to_string()),
predicate: Term::Constant("type".to_string()),
object: Term::Constant("Class".to_string()),
};
let premise_id = engine.record_assertion(premise.clone());
let rdfs_derived = RuleAtom::Triple {
subject: Term::Constant("x".to_string()),
predicate: Term::Constant("subClassOf".to_string()),
object: Term::Constant("Resource".to_string()),
};
let rule = Rule {
name: "rdfs_rule".to_string(),
body: vec![premise.clone()],
head: vec![rdfs_derived.clone()],
};
engine.record_derivation(
rdfs_derived.clone(),
Some(rule.clone()),
vec![premise_id],
DerivationMethod::Rdfs,
);
let swrl_derived = RuleAtom::Triple {
subject: Term::Constant("x".to_string()),
predicate: Term::Constant("prop".to_string()),
object: Term::Constant("value".to_string()),
};
engine.record_derivation(
swrl_derived.clone(),
Some(rule),
vec![premise_id],
DerivationMethod::Swrl,
);
let rdfs_explanation = engine.explain_why(&rdfs_derived)?;
let swrl_explanation = engine.explain_why(&swrl_derived)?;
assert!(
rdfs_explanation.confidence > swrl_explanation.confidence,
"RDFS confidence ({}) should be > SWRL confidence ({})",
rdfs_explanation.confidence,
swrl_explanation.confidence
);
Ok(())
}
#[test]
fn test_confidence_rule_complexity() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let premise1 = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p1".to_string()),
object: Term::Constant("b".to_string()),
};
let id1 = engine.record_assertion(premise1.clone());
let simple_derived = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("q1".to_string()),
object: Term::Constant("b".to_string()),
};
let simple_rule = Rule {
name: "simple".to_string(),
body: vec![premise1.clone()],
head: vec![simple_derived.clone()],
};
engine.record_derivation(
simple_derived.clone(),
Some(simple_rule),
vec![id1],
DerivationMethod::ForwardChaining,
);
let mut premises = Vec::new();
let mut premise_ids = Vec::new();
for i in 0..6 {
let p = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant(format!("p{}", i)),
object: Term::Constant("b".to_string()),
};
let pid = engine.record_assertion(p.clone());
premises.push(p);
premise_ids.push(pid);
}
let complex_derived = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("q2".to_string()),
object: Term::Constant("b".to_string()),
};
let complex_rule = Rule {
name: "complex".to_string(),
body: premises,
head: vec![complex_derived.clone()],
};
engine.record_derivation(
complex_derived.clone(),
Some(complex_rule),
premise_ids,
DerivationMethod::ForwardChaining,
);
let simple_explanation = engine.explain_why(&simple_derived)?;
let complex_explanation = engine.explain_why(&complex_derived)?;
assert!(
complex_explanation.confidence > 0.85,
"Complex confidence ({}) should be > 0.85",
complex_explanation.confidence
);
assert!(
simple_explanation.confidence > 0.9,
"Simple confidence ({}) should be > 0.9",
simple_explanation.confidence
);
Ok(())
}
#[test]
fn test_confidence_axiom_ratio() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let axiom1 = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("p1".to_string()),
object: Term::Constant("b".to_string()),
};
let axiom2 = RuleAtom::Triple {
subject: Term::Constant("b".to_string()),
predicate: Term::Constant("p2".to_string()),
object: Term::Constant("c".to_string()),
};
let id1 = engine.record_assertion(axiom1.clone());
let id2 = engine.record_assertion(axiom2.clone());
let derived = RuleAtom::Triple {
subject: Term::Constant("a".to_string()),
predicate: Term::Constant("q".to_string()),
object: Term::Constant("c".to_string()),
};
let rule = Rule {
name: "combine".to_string(),
body: vec![axiom1, axiom2],
head: vec![derived.clone()],
};
engine.record_derivation(
derived.clone(),
Some(rule),
vec![id1, id2],
DerivationMethod::ForwardChaining,
);
let explanation = engine.explain_why(&derived)?;
assert!(
explanation.confidence > 0.85,
"Expected confidence > 0.85 with high axiom ratio, got {}",
explanation.confidence
);
Ok(())
}
#[test]
fn test_confidence_bounds() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ExplanationEngine::new();
let fact = RuleAtom::Triple {
subject: Term::Constant("test".to_string()),
predicate: Term::Constant("p".to_string()),
object: Term::Constant("value".to_string()),
};
engine.record_assertion(fact.clone());
let explanation = engine.explain_why(&fact)?;
assert!(
explanation.confidence >= 0.0,
"Confidence {} is below 0.0",
explanation.confidence
);
assert!(
explanation.confidence <= 1.0,
"Confidence {} is above 1.0",
explanation.confidence
);
Ok(())
}
}