Skip to main content

oxirs_graphrag/reasoning/
path_finder.rs

//! Multi-hop path finder with configurable path scoring strategies
//!
//! Provides `MultiHopPathFinder` for finding paths between entities in a
//! knowledge graph, supporting Uniform, AttentionWeighted, and PathLength scoring.
5
6use crate::Triple;
7use std::collections::{HashMap, HashSet, VecDeque};
8
9// ── PathScoring ────────────────────────────────────────────────────────────────
10
/// Selects how a discovered multi-hop path is converted into a numeric score.
///
/// In every formula below, `hop_count` is the number of relations (edges) in
/// the path being scored.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PathScoring {
    /// Constant scoring: every path is worth exactly 1.0, whatever its length
    Uniform,
    /// Length-penalised scoring: `1.0 / (1 + hop_count)`, so longer paths score lower
    PathLength,
    /// Simulated attention weighting that favours earlier hops: hop `i`
    /// (zero-based) contributes `1 / (i + 1)`, and the contributions are summed
    /// and then divided by `hop_count`. This is the default strategy.
    #[default]
    AttentionWeighted,
}
23
24// ── MultiHopReasoningConfig ────────────────────────────────────────────────────
25
26/// Configuration for multi-hop path finding
27#[derive(Debug, Clone)]
28pub struct MultiHopReasoningConfig {
29    /// Maximum number of hops to traverse per path
30    pub max_hops: u8,
31    /// Minimum confidence threshold for a path to be returned
32    pub min_confidence: f64,
33    /// Path scoring strategy
34    pub path_scoring: PathScoring,
35    /// Maximum number of paths to return per (start, end) pair
36    pub max_paths_per_pair: usize,
37    /// Maximum BFS frontier size to prevent explosion on dense graphs
38    pub max_frontier: usize,
39}
40
41impl Default for MultiHopReasoningConfig {
42    fn default() -> Self {
43        Self {
44            max_hops: 3,
45            min_confidence: 0.0,
46            path_scoring: PathScoring::default(),
47            max_paths_per_pair: 20,
48            max_frontier: 10_000,
49        }
50    }
51}
52
53// ── HopPath ────────────────────────────────────────────────────────────────────
54
/// One concrete route through the knowledge graph, together with its score.
#[derive(Clone, Debug)]
pub struct HopPath {
    /// Entity URIs visited, in order, from the start entity to the end entity
    pub entities: Vec<String>,
    /// Predicates linking each consecutive pair of entities
    pub relations: Vec<String>,
    /// Score assigned to this path by the selected `PathScoring` strategy
    pub score: f64,
}

impl HopPath {
    /// Length of the path measured in hops, i.e. the number of relations it holds.
    pub fn hop_count(&self) -> usize {
        self.relations.len()
    }
}
72
73// ── KnowledgeGraph (adjacency helper) ──────────────────────────────────────────
74
75/// Lightweight adjacency representation built from RDF triples.
76pub struct KnowledgeGraph {
77    /// Forward adjacency: subject -> Vec<(predicate, object)>
78    adj: HashMap<String, Vec<(String, String)>>,
79}
80
81impl KnowledgeGraph {
82    /// Build from a slice of RDF triples.
83    pub fn from_triples(triples: &[Triple]) -> Self {
84        let mut adj: HashMap<String, Vec<(String, String)>> = HashMap::new();
85        for t in triples {
86            adj.entry(t.subject.clone())
87                .or_default()
88                .push((t.predicate.clone(), t.object.clone()));
89        }
90        Self { adj }
91    }
92
93    /// Iterate over the (predicate, object) neighbours of a node.
94    pub fn neighbours(&self, node: &str) -> &[(String, String)] {
95        self.adj.get(node).map(|v| v.as_slice()).unwrap_or(&[])
96    }
97
98    /// Number of unique subject nodes
99    pub fn node_count(&self) -> usize {
100        self.adj.len()
101    }
102}
103
104// ── MultiHopPathFinder ─────────────────────────────────────────────────────────
105
106/// Finds multi-hop paths between entities using BFS.
107pub struct MultiHopPathFinder {
108    config: MultiHopReasoningConfig,
109}
110
111impl MultiHopPathFinder {
112    /// Create with the given config.
113    pub fn new(config: MultiHopReasoningConfig) -> Self {
114        Self { config }
115    }
116
117    /// Create with default config.
118    pub fn with_defaults() -> Self {
119        Self::new(MultiHopReasoningConfig::default())
120    }
121
122    /// Find all paths from `start` to `end` in the given graph, up to `max_hops`.
123    ///
124    /// Returns paths sorted descending by score, limited to `max_paths_per_pair`.
125    pub fn find_paths(
126        &self,
127        start: &str,
128        end: &str,
129        max_hops: u8,
130        graph: &KnowledgeGraph,
131    ) -> Vec<HopPath> {
132        // BFS: queue holds (current_node, entity_path, relation_path, visited)
133        struct State {
134            node: String,
135            entities: Vec<String>,
136            relations: Vec<String>,
137            visited: HashSet<String>,
138        }
139
140        let mut queue: VecDeque<State> = VecDeque::new();
141        queue.push_back(State {
142            node: start.to_string(),
143            entities: vec![start.to_string()],
144            relations: vec![],
145            visited: {
146                let mut h = HashSet::new();
147                h.insert(start.to_string());
148                h
149            },
150        });
151
152        let mut paths: Vec<HopPath> = Vec::new();
153        let mut frontier_visited = 0usize;
154
155        while let Some(state) = queue.pop_front() {
156            if frontier_visited >= self.config.max_frontier {
157                break;
158            }
159            frontier_visited += 1;
160
161            let hops_so_far = state.relations.len() as u8;
162
163            if hops_so_far >= max_hops {
164                continue;
165            }
166
167            for (pred, obj) in graph.neighbours(&state.node) {
168                if state.visited.contains(obj) {
169                    continue;
170                }
171                let mut new_entities = state.entities.clone();
172                new_entities.push(obj.clone());
173                let mut new_relations = state.relations.clone();
174                new_relations.push(pred.clone());
175
176                if obj == end {
177                    let score = self.score_path(&new_relations, &self.config.path_scoring);
178                    if score >= self.config.min_confidence {
179                        paths.push(HopPath {
180                            entities: new_entities,
181                            relations: new_relations,
182                            score,
183                        });
184                        if paths.len() >= self.config.max_paths_per_pair {
185                            paths.sort_by(|a, b| {
186                                b.score
187                                    .partial_cmp(&a.score)
188                                    .unwrap_or(std::cmp::Ordering::Equal)
189                            });
190                            return paths;
191                        }
192                    }
193                } else {
194                    let mut new_visited = state.visited.clone();
195                    new_visited.insert(obj.clone());
196                    queue.push_back(State {
197                        node: obj.clone(),
198                        entities: new_entities,
199                        relations: new_relations,
200                        visited: new_visited,
201                    });
202                }
203            }
204        }
205
206        paths.sort_by(|a, b| {
207            b.score
208                .partial_cmp(&a.score)
209                .unwrap_or(std::cmp::Ordering::Equal)
210        });
211        paths
212    }
213
214    /// Score a path given a list of relations and the chosen strategy.
215    pub fn score_path(&self, relations: &[String], scoring: &PathScoring) -> f64 {
216        let hops = relations.len();
217        match scoring {
218            PathScoring::Uniform => 1.0,
219            PathScoring::PathLength => 1.0 / (1.0 + hops as f64),
220            PathScoring::AttentionWeighted => {
221                if hops == 0 {
222                    return 0.0;
223                }
224                // sum(1/(i+1)) for i in 0..hops, divided by hops for normalisation
225                let sum: f64 = (0..hops).map(|i| 1.0 / (i as f64 + 1.0)).sum();
226                sum / hops as f64
227            }
228        }
229    }
230}
231
232// ── Tests ─────────────────────────────────────────────────────────────────────
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: the chain a -> b -> c -> d plus a direct shortcut a -> c.
    fn simple_graph() -> (Vec<Triple>, KnowledgeGraph) {
        let triples = vec![
            Triple::new("http://a", "http://rel/r1", "http://b"),
            Triple::new("http://b", "http://rel/r2", "http://c"),
            Triple::new("http://c", "http://rel/r3", "http://d"),
            Triple::new("http://a", "http://rel/direct", "http://c"),
        ];
        let graph = KnowledgeGraph::from_triples(&triples);
        (triples, graph)
    }

    // ── PathScoring ──────────────────────────────────────────────────────

    #[test]
    fn test_path_scoring_default_is_attention_weighted() {
        assert_eq!(PathScoring::default(), PathScoring::AttentionWeighted);
    }

    #[test]
    fn test_score_uniform_always_one() {
        let finder = MultiHopPathFinder::with_defaults();
        let rels: Vec<String> = vec!["r1".to_string(), "r2".to_string(), "r3".to_string()];
        let s = finder.score_path(&rels, &PathScoring::Uniform);
        assert!((s - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_score_path_length_decreases_with_hops() {
        let finder = MultiHopPathFinder::with_defaults();
        let s1 = finder.score_path(&["r".to_string()], &PathScoring::PathLength);
        let s2 = finder.score_path(
            &["r".to_string(), "r2".to_string()],
            &PathScoring::PathLength,
        );
        assert!(s1 > s2, "Longer path should score lower: {s1} vs {s2}");
    }

    #[test]
    fn test_path_scoring_path_length_formula() {
        let finder = MultiHopPathFinder::with_defaults();
        // 2 hops: 1/(1+2) = 1/3
        let rels = vec!["r1".to_string(), "r2".to_string()];
        let s = finder.score_path(&rels, &PathScoring::PathLength);
        assert!((s - 1.0 / 3.0).abs() < 1e-9, "Expected 1/3, got {s}");
    }

    #[test]
    fn test_score_attention_weighted_single_hop() {
        let finder = MultiHopPathFinder::with_defaults();
        let s = finder.score_path(&["r".to_string()], &PathScoring::AttentionWeighted);
        // 1 hop: sum = 1.0, divided by 1 = 1.0
        assert!((s - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_score_attention_weighted_two_hops() {
        let finder = MultiHopPathFinder::with_defaults();
        let rels: Vec<String> = vec!["r1".to_string(), "r2".to_string()];
        let s = finder.score_path(&rels, &PathScoring::AttentionWeighted);
        // sum = 1 + 0.5 = 1.5, divided by 2 = 0.75
        assert!((s - 0.75).abs() < 1e-9, "Expected 0.75, got {s}");
    }

    #[test]
    fn test_score_attention_weighted_empty_is_zero() {
        let finder = MultiHopPathFinder::with_defaults();
        let s = finder.score_path(&[], &PathScoring::AttentionWeighted);
        assert!((s - 0.0).abs() < f64::EPSILON);
    }

    // ── MultiHopReasoningConfig ──────────────────────────────────────────

    #[test]
    fn test_config_defaults() {
        let cfg = MultiHopReasoningConfig::default();
        assert_eq!(cfg.max_hops, 3);
        assert!((cfg.min_confidence - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.path_scoring, PathScoring::AttentionWeighted);
        assert_eq!(cfg.max_paths_per_pair, 20);
        assert_eq!(cfg.max_frontier, 10_000);
    }

    // ── KnowledgeGraph ───────────────────────────────────────────────────

    #[test]
    fn test_knowledge_graph_node_count() {
        let (_, graph) = simple_graph();
        // Subjects are a, b and c. Both of a's triples share one entry, and d
        // only ever appears as an object, so it is not counted.
        assert_eq!(graph.node_count(), 3);
    }

    #[test]
    fn test_knowledge_graph_neighbours() {
        let triples = vec![Triple::new("http://x", "http://p", "http://y")];
        let graph = KnowledgeGraph::from_triples(&triples);
        let nb = graph.neighbours("http://x");
        assert_eq!(nb.len(), 1);
        assert_eq!(nb[0].0, "http://p");
        assert_eq!(nb[0].1, "http://y");
    }

    #[test]
    fn test_knowledge_graph_missing_node_returns_empty() {
        let graph = KnowledgeGraph::from_triples(&[]);
        assert!(graph.neighbours("http://nobody").is_empty());
    }

    // ── HopPath ──────────────────────────────────────────────────────────

    #[test]
    fn test_hop_path_hop_count() {
        let path = HopPath {
            entities: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            relations: vec!["r1".to_string(), "r2".to_string()],
            score: 0.75,
        };
        assert_eq!(path.hop_count(), 2);
    }

    // ── MultiHopPathFinder::find_paths ────────────────────────────────────

    #[test]
    fn test_find_paths_direct_one_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://b", 1, &graph);
        assert!(!paths.is_empty());
        assert_eq!(paths[0].hop_count(), 1);
    }

    #[test]
    fn test_find_paths_two_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://c", 3, &graph);
        // Both the direct edge (1 hop) and the route via b (2 hops) must be
        // found; the shorter path scores higher and therefore comes first.
        let hop_counts: Vec<usize> = paths.iter().map(|p| p.hop_count()).collect();
        assert_eq!(hop_counts, vec![1, 2]);
    }

    #[test]
    fn test_find_paths_three_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://d", 3, &graph);
        // a->b->c->d is a 3-hop path
        assert!(!paths.is_empty());
        let three_hop = paths.iter().any(|p| p.hop_count() == 3);
        assert!(three_hop, "Expected at least one 3-hop path");
    }

    #[test]
    fn test_find_paths_no_path_returns_empty() {
        let triples = vec![Triple::new("http://a", "http://p", "http://b")];
        let graph = KnowledgeGraph::from_triples(&triples);
        let finder = MultiHopPathFinder::with_defaults();
        // No path from b to a (edges are directed)
        let paths = finder.find_paths("http://b", "http://a", 3, &graph);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_respects_max_hops() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            max_hops: 1,
            ..Default::default()
        });
        // Only the 1-hop direct path a->c should be found; 2-hop via b should NOT
        let paths = finder.find_paths("http://a", "http://c", 1, &graph);
        assert_eq!(paths.len(), 1, "Expected exactly the direct path");
        for p in &paths {
            assert!(
                p.hop_count() <= 1,
                "Found path with {} hops > max 1",
                p.hop_count()
            );
        }
    }

    #[test]
    fn test_find_paths_zero_max_hops_returns_empty() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://b", 0, &graph);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_ignores_cycles() {
        let triples = vec![
            Triple::new("http://a", "http://p", "http://b"),
            Triple::new("http://b", "http://back", "http://a"),
            Triple::new("http://b", "http://q", "http://c"),
        ];
        let graph = KnowledgeGraph::from_triples(&triples);
        let finder = MultiHopPathFinder::with_defaults();
        // The b->a back edge must not produce looping paths.
        let paths = finder.find_paths("http://a", "http://c", 5, &graph);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].hop_count(), 2);
    }

    #[test]
    fn test_find_paths_sorted_descending() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://c", 3, &graph);
        for pair in paths.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Paths not sorted: {} < {}",
                pair[0].score,
                pair[1].score
            );
        }
    }

    #[test]
    fn test_find_paths_min_confidence_filters() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            min_confidence: 10.0, // impossibly high
            ..Default::default()
        });
        let paths = finder.find_paths("http://a", "http://b", 3, &graph);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_uniform_scoring() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            path_scoring: PathScoring::Uniform,
            ..Default::default()
        });
        let paths = finder.find_paths("http://a", "http://b", 3, &graph);
        assert!(!paths.is_empty());
        for p in &paths {
            assert!((p.score - 1.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_find_paths_max_paths_per_pair() {
        // Ten distinct 2-hop routes src -> m{i} -> target, so the cap of 3
        // is what actually limits the result.
        let mut triples: Vec<Triple> = Vec::new();
        for i in 0..10 {
            let mid = format!("http://m{i}");
            triples.push(Triple::new("http://src", "http://p", mid.clone()));
            triples.push(Triple::new(mid.as_str(), "http://p", "http://target"));
        }
        let graph = KnowledgeGraph::from_triples(&triples);
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            max_paths_per_pair: 3,
            ..Default::default()
        });
        let paths = finder.find_paths("http://src", "http://target", 2, &graph);
        assert_eq!(paths.len(), 3);
        assert!(paths.iter().all(|p| p.hop_count() == 2));
    }
}