Skip to main content

codemem_engine/
graph_ops.rs

//! Engine facade methods for graph algorithms.
//!
//! These methods wrap lock acquisition and graph-algorithm calls so that the
//! MCP/API transport layers never interact with the graph mutex directly.
5
6use crate::CodememEngine;
7use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, NodeKind, RelationshipType};
8use std::collections::HashMap;
9
// ── Result Types ─────────────────────────────────────────────────────────────
11
/// A graph node together with its PageRank score, as returned by
/// [`CodememEngine::find_important_nodes`].
#[derive(Debug, Clone, PartialEq)]
pub struct RankedNode {
    /// ID of the ranked node.
    pub id: String,
    /// PageRank score; results are ordered with higher scores first.
    pub score: f64,
    /// The node's kind rendered via `Display`, or `None` if the node lookup
    /// failed or found nothing.
    pub kind: Option<String>,
    /// The node's label, or `None` if the node lookup failed or found nothing.
    pub label: Option<String>,
}
20
/// Point-in-time snapshot of the in-memory graph's size and composition.
///
/// `Default` yields an empty snapshot (zero counts, empty maps).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphStats {
    /// Total number of nodes in the graph.
    pub node_count: usize,
    /// Total number of edges in the graph.
    pub edge_count: usize,
    /// Node counts grouped by node kind, as reported by the graph backend.
    pub node_kind_counts: HashMap<String, usize>,
    /// Edge counts grouped by relationship type, as reported by the graph backend.
    pub relationship_type_counts: HashMap<String, usize>,
}
29
// ── Engine Methods ───────────────────────────────────────────────────────────
31
32impl CodememEngine {
33    /// BFS or DFS traversal from a start node, with optional kind/relationship filters.
34    pub fn graph_traverse(
35        &self,
36        start_id: &str,
37        depth: usize,
38        algorithm: &str,
39        exclude_kinds: &[NodeKind],
40        include_relationships: Option<&[RelationshipType]>,
41    ) -> Result<Vec<GraphNode>, CodememError> {
42        let graph = self.lock_graph()?;
43        let has_filters = !exclude_kinds.is_empty() || include_relationships.is_some();
44
45        if has_filters {
46            match algorithm {
47                "bfs" => graph.bfs_filtered(start_id, depth, exclude_kinds, include_relationships),
48                "dfs" => graph.dfs_filtered(start_id, depth, exclude_kinds, include_relationships),
49                _ => Err(CodememError::InvalidInput(format!(
50                    "Unknown algorithm: {algorithm}"
51                ))),
52            }
53        } else {
54            match algorithm {
55                "bfs" => graph.bfs(start_id, depth),
56                "dfs" => graph.dfs(start_id, depth),
57                _ => Err(CodememError::InvalidInput(format!(
58                    "Unknown algorithm: {algorithm}"
59                ))),
60            }
61        }
62    }
63
64    /// Get in-memory graph statistics.
65    pub fn graph_stats(&self) -> Result<GraphStats, CodememError> {
66        let graph = self.lock_graph()?;
67        let stats = graph.stats();
68        Ok(GraphStats {
69            node_count: stats.node_count,
70            edge_count: stats.edge_count,
71            node_kind_counts: stats.node_kind_counts,
72            relationship_type_counts: stats.relationship_type_counts,
73        })
74    }
75
76    /// Get all edges for a node.
77    pub fn get_node_edges(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
78        let graph = self.lock_graph()?;
79        graph.get_edges(node_id)
80    }
81
82    /// Run Louvain community detection at the given resolution.
83    pub fn louvain_communities(&self, resolution: f64) -> Result<Vec<Vec<String>>, CodememError> {
84        let graph = self.lock_graph()?;
85        Ok(graph.louvain_communities(resolution))
86    }
87
88    /// Compute PageRank and return the top-k nodes with their scores,
89    /// kinds, and labels.
90    pub fn find_important_nodes(
91        &self,
92        top_k: usize,
93        damping: f64,
94    ) -> Result<Vec<RankedNode>, CodememError> {
95        let graph = self.lock_graph()?;
96        let scores = graph.pagerank(damping, 100, 1e-6);
97
98        let mut sorted: Vec<(String, f64)> = scores.into_iter().collect();
99        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
100        sorted.truncate(top_k);
101
102        let results = sorted
103            .into_iter()
104            .map(|(id, score)| {
105                let node = graph.get_node(&id).ok().flatten();
106                RankedNode {
107                    id,
108                    score,
109                    kind: node.as_ref().map(|n| n.kind.to_string()),
110                    label: node.as_ref().map(|n| n.label.clone()),
111                }
112            })
113            .collect();
114
115        Ok(results)
116    }
117}