Skip to main content

oxirs_graphrag/reasoning/
multihop.rs

1//! Multi-hop reasoning with rule engine integration
2//!
3//! This module implements rule-guided multi-hop graph traversal.  Given a set
4//! of seed entities and a collection of inference rules (expressed using the
5//! oxirs-rule `Rule` / `RuleAtom` / `Term` API), the engine:
6//!
7//! 1. Loads the subgraph into the rule engine as `RuleAtom::Triple` facts.
8//! 2. Runs forward chaining to materialise derived facts.
//! 3. Enumerates cycle-free paths that start at each seed and follow the
//!    asserted and (optionally) derived edges, ranking them by a configurable
//!    path scoring function.
11//! 4. Returns scored `HopPath` objects that can be fed back into a context
12//!    builder or summariser.
13//!
14//! The module is intentionally self-contained (no external ML dependencies)
15//! and operates purely on the RDF triple model from this crate.
16
17use crate::{GraphRAGError, GraphRAGResult, ScoredEntity, Triple};
18use std::collections::{HashMap, HashSet, VecDeque};
19
20// ─── Re-export rule-engine types ────────────────────────────────────────────
21
// We import only the oxirs-rule items this module actually uses, which limits
// our exposure to changes in that crate's public API.
24use oxirs_rule::{Rule, RuleAtom, RuleEngine, Term};
25
26// ─── Types ──────────────────────────────────────────────────────────────────
27
28/// Maps a (subject, predicate, object) triple key to the list of rule names that fired it.
29pub type FiredRulesMap = HashMap<(String, String, String), Vec<String>>;
30
/// A directed edge in the knowledge graph.
///
/// Edges are produced by `atoms_to_edges` from rule-engine triple atoms whose
/// subject, predicate and object all resolve to concrete strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphEdge {
    /// Source node of the edge (the triple's subject).
    pub subject: String,
    /// Edge label (the triple's predicate).
    pub predicate: String,
    /// Target node of the edge (the triple's object).
    pub object: String,
    /// Whether this edge was derived by rule inference (vs. asserted)
    pub inferred: bool,
}
40
/// A multi-hop path through the knowledge graph, as returned by
/// `MultiHopEngine::reason`.
#[derive(Debug, Clone)]
pub struct HopPath {
    /// Ordered list of edges traversed, from `start` towards `end`
    pub edges: Vec<GraphEdge>,
    /// Starting entity (the seed the search was launched from)
    pub start: String,
    /// Ending entity (the node reached by the last edge)
    pub end: String,
    /// Path score (higher = more relevant), computed by the configured
    /// `PathScoringFn`
    pub score: f64,
    /// Number of inferred edges on this path
    pub inferred_hops: usize,
    /// Names of the rules associated with the inferred edges on this path,
    /// without duplicates
    pub fired_rules: Vec<String>,
}
57
58impl HopPath {
59    /// Total number of hops (edges)
60    pub fn hop_count(&self) -> usize {
61        self.edges.len()
62    }
63
64    /// Whether this path contains at least one inferred edge
65    pub fn has_inferred_hop(&self) -> bool {
66        self.inferred_hops > 0
67    }
68}
69
70// ─── Configuration ──────────────────────────────────────────────────────────
71
/// Path scoring function variant.
///
/// Selects the formula `MultiHopEngine` uses to turn a path's hop count,
/// inferred-hop count and seed score into a single relevance score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathScoringFn {
    /// Score = 1 / hop_count (prefer shorter paths)
    InverseHopCount,
    /// Score = seed_score / hop_count (the default variant)
    #[default]
    SeedWeighted,
    /// Uniform score of 1.0 for all paths
    Uniform,
    /// Penalise inferred hops: score = (1/hop_count) * (0.8 ^ inferred_hops)
    InferencePenalised,
}
85
/// Configuration for multi-hop reasoning
#[derive(Debug, Clone)]
pub struct MultiHopConfig {
    /// Maximum hop count per path; the search never extends a path beyond
    /// this many edges
    pub max_hops: usize,
    /// Maximum number of paths to return; also caps how many paths are
    /// collected from each individual seed
    pub max_paths: usize,
    /// Budget guard for the BFS: the number of search states that may be
    /// dequeued per seed before that seed's search stops
    pub max_edges_budget: usize,
    /// Whether to include inferred (rule-derived) edges in the traversal
    pub include_inferred: bool,
    /// Path scoring function
    pub scoring_fn: PathScoringFn,
    /// Predicates to follow (empty = all)
    pub allowed_predicates: HashSet<String>,
    /// Predicates to skip; a blocked predicate is skipped even if it is also
    /// listed in `allowed_predicates`
    pub blocked_predicates: HashSet<String>,
    /// Minimum path score threshold; paths scoring below it are dropped
    pub min_path_score: f64,
}
106
107impl Default for MultiHopConfig {
108    fn default() -> Self {
109        Self {
110            max_hops: 3,
111            max_paths: 50,
112            max_edges_budget: 100_000,
113            include_inferred: true,
114            scoring_fn: PathScoringFn::SeedWeighted,
115            allowed_predicates: HashSet::new(),
116            blocked_predicates: HashSet::new(),
117            min_path_score: 0.0,
118        }
119    }
120}
121
122// ─── Graph builder from Rule atoms ──────────────────────────────────────────
123
124fn atoms_to_edges(atoms: &[RuleAtom], inferred: bool) -> Vec<GraphEdge> {
125    atoms
126        .iter()
127        .filter_map(|atom| match atom {
128            RuleAtom::Triple {
129                subject,
130                predicate,
131                object,
132            } => {
133                let s = term_to_str(subject)?;
134                let p = term_to_str(predicate)?;
135                let o = term_to_str(object)?;
136                Some(GraphEdge {
137                    subject: s,
138                    predicate: p,
139                    object: o,
140                    inferred,
141                })
142            }
143            _ => None,
144        })
145        .collect()
146}
147
148fn term_to_str(term: &Term) -> Option<String> {
149    match term {
150        Term::Constant(c) | Term::Literal(c) => Some(c.clone()),
151        Term::Variable(_) | Term::Function { .. } => None, // unbound variables
152    }
153}
154
155fn triples_to_atoms(triples: &[Triple]) -> Vec<RuleAtom> {
156    triples
157        .iter()
158        .map(|t| RuleAtom::Triple {
159            subject: Term::Constant(t.subject.clone()),
160            predicate: Term::Constant(t.predicate.clone()),
161            object: Term::Constant(t.object.clone()),
162        })
163        .collect()
164}
165
166// ─── Multi-hop engine ────────────────────────────────────────────────────────
167
168/// Multi-hop reasoning engine backed by the oxirs-rule RuleEngine
169pub struct MultiHopEngine {
170    config: MultiHopConfig,
171}
172
173impl Default for MultiHopEngine {
174    fn default() -> Self {
175        Self::new(MultiHopConfig::default())
176    }
177}
178
179impl MultiHopEngine {
180    pub fn new(config: MultiHopConfig) -> Self {
181        Self { config }
182    }
183
184    /// Run multi-hop reasoning over `subgraph`, guided by `rules`.
185    ///
186    /// Returns all scored paths starting from `seeds`.
187    pub fn reason(
188        &self,
189        seeds: &[ScoredEntity],
190        subgraph: &[Triple],
191        rules: &[Rule],
192    ) -> GraphRAGResult<Vec<HopPath>> {
193        if seeds.is_empty() || subgraph.is_empty() {
194            return Ok(vec![]);
195        }
196
197        // 1. Materialise inferred facts with the rule engine
198        let (asserted_edges, inferred_edges, fired_rule_map) = self.materialise(subgraph, rules)?;
199
200        // 2. Build adjacency index
201        let mut all_edges: Vec<GraphEdge> = asserted_edges;
202        if self.config.include_inferred {
203            all_edges.extend(inferred_edges);
204        }
205
206        let adj = self.build_adjacency(&all_edges);
207
208        // 3. BFS/DFS from each seed to find paths
209        let mut paths: Vec<HopPath> = Vec::new();
210        let seed_map: HashMap<String, f64> =
211            seeds.iter().map(|s| (s.uri.clone(), s.score)).collect();
212
213        for seed in seeds {
214            let new_paths =
215                self.bfs_paths(&seed.uri, seed.score, &adj, &all_edges, &fired_rule_map);
216            paths.extend(new_paths);
217        }
218
219        // 4. Score and filter
220        paths.retain(|p| p.score >= self.config.min_path_score);
221        paths.sort_by(|a, b| {
222            b.score
223                .partial_cmp(&a.score)
224                .unwrap_or(std::cmp::Ordering::Equal)
225        });
226        paths.truncate(self.config.max_paths);
227
228        // Suppress unused warning for seed_map (used conceptually above)
229        let _ = seed_map;
230
231        Ok(paths)
232    }
233
234    /// Materialise inferred facts and return (asserted_edges, inferred_edges, fired_rules_by_triple)
235    fn materialise(
236        &self,
237        subgraph: &[Triple],
238        rules: &[Rule],
239    ) -> GraphRAGResult<(Vec<GraphEdge>, Vec<GraphEdge>, FiredRulesMap)> {
240        let asserted_edges = atoms_to_edges(&triples_to_atoms(subgraph), false);
241
242        if rules.is_empty() {
243            return Ok((asserted_edges, vec![], HashMap::new()));
244        }
245
246        let mut engine = RuleEngine::new();
247        engine.add_rules(rules.to_vec());
248        engine.enable_cache();
249
250        let facts = triples_to_atoms(subgraph);
251
252        let inferred_atoms = engine
253            .forward_chain(&facts)
254            .map_err(|e| GraphRAGError::InternalError(format!("Rule engine error: {e}")))?;
255
256        // Collect inferred triples (skip those already in subgraph)
257        let asserted_keys: HashSet<(String, String, String)> = subgraph
258            .iter()
259            .map(|t| (t.subject.clone(), t.predicate.clone(), t.object.clone()))
260            .collect();
261
262        let inferred_edges: Vec<GraphEdge> = atoms_to_edges(&inferred_atoms, true)
263            .into_iter()
264            .filter(|e| {
265                !asserted_keys.contains(&(e.subject.clone(), e.predicate.clone(), e.object.clone()))
266            })
267            .collect();
268
269        // Build a map from triple → fired rule names (approximation: one rule per triple)
270        let fired_rule_map: FiredRulesMap = rules
271            .iter()
272            .flat_map(|rule| {
273                rule.head.iter().filter_map(|atom| match atom {
274                    RuleAtom::Triple {
275                        subject,
276                        predicate,
277                        object,
278                    } => {
279                        let s = term_to_str(subject)?;
280                        let p = term_to_str(predicate)?;
281                        let o = term_to_str(object)?;
282                        Some(((s, p, o), rule.name.clone()))
283                    }
284                    _ => None,
285                })
286            })
287            .fold(HashMap::new(), |mut acc, (key, rule_name)| {
288                acc.entry(key).or_default().push(rule_name);
289                acc
290            });
291
292        Ok((asserted_edges, inferred_edges, fired_rule_map))
293    }
294
295    /// Build an adjacency list: node → list of edge indices
296    fn build_adjacency(&self, edges: &[GraphEdge]) -> HashMap<String, Vec<usize>> {
297        let mut adj: HashMap<String, Vec<usize>> = HashMap::new();
298        for (i, edge) in edges.iter().enumerate() {
299            if self.allow_predicate(&edge.predicate) {
300                adj.entry(edge.subject.clone()).or_default().push(i);
301            }
302        }
303        adj
304    }
305
306    fn allow_predicate(&self, pred: &str) -> bool {
307        if !self.config.allowed_predicates.is_empty()
308            && !self.config.allowed_predicates.contains(pred)
309        {
310            return false;
311        }
312        !self.config.blocked_predicates.contains(pred)
313    }
314
315    /// BFS from a seed node; returns all valid paths
316    fn bfs_paths(
317        &self,
318        start: &str,
319        seed_score: f64,
320        adj: &HashMap<String, Vec<usize>>,
321        edges: &[GraphEdge],
322        fired_rule_map: &HashMap<(String, String, String), Vec<String>>,
323    ) -> Vec<HopPath> {
324        // Queue entry: (current_node, path_so_far_edge_indices, visited_nodes)
325        struct State {
326            node: String,
327            edge_path: Vec<usize>,
328            visited: HashSet<String>,
329        }
330
331        let mut queue: VecDeque<State> = VecDeque::new();
332        queue.push_back(State {
333            node: start.to_string(),
334            edge_path: vec![],
335            visited: {
336                let mut h = HashSet::new();
337                h.insert(start.to_string());
338                h
339            },
340        });
341
342        let mut paths: Vec<HopPath> = Vec::new();
343        let mut budget = self.config.max_edges_budget;
344
345        while let Some(state) = queue.pop_front() {
346            if budget == 0 {
347                break;
348            }
349            budget -= 1;
350
351            if state.edge_path.len() > self.config.max_hops {
352                continue;
353            }
354
355            // If we have at least one hop, record as a path
356            if !state.edge_path.is_empty() {
357                let path_edges: Vec<GraphEdge> =
358                    state.edge_path.iter().map(|&i| edges[i].clone()).collect();
359
360                let inferred_hops = path_edges.iter().filter(|e| e.inferred).count();
361                let fired_rules: Vec<String> = path_edges
362                    .iter()
363                    .filter(|e| e.inferred)
364                    .flat_map(|e| {
365                        let key = (e.subject.clone(), e.predicate.clone(), e.object.clone());
366                        fired_rule_map.get(&key).cloned().unwrap_or_default()
367                    })
368                    .collect::<HashSet<_>>()
369                    .into_iter()
370                    .collect();
371
372                let score = self.score_path(state.edge_path.len(), inferred_hops, seed_score);
373
374                paths.push(HopPath {
375                    edges: path_edges,
376                    start: start.to_string(),
377                    end: state.node.clone(),
378                    score,
379                    inferred_hops,
380                    fired_rules,
381                });
382
383                if paths.len() >= self.config.max_paths {
384                    return paths;
385                }
386            }
387
388            if state.edge_path.len() >= self.config.max_hops {
389                continue;
390            }
391
392            // Expand neighbours
393            if let Some(edge_indices) = adj.get(&state.node) {
394                for &ei in edge_indices {
395                    let edge = &edges[ei];
396                    if !state.visited.contains(&edge.object) {
397                        let mut new_visited = state.visited.clone();
398                        new_visited.insert(edge.object.clone());
399                        let mut new_path = state.edge_path.clone();
400                        new_path.push(ei);
401                        queue.push_back(State {
402                            node: edge.object.clone(),
403                            edge_path: new_path,
404                            visited: new_visited,
405                        });
406                    }
407                }
408            }
409        }
410
411        paths
412    }
413
414    fn score_path(&self, hops: usize, inferred_hops: usize, seed_score: f64) -> f64 {
415        let h = hops.max(1) as f64;
416        match self.config.scoring_fn {
417            PathScoringFn::InverseHopCount => 1.0 / h,
418            PathScoringFn::SeedWeighted => seed_score / h,
419            PathScoringFn::Uniform => 1.0,
420            PathScoringFn::InferencePenalised => (1.0 / h) * 0.8_f64.powi(inferred_hops as i32),
421        }
422    }
423}
424
425// ─── Convenience: build rules from SPARQL-like property chains ──────────────
426
427/// Build a transitivity rule for a given predicate
428/// e.g.  `subClassOf(X,Y) ∧ subClassOf(Y,Z) → subClassOf(X,Z)`
429pub fn transitivity_rule(predicate: &str) -> Rule {
430    Rule {
431        name: format!("{predicate}_transitive"),
432        body: vec![
433            RuleAtom::Triple {
434                subject: Term::Variable("X".to_string()),
435                predicate: Term::Constant(predicate.to_string()),
436                object: Term::Variable("Y".to_string()),
437            },
438            RuleAtom::Triple {
439                subject: Term::Variable("Y".to_string()),
440                predicate: Term::Constant(predicate.to_string()),
441                object: Term::Variable("Z".to_string()),
442            },
443        ],
444        head: vec![RuleAtom::Triple {
445            subject: Term::Variable("X".to_string()),
446            predicate: Term::Constant(predicate.to_string()),
447            object: Term::Variable("Z".to_string()),
448        }],
449    }
450}
451
452/// Build a property chain rule:  p1(X,Y) ∧ p2(Y,Z) → q(X,Z)
453pub fn property_chain_rule(p1: &str, p2: &str, conclusion_pred: &str) -> Rule {
454    Rule {
455        name: format!("{p1}_{p2}_chain"),
456        body: vec![
457            RuleAtom::Triple {
458                subject: Term::Variable("X".to_string()),
459                predicate: Term::Constant(p1.to_string()),
460                object: Term::Variable("Y".to_string()),
461            },
462            RuleAtom::Triple {
463                subject: Term::Variable("Y".to_string()),
464                predicate: Term::Constant(p2.to_string()),
465                object: Term::Variable("Z".to_string()),
466            },
467        ],
468        head: vec![RuleAtom::Triple {
469            subject: Term::Variable("X".to_string()),
470            predicate: Term::Constant(conclusion_pred.to_string()),
471            object: Term::Variable("Z".to_string()),
472        }],
473    }
474}
475
476/// Build a symmetry rule:  p(X,Y) → p(Y,X)
477pub fn symmetry_rule(predicate: &str) -> Rule {
478    Rule {
479        name: format!("{predicate}_symmetric"),
480        body: vec![RuleAtom::Triple {
481            subject: Term::Variable("X".to_string()),
482            predicate: Term::Constant(predicate.to_string()),
483            object: Term::Variable("Y".to_string()),
484        }],
485        head: vec![RuleAtom::Triple {
486            subject: Term::Variable("Y".to_string()),
487            predicate: Term::Constant(predicate.to_string()),
488            object: Term::Variable("X".to_string()),
489        }],
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;

    /// Builds a vector-sourced seed with the given URI and score and no metadata.
    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
        let metadata = HashMap::new();
        ScoredEntity {
            uri: uri.to_owned(),
            score,
            source: ScoreSource::Vector,
            metadata,
        }
    }

    /// Shorthand for `Triple::new`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }

    // ── rule helper tests ─────────────────────────────────────────────────

    #[test]
    fn test_transitivity_rule_structure() {
        let transitive = transitivity_rule("subClassOf");
        let (body_atoms, head_atoms) = (transitive.body.len(), transitive.head.len());
        assert_eq!(transitive.name, "subClassOf_transitive");
        assert_eq!(body_atoms, 2);
        assert_eq!(head_atoms, 1);
    }

    #[test]
    fn test_property_chain_rule_structure() {
        let chain = property_chain_rule("partOf", "locatedIn", "indirectlyIn");
        let body_atoms = chain.body.len();
        assert_eq!(chain.name, "partOf_locatedIn_chain");
        assert_eq!(body_atoms, 2);
    }

    #[test]
    fn test_symmetry_rule_structure() {
        let symmetric = symmetry_rule("sameAs");
        let (body_atoms, head_atoms) = (symmetric.body.len(), symmetric.head.len());
        assert_eq!(symmetric.name, "sameAs_symmetric");
        assert_eq!(body_atoms, 1);
        assert_eq!(head_atoms, 1);
    }

    // ── MultiHopEngine – basic ─────────────────────────────────────────────

    #[test]
    fn test_reason_empty_seeds() {
        let graph = [make_triple("http://a", "http://rel", "http://b")];
        let no_seeds: Vec<ScoredEntity> = Vec::new();
        let found = MultiHopEngine::default()
            .reason(&no_seeds, &graph, &[])
            .expect("should succeed");
        assert!(found.is_empty());
    }

    #[test]
    fn test_reason_empty_subgraph() {
        let start_points = [make_seed("http://a", 0.9)];
        let empty_graph: Vec<Triple> = Vec::new();
        let found = MultiHopEngine::default()
            .reason(&start_points, &empty_graph, &[])
            .expect("should succeed");
        assert!(found.is_empty());
    }

    #[test]
    fn test_reason_single_hop_no_rules() {
        let graph: Vec<Triple> = [
            ("http://a", "http://p/rel", "http://b"),
            ("http://b", "http://p/rel", "http://c"),
            ("http://x", "http://p/other", "http://y"),
        ]
        .iter()
        .map(|&(s, p, o)| make_triple(s, p, o))
        .collect();
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::default()
            .reason(&start_points, &graph, &[])
            .expect("should succeed");

        assert!(!found.is_empty());
        for path in &found {
            // All paths should start from http://a
            assert_eq!(path.start, "http://a");
            // There should be no inferred hops (no rules)
            assert_eq!(path.inferred_hops, 0);
        }
    }

    #[test]
    fn test_reason_respects_max_hops() {
        let mut config = MultiHopConfig::default();
        config.max_hops = 1;
        let graph = [
            make_triple("http://a", "http://p", "http://b"),
            make_triple("http://b", "http://p", "http://c"),
        ];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &[])
            .expect("should succeed");

        if let Some(too_long) = found.iter().find(|path| path.hop_count() > 1) {
            panic!("Path hop count {} > max_hops 1", too_long.hop_count());
        }
    }

    #[test]
    fn test_reason_respects_max_paths() {
        let mut config = MultiHopConfig::default();
        config.max_paths = 2;

        let mut fan_out: Vec<Triple> = Vec::with_capacity(20);
        for i in 0..20 {
            fan_out.push(make_triple(
                "http://a",
                "http://p",
                &format!("http://n{i}"),
            ));
        }
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &fan_out, &[])
            .expect("should succeed");
        assert!(found.len() <= 2);
    }

    // ── MultiHopEngine – with rules ───────────────────────────────────────

    #[test]
    fn test_reason_transitivity_rule() {
        let mut config = MultiHopConfig::default();
        config.max_hops = 3;
        config.include_inferred = true;

        let graph = [
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rule_set = [transitivity_rule("http://subClassOf")];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &rule_set)
            .expect("should succeed");

        assert!(!found.is_empty());
        // At least some paths should have inferred hops
        let mut saw_inferred = false;
        for path in &found {
            if path.has_inferred_hop() {
                saw_inferred = true;
                break;
            }
        }
        assert!(
            saw_inferred,
            "Expected at least one path with inferred hops"
        );
    }

    #[test]
    fn test_reason_no_inferred_when_disabled() {
        let mut config = MultiHopConfig::default();
        config.include_inferred = false;

        let graph = [
            make_triple("http://a", "http://subClassOf", "http://b"),
            make_triple("http://b", "http://subClassOf", "http://c"),
        ];
        let rule_set = [transitivity_rule("http://subClassOf")];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &rule_set)
            .expect("should succeed");

        found.iter().map(|path| path.inferred_hops).for_each(|count| {
            assert_eq!(count, 0, "Expected no inferred hops when disabled");
        });
    }

    // ── Path scoring ──────────────────────────────────────────────────────

    #[test]
    fn test_score_inverse_hop_count() {
        let mut config = MultiHopConfig::default();
        config.scoring_fn = PathScoringFn::InverseHopCount;
        let engine = MultiHopEngine::new(config);

        // 1-hop path: score = 1.0, 2-hop: score = 0.5
        for &(hops, expected) in &[(1_usize, 1.0_f64), (2, 0.5)] {
            let delta = (engine.score_path(hops, 0, 1.0) - expected).abs();
            assert!(delta < 1e-9);
        }
    }

    #[test]
    fn test_score_seed_weighted() {
        let mut config = MultiHopConfig::default();
        config.scoring_fn = PathScoringFn::SeedWeighted;
        let engine = MultiHopEngine::new(config);

        let delta = (engine.score_path(2, 0, 0.8) - 0.4).abs();
        assert!(delta < 1e-9);
    }

    #[test]
    fn test_score_uniform() {
        let mut config = MultiHopConfig::default();
        config.scoring_fn = PathScoringFn::Uniform;
        let score = MultiHopEngine::new(config).score_path(5, 3, 0.5);
        assert_eq!(score, 1.0);
    }

    #[test]
    fn test_score_inference_penalised() {
        let mut config = MultiHopConfig::default();
        config.scoring_fn = PathScoringFn::InferencePenalised;
        let engine = MultiHopEngine::new(config);

        let asserted_only = engine.score_path(2, 0, 1.0);
        let one_inferred = engine.score_path(2, 1, 1.0);
        assert!(
            asserted_only > one_inferred,
            "Inferred hop should reduce score"
        );
    }

    // ── Predicate filtering ───────────────────────────────────────────────

    #[test]
    fn test_blocked_predicates_filter() {
        let config = MultiHopConfig {
            blocked_predicates: std::iter::once("http://p/blocked".to_string()).collect(),
            ..Default::default()
        };
        let graph = [
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/blocked", "http://c"),
        ];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &[])
            .expect("should succeed");

        // No path should traverse the blocked predicate
        for edge in found.iter().flat_map(|path| path.edges.iter()) {
            assert_ne!(
                edge.predicate, "http://p/blocked",
                "Blocked predicate found in path"
            );
        }
    }

    #[test]
    fn test_allowed_predicates_whitelist() {
        let config = MultiHopConfig {
            allowed_predicates: std::iter::once("http://p/allowed".to_string()).collect(),
            ..Default::default()
        };
        let graph = [
            make_triple("http://a", "http://p/allowed", "http://b"),
            make_triple("http://a", "http://p/other", "http://c"),
        ];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &[])
            .expect("should succeed");

        for edge in found.iter().flat_map(|path| path.edges.iter()) {
            assert_eq!(edge.predicate, "http://p/allowed");
        }
    }

    // ── Path metadata ─────────────────────────────────────────────────────

    #[test]
    fn test_hop_path_fields() {
        let graph = [make_triple("http://a", "http://p", "http://b")];
        let start_points = [make_seed("http://a", 0.8)];

        let found = MultiHopEngine::default()
            .reason(&start_points, &graph, &[])
            .expect("should succeed");

        assert!(!found.is_empty());
        let best = &found[0];
        assert_eq!(best.start, "http://a");
        assert_eq!(best.end, "http://b");
        assert_eq!(best.hop_count(), 1);
        assert!(!best.has_inferred_hop());
    }

    #[test]
    fn test_min_path_score_threshold() {
        let mut config = MultiHopConfig::default();
        config.min_path_score = 99.0; // impossible

        let graph = [make_triple("http://a", "http://p", "http://b")];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &[])
            .expect("should succeed");
        assert!(found.is_empty());
    }

    // ── Helper functions ──────────────────────────────────────────────────

    #[test]
    fn test_triples_to_atoms_roundtrip() {
        let input = [Triple::new("http://s", "http://p", "http://o")];
        let atoms = triples_to_atoms(&input);
        assert_eq!(atoms.len(), 1);

        match &atoms[0] {
            RuleAtom::Triple {
                subject,
                predicate,
                object,
            } => {
                assert_eq!(term_to_str(subject).expect("should succeed"), "http://s");
                assert_eq!(term_to_str(predicate).expect("should succeed"), "http://p");
                assert_eq!(term_to_str(object).expect("should succeed"), "http://o");
            }
            _ => panic!("Expected Triple atom"),
        }
    }

    #[test]
    fn test_property_chain_produces_derived_edges() {
        let mut config = MultiHopConfig::default();
        config.max_hops = 3;
        config.include_inferred = true;

        let graph = [
            make_triple("http://a", "http://partOf", "http://b"),
            make_triple("http://b", "http://locatedIn", "http://c"),
        ];
        let rule_set = [property_chain_rule(
            "http://partOf",
            "http://locatedIn",
            "http://indirectlyIn",
        )];
        let start_points = [make_seed("http://a", 0.9)];

        let found = MultiHopEngine::new(config)
            .reason(&start_points, &graph, &rule_set)
            .expect("should succeed");
        assert!(!found.is_empty());
    }
}
845
846// ─── Additional tests ─────────────────────────────────────────────────────────
847
848#[cfg(test)]
849mod additional_tests {
850    use super::*;
851    use crate::ScoreSource;
852
853    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
854        ScoredEntity {
855            uri: uri.to_string(),
856            score,
857            source: ScoreSource::Vector,
858            metadata: HashMap::new(),
859        }
860    }
861
    /// Shorthand for `Triple::new(s, p, o)` (subject, predicate, object).
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }
865
866    // ── Rule builder tests ────────────────────────────────────────────────
867
868    #[test]
869    fn test_symmetry_rule_head_swapped() {
870        let rule = symmetry_rule("http://sameAs");
871        // head should have subject=Variable(Y), object=Variable(X) — swapped from body
872        assert!(matches!(&rule.head[0], RuleAtom::Triple { .. }));
873        if let RuleAtom::Triple {
874            subject, object, ..
875        } = &rule.head[0]
876        {
877            // Both are Variable terms; they should be different variable names
878            match (subject, object) {
879                (Term::Variable(sv), Term::Variable(ov)) => {
880                    assert_ne!(sv, ov, "Head subject and object variables should differ");
881                }
882                _ => panic!("Expected Variable terms in head"),
883            }
884        } else {
885            panic!("Expected Triple head");
886        }
887    }
888
889    #[test]
890    fn test_transitivity_rule_has_shared_variable() {
891        // body[0].object == body[1].subject (shared Y)
892        let rule = transitivity_rule("http://subClassOf");
893        if let (RuleAtom::Triple { object: obj0, .. }, RuleAtom::Triple { subject: subj1, .. }) =
894            (&rule.body[0], &rule.body[1])
895        {
896            // Both should be a Variable("Y")
897            matches!(obj0, Term::Variable(v) if v == "Y");
898            matches!(subj1, Term::Variable(v) if v == "Y");
899        }
900    }
901
902    #[test]
903    fn test_property_chain_rule_body_predicates() {
904        let rule = property_chain_rule("http://partOf", "http://locatedIn", "http://indirectlyIn");
905        if let RuleAtom::Triple { predicate: p1, .. } = &rule.body[0] {
906            assert_eq!(term_to_str(p1).expect("should succeed"), "http://partOf");
907        }
908        if let RuleAtom::Triple { predicate: p2, .. } = &rule.body[1] {
909            assert_eq!(term_to_str(p2).expect("should succeed"), "http://locatedIn");
910        }
911        if let RuleAtom::Triple { predicate: ph, .. } = &rule.head[0] {
912            assert_eq!(
913                term_to_str(ph).expect("should succeed"),
914                "http://indirectlyIn"
915            );
916        }
917    }
918
919    // ── term_to_str helper ────────────────────────────────────────────────
920
921    #[test]
922    fn test_term_to_str_constant() {
923        let t = Term::Constant("http://example.org/x".to_string());
924        assert_eq!(
925            term_to_str(&t).expect("should succeed"),
926            "http://example.org/x"
927        );
928    }
929
930    #[test]
931    fn test_term_to_str_literal() {
932        let t = Term::Literal("hello world".to_string());
933        assert_eq!(term_to_str(&t).expect("should succeed"), "hello world");
934    }
935
936    #[test]
937    fn test_term_to_str_variable_returns_none() {
938        let t = Term::Variable("X".to_string());
939        assert!(term_to_str(&t).is_none());
940    }
941
942    // ── atoms_to_edges helper ─────────────────────────────────────────────
943
944    #[test]
945    fn test_atoms_to_edges_filters_non_triple() {
946        let atoms = vec![RuleAtom::Triple {
947            subject: Term::Constant("http://s".to_string()),
948            predicate: Term::Constant("http://p".to_string()),
949            object: Term::Constant("http://o".to_string()),
950        }];
951        let edges = atoms_to_edges(&atoms, false);
952        assert_eq!(edges.len(), 1);
953        assert_eq!(edges[0].subject, "http://s");
954        assert!(!edges[0].inferred);
955    }
956
957    #[test]
958    fn test_atoms_to_edges_inferred_flag() {
959        let atoms = vec![RuleAtom::Triple {
960            subject: Term::Constant("http://s".to_string()),
961            predicate: Term::Constant("http://p".to_string()),
962            object: Term::Constant("http://o".to_string()),
963        }];
964        let edges = atoms_to_edges(&atoms, true);
965        assert!(edges[0].inferred);
966    }
967
968    // ── Path scoring ──────────────────────────────────────────────────────
969
970    #[test]
971    fn test_score_inference_penalised_zero_inferred_equals_inverse() {
972        let config = MultiHopConfig {
973            scoring_fn: PathScoringFn::InferencePenalised,
974            ..Default::default()
975        };
976        let engine = MultiHopEngine::new(config);
977        // 0 inferred hops → 0.8^0 = 1.0 → same as InverseHopCount
978        let s = engine.score_path(2, 0, 1.0);
979        assert!((s - 0.5).abs() < 1e-9);
980    }
981
982    #[test]
983    fn test_score_seed_weighted_scales_with_seed_score() {
984        let config = MultiHopConfig {
985            scoring_fn: PathScoringFn::SeedWeighted,
986            ..Default::default()
987        };
988        let engine = MultiHopEngine::new(config);
989        let s1 = engine.score_path(1, 0, 1.0);
990        let s2 = engine.score_path(1, 0, 0.5);
991        assert!((s1 - 2.0 * s2).abs() < 1e-9, "s1={s1}, s2={s2}");
992    }
993
994    // ── GraphEdge ─────────────────────────────────────────────────────────
995
996    #[test]
997    fn test_graph_edge_equality() {
998        let e1 = GraphEdge {
999            subject: "http://a".to_string(),
1000            predicate: "http://p".to_string(),
1001            object: "http://b".to_string(),
1002            inferred: false,
1003        };
1004        let e2 = e1.clone();
1005        assert_eq!(e1, e2);
1006    }
1007
1008    // ── MultiHopConfig ────────────────────────────────────────────────────
1009
1010    #[test]
1011    fn test_multihop_config_defaults() {
1012        let cfg = MultiHopConfig::default();
1013        assert_eq!(cfg.max_hops, 3);
1014        assert_eq!(cfg.max_paths, 50);
1015        assert!(cfg.include_inferred);
1016        assert!(cfg.allowed_predicates.is_empty());
1017        assert!(cfg.blocked_predicates.is_empty());
1018    }
1019
1020    // ── Cycle detection ───────────────────────────────────────────────────
1021
1022    #[test]
1023    fn test_reason_cycle_in_graph_does_not_loop() {
1024        // a → b → c → a (cycle); BFS should not loop infinitely
1025        let engine = MultiHopEngine::default();
1026        let seeds = vec![make_seed("http://a", 0.9)];
1027        let triples = vec![
1028            make_triple("http://a", "http://p", "http://b"),
1029            make_triple("http://b", "http://p", "http://c"),
1030            make_triple("http://c", "http://p", "http://a"), // back edge
1031        ];
1032        let paths = engine.reason(&seeds, &triples, &[]);
1033        assert!(paths.is_ok(), "Should not error on cyclic graphs");
1034        // Should return some paths without hanging
1035        let paths = paths.expect("should succeed");
1036        assert!(paths.len() < 1000, "Cycle guard should bound path count");
1037    }
1038
1039    #[test]
1040    fn test_hop_path_has_inferred_hop_false_for_asserted() {
1041        let path = HopPath {
1042            edges: vec![GraphEdge {
1043                subject: "http://s".to_string(),
1044                predicate: "http://p".to_string(),
1045                object: "http://o".to_string(),
1046                inferred: false,
1047            }],
1048            start: "http://s".to_string(),
1049            end: "http://o".to_string(),
1050            score: 0.8,
1051            inferred_hops: 0,
1052            fired_rules: vec![],
1053        };
1054        assert!(!path.has_inferred_hop());
1055        assert_eq!(path.hop_count(), 1);
1056    }
1057
1058    #[test]
1059    fn test_hop_path_has_inferred_hop_true_for_inferred() {
1060        let path = HopPath {
1061            edges: vec![GraphEdge {
1062                subject: "http://s".to_string(),
1063                predicate: "http://p".to_string(),
1064                object: "http://o".to_string(),
1065                inferred: true,
1066            }],
1067            start: "http://s".to_string(),
1068            end: "http://o".to_string(),
1069            score: 0.5,
1070            inferred_hops: 1,
1071            fired_rules: vec!["rule1".to_string()],
1072        };
1073        assert!(path.has_inferred_hop());
1074    }
1075
1076    // ── Multiple seeds produce paths from each ────────────────────────────
1077
1078    #[test]
1079    fn test_reason_two_seeds_produce_paths_from_both() {
1080        let engine = MultiHopEngine::default();
1081        let seeds = vec![make_seed("http://a", 0.9), make_seed("http://x", 0.8)];
1082        let triples = vec![
1083            make_triple("http://a", "http://p", "http://b"),
1084            make_triple("http://x", "http://q", "http://y"),
1085        ];
1086        let paths = engine
1087            .reason(&seeds, &triples, &[])
1088            .expect("should succeed");
1089        let from_a = paths.iter().any(|p| p.start == "http://a");
1090        let from_x = paths.iter().any(|p| p.start == "http://x");
1091        assert!(from_a, "Expected paths from http://a");
1092        assert!(from_x, "Expected paths from http://x");
1093    }
1094
1095    // ── Symmetry rule produces inferred edges ─────────────────────────────
1096
1097    #[test]
1098    fn test_reason_symmetry_rule_adds_reverse_edge() {
1099        let config = MultiHopConfig {
1100            max_hops: 2,
1101            include_inferred: true,
1102            ..Default::default()
1103        };
1104        let engine = MultiHopEngine::new(config);
1105        let seeds = vec![make_seed("http://b", 0.9)];
1106        let triples = vec![make_triple("http://a", "http://sameAs", "http://b")];
1107        let rules = vec![symmetry_rule("http://sameAs")];
1108        let paths = engine
1109            .reason(&seeds, &triples, &rules)
1110            .expect("should succeed");
1111        // Should find path from http://b via the inferred reverse edge
1112        let has_inferred = paths.iter().any(|p| p.has_inferred_hop());
1113        assert!(has_inferred, "Symmetry rule should create inferred edges");
1114    }
1115
1116    // ── Budget guard ──────────────────────────────────────────────────────
1117
1118    #[test]
1119    fn test_reason_budget_guard_limits_expansion() {
1120        let config = MultiHopConfig {
1121            max_edges_budget: 5, // very small
1122            max_hops: 10,
1123            max_paths: 1000,
1124            ..Default::default()
1125        };
1126        let engine = MultiHopEngine::new(config);
1127        let seeds = vec![make_seed("http://a", 0.9)];
1128        // Large star graph: http://a → http://n0..n99
1129        let triples: Vec<Triple> = (0..100)
1130            .map(|i| make_triple("http://a", "http://p", &format!("http://n{i}")))
1131            .collect();
1132        let paths = engine
1133            .reason(&seeds, &triples, &[])
1134            .expect("should succeed");
1135        // Budget=5 should stop after at most a few paths
1136        assert!(paths.len() < 100, "Budget guard should limit path count");
1137    }
1138
1139    // ── Scoring sorted descending ─────────────────────────────────────────
1140
1141    #[test]
1142    fn test_reason_paths_sorted_descending() {
1143        let engine = MultiHopEngine::default();
1144        let seeds = vec![make_seed("http://a", 1.0)];
1145        let triples = vec![
1146            make_triple("http://a", "http://p", "http://b"),
1147            make_triple("http://b", "http://p", "http://c"),
1148            make_triple("http://c", "http://p", "http://d"),
1149        ];
1150        let paths = engine
1151            .reason(&seeds, &triples, &[])
1152            .expect("should succeed");
1153        for i in 1..paths.len() {
1154            assert!(
1155                paths[i - 1].score >= paths[i].score,
1156                "Paths should be sorted descending: {} < {}",
1157                paths[i - 1].score,
1158                paths[i].score
1159            );
1160        }
1161    }
1162}