use crate::{GraphRAGError, GraphRAGResult, ScoredEntity, Triple};
use std::collections::{HashMap, HashSet, VecDeque};
use oxirs_rule::{Rule, RuleAtom, RuleEngine, Term};
/// Lookup table from a `(subject, predicate, object)` triple key to the names
/// of the rules recorded for that triple.
pub type FiredRulesMap = HashMap<(String, String, String), Vec<String>>;
/// A directed, labelled edge of the reasoning graph:
/// `subject --predicate--> object`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphEdge {
/// Node the edge starts from.
pub subject: String,
/// Label of the edge.
pub predicate: String,
/// Node the edge points to.
pub object: String,
/// `true` when the edge was built from rule-engine output, `false` when it
/// was built from the input subgraph.
pub inferred: bool,
}
/// A path found by the multi-hop traversal, together with its score and
/// provenance information.
///
/// The field descriptions below describe the values the engine in this file
/// stores; all fields are public, so other code may construct paths freely.
#[derive(Debug, Clone)]
pub struct HopPath {
/// Edges of the path, in traversal order.
pub edges: Vec<GraphEdge>,
/// Node the traversal started from.
pub start: String,
/// Node reached by the last edge of the path.
pub end: String,
/// Score computed for the path by the configured [`PathScoringFn`].
pub score: f64,
/// Number of edges in `edges` whose `inferred` flag is set.
pub inferred_hops: usize,
/// Rule names looked up for the inferred edges of the path, without duplicates.
pub fired_rules: Vec<String>,
}
impl HopPath {
    /// Number of hops, i.e. the number of edges stored in the path.
    pub fn hop_count(&self) -> usize {
        let hops = &self.edges;
        hops.len()
    }

    /// Whether the path counts at least one inferred hop.
    pub fn has_inferred_hop(&self) -> bool {
        self.inferred_hops != 0
    }
}
/// Formula used to score a path from its hop count, its number of inferred
/// hops and the score of the seed it started from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathScoringFn {
/// `1 / hops`.
InverseHopCount,
/// `seed_score / hops`; this is the default variant.
#[default]
SeedWeighted,
/// Constant score of `1.0` for every path.
Uniform,
/// `(1 / hops) * 0.8^inferred_hops`.
InferencePenalised,
}
/// Settings read by `MultiHopEngine` while expanding and scoring paths.
#[derive(Debug, Clone)]
pub struct MultiHopConfig {
/// Maximum number of edges a path may contain.
pub max_hops: usize,
/// Maximum number of paths returned by `reason`; the traversal of a single
/// seed also stops once it has collected this many paths.
pub max_paths: usize,
/// Maximum number of traversal states processed per seed.
pub max_edges_budget: usize,
/// Whether edges produced by the rule engine are added to the traversed graph.
pub include_inferred: bool,
/// Formula used to score each path.
pub scoring_fn: PathScoringFn,
/// When non-empty, only edges with one of these predicates are traversed.
pub allowed_predicates: HashSet<String>,
/// Edges with one of these predicates are never traversed.
pub blocked_predicates: HashSet<String>,
/// Paths scoring below this value are dropped from the result.
pub min_path_score: f64,
}
impl Default for MultiHopConfig {
fn default() -> Self {
Self {
max_hops: 3,
max_paths: 50,
max_edges_budget: 100_000,
include_inferred: true,
scoring_fn: PathScoringFn::SeedWeighted,
allowed_predicates: HashSet::new(),
blocked_predicates: HashSet::new(),
min_path_score: 0.0,
}
}
}
/// Converts rule atoms into graph edges, tagging every edge with `inferred`.
///
/// Only `RuleAtom::Triple` atoms whose three terms all resolve to a string
/// through `term_to_str` yield an edge; every other atom is skipped.
fn atoms_to_edges(atoms: &[RuleAtom], inferred: bool) -> Vec<GraphEdge> {
    let mut edges = Vec::new();
    for atom in atoms {
        if let RuleAtom::Triple {
            subject,
            predicate,
            object,
        } = atom
        {
            let resolved = (
                term_to_str(subject),
                term_to_str(predicate),
                term_to_str(object),
            );
            // An atom with any unresolved term (e.g. a variable) is dropped.
            if let (Some(subject), Some(predicate), Some(object)) = resolved {
                edges.push(GraphEdge {
                    subject,
                    predicate,
                    object,
                    inferred,
                });
            }
        }
    }
    edges
}
/// Returns a copy of the string held by a constant or literal term, and
/// `None` for variables and function terms.
fn term_to_str(term: &Term) -> Option<String> {
    let text = match term {
        Term::Constant(value) => value,
        Term::Literal(value) => value,
        Term::Variable(_) => return None,
        Term::Function { .. } => return None,
    };
    Some(text.clone())
}
/// Turns every triple into a `RuleAtom::Triple` whose subject, predicate and
/// object are all `Term::Constant` copies of the triple's strings.
fn triples_to_atoms(triples: &[Triple]) -> Vec<RuleAtom> {
    let mut atoms = Vec::with_capacity(triples.len());
    for triple in triples {
        let subject = Term::Constant(triple.subject.clone());
        let predicate = Term::Constant(triple.predicate.clone());
        let object = Term::Constant(triple.object.clone());
        atoms.push(RuleAtom::Triple {
            subject,
            predicate,
            object,
        });
    }
    atoms
}
/// Multi-hop reasoning engine: enumerates scored paths through a subgraph
/// according to its [`MultiHopConfig`].
///
/// `Debug` and `Clone` are derived so the engine can be logged and duplicated
/// like the configuration it wraps.
#[derive(Debug, Clone)]
pub struct MultiHopEngine {
    /// Settings that control traversal limits, predicate filters and scoring.
    config: MultiHopConfig,
}
impl Default for MultiHopEngine {
fn default() -> Self {
Self::new(MultiHopConfig::default())
}
}
impl MultiHopEngine {
pub fn new(config: MultiHopConfig) -> Self {
Self { config }
}
pub fn reason(
&self,
seeds: &[ScoredEntity],
subgraph: &[Triple],
rules: &[Rule],
) -> GraphRAGResult<Vec<HopPath>> {
if seeds.is_empty() || subgraph.is_empty() {
return Ok(vec![]);
}
let (asserted_edges, inferred_edges, fired_rule_map) = self.materialise(subgraph, rules)?;
let mut all_edges: Vec<GraphEdge> = asserted_edges;
if self.config.include_inferred {
all_edges.extend(inferred_edges);
}
let adj = self.build_adjacency(&all_edges);
let mut paths: Vec<HopPath> = Vec::new();
let seed_map: HashMap<String, f64> =
seeds.iter().map(|s| (s.uri.clone(), s.score)).collect();
for seed in seeds {
let new_paths =
self.bfs_paths(&seed.uri, seed.score, &adj, &all_edges, &fired_rule_map);
paths.extend(new_paths);
}
paths.retain(|p| p.score >= self.config.min_path_score);
paths.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
paths.truncate(self.config.max_paths);
let _ = seed_map;
Ok(paths)
}
fn materialise(
&self,
subgraph: &[Triple],
rules: &[Rule],
) -> GraphRAGResult<(Vec<GraphEdge>, Vec<GraphEdge>, FiredRulesMap)> {
let asserted_edges = atoms_to_edges(&triples_to_atoms(subgraph), false);
if rules.is_empty() {
return Ok((asserted_edges, vec![], HashMap::new()));
}
let mut engine = RuleEngine::new();
engine.add_rules(rules.to_vec());
engine.enable_cache();
let facts = triples_to_atoms(subgraph);
let inferred_atoms = engine
.forward_chain(&facts)
.map_err(|e| GraphRAGError::InternalError(format!("Rule engine error: {e}")))?;
let asserted_keys: HashSet<(String, String, String)> = subgraph
.iter()
.map(|t| (t.subject.clone(), t.predicate.clone(), t.object.clone()))
.collect();
let inferred_edges: Vec<GraphEdge> = atoms_to_edges(&inferred_atoms, true)
.into_iter()
.filter(|e| {
!asserted_keys.contains(&(e.subject.clone(), e.predicate.clone(), e.object.clone()))
})
.collect();
let fired_rule_map: FiredRulesMap = rules
.iter()
.flat_map(|rule| {
rule.head.iter().filter_map(|atom| match atom {
RuleAtom::Triple {
subject,
predicate,
object,
} => {
let s = term_to_str(subject)?;
let p = term_to_str(predicate)?;
let o = term_to_str(object)?;
Some(((s, p, o), rule.name.clone()))
}
_ => None,
})
})
.fold(HashMap::new(), |mut acc, (key, rule_name)| {
acc.entry(key).or_default().push(rule_name);
acc
});
Ok((asserted_edges, inferred_edges, fired_rule_map))
}
fn build_adjacency(&self, edges: &[GraphEdge]) -> HashMap<String, Vec<usize>> {
let mut adj: HashMap<String, Vec<usize>> = HashMap::new();
for (i, edge) in edges.iter().enumerate() {
if self.allow_predicate(&edge.predicate) {
adj.entry(edge.subject.clone()).or_default().push(i);
}
}
adj
}
fn allow_predicate(&self, pred: &str) -> bool {
if !self.config.allowed_predicates.is_empty()
&& !self.config.allowed_predicates.contains(pred)
{
return false;
}
!self.config.blocked_predicates.contains(pred)
}
fn bfs_paths(
&self,
start: &str,
seed_score: f64,
adj: &HashMap<String, Vec<usize>>,
edges: &[GraphEdge],
fired_rule_map: &HashMap<(String, String, String), Vec<String>>,
) -> Vec<HopPath> {
struct State {
node: String,
edge_path: Vec<usize>,
visited: HashSet<String>,
}
let mut queue: VecDeque<State> = VecDeque::new();
queue.push_back(State {
node: start.to_string(),
edge_path: vec![],
visited: {
let mut h = HashSet::new();
h.insert(start.to_string());
h
},
});
let mut paths: Vec<HopPath> = Vec::new();
let mut budget = self.config.max_edges_budget;
while let Some(state) = queue.pop_front() {
if budget == 0 {
break;
}
budget -= 1;
if state.edge_path.len() > self.config.max_hops {
continue;
}
if !state.edge_path.is_empty() {
let path_edges: Vec<GraphEdge> =
state.edge_path.iter().map(|&i| edges[i].clone()).collect();
let inferred_hops = path_edges.iter().filter(|e| e.inferred).count();
let fired_rules: Vec<String> = path_edges
.iter()
.filter(|e| e.inferred)
.flat_map(|e| {
let key = (e.subject.clone(), e.predicate.clone(), e.object.clone());
fired_rule_map.get(&key).cloned().unwrap_or_default()
})
.collect::<HashSet<_>>()
.into_iter()
.collect();
let score = self.score_path(state.edge_path.len(), inferred_hops, seed_score);
paths.push(HopPath {
edges: path_edges,
start: start.to_string(),
end: state.node.clone(),
score,
inferred_hops,
fired_rules,
});
if paths.len() >= self.config.max_paths {
return paths;
}
}
if state.edge_path.len() >= self.config.max_hops {
continue;
}
if let Some(edge_indices) = adj.get(&state.node) {
for &ei in edge_indices {
let edge = &edges[ei];
if !state.visited.contains(&edge.object) {
let mut new_visited = state.visited.clone();
new_visited.insert(edge.object.clone());
let mut new_path = state.edge_path.clone();
new_path.push(ei);
queue.push_back(State {
node: edge.object.clone(),
edge_path: new_path,
visited: new_visited,
});
}
}
}
}
paths
}
fn score_path(&self, hops: usize, inferred_hops: usize, seed_score: f64) -> f64 {
let h = hops.max(1) as f64;
match self.config.scoring_fn {
PathScoringFn::InverseHopCount => 1.0 / h,
PathScoringFn::SeedWeighted => seed_score / h,
PathScoringFn::Uniform => 1.0,
PathScoringFn::InferencePenalised => (1.0 / h) * 0.8_f64.powi(inferred_hops as i32),
}
}
}
/// Builds the transitivity rule for `predicate`:
/// `(?X p ?Y), (?Y p ?Z) => (?X p ?Z)`, named `"{predicate}_transitive"`.
pub fn transitivity_rule(predicate: &str) -> Rule {
    let var = |name: &str| Term::Variable(name.to_owned());
    // One `from --predicate--> to` pattern over two variables.
    let link = |from: &str, to: &str| RuleAtom::Triple {
        subject: var(from),
        predicate: Term::Constant(predicate.to_owned()),
        object: var(to),
    };

    let body = vec![link("X", "Y"), link("Y", "Z")];
    let head = vec![link("X", "Z")];
    let name = format!("{predicate}_transitive");
    Rule { name, body, head }
}
/// Builds the property-chain rule
/// `(?X p1 ?Y), (?Y p2 ?Z) => (?X conclusion_pred ?Z)`, named `"{p1}_{p2}_chain"`.
pub fn property_chain_rule(p1: &str, p2: &str, conclusion_pred: &str) -> Rule {
    let var = |name: &str| Term::Variable(name.to_owned());
    // One `from --label--> to` pattern over two variables.
    let link = |from: &str, label: &str, to: &str| RuleAtom::Triple {
        subject: var(from),
        predicate: Term::Constant(label.to_owned()),
        object: var(to),
    };

    let body = vec![link("X", p1, "Y"), link("Y", p2, "Z")];
    let head = vec![link("X", conclusion_pred, "Z")];
    let name = format!("{p1}_{p2}_chain");
    Rule { name, body, head }
}
/// Builds the symmetry rule for `predicate`:
/// `(?X p ?Y) => (?Y p ?X)`, named `"{predicate}_symmetric"`.
pub fn symmetry_rule(predicate: &str) -> Rule {
    let var = |name: &str| Term::Variable(name.to_owned());
    // One `from --predicate--> to` pattern over two variables.
    let link = |from: &str, to: &str| RuleAtom::Triple {
        subject: var(from),
        predicate: Term::Constant(predicate.to_owned()),
        object: var(to),
    };

    let body = vec![link("X", "Y")];
    let head = vec![link("Y", "X")];
    let name = format!("{predicate}_symmetric");
    Rule { name, body, head }
}
#[cfg(test)]
mod tests {
    //! Core behaviour tests for the rule constructors, the scoring functions
    //! and `MultiHopEngine::reason`.
    //!
    //! Tests that check a property of every returned path first assert that
    //! at least one path was returned, so they cannot pass on an empty result.

    use super::*;
    use crate::ScoreSource;

    /// Seed entity with a vector-sourced score and no metadata.
    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    /// Shorthand for `Triple::new`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }

    #[test]
    fn test_transitivity_rule_structure() {
        let rule = transitivity_rule("subClassOf");
        assert_eq!(rule.name, "subClassOf_transitive");
        assert_eq!(rule.body.len(), 2);
        assert_eq!(rule.head.len(), 1);
    }

    #[test]
    fn test_property_chain_rule_structure() {
        let rule = property_chain_rule("partOf", "locatedIn", "indirectlyIn");
        assert_eq!(rule.name, "partOf_locatedIn_chain");
        assert_eq!(rule.body.len(), 2);
        assert_eq!(rule.head.len(), 1);
    }

    #[test]
    fn test_symmetry_rule_structure() {
        let rule = symmetry_rule("sameAs");
        assert_eq!(rule.name, "sameAs_symmetric");
        assert_eq!(rule.body.len(), 1);
        assert_eq!(rule.head.len(), 1);
    }

    #[test]
    fn test_reason_empty_seeds() {
        let engine = MultiHopEngine::default();
        let triples = vec![make_triple("http://a", "http://rel", "http://b")];
        let result = engine.reason(&[], &triples, &[]).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reason_empty_subgraph() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let result = engine.reason(&seeds, &[], &[]).expect("should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_reason_single_hop_no_rules() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/rel", "http://b"),
            make_triple("http://b", "http://p/rel", "http://c"),
            make_triple("http://x", "http://p/other", "http://y"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty());
        for p in &paths {
            assert_eq!(p.start, "http://a");
            assert_eq!(p.inferred_hops, 0);
        }
    }

    #[test]
    fn test_reason_respects_max_hops() {
        let config = MultiHopConfig {
            max_hops: 1,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty(), "Expected the one-hop path to be found");
        for p in &paths {
            assert!(
                p.hop_count() <= 1,
                "Path hop count {} > max_hops 1",
                p.hop_count()
            );
        }
    }

    #[test]
    fn test_reason_respects_max_paths() {
        let config = MultiHopConfig {
            max_paths: 2,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples: Vec<Triple> = (0..20)
            .map(|i| make_triple("http://a", "http://p", &format!("http://n{i}")))
            .collect();
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        // Twenty one-hop paths exist, so the cap is reached exactly.
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_reason_transitivity_rule() {
        let config = MultiHopConfig {
            max_hops: 3,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rules = vec![transitivity_rule("http://subClassOf")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        assert!(!paths.is_empty());
        let has_inferred = paths.iter().any(|p| p.has_inferred_hop());
        assert!(
            has_inferred,
            "Expected at least one path with inferred hops"
        );
    }

    #[test]
    fn test_reason_no_inferred_when_disabled() {
        let config = MultiHopConfig {
            include_inferred: false,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rules = vec![transitivity_rule("http://subClassOf")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        assert!(!paths.is_empty(), "Asserted paths should still be found");
        for p in &paths {
            assert_eq!(
                p.inferred_hops, 0,
                "Expected no inferred hops when disabled"
            );
        }
    }

    #[test]
    fn test_score_inverse_hop_count() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InverseHopCount,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        assert!((engine.score_path(1, 0, 1.0) - 1.0).abs() < 1e-9);
        assert!((engine.score_path(2, 0, 1.0) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_score_seed_weighted() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::SeedWeighted,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s = engine.score_path(2, 0, 0.8);
        assert!((s - 0.4).abs() < 1e-9);
    }

    #[test]
    fn test_score_uniform() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::Uniform,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        assert_eq!(engine.score_path(5, 3, 0.5), 1.0);
    }

    #[test]
    fn test_score_inference_penalised() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InferencePenalised,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s_no_inf = engine.score_path(2, 0, 1.0);
        let s_with_inf = engine.score_path(2, 1, 1.0);
        assert!(s_no_inf > s_with_inf, "Inferred hop should reduce score");
    }

    #[test]
    fn test_blocked_predicates_filter() {
        let mut config = MultiHopConfig::default();
        config
            .blocked_predicates
            .insert("http://p/blocked".to_string());
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/blocked", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(
            !paths.is_empty(),
            "The non-blocked edge should still be traversed"
        );
        for p in &paths {
            for e in &p.edges {
                assert_ne!(
                    e.predicate, "http://p/blocked",
                    "Blocked predicate found in path"
                );
            }
        }
    }

    #[test]
    fn test_allowed_predicates_whitelist() {
        let mut config = MultiHopConfig::default();
        config
            .allowed_predicates
            .insert("http://p/allowed".to_string());
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/other", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty(), "The allowed edge should be traversed");
        for p in &paths {
            for e in &p.edges {
                assert_eq!(e.predicate, "http://p/allowed");
            }
        }
    }

    #[test]
    fn test_hop_path_fields() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.8)];
        let triples = vec![make_triple("http://a", "http://p", "http://b")];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty());
        let path = &paths[0];
        assert_eq!(path.start, "http://a");
        assert_eq!(path.end, "http://b");
        assert_eq!(path.hop_count(), 1);
        assert!(!path.has_inferred_hop());
    }

    #[test]
    fn test_min_path_score_threshold() {
        // No scoring function can reach 99.0 with a seed score of 0.9.
        let config = MultiHopConfig {
            min_path_score: 99.0,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![make_triple("http://a", "http://p", "http://b")];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(paths.is_empty());
    }

    #[test]
    fn test_triples_to_atoms_roundtrip() {
        let triples = vec![Triple::new("http://s", "http://p", "http://o")];
        let atoms = triples_to_atoms(&triples);
        assert_eq!(atoms.len(), 1);
        if let RuleAtom::Triple {
            subject,
            predicate,
            object,
        } = &atoms[0]
        {
            assert_eq!(term_to_str(subject).expect("should succeed"), "http://s");
            assert_eq!(term_to_str(predicate).expect("should succeed"), "http://p");
            assert_eq!(term_to_str(object).expect("should succeed"), "http://o");
        } else {
            panic!("Expected Triple atom");
        }
    }

    #[test]
    fn test_property_chain_produces_derived_edges() {
        let config = MultiHopConfig {
            max_hops: 3,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://partOf", "http://b"),
            make_triple("http://b", "http://locatedIn", "http://c"),
        ];
        let rules = vec![property_chain_rule(
            "http://partOf",
            "http://locatedIn",
            "http://indirectlyIn",
        )];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        assert!(!paths.is_empty());
        // The chain conclusion must show up as an inferred edge on some path.
        let has_derived_edge = paths
            .iter()
            .flat_map(|p| p.edges.iter())
            .any(|e| e.inferred && e.predicate == "http://indirectlyIn");
        assert!(
            has_derived_edge,
            "Expected an inferred http://indirectlyIn edge"
        );
    }
}
#[cfg(test)]
mod additional_tests {
    //! Further tests for rule shapes, term/atom conversion helpers, scoring
    //! and traversal guards (cycles, budget, ordering).

    use super::*;
    use crate::ScoreSource;

    /// Seed entity with a vector-sourced score and no metadata.
    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    /// Shorthand for `Triple::new`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }

    #[test]
    fn test_symmetry_rule_head_swapped() {
        let rule = symmetry_rule("http://sameAs");
        match &rule.head[0] {
            RuleAtom::Triple {
                subject, object, ..
            } => match (subject, object) {
                (Term::Variable(sv), Term::Variable(ov)) => {
                    assert_ne!(sv, ov, "Head subject and object variables should differ");
                    // The body is (?X p ?Y), so the head must be (?Y p ?X).
                    assert_eq!(sv, "Y", "Head subject should be the body object");
                    assert_eq!(ov, "X", "Head object should be the body subject");
                }
                _ => panic!("Expected Variable terms in head"),
            },
            _ => panic!("Expected Triple head"),
        }
    }

    #[test]
    fn test_transitivity_rule_has_shared_variable() {
        let rule = transitivity_rule("http://subClassOf");
        match (&rule.body[0], &rule.body[1]) {
            (RuleAtom::Triple { object: obj0, .. }, RuleAtom::Triple { subject: subj1, .. }) => {
                // `matches!` only yields a bool; it has to be asserted.
                assert!(
                    matches!(obj0, Term::Variable(v) if v == "Y"),
                    "First body atom should end in ?Y"
                );
                assert!(
                    matches!(subj1, Term::Variable(v) if v == "Y"),
                    "Second body atom should start at ?Y"
                );
            }
            _ => panic!("Expected Triple atoms in body"),
        }
    }

    #[test]
    fn test_property_chain_rule_body_predicates() {
        /// Predicate string of a triple atom; panics on any other atom.
        fn predicate_of(atom: &RuleAtom) -> String {
            match atom {
                RuleAtom::Triple { predicate, .. } => {
                    term_to_str(predicate).expect("should succeed")
                }
                _ => panic!("Expected Triple atom"),
            }
        }

        let rule = property_chain_rule("http://partOf", "http://locatedIn", "http://indirectlyIn");
        assert_eq!(predicate_of(&rule.body[0]), "http://partOf");
        assert_eq!(predicate_of(&rule.body[1]), "http://locatedIn");
        assert_eq!(predicate_of(&rule.head[0]), "http://indirectlyIn");
    }

    #[test]
    fn test_term_to_str_constant() {
        let t = Term::Constant("http://example.org/x".to_string());
        assert_eq!(
            term_to_str(&t).expect("should succeed"),
            "http://example.org/x"
        );
    }

    #[test]
    fn test_term_to_str_literal() {
        let t = Term::Literal("hello world".to_string());
        assert_eq!(term_to_str(&t).expect("should succeed"), "hello world");
    }

    #[test]
    fn test_term_to_str_variable_returns_none() {
        let t = Term::Variable("X".to_string());
        assert!(term_to_str(&t).is_none());
    }

    #[test]
    fn test_atoms_to_edges_filters_non_triple() {
        let atoms = vec![
            RuleAtom::Triple {
                subject: Term::Constant("http://s".to_string()),
                predicate: Term::Constant("http://p".to_string()),
                object: Term::Constant("http://o".to_string()),
            },
            // A triple with an unbound variable cannot become an edge and
            // must be filtered out.
            RuleAtom::Triple {
                subject: Term::Variable("X".to_string()),
                predicate: Term::Constant("http://p".to_string()),
                object: Term::Constant("http://o".to_string()),
            },
        ];
        let edges = atoms_to_edges(&atoms, false);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].subject, "http://s");
        assert!(!edges[0].inferred);
    }

    #[test]
    fn test_atoms_to_edges_inferred_flag() {
        let atoms = vec![RuleAtom::Triple {
            subject: Term::Constant("http://s".to_string()),
            predicate: Term::Constant("http://p".to_string()),
            object: Term::Constant("http://o".to_string()),
        }];
        let edges = atoms_to_edges(&atoms, true);
        assert_eq!(edges.len(), 1);
        assert!(edges[0].inferred);
    }

    #[test]
    fn test_score_inference_penalised_zero_inferred_equals_inverse() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InferencePenalised,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s = engine.score_path(2, 0, 1.0);
        assert!((s - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_score_seed_weighted_scales_with_seed_score() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::SeedWeighted,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s1 = engine.score_path(1, 0, 1.0);
        let s2 = engine.score_path(1, 0, 0.5);
        assert!((s1 - 2.0 * s2).abs() < 1e-9, "s1={s1}, s2={s2}");
    }

    #[test]
    fn test_graph_edge_equality() {
        let e1 = GraphEdge {
            subject: "http://a".to_string(),
            predicate: "http://p".to_string(),
            object: "http://b".to_string(),
            inferred: false,
        };
        let e2 = e1.clone();
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_multihop_config_defaults() {
        let cfg = MultiHopConfig::default();
        assert_eq!(cfg.max_hops, 3);
        assert_eq!(cfg.max_paths, 50);
        assert!(cfg.include_inferred);
        assert!(cfg.allowed_predicates.is_empty());
        assert!(cfg.blocked_predicates.is_empty());
    }

    #[test]
    fn test_reason_cycle_in_graph_does_not_loop() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
            make_triple("http://c", "http://p", "http://a"),
        ];
        let paths = engine.reason(&seeds, &triples, &[]);
        assert!(paths.is_ok(), "Should not error on cyclic graphs");
        let paths = paths.expect("should succeed");
        assert!(paths.len() < 1000, "Cycle guard should bound path count");
    }

    #[test]
    fn test_hop_path_has_inferred_hop_false_for_asserted() {
        let path = HopPath {
            edges: vec![GraphEdge {
                subject: "http://s".to_string(),
                predicate: "http://p".to_string(),
                object: "http://o".to_string(),
                inferred: false,
            }],
            start: "http://s".to_string(),
            end: "http://o".to_string(),
            score: 0.8,
            inferred_hops: 0,
            fired_rules: vec![],
        };
        assert!(!path.has_inferred_hop());
        assert_eq!(path.hop_count(), 1);
    }

    #[test]
    fn test_hop_path_has_inferred_hop_true_for_inferred() {
        let path = HopPath {
            edges: vec![GraphEdge {
                subject: "http://s".to_string(),
                predicate: "http://p".to_string(),
                object: "http://o".to_string(),
                inferred: true,
            }],
            start: "http://s".to_string(),
            end: "http://o".to_string(),
            score: 0.5,
            inferred_hops: 1,
            fired_rules: vec!["rule1".to_string()],
        };
        assert!(path.has_inferred_hop());
    }

    #[test]
    fn test_reason_two_seeds_produce_paths_from_both() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9), make_seed("http://x", 0.8)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://x", "http://q", "http://y"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        let from_a = paths.iter().any(|p| p.start == "http://a");
        let from_x = paths.iter().any(|p| p.start == "http://x");
        assert!(from_a, "Expected paths from http://a");
        assert!(from_x, "Expected paths from http://x");
    }

    #[test]
    fn test_reason_symmetry_rule_adds_reverse_edge() {
        let config = MultiHopConfig {
            max_hops: 2,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        // The seed is only the *object* of the asserted triple, so any path
        // from it must use the inferred reverse edge.
        let seeds = vec![make_seed("http://b", 0.9)];
        let triples = vec![make_triple("http://a", "http://sameAs", "http://b")];
        let rules = vec![symmetry_rule("http://sameAs")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        let has_inferred = paths.iter().any(|p| p.has_inferred_hop());
        assert!(has_inferred, "Symmetry rule should create inferred edges");
    }

    #[test]
    fn test_reason_budget_guard_limits_expansion() {
        let config = MultiHopConfig {
            max_edges_budget: 5,
            max_hops: 10,
            max_paths: 1000,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples: Vec<Triple> = (0..100)
            .map(|i| make_triple("http://a", "http://p", &format!("http://n{i}")))
            .collect();
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(paths.len() < 100, "Budget guard should limit path count");
    }

    #[test]
    fn test_reason_paths_sorted_descending() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 1.0)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
            make_triple("http://c", "http://p", "http://d"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(
            paths.len() > 1,
            "Need at least two paths to check the ordering"
        );
        for pair in paths.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Paths should be sorted descending: {} < {}",
                pair[0].score,
                pair[1].score
            );
        }
    }
}