1use crate::{GraphRAGError, GraphRAGResult, ScoredEntity, Triple};
18use std::collections::{HashMap, HashSet, VecDeque};
19
20use oxirs_rule::{Rule, RuleAtom, RuleEngine, Term};
25
/// Lookup table from a `(subject, predicate, object)` key to the names of the
/// rules associated with that triple.
pub type FiredRulesMap = HashMap<(String, String, String), Vec<String>>;
30
/// A directed, labelled edge of the traversal graph: `subject --predicate--> object`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphEdge {
    /// Node the edge starts from.
    pub subject: String,
    /// Label (predicate) of the edge.
    pub predicate: String,
    /// Node the edge points to.
    pub object: String,
    /// `true` when the edge came out of rule inference rather than the input subgraph.
    pub inferred: bool,
}
40
/// One path found by the multi-hop search, together with its score and provenance.
#[derive(Debug, Clone)]
pub struct HopPath {
    /// Edges of the path in traversal order.
    pub edges: Vec<GraphEdge>,
    /// Node the search started from (the seed URI).
    pub start: String,
    /// Node reached by the last edge of the path.
    pub end: String,
    /// Score assigned by the configured [`PathScoringFn`].
    pub score: f64,
    /// Number of edges on the path whose `inferred` flag is set.
    pub inferred_hops: usize,
    /// Deduplicated rule names looked up in the fired-rules map for the path's inferred edges.
    pub fired_rules: Vec<String>,
}
57
58impl HopPath {
59 pub fn hop_count(&self) -> usize {
61 self.edges.len()
62 }
63
64 pub fn has_inferred_hop(&self) -> bool {
66 self.inferred_hops > 0
67 }
68}
69
/// Strategy used to turn a path's hop count, inferred-hop count and seed score
/// into a single path score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathScoringFn {
    /// `1 / hops`.
    InverseHopCount,
    /// `seed_score / hops`; this is the default strategy.
    #[default]
    SeedWeighted,
    /// Every path scores `1.0`.
    Uniform,
    /// `(1 / hops) * 0.8^inferred_hops`.
    InferencePenalised,
}
85
/// Tunable limits and filters for the multi-hop search.
#[derive(Debug, Clone)]
pub struct MultiHopConfig {
    /// Maximum number of edges a path may contain.
    pub max_hops: usize,
    /// Cap on the number of returned paths; each per-seed search also stops
    /// once it has collected this many paths.
    pub max_paths: usize,
    /// Per-seed cap on the number of search states taken from the BFS queue.
    pub max_edges_budget: usize,
    /// Whether edges derived by rules take part in the traversal.
    pub include_inferred: bool,
    /// Strategy used to score each path.
    pub scoring_fn: PathScoringFn,
    /// When non-empty, only edges with one of these predicates are traversed.
    pub allowed_predicates: HashSet<String>,
    /// Edges with one of these predicates are never traversed.
    pub blocked_predicates: HashSet<String>,
    /// Paths scoring below this value are dropped from the result.
    pub min_path_score: f64,
}
106
107impl Default for MultiHopConfig {
108 fn default() -> Self {
109 Self {
110 max_hops: 3,
111 max_paths: 50,
112 max_edges_budget: 100_000,
113 include_inferred: true,
114 scoring_fn: PathScoringFn::SeedWeighted,
115 allowed_predicates: HashSet::new(),
116 blocked_predicates: HashSet::new(),
117 min_path_score: 0.0,
118 }
119 }
120}
121
122fn atoms_to_edges(atoms: &[RuleAtom], inferred: bool) -> Vec<GraphEdge> {
125 atoms
126 .iter()
127 .filter_map(|atom| match atom {
128 RuleAtom::Triple {
129 subject,
130 predicate,
131 object,
132 } => {
133 let s = term_to_str(subject)?;
134 let p = term_to_str(predicate)?;
135 let o = term_to_str(object)?;
136 Some(GraphEdge {
137 subject: s,
138 predicate: p,
139 object: o,
140 inferred,
141 })
142 }
143 _ => None,
144 })
145 .collect()
146}
147
148fn term_to_str(term: &Term) -> Option<String> {
149 match term {
150 Term::Constant(c) | Term::Literal(c) => Some(c.clone()),
151 Term::Variable(_) | Term::Function { .. } => None, }
153}
154
155fn triples_to_atoms(triples: &[Triple]) -> Vec<RuleAtom> {
156 triples
157 .iter()
158 .map(|t| RuleAtom::Triple {
159 subject: Term::Constant(t.subject.clone()),
160 predicate: Term::Constant(t.predicate.clone()),
161 object: Term::Constant(t.object.clone()),
162 })
163 .collect()
164}
165
/// Multi-hop path finder that combines asserted triples with rule-inferred edges.
pub struct MultiHopEngine {
    /// Limits, filters and scoring strategy applied by this engine.
    config: MultiHopConfig,
}
172
173impl Default for MultiHopEngine {
174 fn default() -> Self {
175 Self::new(MultiHopConfig::default())
176 }
177}
178
179impl MultiHopEngine {
180 pub fn new(config: MultiHopConfig) -> Self {
181 Self { config }
182 }
183
184 pub fn reason(
188 &self,
189 seeds: &[ScoredEntity],
190 subgraph: &[Triple],
191 rules: &[Rule],
192 ) -> GraphRAGResult<Vec<HopPath>> {
193 if seeds.is_empty() || subgraph.is_empty() {
194 return Ok(vec![]);
195 }
196
197 let (asserted_edges, inferred_edges, fired_rule_map) = self.materialise(subgraph, rules)?;
199
200 let mut all_edges: Vec<GraphEdge> = asserted_edges;
202 if self.config.include_inferred {
203 all_edges.extend(inferred_edges);
204 }
205
206 let adj = self.build_adjacency(&all_edges);
207
208 let mut paths: Vec<HopPath> = Vec::new();
210 let seed_map: HashMap<String, f64> =
211 seeds.iter().map(|s| (s.uri.clone(), s.score)).collect();
212
213 for seed in seeds {
214 let new_paths =
215 self.bfs_paths(&seed.uri, seed.score, &adj, &all_edges, &fired_rule_map);
216 paths.extend(new_paths);
217 }
218
219 paths.retain(|p| p.score >= self.config.min_path_score);
221 paths.sort_by(|a, b| {
222 b.score
223 .partial_cmp(&a.score)
224 .unwrap_or(std::cmp::Ordering::Equal)
225 });
226 paths.truncate(self.config.max_paths);
227
228 let _ = seed_map;
230
231 Ok(paths)
232 }
233
234 fn materialise(
236 &self,
237 subgraph: &[Triple],
238 rules: &[Rule],
239 ) -> GraphRAGResult<(Vec<GraphEdge>, Vec<GraphEdge>, FiredRulesMap)> {
240 let asserted_edges = atoms_to_edges(&triples_to_atoms(subgraph), false);
241
242 if rules.is_empty() {
243 return Ok((asserted_edges, vec![], HashMap::new()));
244 }
245
246 let mut engine = RuleEngine::new();
247 engine.add_rules(rules.to_vec());
248 engine.enable_cache();
249
250 let facts = triples_to_atoms(subgraph);
251
252 let inferred_atoms = engine
253 .forward_chain(&facts)
254 .map_err(|e| GraphRAGError::InternalError(format!("Rule engine error: {e}")))?;
255
256 let asserted_keys: HashSet<(String, String, String)> = subgraph
258 .iter()
259 .map(|t| (t.subject.clone(), t.predicate.clone(), t.object.clone()))
260 .collect();
261
262 let inferred_edges: Vec<GraphEdge> = atoms_to_edges(&inferred_atoms, true)
263 .into_iter()
264 .filter(|e| {
265 !asserted_keys.contains(&(e.subject.clone(), e.predicate.clone(), e.object.clone()))
266 })
267 .collect();
268
269 let fired_rule_map: FiredRulesMap = rules
271 .iter()
272 .flat_map(|rule| {
273 rule.head.iter().filter_map(|atom| match atom {
274 RuleAtom::Triple {
275 subject,
276 predicate,
277 object,
278 } => {
279 let s = term_to_str(subject)?;
280 let p = term_to_str(predicate)?;
281 let o = term_to_str(object)?;
282 Some(((s, p, o), rule.name.clone()))
283 }
284 _ => None,
285 })
286 })
287 .fold(HashMap::new(), |mut acc, (key, rule_name)| {
288 acc.entry(key).or_default().push(rule_name);
289 acc
290 });
291
292 Ok((asserted_edges, inferred_edges, fired_rule_map))
293 }
294
295 fn build_adjacency(&self, edges: &[GraphEdge]) -> HashMap<String, Vec<usize>> {
297 let mut adj: HashMap<String, Vec<usize>> = HashMap::new();
298 for (i, edge) in edges.iter().enumerate() {
299 if self.allow_predicate(&edge.predicate) {
300 adj.entry(edge.subject.clone()).or_default().push(i);
301 }
302 }
303 adj
304 }
305
306 fn allow_predicate(&self, pred: &str) -> bool {
307 if !self.config.allowed_predicates.is_empty()
308 && !self.config.allowed_predicates.contains(pred)
309 {
310 return false;
311 }
312 !self.config.blocked_predicates.contains(pred)
313 }
314
315 fn bfs_paths(
317 &self,
318 start: &str,
319 seed_score: f64,
320 adj: &HashMap<String, Vec<usize>>,
321 edges: &[GraphEdge],
322 fired_rule_map: &HashMap<(String, String, String), Vec<String>>,
323 ) -> Vec<HopPath> {
324 struct State {
326 node: String,
327 edge_path: Vec<usize>,
328 visited: HashSet<String>,
329 }
330
331 let mut queue: VecDeque<State> = VecDeque::new();
332 queue.push_back(State {
333 node: start.to_string(),
334 edge_path: vec![],
335 visited: {
336 let mut h = HashSet::new();
337 h.insert(start.to_string());
338 h
339 },
340 });
341
342 let mut paths: Vec<HopPath> = Vec::new();
343 let mut budget = self.config.max_edges_budget;
344
345 while let Some(state) = queue.pop_front() {
346 if budget == 0 {
347 break;
348 }
349 budget -= 1;
350
351 if state.edge_path.len() > self.config.max_hops {
352 continue;
353 }
354
355 if !state.edge_path.is_empty() {
357 let path_edges: Vec<GraphEdge> =
358 state.edge_path.iter().map(|&i| edges[i].clone()).collect();
359
360 let inferred_hops = path_edges.iter().filter(|e| e.inferred).count();
361 let fired_rules: Vec<String> = path_edges
362 .iter()
363 .filter(|e| e.inferred)
364 .flat_map(|e| {
365 let key = (e.subject.clone(), e.predicate.clone(), e.object.clone());
366 fired_rule_map.get(&key).cloned().unwrap_or_default()
367 })
368 .collect::<HashSet<_>>()
369 .into_iter()
370 .collect();
371
372 let score = self.score_path(state.edge_path.len(), inferred_hops, seed_score);
373
374 paths.push(HopPath {
375 edges: path_edges,
376 start: start.to_string(),
377 end: state.node.clone(),
378 score,
379 inferred_hops,
380 fired_rules,
381 });
382
383 if paths.len() >= self.config.max_paths {
384 return paths;
385 }
386 }
387
388 if state.edge_path.len() >= self.config.max_hops {
389 continue;
390 }
391
392 if let Some(edge_indices) = adj.get(&state.node) {
394 for &ei in edge_indices {
395 let edge = &edges[ei];
396 if !state.visited.contains(&edge.object) {
397 let mut new_visited = state.visited.clone();
398 new_visited.insert(edge.object.clone());
399 let mut new_path = state.edge_path.clone();
400 new_path.push(ei);
401 queue.push_back(State {
402 node: edge.object.clone(),
403 edge_path: new_path,
404 visited: new_visited,
405 });
406 }
407 }
408 }
409 }
410
411 paths
412 }
413
414 fn score_path(&self, hops: usize, inferred_hops: usize, seed_score: f64) -> f64 {
415 let h = hops.max(1) as f64;
416 match self.config.scoring_fn {
417 PathScoringFn::InverseHopCount => 1.0 / h,
418 PathScoringFn::SeedWeighted => seed_score / h,
419 PathScoringFn::Uniform => 1.0,
420 PathScoringFn::InferencePenalised => (1.0 / h) * 0.8_f64.powi(inferred_hops as i32),
421 }
422 }
423}
424
425pub fn transitivity_rule(predicate: &str) -> Rule {
430 Rule {
431 name: format!("{predicate}_transitive"),
432 body: vec![
433 RuleAtom::Triple {
434 subject: Term::Variable("X".to_string()),
435 predicate: Term::Constant(predicate.to_string()),
436 object: Term::Variable("Y".to_string()),
437 },
438 RuleAtom::Triple {
439 subject: Term::Variable("Y".to_string()),
440 predicate: Term::Constant(predicate.to_string()),
441 object: Term::Variable("Z".to_string()),
442 },
443 ],
444 head: vec![RuleAtom::Triple {
445 subject: Term::Variable("X".to_string()),
446 predicate: Term::Constant(predicate.to_string()),
447 object: Term::Variable("Z".to_string()),
448 }],
449 }
450}
451
452pub fn property_chain_rule(p1: &str, p2: &str, conclusion_pred: &str) -> Rule {
454 Rule {
455 name: format!("{p1}_{p2}_chain"),
456 body: vec![
457 RuleAtom::Triple {
458 subject: Term::Variable("X".to_string()),
459 predicate: Term::Constant(p1.to_string()),
460 object: Term::Variable("Y".to_string()),
461 },
462 RuleAtom::Triple {
463 subject: Term::Variable("Y".to_string()),
464 predicate: Term::Constant(p2.to_string()),
465 object: Term::Variable("Z".to_string()),
466 },
467 ],
468 head: vec![RuleAtom::Triple {
469 subject: Term::Variable("X".to_string()),
470 predicate: Term::Constant(conclusion_pred.to_string()),
471 object: Term::Variable("Z".to_string()),
472 }],
473 }
474}
475
476pub fn symmetry_rule(predicate: &str) -> Rule {
478 Rule {
479 name: format!("{predicate}_symmetric"),
480 body: vec![RuleAtom::Triple {
481 subject: Term::Variable("X".to_string()),
482 predicate: Term::Constant(predicate.to_string()),
483 object: Term::Variable("Y".to_string()),
484 }],
485 head: vec![RuleAtom::Triple {
486 subject: Term::Variable("Y".to_string()),
487 predicate: Term::Constant(predicate.to_string()),
488 object: Term::Variable("X".to_string()),
489 }],
490 }
491}
492
// Unit tests for the rule constructors, the `reason` entry point, the scoring
// functions and the predicate filters.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;

    // Builds a vector-sourced seed entity with the given URI and score.
    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    // Shorthand for `Triple::new`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }

    // --- Rule constructors ---

    // The transitivity rule has two body atoms, one head atom and a derived name.
    #[test]
    fn test_transitivity_rule_structure() {
        let rule = transitivity_rule("subClassOf");
        assert_eq!(rule.name, "subClassOf_transitive");
        assert_eq!(rule.body.len(), 2);
        assert_eq!(rule.head.len(), 1);
    }

    // The chain rule is named after its two body predicates.
    #[test]
    fn test_property_chain_rule_structure() {
        let rule = property_chain_rule("partOf", "locatedIn", "indirectlyIn");
        assert_eq!(rule.name, "partOf_locatedIn_chain");
        assert_eq!(rule.body.len(), 2);
    }

    // The symmetry rule has exactly one body atom and one head atom.
    #[test]
    fn test_symmetry_rule_structure() {
        let rule = symmetry_rule("sameAs");
        assert_eq!(rule.name, "sameAs_symmetric");
        assert_eq!(rule.body.len(), 1);
        assert_eq!(rule.head.len(), 1);
    }

    // --- `reason`: empty inputs and basic traversal ---

    // No seeds means no paths.
    #[test]
    fn test_reason_empty_seeds() {
        let engine = MultiHopEngine::default();
        let triples = vec![make_triple("http://a", "http://rel", "http://b")];
        let result = engine.reason(&[], &triples, &[]).expect("should succeed");
        assert!(result.is_empty());
    }

    // An empty subgraph means no paths.
    #[test]
    fn test_reason_empty_subgraph() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let result = engine.reason(&seeds, &[], &[]).expect("should succeed");
        assert!(result.is_empty());
    }

    // Without rules every path starts at the seed and contains no inferred hop.
    #[test]
    fn test_reason_single_hop_no_rules() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/rel", "http://b"),
            make_triple("http://b", "http://p/rel", "http://c"),
            make_triple("http://x", "http://p/other", "http://y"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty());
        for p in &paths {
            assert_eq!(p.start, "http://a");
        }
        for p in &paths {
            assert_eq!(p.inferred_hops, 0);
        }
    }

    // Paths never exceed the configured hop limit.
    #[test]
    fn test_reason_respects_max_hops() {
        let config = MultiHopConfig {
            max_hops: 1,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        for p in &paths {
            assert!(
                p.hop_count() <= 1,
                "Path hop count {} > max_hops 1",
                p.hop_count()
            );
        }
    }

    // The result is truncated to `max_paths`.
    #[test]
    fn test_reason_respects_max_paths() {
        let config = MultiHopConfig {
            max_paths: 2,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples: Vec<Triple> = (0..20)
            .map(|i| make_triple("http://a", "http://p", &format!("http://n{i}")))
            .collect();
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(paths.len() <= 2);
    }

    // --- `reason`: rule inference ---

    // A transitivity rule over a two-edge chain yields a path with an inferred hop.
    #[test]
    fn test_reason_transitivity_rule() {
        let config = MultiHopConfig {
            max_hops: 3,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rules = vec![transitivity_rule("http://subClassOf")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        assert!(!paths.is_empty());
        let has_inferred = paths.iter().any(|p| p.has_inferred_hop());
        assert!(
            has_inferred,
            "Expected at least one path with inferred hops"
        );
    }

    // With `include_inferred` off, no path may contain an inferred hop.
    #[test]
    fn test_reason_no_inferred_when_disabled() {
        let config = MultiHopConfig {
            include_inferred: false,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rules = vec![transitivity_rule("http://subClassOf")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        for p in &paths {
            assert_eq!(
                p.inferred_hops, 0,
                "Expected no inferred hops when disabled"
            );
        }
    }

    // --- Scoring functions ---

    // InverseHopCount: 1 hop scores 1.0, 2 hops score 0.5.
    #[test]
    fn test_score_inverse_hop_count() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InverseHopCount,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        assert!((engine.score_path(1, 0, 1.0) - 1.0).abs() < 1e-9);
        assert!((engine.score_path(2, 0, 1.0) - 0.5).abs() < 1e-9);
    }

    // SeedWeighted: seed score 0.8 over 2 hops scores 0.4.
    #[test]
    fn test_score_seed_weighted() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::SeedWeighted,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s = engine.score_path(2, 0, 0.8);
        assert!((s - 0.4).abs() < 1e-9);
    }

    // Uniform: the score is 1.0 whatever the inputs.
    #[test]
    fn test_score_uniform() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::Uniform,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        assert_eq!(engine.score_path(5, 3, 0.5), 1.0);
    }

    // InferencePenalised: an inferred hop lowers the score.
    #[test]
    fn test_score_inference_penalised() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InferencePenalised,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s_no_inf = engine.score_path(2, 0, 1.0);
        let s_with_inf = engine.score_path(2, 1, 1.0);
        assert!(s_no_inf > s_with_inf, "Inferred hop should reduce score");
    }

    // --- Predicate filters ---

    // Edges with a blocked predicate never appear in a path.
    #[test]
    fn test_blocked_predicates_filter() {
        let mut config = MultiHopConfig::default();
        config
            .blocked_predicates
            .insert("http://p/blocked".to_string());
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/blocked", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        for p in &paths {
            for e in &p.edges {
                assert_ne!(
                    e.predicate, "http://p/blocked",
                    "Blocked predicate found in path"
                );
            }
        }
    }

    // With a non-empty allow-list only listed predicates appear in paths.
    #[test]
    fn test_allowed_predicates_whitelist() {
        let mut config = MultiHopConfig::default();
        config
            .allowed_predicates
            .insert("http://p/allowed".to_string());
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/other", "http://c"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        for p in &paths {
            for e in &p.edges {
                assert_eq!(e.predicate, "http://p/allowed");
            }
        }
    }

    // --- Path fields and thresholds ---

    // A single asserted edge yields a one-hop path from subject to object.
    #[test]
    fn test_hop_path_fields() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.8)];
        let triples = vec![make_triple("http://a", "http://p", "http://b")];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(!paths.is_empty());
        let path = &paths[0];
        assert_eq!(path.start, "http://a");
        assert_eq!(path.end, "http://b");
        assert_eq!(path.hop_count(), 1);
        assert!(!path.has_inferred_hop());
    }

    // A threshold no path can reach filters everything out.
    #[test]
    fn test_min_path_score_threshold() {
        let config = MultiHopConfig {
            min_path_score: 99.0,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![make_triple("http://a", "http://p", "http://b")];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(paths.is_empty());
    }

    // --- Conversion helpers ---

    // Triples become triple atoms whose terms render back to the original strings.
    #[test]
    fn test_triples_to_atoms_roundtrip() {
        let triples = vec![Triple::new("http://s", "http://p", "http://o")];
        let atoms = triples_to_atoms(&triples);
        assert_eq!(atoms.len(), 1);
        if let RuleAtom::Triple {
            subject,
            predicate,
            object,
        } = &atoms[0]
        {
            assert_eq!(term_to_str(subject).expect("should succeed"), "http://s");
            assert_eq!(term_to_str(predicate).expect("should succeed"), "http://p");
            assert_eq!(term_to_str(object).expect("should succeed"), "http://o");
        } else {
            panic!("Expected Triple atom");
        }
    }

    // Smoke test: reasoning with a property-chain rule still returns paths.
    #[test]
    fn test_property_chain_produces_derived_edges() {
        let config = MultiHopConfig {
            max_hops: 3,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://partOf", "http://b"),
            make_triple("http://b", "http://locatedIn", "http://c"),
        ];
        let rules = vec![property_chain_rule(
            "http://partOf",
            "http://locatedIn",
            "http://indirectlyIn",
        )];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        assert!(!paths.is_empty());
    }
}
845
// Additional unit tests: rule internals, term/atom conversion, scoring details,
// cycle and budget guards, multi-seed behaviour and result ordering.
#[cfg(test)]
mod additional_tests {
    use super::*;
    use crate::ScoreSource;

    // Builds a vector-sourced seed entity with the given URI and score.
    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source: ScoreSource::Vector,
            metadata: HashMap::new(),
        }
    }

    // Shorthand for `Triple::new`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }

    // --- Rule internals ---

    // The symmetry head must be the body with subject and object swapped: (Y p X).
    #[test]
    fn test_symmetry_rule_head_swapped() {
        let rule = symmetry_rule("http://sameAs");
        assert!(matches!(&rule.head[0], RuleAtom::Triple { .. }));
        if let RuleAtom::Triple {
            subject, object, ..
        } = &rule.head[0]
        {
            match (subject, object) {
                (Term::Variable(sv), Term::Variable(ov)) => {
                    assert_ne!(sv, ov, "Head subject and object variables should differ");
                    // Check the actual swap, not merely that the variables differ.
                    assert_eq!(sv, "Y", "Head subject should be the body's object variable");
                    assert_eq!(ov, "X", "Head object should be the body's subject variable");
                }
                _ => panic!("Expected Variable terms in head"),
            }
        } else {
            panic!("Expected Triple head");
        }
    }

    // The two body atoms of a transitivity rule are joined through variable `Y`.
    #[test]
    fn test_transitivity_rule_has_shared_variable() {
        let rule = transitivity_rule("http://subClassOf");
        if let (RuleAtom::Triple { object: obj0, .. }, RuleAtom::Triple { subject: subj1, .. }) =
            (&rule.body[0], &rule.body[1])
        {
            // The `matches!` results must be asserted; as bare statements they
            // were discarded and the test could never fail.
            assert!(
                matches!(obj0, Term::Variable(v) if v == "Y"),
                "First body atom should end in variable Y"
            );
            assert!(
                matches!(subj1, Term::Variable(v) if v == "Y"),
                "Second body atom should start at variable Y"
            );
        } else {
            panic!("Expected Triple atoms in rule body");
        }
    }

    // Body and head predicates of a chain rule are the three supplied predicates.
    #[test]
    fn test_property_chain_rule_body_predicates() {
        let rule = property_chain_rule("http://partOf", "http://locatedIn", "http://indirectlyIn");
        if let RuleAtom::Triple { predicate: p1, .. } = &rule.body[0] {
            assert_eq!(term_to_str(p1).expect("should succeed"), "http://partOf");
        }
        if let RuleAtom::Triple { predicate: p2, .. } = &rule.body[1] {
            assert_eq!(term_to_str(p2).expect("should succeed"), "http://locatedIn");
        }
        if let RuleAtom::Triple { predicate: ph, .. } = &rule.head[0] {
            assert_eq!(
                term_to_str(ph).expect("should succeed"),
                "http://indirectlyIn"
            );
        }
    }

    // --- `term_to_str` ---

    // Constants render as their text.
    #[test]
    fn test_term_to_str_constant() {
        let t = Term::Constant("http://example.org/x".to_string());
        assert_eq!(
            term_to_str(&t).expect("should succeed"),
            "http://example.org/x"
        );
    }

    // Literals render as their text.
    #[test]
    fn test_term_to_str_literal() {
        let t = Term::Literal("hello world".to_string());
        assert_eq!(term_to_str(&t).expect("should succeed"), "hello world");
    }

    // Variables have no string form.
    #[test]
    fn test_term_to_str_variable_returns_none() {
        let t = Term::Variable("X".to_string());
        assert!(term_to_str(&t).is_none());
    }

    // --- `atoms_to_edges` ---

    // A ground triple atom becomes exactly one asserted edge.
    #[test]
    fn test_atoms_to_edges_filters_non_triple() {
        let atoms = vec![RuleAtom::Triple {
            subject: Term::Constant("http://s".to_string()),
            predicate: Term::Constant("http://p".to_string()),
            object: Term::Constant("http://o".to_string()),
        }];
        let edges = atoms_to_edges(&atoms, false);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].subject, "http://s");
        assert!(!edges[0].inferred);
    }

    // The `inferred` argument is copied onto every produced edge.
    #[test]
    fn test_atoms_to_edges_inferred_flag() {
        let atoms = vec![RuleAtom::Triple {
            subject: Term::Constant("http://s".to_string()),
            predicate: Term::Constant("http://p".to_string()),
            object: Term::Constant("http://o".to_string()),
        }];
        let edges = atoms_to_edges(&atoms, true);
        assert!(edges[0].inferred);
    }

    // --- Scoring details ---

    // With zero inferred hops the penalised score equals the inverse hop count.
    #[test]
    fn test_score_inference_penalised_zero_inferred_equals_inverse() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::InferencePenalised,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s = engine.score_path(2, 0, 1.0);
        assert!((s - 0.5).abs() < 1e-9);
    }

    // Seed-weighted scores scale linearly with the seed score.
    #[test]
    fn test_score_seed_weighted_scales_with_seed_score() {
        let config = MultiHopConfig {
            scoring_fn: PathScoringFn::SeedWeighted,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let s1 = engine.score_path(1, 0, 1.0);
        let s2 = engine.score_path(1, 0, 0.5);
        assert!((s1 - 2.0 * s2).abs() < 1e-9, "s1={s1}, s2={s2}");
    }

    // --- Plain data types ---

    // A cloned edge compares equal to the original.
    #[test]
    fn test_graph_edge_equality() {
        let e1 = GraphEdge {
            subject: "http://a".to_string(),
            predicate: "http://p".to_string(),
            object: "http://b".to_string(),
            inferred: false,
        };
        let e2 = e1.clone();
        assert_eq!(e1, e2);
    }

    // The default configuration carries the documented values.
    #[test]
    fn test_multihop_config_defaults() {
        let cfg = MultiHopConfig::default();
        assert_eq!(cfg.max_hops, 3);
        assert_eq!(cfg.max_paths, 50);
        assert!(cfg.include_inferred);
        assert!(cfg.allowed_predicates.is_empty());
        assert!(cfg.blocked_predicates.is_empty());
    }

    // --- Guards ---

    // A cyclic graph (a -> b -> c -> a) terminates with a bounded path count.
    #[test]
    fn test_reason_cycle_in_graph_does_not_loop() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
            make_triple("http://c", "http://p", "http://a"),
        ];
        let paths = engine.reason(&seeds, &triples, &[]);
        assert!(paths.is_ok(), "Should not error on cyclic graphs");
        let paths = paths.expect("should succeed");
        assert!(paths.len() < 1000, "Cycle guard should bound path count");
    }

    // `has_inferred_hop` is false when `inferred_hops` is zero.
    #[test]
    fn test_hop_path_has_inferred_hop_false_for_asserted() {
        let path = HopPath {
            edges: vec![GraphEdge {
                subject: "http://s".to_string(),
                predicate: "http://p".to_string(),
                object: "http://o".to_string(),
                inferred: false,
            }],
            start: "http://s".to_string(),
            end: "http://o".to_string(),
            score: 0.8,
            inferred_hops: 0,
            fired_rules: vec![],
        };
        assert!(!path.has_inferred_hop());
        assert_eq!(path.hop_count(), 1);
    }

    // `has_inferred_hop` is true when `inferred_hops` is positive.
    #[test]
    fn test_hop_path_has_inferred_hop_true_for_inferred() {
        let path = HopPath {
            edges: vec![GraphEdge {
                subject: "http://s".to_string(),
                predicate: "http://p".to_string(),
                object: "http://o".to_string(),
                inferred: true,
            }],
            start: "http://s".to_string(),
            end: "http://o".to_string(),
            score: 0.5,
            inferred_hops: 1,
            fired_rules: vec!["rule1".to_string()],
        };
        assert!(path.has_inferred_hop());
    }

    // --- Multiple seeds ---

    // Every seed with outgoing edges contributes paths to the result.
    #[test]
    fn test_reason_two_seeds_produce_paths_from_both() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 0.9), make_seed("http://x", 0.8)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://x", "http://q", "http://y"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        let from_a = paths.iter().any(|p| p.start == "http://a");
        let from_x = paths.iter().any(|p| p.start == "http://x");
        assert!(from_a, "Expected paths from http://a");
        assert!(from_x, "Expected paths from http://x");
    }

    // --- Symmetry inference ---

    // Seeding at the object of a symmetric edge reaches the subject via an inferred edge.
    #[test]
    fn test_reason_symmetry_rule_adds_reverse_edge() {
        let config = MultiHopConfig {
            max_hops: 2,
            include_inferred: true,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://b", 0.9)];
        let triples = vec![make_triple("http://a", "http://sameAs", "http://b")];
        let rules = vec![symmetry_rule("http://sameAs")];
        let paths = engine
            .reason(&seeds, &triples, &rules)
            .expect("should succeed");
        let has_inferred = paths.iter().any(|p| p.has_inferred_hop());
        assert!(has_inferred, "Symmetry rule should create inferred edges");
    }

    // --- Budget guard ---

    // A tiny budget stops the search long before all 100 fan-out edges are emitted.
    #[test]
    fn test_reason_budget_guard_limits_expansion() {
        let config = MultiHopConfig {
            max_edges_budget: 5,
            max_hops: 10,
            max_paths: 1000,
            ..Default::default()
        };
        let engine = MultiHopEngine::new(config);
        let seeds = vec![make_seed("http://a", 0.9)];
        let triples: Vec<Triple> = (0..100)
            .map(|i| make_triple("http://a", "http://p", &format!("http://n{i}")))
            .collect();
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        assert!(paths.len() < 100, "Budget guard should limit path count");
    }

    // --- Ordering ---

    // Returned paths are ordered by non-increasing score.
    #[test]
    fn test_reason_paths_sorted_descending() {
        let engine = MultiHopEngine::default();
        let seeds = vec![make_seed("http://a", 1.0)];
        let triples = vec![
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
            make_triple("http://c", "http://p", "http://d"),
        ];
        let paths = engine
            .reason(&seeds, &triples, &[])
            .expect("should succeed");
        for i in 1..paths.len() {
            assert!(
                paths[i - 1].score >= paths[i].score,
                "Paths should be sorted descending: {} < {}",
                paths[i - 1].score,
                paths[i].score
            );
        }
    }
}