1use crate::Triple;
7use std::collections::{HashMap, HashSet, VecDeque};
8
/// Strategy used by [`MultiHopPathFinder::score_path`] to turn a path's
/// relation sequence into a numeric score.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PathScoring {
    /// Every path scores `1.0`, regardless of its length.
    Uniform,
    /// Score is `1 / (1 + hops)`, so longer paths score lower.
    PathLength,
    /// Mean of the positional weights `1 / (i + 1)` over the path's hops
    /// (`0.0` for an empty path). This is the default strategy.
    #[default]
    AttentionWeighted,
}
23
24#[derive(Debug, Clone)]
28pub struct MultiHopReasoningConfig {
29 pub max_hops: u8,
31 pub min_confidence: f64,
33 pub path_scoring: PathScoring,
35 pub max_paths_per_pair: usize,
37 pub max_frontier: usize,
39}
40
41impl Default for MultiHopReasoningConfig {
42 fn default() -> Self {
43 Self {
44 max_hops: 3,
45 min_confidence: 0.0,
46 path_scoring: PathScoring::default(),
47 max_paths_per_pair: 20,
48 max_frontier: 10_000,
49 }
50 }
51}
52
/// One path between two entities.
///
/// For paths produced by `MultiHopPathFinder::find_paths`, `entities` lists
/// every node visited (start and end included), so it holds exactly one more
/// element than `relations`.
#[derive(Clone, Debug)]
pub struct HopPath {
    /// Nodes along the path, from start to end.
    pub entities: Vec<String>,
    /// Predicates traversed; `relations[i]` links `entities[i]` to `entities[i + 1]`.
    pub relations: Vec<String>,
    /// Score assigned by the configured scoring strategy.
    pub score: f64,
}
65
66impl HopPath {
67 pub fn hop_count(&self) -> usize {
69 self.relations.len()
70 }
71}
72
/// Directed, labelled graph stored as an adjacency list keyed by subject.
///
/// Each entry maps a node to the `(predicate, object)` pairs of its outgoing
/// edges. When built with `from_triples`, nodes that only ever appear as
/// objects have no entry of their own. `Default` yields an empty graph.
#[derive(Debug, Clone, Default)]
pub struct KnowledgeGraph {
    /// subject -> outgoing `(predicate, object)` edges, in insertion order.
    adj: HashMap<String, Vec<(String, String)>>,
}
80
81impl KnowledgeGraph {
82 pub fn from_triples(triples: &[Triple]) -> Self {
84 let mut adj: HashMap<String, Vec<(String, String)>> = HashMap::new();
85 for t in triples {
86 adj.entry(t.subject.clone())
87 .or_default()
88 .push((t.predicate.clone(), t.object.clone()));
89 }
90 Self { adj }
91 }
92
93 pub fn neighbours(&self, node: &str) -> &[(String, String)] {
95 self.adj.get(node).map(|v| v.as_slice()).unwrap_or(&[])
96 }
97
98 pub fn node_count(&self) -> usize {
100 self.adj.len()
101 }
102}
103
104pub struct MultiHopPathFinder {
108 config: MultiHopReasoningConfig,
109}
110
111impl MultiHopPathFinder {
112 pub fn new(config: MultiHopReasoningConfig) -> Self {
114 Self { config }
115 }
116
117 pub fn with_defaults() -> Self {
119 Self::new(MultiHopReasoningConfig::default())
120 }
121
122 pub fn find_paths(
126 &self,
127 start: &str,
128 end: &str,
129 max_hops: u8,
130 graph: &KnowledgeGraph,
131 ) -> Vec<HopPath> {
132 struct State {
134 node: String,
135 entities: Vec<String>,
136 relations: Vec<String>,
137 visited: HashSet<String>,
138 }
139
140 let mut queue: VecDeque<State> = VecDeque::new();
141 queue.push_back(State {
142 node: start.to_string(),
143 entities: vec![start.to_string()],
144 relations: vec![],
145 visited: {
146 let mut h = HashSet::new();
147 h.insert(start.to_string());
148 h
149 },
150 });
151
152 let mut paths: Vec<HopPath> = Vec::new();
153 let mut frontier_visited = 0usize;
154
155 while let Some(state) = queue.pop_front() {
156 if frontier_visited >= self.config.max_frontier {
157 break;
158 }
159 frontier_visited += 1;
160
161 let hops_so_far = state.relations.len() as u8;
162
163 if hops_so_far >= max_hops {
164 continue;
165 }
166
167 for (pred, obj) in graph.neighbours(&state.node) {
168 if state.visited.contains(obj) {
169 continue;
170 }
171 let mut new_entities = state.entities.clone();
172 new_entities.push(obj.clone());
173 let mut new_relations = state.relations.clone();
174 new_relations.push(pred.clone());
175
176 if obj == end {
177 let score = self.score_path(&new_relations, &self.config.path_scoring);
178 if score >= self.config.min_confidence {
179 paths.push(HopPath {
180 entities: new_entities,
181 relations: new_relations,
182 score,
183 });
184 if paths.len() >= self.config.max_paths_per_pair {
185 paths.sort_by(|a, b| {
186 b.score
187 .partial_cmp(&a.score)
188 .unwrap_or(std::cmp::Ordering::Equal)
189 });
190 return paths;
191 }
192 }
193 } else {
194 let mut new_visited = state.visited.clone();
195 new_visited.insert(obj.clone());
196 queue.push_back(State {
197 node: obj.clone(),
198 entities: new_entities,
199 relations: new_relations,
200 visited: new_visited,
201 });
202 }
203 }
204 }
205
206 paths.sort_by(|a, b| {
207 b.score
208 .partial_cmp(&a.score)
209 .unwrap_or(std::cmp::Ordering::Equal)
210 });
211 paths
212 }
213
214 pub fn score_path(&self, relations: &[String], scoring: &PathScoring) -> f64 {
216 let hops = relations.len();
217 match scoring {
218 PathScoring::Uniform => 1.0,
219 PathScoring::PathLength => 1.0 / (1.0 + hops as f64),
220 PathScoring::AttentionWeighted => {
221 if hops == 0 {
222 return 0.0;
223 }
224 let sum: f64 = (0..hops).map(|i| 1.0 / (i as f64 + 1.0)).sum();
226 sum / hops as f64
227 }
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    // Graph used by most tests:
    //   a --r1--> b --r2--> c --r3--> d
    //   a -------direct---> c
    fn simple_graph() -> (Vec<Triple>, KnowledgeGraph) {
        let triples = vec![
            Triple::new("http://a", "http://rel/r1", "http://b"),
            Triple::new("http://b", "http://rel/r2", "http://c"),
            Triple::new("http://c", "http://rel/r3", "http://d"),
            Triple::new("http://a", "http://rel/direct", "http://c"),
        ];
        let graph = KnowledgeGraph::from_triples(&triples);
        (triples, graph)
    }

    #[test]
    fn test_path_scoring_default_is_attention_weighted() {
        assert_eq!(PathScoring::default(), PathScoring::AttentionWeighted);
    }

    #[test]
    fn test_score_uniform_always_one() {
        let finder = MultiHopPathFinder::with_defaults();
        let rels: Vec<String> = vec!["r1".to_string(), "r2".to_string(), "r3".to_string()];
        let s = finder.score_path(&rels, &PathScoring::Uniform);
        assert!((s - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_score_path_length_decreases_with_hops() {
        let finder = MultiHopPathFinder::with_defaults();
        let s1 = finder.score_path(&["r".to_string()], &PathScoring::PathLength);
        let s2 = finder.score_path(
            &["r".to_string(), "r2".to_string()],
            &PathScoring::PathLength,
        );
        assert!(s1 > s2, "Longer path should score lower: {s1} vs {s2}");
    }

    #[test]
    fn test_score_attention_weighted_single_hop() {
        let finder = MultiHopPathFinder::with_defaults();
        let s = finder.score_path(&["r".to_string()], &PathScoring::AttentionWeighted);
        // Single weight 1/1.
        assert!((s - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_score_attention_weighted_two_hops() {
        let finder = MultiHopPathFinder::with_defaults();
        let rels: Vec<String> = vec!["r1".to_string(), "r2".to_string()];
        let s = finder.score_path(&rels, &PathScoring::AttentionWeighted);
        // (1/1 + 1/2) / 2 = 0.75
        assert!((s - 0.75).abs() < 1e-9, "Expected 0.75, got {s}");
    }

    #[test]
    fn test_score_attention_weighted_empty_is_zero() {
        let finder = MultiHopPathFinder::with_defaults();
        let s = finder.score_path(&[], &PathScoring::AttentionWeighted);
        assert!((s - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_config_defaults() {
        let cfg = MultiHopReasoningConfig::default();
        assert_eq!(cfg.max_hops, 3);
        assert!((cfg.min_confidence - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.path_scoring, PathScoring::AttentionWeighted);
        assert_eq!(cfg.max_paths_per_pair, 20);
    }

    #[test]
    fn test_knowledge_graph_node_count() {
        let (_, graph) = simple_graph();
        // a, b and c have outgoing edges; d appears only as an object.
        assert_eq!(graph.node_count(), 3);
    }

    #[test]
    fn test_knowledge_graph_neighbours() {
        let triples = vec![Triple::new("http://x", "http://p", "http://y")];
        let graph = KnowledgeGraph::from_triples(&triples);
        let nb = graph.neighbours("http://x");
        assert_eq!(nb.len(), 1);
        assert_eq!(nb[0].0, "http://p");
        assert_eq!(nb[0].1, "http://y");
    }

    #[test]
    fn test_knowledge_graph_missing_node_returns_empty() {
        let graph = KnowledgeGraph::from_triples(&[]);
        assert!(graph.neighbours("http://nobody").is_empty());
    }

    #[test]
    fn test_find_paths_direct_one_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://b", 1, &graph);
        assert!(!paths.is_empty());
        assert_eq!(paths[0].hop_count(), 1);
    }

    #[test]
    fn test_find_paths_two_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://c", 3, &graph);
        // Both the direct edge a -> c and the detour a -> b -> c must be found.
        assert_eq!(paths.len(), 2);
        let hop_counts: Vec<usize> = paths.iter().map(|p| p.hop_count()).collect();
        assert!(hop_counts.contains(&1), "Expected the 1-hop path");
        assert!(hop_counts.contains(&2), "Expected the 2-hop path");
    }

    #[test]
    fn test_find_paths_no_path_returns_empty() {
        let triples = vec![Triple::new("http://a", "http://p", "http://b")];
        let graph = KnowledgeGraph::from_triples(&triples);
        let finder = MultiHopPathFinder::with_defaults();
        // Edges are directed, so b cannot reach a.
        let paths = finder.find_paths("http://b", "http://a", 3, &graph);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_find_paths_respects_max_hops() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            max_hops: 1,
            ..Default::default()
        });
        let paths = finder.find_paths("http://a", "http://c", 1, &graph);
        for p in &paths {
            assert!(
                p.hop_count() <= 1,
                "Found path with {} hops > max 1",
                p.hop_count()
            );
        }
    }

    #[test]
    fn test_find_paths_respects_max_frontier() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            max_frontier: 1,
            ..Default::default()
        });
        // Only the start state is expanded, so only the direct edge a -> c is found.
        let paths = finder.find_paths("http://a", "http://c", 3, &graph);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].hop_count(), 1);
    }

    #[test]
    fn test_find_paths_sorted_descending() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://c", 3, &graph);
        for pair in paths.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Paths not sorted: {} < {}",
                pair[0].score,
                pair[1].score
            );
        }
    }

    #[test]
    fn test_find_paths_min_confidence_filters() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            // No scoring strategy can exceed 1.0.
            min_confidence: 10.0,
            ..Default::default()
        });
        let paths = finder.find_paths("http://a", "http://b", 3, &graph);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_hop_path_hop_count() {
        let path = HopPath {
            entities: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            relations: vec!["r1".to_string(), "r2".to_string()],
            score: 0.75,
        };
        assert_eq!(path.hop_count(), 2);
    }

    #[test]
    fn test_find_paths_three_hop() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::with_defaults();
        let paths = finder.find_paths("http://a", "http://d", 3, &graph);
        assert!(!paths.is_empty());
        // a -> b -> c -> d
        let three_hop = paths.iter().any(|p| p.hop_count() == 3);
        assert!(three_hop, "Expected at least one 3-hop path");
        for p in &paths {
            assert_eq!(p.entities.len(), p.relations.len() + 1);
        }
    }

    #[test]
    fn test_find_paths_uniform_scoring() {
        let (_, graph) = simple_graph();
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            path_scoring: PathScoring::Uniform,
            ..Default::default()
        });
        let paths = finder.find_paths("http://a", "http://b", 3, &graph);
        for p in &paths {
            assert!((p.score - 1.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_path_scoring_path_length_formula() {
        let finder = MultiHopPathFinder::with_defaults();
        let rels = vec!["r1".to_string(), "r2".to_string()];
        // 1 / (1 + 2) = 1/3
        let s = finder.score_path(&rels, &PathScoring::PathLength);
        assert!((s - 1.0 / 3.0).abs() < 1e-9, "Expected 1/3, got {s}");
    }

    #[test]
    fn test_find_paths_max_paths_per_pair() {
        // 50 distinct two-hop routes src -> t{i} -> target, so the cap is
        // genuinely exercised.
        let triples: Vec<Triple> = (0..50)
            .flat_map(|i| {
                let mid = format!("http://t{i}");
                vec![
                    Triple::new("http://src", "http://p", mid.clone()),
                    Triple::new(mid.as_str(), "http://p", "http://target"),
                ]
            })
            .collect();
        let graph = KnowledgeGraph::from_triples(&triples);
        let finder = MultiHopPathFinder::new(MultiHopReasoningConfig {
            max_paths_per_pair: 3,
            ..Default::default()
        });
        let paths = finder.find_paths("http://src", "http://target", 2, &graph);
        assert_eq!(paths.len(), 3);
        assert!(paths.iter().all(|p| p.hop_count() == 2));
    }
}