// agentic_memory/graph/traversal.rs
use std::collections::{HashMap, HashSet, VecDeque};

use crate::types::{AmemError, AmemResult, Edge, EdgeType};

use super::MemoryGraph;

/// Selects which edges a traversal follows from each node it expands.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TraversalDirection {
    /// Follow outgoing edges, moving from an edge's source to its target.
    Forward,
    /// Follow incoming edges in reverse, moving from an edge's target to its source.
    Backward,
    /// Follow both outgoing and incoming edges.
    Both,
}

20#[allow(clippy::type_complexity)]
22pub fn bfs_traverse(
23 graph: &MemoryGraph,
24 start_id: u64,
25 edge_types: &[EdgeType],
26 direction: TraversalDirection,
27 max_depth: u32,
28 max_results: usize,
29 min_confidence: f32,
30) -> AmemResult<(Vec<u64>, Vec<Edge>, HashMap<u64, u32>)> {
31 if graph.get_node(start_id).is_none() {
32 return Err(AmemError::NodeNotFound(start_id));
33 }
34
35 let edge_set: HashSet<EdgeType> = edge_types.iter().copied().collect();
36 let mut visited: HashSet<u64> = HashSet::new();
37 let mut visited_order: Vec<u64> = Vec::new();
38 let mut edges_traversed: Vec<Edge> = Vec::new();
39 let mut depths: HashMap<u64, u32> = HashMap::new();
40 let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
41
42 visited.insert(start_id);
43 visited_order.push(start_id);
44 depths.insert(start_id, 0);
45 queue.push_back((start_id, 0));
46
47 while let Some((current_id, depth)) = queue.pop_front() {
48 if depth >= max_depth {
49 continue;
50 }
51 if visited_order.len() >= max_results {
52 break;
53 }
54
55 let mut neighbors: Vec<(u64, Edge)> = Vec::new();
56
57 if direction == TraversalDirection::Forward || direction == TraversalDirection::Both {
59 for edge in graph.edges_from(current_id) {
60 if edge_set.contains(&edge.edge_type) {
61 neighbors.push((edge.target_id, *edge));
62 }
63 }
64 }
65
66 if direction == TraversalDirection::Backward || direction == TraversalDirection::Both {
68 for edge in graph.edges_to(current_id) {
69 if edge_set.contains(&edge.edge_type) {
70 neighbors.push((edge.source_id, *edge));
71 }
72 }
73 }
74
75 for (neighbor_id, edge) in neighbors {
76 if visited.contains(&neighbor_id) {
77 continue;
78 }
79 if visited_order.len() >= max_results {
80 break;
81 }
82
83 if let Some(node) = graph.get_node(neighbor_id) {
85 if node.confidence < min_confidence {
86 continue;
87 }
88 } else {
89 continue;
90 }
91
92 visited.insert(neighbor_id);
93 visited_order.push(neighbor_id);
94 depths.insert(neighbor_id, depth + 1);
95 edges_traversed.push(edge);
96 queue.push_back((neighbor_id, depth + 1));
97 }
98 }
99
100 Ok((visited_order, edges_traversed, depths))
101}