1use crate::entity::Entity;
4use crate::relation::{Direction, Relation};
5use serde::{Deserialize, Serialize};
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TraversalQuery {
12 pub start: String,
14
15 #[serde(skip_serializing_if = "Option::is_none")]
17 pub target: Option<String>,
18
19 #[serde(default = "default_depth")]
21 pub max_depth: u32,
22
23 #[serde(default)]
25 pub direction: Direction,
26
27 #[serde(default)]
29 pub entity_type_filter: Vec<String>,
30
31 #[serde(default)]
33 pub relation_type_filter: Vec<String>,
34
35 #[serde(default)]
37 pub use_weights: bool,
38
39 #[serde(default)]
41 pub all_paths: bool,
42
43 #[serde(default = "default_max_paths")]
45 pub max_paths: usize,
46}
47
/// Depth limit used when a query does not specify one.
const DEFAULT_MAX_DEPTH: u32 = 10;

/// Supplies the fallback value for `TraversalQuery::max_depth`, both for serde
/// (`#[serde(default = "default_depth")]`) and for `TraversalQuery::default()`.
fn default_depth() -> u32 {
    DEFAULT_MAX_DEPTH
}
51
/// Path-count limit used when a query does not specify one.
const DEFAULT_MAX_PATHS: usize = 5;

/// Supplies the fallback value for `TraversalQuery::max_paths`, both for serde
/// (`#[serde(default = "default_max_paths")]`) and for `TraversalQuery::default()`.
fn default_max_paths() -> usize {
    DEFAULT_MAX_PATHS
}
55
56impl Default for TraversalQuery {
57 fn default() -> Self {
58 Self {
59 start: String::new(),
60 target: None,
61 max_depth: default_depth(),
62 direction: Direction::Both,
63 entity_type_filter: Vec::new(),
64 relation_type_filter: Vec::new(),
65 use_weights: false,
66 all_paths: false,
67 max_paths: default_max_paths(),
68 }
69 }
70}
71
72impl TraversalQuery {
73 pub fn new(start: impl Into<String>) -> Self {
75 Self {
76 start: start.into(),
77 ..Default::default()
78 }
79 }
80
81 pub fn find_path_to(mut self, target: impl Into<String>) -> Self {
83 self.target = Some(target.into());
84 self
85 }
86
87 pub fn with_depth(mut self, depth: u32) -> Self {
89 self.max_depth = depth;
90 self
91 }
92
93 pub fn with_direction(mut self, direction: Direction) -> Self {
95 self.direction = direction;
96 self
97 }
98
99 pub fn filter_entity_types(mut self, types: Vec<String>) -> Self {
101 self.entity_type_filter = types;
102 self
103 }
104
105 pub fn filter_relation_types(mut self, types: Vec<String>) -> Self {
107 self.relation_type_filter = types;
108 self
109 }
110
111 pub fn weighted(mut self) -> Self {
113 self.use_weights = true;
114 self
115 }
116
117 pub fn all_paths(mut self, max: usize) -> Self {
119 self.all_paths = true;
120 self.max_paths = max;
121 self
122 }
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct GraphPath {
128 pub nodes: Vec<String>,
130
131 pub edges: Vec<PathEdge>,
133
134 pub total_weight: f64,
136
137 pub length: usize,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct PathEdge {
144 pub from: String,
145 pub to: String,
146 pub relation_type: String,
147 pub weight: Option<f64>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct TraversalResult {
153 pub start: String,
155
156 #[serde(skip_serializing_if = "Option::is_none")]
158 pub target: Option<String>,
159
160 pub paths: Vec<GraphPath>,
162
163 pub visited_entities: Vec<String>,
165
166 pub entities: Vec<Entity>,
168
169 pub relations: Vec<Relation>,
171
172 pub stats: TraversalStats,
174}
175
176#[derive(Debug, Clone, Default, Serialize, Deserialize)]
178pub struct TraversalStats {
179 pub nodes_visited: usize,
180 pub edges_traversed: usize,
181 pub max_depth_reached: u32,
182 pub path_found: bool,
183}
184
/// Priority-queue entry for the weighted search: a node together with the
/// cost of the cheapest route to it known when the entry was pushed.
///
/// The ordering is reversed on `cost` so that `BinaryHeap` (a max-heap) pops
/// the cheapest entry first.
#[derive(Clone)]
struct DijkstraState {
    cost: f64,
    node: String,
}

impl PartialEq for DijkstraState {
    /// Equality is defined through `cmp` so that `==` and `Ord` can never
    /// disagree, as the `Ord` contract requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraState {}

impl Ord for DijkstraState {
    /// Total order: lower cost ranks higher; among equal costs the
    /// lexicographically smaller node name ranks higher.
    ///
    /// `f64::total_cmp` keeps the order total even for NaN (a positive NaN
    /// ranks as the most expensive cost), and the node tie-break makes the
    /// heap's pop order deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl PartialOrd for DijkstraState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
209
210pub struct TraversalEngine;
212
213impl TraversalEngine {
214 pub fn execute(
216 query: &TraversalQuery,
217 entities: &HashMap<String, Entity>,
218 relations: &[Relation],
219 ) -> TraversalResult {
220 tracing::debug!(
221 "Executing traversal: start={}, target={:?}, depth={}, direction={:?}",
222 query.start,
223 query.target,
224 query.max_depth,
225 query.direction
226 );
227
228 if query.target.is_some() {
229 if query.use_weights {
230 Self::dijkstra_path(query, entities, relations)
231 } else {
232 Self::bfs_path(query, entities, relations)
233 }
234 } else {
235 Self::filtered_bfs(query, entities, relations)
236 }
237 }
238
239 fn bfs_path(
241 query: &TraversalQuery,
242 entities: &HashMap<String, Entity>,
243 relations: &[Relation],
244 ) -> TraversalResult {
245 let target = query.target.as_ref().unwrap();
246 let mut visited: HashSet<String> = HashSet::new();
247 let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
248 let mut queue: VecDeque<(String, u32)> = VecDeque::new();
249 let mut stats = TraversalStats::default();
250
251 queue.push_back((query.start.clone(), 0));
252 visited.insert(query.start.clone());
253
254 while let Some((current, depth)) = queue.pop_front() {
255 stats.nodes_visited += 1;
256 stats.max_depth_reached = stats.max_depth_reached.max(depth);
257
258 if ¤t == target {
259 stats.path_found = true;
260 tracing::debug!("BFS found path at depth {}", depth);
261 break;
262 }
263
264 if depth >= query.max_depth {
265 continue;
266 }
267
268 for rel in Self::get_neighbors(¤t, &query.direction, relations) {
269 stats.edges_traversed += 1;
270
271 if !query.relation_type_filter.is_empty()
273 && !query.relation_type_filter.contains(&rel.relation_type)
274 {
275 continue;
276 }
277
278 let next = if rel.from_name == current {
279 &rel.to_name
280 } else {
281 &rel.from_name
282 };
283
284 if let Some(entity) = entities.get(next) {
286 if !query.entity_type_filter.is_empty()
287 && !query.entity_type_filter.contains(&entity.entity_type.0)
288 {
289 continue;
290 }
291 }
292
293 if !visited.contains(next) {
294 visited.insert(next.clone());
295 parent.insert(
296 next.clone(),
297 (
298 current.clone(),
299 PathEdge {
300 from: rel.from_name.clone(),
301 to: rel.to_name.clone(),
302 relation_type: rel.relation_type.clone(),
303 weight: rel.weight,
304 },
305 ),
306 );
307 queue.push_back((next.clone(), depth + 1));
308 }
309 }
310 }
311
312 let paths = if stats.path_found {
314 vec![Self::reconstruct_path(&query.start, target, &parent)]
315 } else {
316 vec![]
317 };
318
319 Self::build_result(query, paths, &visited, entities, relations, stats)
320 }
321
322 fn dijkstra_path(
324 query: &TraversalQuery,
325 entities: &HashMap<String, Entity>,
326 relations: &[Relation],
327 ) -> TraversalResult {
328 let target = query.target.as_ref().unwrap();
329 let mut dist: HashMap<String, f64> = HashMap::new();
330 let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
331 let mut heap = BinaryHeap::new();
332 let mut stats = TraversalStats::default();
333
334 dist.insert(query.start.clone(), 0.0);
335 heap.push(DijkstraState {
336 cost: 0.0,
337 node: query.start.clone(),
338 });
339
340 while let Some(DijkstraState { cost, node }) = heap.pop() {
341 stats.nodes_visited += 1;
342
343 if &node == target {
344 stats.path_found = true;
345 tracing::debug!("Dijkstra found path with cost {}", cost);
346 break;
347 }
348
349 if cost > *dist.get(&node).unwrap_or(&f64::INFINITY) {
351 continue;
352 }
353
354 for rel in Self::get_neighbors(&node, &query.direction, relations) {
355 stats.edges_traversed += 1;
356
357 if !query.relation_type_filter.is_empty()
359 && !query.relation_type_filter.contains(&rel.relation_type)
360 {
361 continue;
362 }
363
364 let next = if rel.from_name == node {
365 &rel.to_name
366 } else {
367 &rel.from_name
368 };
369
370 if let Some(entity) = entities.get(next) {
372 if !query.entity_type_filter.is_empty()
373 && !query.entity_type_filter.contains(&entity.entity_type.0)
374 {
375 continue;
376 }
377 }
378
379 let edge_weight = rel.weight.unwrap_or(1.0);
380 let new_cost = cost + edge_weight;
381
382 if new_cost < *dist.get(next).unwrap_or(&f64::INFINITY) {
383 dist.insert(next.clone(), new_cost);
384 parent.insert(
385 next.clone(),
386 (
387 node.clone(),
388 PathEdge {
389 from: rel.from_name.clone(),
390 to: rel.to_name.clone(),
391 relation_type: rel.relation_type.clone(),
392 weight: rel.weight,
393 },
394 ),
395 );
396 heap.push(DijkstraState {
397 cost: new_cost,
398 node: next.clone(),
399 });
400 }
401 }
402 }
403
404 let paths = if stats.path_found {
405 vec![Self::reconstruct_path(&query.start, target, &parent)]
406 } else {
407 vec![]
408 };
409
410 let visited: HashSet<String> = dist.keys().cloned().collect();
411 Self::build_result(query, paths, &visited, entities, relations, stats)
412 }
413
414 fn filtered_bfs(
416 query: &TraversalQuery,
417 entities: &HashMap<String, Entity>,
418 relations: &[Relation],
419 ) -> TraversalResult {
420 let mut visited: HashSet<String> = HashSet::new();
421 let mut queue: VecDeque<(String, u32)> = VecDeque::new();
422 let mut stats = TraversalStats::default();
423
424 queue.push_back((query.start.clone(), 0));
425 visited.insert(query.start.clone());
426
427 while let Some((current, depth)) = queue.pop_front() {
428 stats.nodes_visited += 1;
429 stats.max_depth_reached = stats.max_depth_reached.max(depth);
430
431 if depth >= query.max_depth {
432 continue;
433 }
434
435 for rel in Self::get_neighbors(¤t, &query.direction, relations) {
436 stats.edges_traversed += 1;
437
438 if !query.relation_type_filter.is_empty()
440 && !query.relation_type_filter.contains(&rel.relation_type)
441 {
442 continue;
443 }
444
445 let next = if rel.from_name == current {
446 &rel.to_name
447 } else {
448 &rel.from_name
449 };
450
451 if let Some(entity) = entities.get(next) {
453 if !query.entity_type_filter.is_empty()
454 && !query.entity_type_filter.contains(&entity.entity_type.0)
455 {
456 continue;
457 }
458 }
459
460 if !visited.contains(next) {
461 visited.insert(next.clone());
462 queue.push_back((next.clone(), depth + 1));
463 }
464 }
465 }
466
467 tracing::debug!(
468 "Filtered BFS visited {} nodes, traversed {} edges",
469 stats.nodes_visited,
470 stats.edges_traversed
471 );
472
473 Self::build_result(query, vec![], &visited, entities, relations, stats)
474 }
475
476 fn get_neighbors<'a>(
478 node: &str,
479 direction: &Direction,
480 relations: &'a [Relation],
481 ) -> Vec<&'a Relation> {
482 relations
483 .iter()
484 .filter(|rel| match direction {
485 Direction::Outgoing => rel.from_name == node,
486 Direction::Incoming => rel.to_name == node,
487 Direction::Both => rel.from_name == node || rel.to_name == node,
488 })
489 .collect()
490 }
491
492 fn reconstruct_path(
494 start: &str,
495 end: &str,
496 parent: &HashMap<String, (String, PathEdge)>,
497 ) -> GraphPath {
498 let mut nodes = vec![end.to_string()];
499 let mut edges = Vec::new();
500 let mut current = end.to_string();
501 let mut total_weight = 0.0;
502
503 while ¤t != start {
504 if let Some((prev, edge)) = parent.get(¤t) {
505 total_weight += edge.weight.unwrap_or(1.0);
506 edges.push(edge.clone());
507 nodes.push(prev.clone());
508 current = prev.clone();
509 } else {
510 break;
511 }
512 }
513
514 nodes.reverse();
515 edges.reverse();
516
517 GraphPath {
518 length: edges.len(),
519 nodes,
520 edges,
521 total_weight,
522 }
523 }
524
525 fn build_result(
527 query: &TraversalQuery,
528 paths: Vec<GraphPath>,
529 visited: &HashSet<String>,
530 entities: &HashMap<String, Entity>,
531 relations: &[Relation],
532 stats: TraversalStats,
533 ) -> TraversalResult {
534 let visited_entities: Vec<String> = visited.iter().cloned().collect();
535
536 let result_entities: Vec<Entity> = visited_entities
537 .iter()
538 .filter_map(|name| entities.get(name).cloned())
539 .collect();
540
541 let result_relations: Vec<Relation> = relations
542 .iter()
543 .filter(|r| visited.contains(&r.from_name) && visited.contains(&r.to_name))
544 .cloned()
545 .collect();
546
547 TraversalResult {
548 start: query.start.clone(),
549 target: query.target.clone(),
550 paths,
551 visited_entities,
552 entities: result_entities,
553 relations: result_relations,
554 stats,
555 }
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::project::ProjectId;

    /// Builds an entity map with one `"node"`-typed entity per given name.
    fn node_entities(
        project_id: &ProjectId,
        names: &[&'static str],
    ) -> HashMap<String, Entity> {
        let mut entities = HashMap::new();
        for &name in names {
            entities.insert(
                name.to_string(),
                Entity::new(project_id.clone(), name, "node"),
            );
        }
        entities
    }

    /// True when `name` is among the nodes the traversal reached.
    fn reached(result: &TraversalResult, name: &str) -> bool {
        result.visited_entities.iter().any(|seen| seen == name)
    }

    /// Six-node fixture with weighted `"connects"` edges:
    /// A-B (1), B-C (2), C-D (1), B-E (1), C-F (1), E-F (3).
    fn create_test_graph() -> (HashMap<String, Entity>, Vec<Relation>) {
        let project_id = ProjectId::new();
        let entities = node_entities(&project_id, &["A", "B", "C", "D", "E", "F"]);

        let mut relations = Vec::new();
        for (from, to, weight) in [
            ("A", "B", 1.0),
            ("B", "C", 2.0),
            ("C", "D", 1.0),
            ("B", "E", 1.0),
            ("C", "F", 1.0),
            ("E", "F", 3.0),
        ] {
            relations.push(
                Relation::from_names(project_id.clone(), from, to, "connects")
                    .with_weight(weight),
            );
        }

        (entities, relations)
    }

    #[test]
    fn test_bfs_shortest_path() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").find_path_to("D");

        let result = TraversalEngine::execute(&query, &entities, &relations);

        assert!(result.stats.path_found);
        assert_eq!(result.paths.len(), 1);
        let path = &result.paths[0];
        assert_eq!(path.nodes, vec!["A", "B", "C", "D"]);
        assert_eq!(path.length, 3);
    }

    #[test]
    fn test_dijkstra_weighted_path() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").find_path_to("F").weighted();

        let result = TraversalEngine::execute(&query, &entities, &relations);

        // A-B-C-F costs 1 + 2 + 1 = 4, cheaper than A-B-E-F at 1 + 1 + 3 = 5.
        assert!(result.stats.path_found);
        assert_eq!(result.paths[0].nodes, vec!["A", "B", "C", "F"]);
        assert!((result.paths[0].total_weight - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_filtered_traversal() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").with_depth(2);

        let result = TraversalEngine::execute(&query, &entities, &relations);

        for name in ["A", "B", "C", "E"] {
            assert!(reached(&result, name));
        }
    }

    #[test]
    fn test_no_path_found() {
        let project_id = ProjectId::new();
        let entities = node_entities(&project_id, &["A", "B"]);
        let query = TraversalQuery::new("A").find_path_to("B");

        // No relations at all, so B is unreachable.
        let result = TraversalEngine::execute(&query, &entities, &[]);

        assert!(!result.stats.path_found);
        assert!(result.paths.is_empty());
    }

    #[test]
    fn test_direction_filtering() {
        let (entities, relations) = create_test_graph();

        let outgoing = TraversalQuery::new("B")
            .with_direction(Direction::Outgoing)
            .with_depth(1);
        let result = TraversalEngine::execute(&outgoing, &entities, &relations);
        assert!(reached(&result, "C"));
        assert!(reached(&result, "E"));
        assert!(!reached(&result, "A"));

        let incoming = TraversalQuery::new("B")
            .with_direction(Direction::Incoming)
            .with_depth(1);
        let result = TraversalEngine::execute(&incoming, &entities, &relations);
        assert!(reached(&result, "A"));
        assert!(!reached(&result, "C"));
    }

    #[test]
    fn test_relation_type_filter() {
        let project_id = ProjectId::new();
        let entities = node_entities(&project_id, &["A", "B", "C"]);
        let relations = vec![
            Relation::from_names(project_id.clone(), "A", "B", "works_at"),
            Relation::from_names(project_id.clone(), "B", "C", "knows"),
        ];

        let query = TraversalQuery::new("A")
            .with_depth(2)
            .filter_relation_types(vec!["works_at".to_string()]);
        let result = TraversalEngine::execute(&query, &entities, &relations);

        // Only the "works_at" edge may be followed, so C stays unreached.
        assert!(reached(&result, "B"));
        assert!(!reached(&result, "C"));
    }
}