Skip to main content

agentic_memory/graph/
traversal.rs

1//! Graph traversal algorithms (BFS).
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::types::{AmemError, AmemResult, Edge, EdgeType};
6
7use super::MemoryGraph;
8
/// Which way edges are walked while traversing the graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TraversalDirection {
    /// Walk outgoing edges, moving from an edge's source to its target.
    Forward,
    /// Walk incoming edges, moving from an edge's target back to its source.
    Backward,
    /// Walk outgoing and incoming edges alike.
    Both,
}
19
20/// BFS traversal from a starting node, following specific edge types.
21#[allow(clippy::type_complexity)]
22pub fn bfs_traverse(
23    graph: &MemoryGraph,
24    start_id: u64,
25    edge_types: &[EdgeType],
26    direction: TraversalDirection,
27    max_depth: u32,
28    max_results: usize,
29    min_confidence: f32,
30) -> AmemResult<(Vec<u64>, Vec<Edge>, HashMap<u64, u32>)> {
31    if graph.get_node(start_id).is_none() {
32        return Err(AmemError::NodeNotFound(start_id));
33    }
34
35    let edge_set: HashSet<EdgeType> = edge_types.iter().copied().collect();
36    let mut visited: HashSet<u64> = HashSet::new();
37    let mut visited_order: Vec<u64> = Vec::new();
38    let mut edges_traversed: Vec<Edge> = Vec::new();
39    let mut depths: HashMap<u64, u32> = HashMap::new();
40    let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
41
42    visited.insert(start_id);
43    visited_order.push(start_id);
44    depths.insert(start_id, 0);
45    queue.push_back((start_id, 0));
46
47    while let Some((current_id, depth)) = queue.pop_front() {
48        if depth >= max_depth {
49            continue;
50        }
51        if visited_order.len() >= max_results {
52            break;
53        }
54
55        let mut neighbors: Vec<(u64, Edge)> = Vec::new();
56
57        // Forward: follow outgoing edges
58        if direction == TraversalDirection::Forward || direction == TraversalDirection::Both {
59            for edge in graph.edges_from(current_id) {
60                if edge_set.contains(&edge.edge_type) {
61                    neighbors.push((edge.target_id, *edge));
62                }
63            }
64        }
65
66        // Backward: follow incoming edges
67        if direction == TraversalDirection::Backward || direction == TraversalDirection::Both {
68            for edge in graph.edges_to(current_id) {
69                if edge_set.contains(&edge.edge_type) {
70                    neighbors.push((edge.source_id, *edge));
71                }
72            }
73        }
74
75        for (neighbor_id, edge) in neighbors {
76            if visited.contains(&neighbor_id) {
77                continue;
78            }
79            if visited_order.len() >= max_results {
80                break;
81            }
82
83            // Check confidence threshold
84            if let Some(node) = graph.get_node(neighbor_id) {
85                if node.confidence < min_confidence {
86                    continue;
87                }
88            } else {
89                continue;
90            }
91
92            visited.insert(neighbor_id);
93            visited_order.push(neighbor_id);
94            depths.insert(neighbor_id, depth + 1);
95            edges_traversed.push(edge);
96            queue.push_back((neighbor_id, depth + 1));
97        }
98    }
99
100    Ok((visited_order, edges_traversed, depths))
101}