Skip to main content

agentic_memory/engine/
graph_algo.rs

//! Graph algorithms: centrality (PageRank, degree, betweenness) and shortest path (queries 10-11).
2
3use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::TraversalDirection;
6use crate::graph::MemoryGraph;
7use crate::types::{AmemResult, Edge, EdgeType, EventType};
8
/// Selects which centrality measure a centrality query computes.
///
/// The type is `Copy` and `PartialEq` (all of its data is a single `f32`), so
/// it can be echoed back in results and compared directly without cloning.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CentralityAlgorithm {
    /// Standard PageRank — importance flows through edges.
    ///
    /// `damping` is the damping factor applied to rank that is propagated
    /// along edges; the remaining `1 - damping` is spread evenly over all nodes.
    PageRank { damping: f32 },
    /// Degree centrality — simple count of connections.
    Degree,
    /// Betweenness centrality — how often a node appears on shortest paths.
    Betweenness,
}
19
20/// Parameters for a centrality query.
21pub struct CentralityParams {
22    pub algorithm: CentralityAlgorithm,
23    pub max_iterations: u32,
24    pub tolerance: f32,
25    pub top_k: usize,
26    pub event_types: Vec<EventType>,
27    pub edge_types: Vec<EdgeType>,
28}
29
30/// Result of a centrality computation.
31pub struct CentralityResult {
32    /// Node ID → centrality score, sorted by score descending.
33    pub scores: Vec<(u64, f32)>,
34    pub algorithm: CentralityAlgorithm,
35    pub iterations: u32,
36    pub converged: bool,
37}
38
39/// Parameters for shortest path query.
40pub struct ShortestPathParams {
41    pub source_id: u64,
42    pub target_id: u64,
43    pub edge_types: Vec<EdgeType>,
44    pub direction: TraversalDirection,
45    pub max_depth: u32,
46    pub weighted: bool,
47}
48
49/// Result of a shortest path query.
50pub struct PathResult {
51    /// Ordered list of node IDs from source to target (inclusive). Empty if no path.
52    pub path: Vec<u64>,
53    /// Edges traversed along the path.
54    pub edges: Vec<Edge>,
55    /// Total path length.
56    pub cost: f32,
57    pub found: bool,
58}
59
60impl super::query::QueryEngine {
61    /// Compute centrality scores for nodes in the graph.
62    pub fn centrality(
63        &self,
64        graph: &MemoryGraph,
65        params: CentralityParams,
66    ) -> AmemResult<CentralityResult> {
67        let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
68        let edge_filter: HashSet<EdgeType> = params.edge_types.iter().copied().collect();
69
70        // Collect candidate nodes
71        let node_ids: Vec<u64> = graph
72            .nodes()
73            .iter()
74            .filter(|n| type_filter.is_empty() || type_filter.contains(&n.event_type))
75            .map(|n| n.id)
76            .collect();
77
78        let node_set: HashSet<u64> = node_ids.iter().copied().collect();
79
80        // Collect relevant edges
81        let edges: Vec<&Edge> = graph
82            .edges()
83            .iter()
84            .filter(|e| {
85                node_set.contains(&e.source_id)
86                    && node_set.contains(&e.target_id)
87                    && (edge_filter.is_empty() || edge_filter.contains(&e.edge_type))
88            })
89            .collect();
90
91        match params.algorithm {
92            CentralityAlgorithm::PageRank { damping } => self.pagerank(
93                &node_ids,
94                &edges,
95                damping,
96                params.max_iterations,
97                params.tolerance,
98                params.top_k,
99            ),
100            CentralityAlgorithm::Degree => self.degree_centrality(&node_ids, &edges, params.top_k),
101            CentralityAlgorithm::Betweenness => {
102                self.betweenness_centrality(&node_ids, &edges, params.top_k)
103            }
104        }
105    }
106
107    fn pagerank(
108        &self,
109        node_ids: &[u64],
110        edges: &[&Edge],
111        damping: f32,
112        max_iterations: u32,
113        tolerance: f32,
114        top_k: usize,
115    ) -> AmemResult<CentralityResult> {
116        let n = node_ids.len();
117        if n == 0 {
118            return Ok(CentralityResult {
119                scores: Vec::new(),
120                algorithm: CentralityAlgorithm::PageRank { damping },
121                iterations: 0,
122                converged: true,
123            });
124        }
125
126        let id_to_idx: HashMap<u64, usize> = node_ids
127            .iter()
128            .enumerate()
129            .map(|(i, &id)| (id, i))
130            .collect();
131
132        // Build outgoing edges and incoming edges
133        let mut outgoing: Vec<Vec<usize>> = vec![Vec::new(); n];
134        let mut incoming: Vec<Vec<usize>> = vec![Vec::new(); n];
135
136        for edge in edges {
137            if let (Some(&src_idx), Some(&tgt_idx)) = (
138                id_to_idx.get(&edge.source_id),
139                id_to_idx.get(&edge.target_id),
140            ) {
141                outgoing[src_idx].push(tgt_idx);
142                incoming[tgt_idx].push(src_idx);
143            }
144        }
145
146        let mut pr = vec![1.0 / n as f32; n];
147        let mut iterations = 0;
148        let mut converged = false;
149
150        for _ in 0..max_iterations {
151            iterations += 1;
152            let mut new_pr = vec![(1.0 - damping) / n as f32; n];
153
154            // Dangling node rank
155            let dangling_sum: f32 = (0..n)
156                .filter(|&i| outgoing[i].is_empty())
157                .map(|i| pr[i])
158                .sum();
159
160            for i in 0..n {
161                new_pr[i] += damping * dangling_sum / n as f32;
162                for &j in &incoming[i] {
163                    let out_degree = outgoing[j].len() as f32;
164                    if out_degree > 0.0 {
165                        new_pr[i] += damping * pr[j] / out_degree;
166                    }
167                }
168            }
169
170            // Check convergence
171            let max_diff = (0..n)
172                .map(|i| (new_pr[i] - pr[i]).abs())
173                .fold(0.0f32, f32::max);
174
175            pr = new_pr;
176
177            if max_diff < tolerance {
178                converged = true;
179                break;
180            }
181        }
182
183        let mut scores: Vec<(u64, f32)> = node_ids
184            .iter()
185            .zip(pr.iter())
186            .map(|(&id, &s)| (id, s))
187            .collect();
188        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189        scores.truncate(top_k);
190
191        Ok(CentralityResult {
192            scores,
193            algorithm: CentralityAlgorithm::PageRank { damping },
194            iterations,
195            converged,
196        })
197    }
198
199    fn degree_centrality(
200        &self,
201        node_ids: &[u64],
202        edges: &[&Edge],
203        top_k: usize,
204    ) -> AmemResult<CentralityResult> {
205        let n = node_ids.len();
206        let mut degrees: HashMap<u64, u32> = HashMap::new();
207        for &id in node_ids {
208            degrees.insert(id, 0);
209        }
210
211        for edge in edges {
212            *degrees.entry(edge.source_id).or_insert(0) += 1;
213            *degrees.entry(edge.target_id).or_insert(0) += 1;
214        }
215
216        let max_possible = if n > 1 { 2 * (n - 1) } else { 1 };
217
218        let mut scores: Vec<(u64, f32)> = degrees
219            .into_iter()
220            .map(|(id, deg)| (id, deg as f32 / max_possible as f32))
221            .collect();
222        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
223        scores.truncate(top_k);
224
225        Ok(CentralityResult {
226            scores,
227            algorithm: CentralityAlgorithm::Degree,
228            iterations: 0,
229            converged: true,
230        })
231    }
232
233    fn betweenness_centrality(
234        &self,
235        node_ids: &[u64],
236        edges: &[&Edge],
237        top_k: usize,
238    ) -> AmemResult<CentralityResult> {
239        let n = node_ids.len();
240        if n == 0 {
241            return Ok(CentralityResult {
242                scores: Vec::new(),
243                algorithm: CentralityAlgorithm::Betweenness,
244                iterations: 0,
245                converged: true,
246            });
247        }
248
249        let id_to_idx: HashMap<u64, usize> = node_ids
250            .iter()
251            .enumerate()
252            .map(|(i, &id)| (id, i))
253            .collect();
254
255        // Build adjacency list (both directions for undirected betweenness)
256        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
257        for edge in edges {
258            if let (Some(&src), Some(&tgt)) = (
259                id_to_idx.get(&edge.source_id),
260                id_to_idx.get(&edge.target_id),
261            ) {
262                adj[src].push(tgt);
263                adj[tgt].push(src);
264            }
265        }
266
267        let mut betweenness = vec![0.0f32; n];
268
269        // Sample source nodes if graph is large
270        let sources: Vec<usize> = if n > 10_000 {
271            (0..1000.min(n)).collect()
272        } else {
273            (0..n).collect()
274        };
275
276        // Brandes' algorithm
277        for &s in &sources {
278            let mut stack: Vec<usize> = Vec::new();
279            let mut pred: Vec<Vec<usize>> = vec![Vec::new(); n];
280            let mut sigma = vec![0.0f64; n];
281            sigma[s] = 1.0;
282            let mut dist: Vec<i64> = vec![-1; n];
283            dist[s] = 0;
284            let mut queue = VecDeque::new();
285            queue.push_back(s);
286
287            while let Some(v) = queue.pop_front() {
288                stack.push(v);
289                for &w in &adj[v] {
290                    if dist[w] < 0 {
291                        queue.push_back(w);
292                        dist[w] = dist[v] + 1;
293                    }
294                    if dist[w] == dist[v] + 1 {
295                        sigma[w] += sigma[v];
296                        pred[w].push(v);
297                    }
298                }
299            }
300
301            let mut delta = vec![0.0f64; n];
302            while let Some(w) = stack.pop() {
303                for &v in &pred[w] {
304                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
305                }
306                if w != s {
307                    betweenness[w] += delta[w] as f32;
308                }
309            }
310        }
311
312        // Normalize
313        let norm = if n > 2 {
314            ((n - 1) * (n - 2)) as f32
315        } else {
316            1.0
317        };
318
319        let mut scores: Vec<(u64, f32)> = node_ids
320            .iter()
321            .enumerate()
322            .map(|(i, &id)| (id, betweenness[i] / norm))
323            .collect();
324        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
325        scores.truncate(top_k);
326
327        Ok(CentralityResult {
328            scores,
329            algorithm: CentralityAlgorithm::Betweenness,
330            iterations: 0,
331            converged: true,
332        })
333    }
334
335    /// Find the shortest path between two nodes.
336    pub fn shortest_path(
337        &self,
338        graph: &MemoryGraph,
339        params: ShortestPathParams,
340    ) -> AmemResult<PathResult> {
341        // Same node
342        if params.source_id == params.target_id {
343            return Ok(PathResult {
344                path: vec![params.source_id],
345                edges: Vec::new(),
346                cost: 0.0,
347                found: true,
348            });
349        }
350
351        // Check nodes exist
352        if graph.get_node(params.source_id).is_none() {
353            return Err(crate::types::AmemError::NodeNotFound(params.source_id));
354        }
355        if graph.get_node(params.target_id).is_none() {
356            return Err(crate::types::AmemError::NodeNotFound(params.target_id));
357        }
358
359        let edge_filter: HashSet<EdgeType> = params.edge_types.iter().copied().collect();
360
361        if params.weighted {
362            self.dijkstra_path(graph, &params, &edge_filter)
363        } else {
364            self.bidirectional_bfs(graph, &params, &edge_filter)
365        }
366    }
367
368    fn bidirectional_bfs(
369        &self,
370        graph: &MemoryGraph,
371        params: &ShortestPathParams,
372        edge_filter: &HashSet<EdgeType>,
373    ) -> AmemResult<PathResult> {
374        let mut forward_visited: HashMap<u64, u64> = HashMap::new(); // node -> parent
375        let mut backward_visited: HashMap<u64, u64> = HashMap::new();
376        let mut forward_queue: VecDeque<(u64, u32)> = VecDeque::new();
377        let mut backward_queue: VecDeque<(u64, u32)> = VecDeque::new();
378
379        forward_visited.insert(params.source_id, params.source_id);
380        backward_visited.insert(params.target_id, params.target_id);
381        forward_queue.push_back((params.source_id, 0));
382        backward_queue.push_back((params.target_id, 0));
383
384        let half_depth = params.max_depth / 2 + 1;
385        let mut meeting_node: Option<u64> = None;
386
387        // Helper to get neighbors
388        let get_neighbors = |node_id: u64, forward: bool| -> Vec<u64> {
389            let mut neighbors = Vec::new();
390            match params.direction {
391                TraversalDirection::Forward | TraversalDirection::Both => {
392                    if forward {
393                        for edge in graph.edges_from(node_id) {
394                            if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
395                                neighbors.push(edge.target_id);
396                            }
397                        }
398                    }
399                }
400                TraversalDirection::Backward => {}
401            }
402            match params.direction {
403                TraversalDirection::Backward | TraversalDirection::Both => {
404                    if forward {
405                        for edge in graph.edges_to(node_id) {
406                            if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
407                                neighbors.push(edge.source_id);
408                            }
409                        }
410                    }
411                }
412                TraversalDirection::Forward => {}
413            }
414            // For backward search, reverse the directions
415            if !forward {
416                let mut rev_neighbors = Vec::new();
417                match params.direction {
418                    TraversalDirection::Forward | TraversalDirection::Both => {
419                        for edge in graph.edges_to(node_id) {
420                            if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
421                                rev_neighbors.push(edge.source_id);
422                            }
423                        }
424                    }
425                    TraversalDirection::Backward => {}
426                }
427                match params.direction {
428                    TraversalDirection::Backward | TraversalDirection::Both => {
429                        for edge in graph.edges_from(node_id) {
430                            if edge_filter.is_empty() || edge_filter.contains(&edge.edge_type) {
431                                rev_neighbors.push(edge.target_id);
432                            }
433                        }
434                    }
435                    TraversalDirection::Forward => {}
436                }
437                return rev_neighbors;
438            }
439            neighbors
440        };
441
442        'outer: while !forward_queue.is_empty() || !backward_queue.is_empty() {
443            // Expand forward
444            if let Some((node, depth)) = forward_queue.pop_front() {
445                if depth < half_depth {
446                    for neighbor in get_neighbors(node, true) {
447                        forward_visited.entry(neighbor).or_insert_with(|| {
448                            forward_queue.push_back((neighbor, depth + 1));
449                            node
450                        });
451                        if backward_visited.contains_key(&neighbor) {
452                            forward_visited.entry(neighbor).or_insert(node);
453                            meeting_node = Some(neighbor);
454                            break 'outer;
455                        }
456                    }
457                }
458            }
459
460            // Expand backward
461            if let Some((node, depth)) = backward_queue.pop_front() {
462                if depth < half_depth {
463                    for neighbor in get_neighbors(node, false) {
464                        backward_visited.entry(neighbor).or_insert_with(|| {
465                            backward_queue.push_back((neighbor, depth + 1));
466                            node
467                        });
468                        if forward_visited.contains_key(&neighbor) {
469                            backward_visited.entry(neighbor).or_insert(node);
470                            meeting_node = Some(neighbor);
471                            break 'outer;
472                        }
473                    }
474                }
475            }
476        }
477
478        match meeting_node {
479            Some(mid) => {
480                // Reconstruct path
481                let mut forward_path = Vec::new();
482                let mut current = mid;
483                while current != params.source_id {
484                    forward_path.push(current);
485                    current = forward_visited[&current];
486                }
487                forward_path.push(params.source_id);
488                forward_path.reverse();
489
490                let mut backward_path = Vec::new();
491                current = mid;
492                while current != params.target_id {
493                    current = backward_visited[&current];
494                    backward_path.push(current);
495                }
496
497                let mut path = forward_path;
498                path.extend(backward_path);
499
500                let cost = (path.len() - 1) as f32;
501
502                // Collect edges along the path
503                let mut edges = Vec::new();
504                for i in 0..path.len() - 1 {
505                    for edge in graph.edges_from(path[i]) {
506                        if edge.target_id == path[i + 1] {
507                            edges.push(*edge);
508                            break;
509                        }
510                    }
511                    if edges.len() < i + 1 {
512                        // Try reverse direction
513                        for edge in graph.edges_from(path[i + 1]) {
514                            if edge.target_id == path[i] {
515                                edges.push(*edge);
516                                break;
517                            }
518                        }
519                    }
520                }
521
522                Ok(PathResult {
523                    path,
524                    edges,
525                    cost,
526                    found: true,
527                })
528            }
529            None => Ok(PathResult {
530                path: Vec::new(),
531                edges: Vec::new(),
532                cost: 0.0,
533                found: false,
534            }),
535        }
536    }
537
538    fn dijkstra_path(
539        &self,
540        graph: &MemoryGraph,
541        params: &ShortestPathParams,
542        edge_filter: &HashSet<EdgeType>,
543    ) -> AmemResult<PathResult> {
544        #[derive(PartialEq)]
545        struct State {
546            cost: f32,
547            node: u64,
548        }
549        impl Eq for State {}
550        impl PartialOrd for State {
551            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
552                Some(self.cmp(other))
553            }
554        }
555        impl Ord for State {
556            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
557                other
558                    .cost
559                    .partial_cmp(&self.cost)
560                    .unwrap_or(std::cmp::Ordering::Equal)
561            }
562        }
563
564        let mut dist: HashMap<u64, f32> = HashMap::new();
565        let mut prev: HashMap<u64, u64> = HashMap::new();
566        let mut heap = BinaryHeap::new();
567
568        dist.insert(params.source_id, 0.0);
569        heap.push(State {
570            cost: 0.0,
571            node: params.source_id,
572        });
573
574        while let Some(State { cost, node }) = heap.pop() {
575            if node == params.target_id {
576                // Reconstruct path
577                let mut path = Vec::new();
578                let mut current = params.target_id;
579                while current != params.source_id {
580                    path.push(current);
581                    current = prev[&current];
582                }
583                path.push(params.source_id);
584                path.reverse();
585
586                // Collect edges
587                let mut edges = Vec::new();
588                for i in 0..path.len() - 1 {
589                    for edge in graph.edges_from(path[i]) {
590                        if edge.target_id == path[i + 1] {
591                            edges.push(*edge);
592                            break;
593                        }
594                    }
595                }
596
597                return Ok(PathResult {
598                    path,
599                    edges,
600                    cost,
601                    found: true,
602                });
603            }
604
605            if cost > *dist.get(&node).unwrap_or(&f32::INFINITY) {
606                continue;
607            }
608
609            // Explore neighbors
610            for edge in graph.edges_from(node) {
611                if !edge_filter.is_empty() && !edge_filter.contains(&edge.edge_type) {
612                    continue;
613                }
614                let edge_cost = 1.0 - edge.weight; // Higher weight = lower cost
615                let next_cost = cost + edge_cost;
616
617                if next_cost < *dist.get(&edge.target_id).unwrap_or(&f32::INFINITY) {
618                    dist.insert(edge.target_id, next_cost);
619                    prev.insert(edge.target_id, node);
620                    heap.push(State {
621                        cost: next_cost,
622                        node: edge.target_id,
623                    });
624                }
625            }
626
627            // If direction allows backward/both, also check incoming edges
628            if matches!(
629                params.direction,
630                TraversalDirection::Backward | TraversalDirection::Both
631            ) {
632                for edge in graph.edges_to(node) {
633                    if !edge_filter.is_empty() && !edge_filter.contains(&edge.edge_type) {
634                        continue;
635                    }
636                    let edge_cost = 1.0 - edge.weight;
637                    let next_cost = cost + edge_cost;
638
639                    if next_cost < *dist.get(&edge.source_id).unwrap_or(&f32::INFINITY) {
640                        dist.insert(edge.source_id, next_cost);
641                        prev.insert(edge.source_id, node);
642                        heap.push(State {
643                            cost: next_cost,
644                            node: edge.source_id,
645                        });
646                    }
647                }
648            }
649        }
650
651        Ok(PathResult {
652            path: Vec::new(),
653            edges: Vec::new(),
654            cost: 0.0,
655            found: false,
656        })
657    }
658}