mod algorithms;
mod traversal;
use codemem_core::{CodememError, Edge, GraphBackend, GraphNode, NodeKind, RawGraphMetrics};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::Direction;
use std::collections::{HashMap, HashSet, VecDeque};
/// In-memory graph of nodes and edges backed by a petgraph `DiGraph`, plus id-keyed
/// lookup tables and cached centrality scores.
pub struct GraphEngine {
/// Directed petgraph graph. Each node weight is the node's string id (used to look the
/// node up in `nodes`); edge weights are `f64`.
pub(crate) graph: DiGraph<String, f64>,
/// Maps a node id to its `NodeIndex` in `graph`.
pub(crate) id_to_index: HashMap<String, NodeIndex>,
/// Full node records keyed by node id.
pub(crate) nodes: HashMap<String, GraphNode>,
/// Full edge records keyed by edge id (the ids stored in `edge_adj`).
pub(crate) edges: HashMap<String, Edge>,
/// Node id -> ids of edges touching that node. Presumably it holds both incoming and
/// outgoing edges (it is populated outside this file; verify in `add_edge`).
pub(crate) edge_adj: HashMap<String, Vec<String>>,
/// Node id -> PageRank score, filled by the `recompute_centrality*` methods.
pub(crate) cached_pagerank: HashMap<String, f64>,
/// Node id -> betweenness centrality. Filled by `recompute_centrality` /
/// `ensure_betweenness_computed`; cleared when recomputing without betweenness.
pub(crate) cached_betweenness: HashMap<String, f64>,
}
impl GraphEngine {
    /// Creates an empty engine with no nodes, no edges and empty centrality caches.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            id_to_index: HashMap::new(),
            nodes: HashMap::new(),
            edges: HashMap::new(),
            edge_adj: HashMap::new(),
            cached_pagerank: HashMap::new(),
            cached_betweenness: HashMap::new(),
        }
    }

    /// Builds an engine from every node and edge held by `storage`.
    ///
    /// All nodes are inserted before any edge is fetched or added (presumably because
    /// `add_edge` needs both endpoints to exist). Degree centrality is computed once at
    /// the end. This constructor does not itself fill the PageRank or betweenness
    /// caches; call [`GraphEngine::recompute_centrality`] for that.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage backend or from `add_node` / `add_edge`.
    pub fn from_storage(
        storage: &dyn codemem_core::StorageBackend,
    ) -> Result<Self, CodememError> {
        let mut engine = Self::new();
        for node in storage.all_graph_nodes()? {
            engine.add_node(node)?;
        }
        for edge in storage.all_graph_edges()? {
            engine.add_edge(edge)?;
        }
        engine.compute_centrality();
        Ok(engine)
    }

    /// Number of nodes stored in the `nodes` map.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges stored in the `edges` map.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Returns the union of the `bfs` neighbourhoods of every id in `start_ids`, each
    /// explored up to `max_hops` hops.
    ///
    /// Nodes are de-duplicated by id and the first occurrence wins. The output therefore
    /// follows the order of `start_ids`, then the order each `bfs` call produces.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `bfs` (defined elsewhere in this crate).
    pub fn expand(
        &self,
        start_ids: &[String],
        max_hops: usize,
    ) -> Result<Vec<GraphNode>, CodememError> {
        let mut seen = HashSet::new();
        let mut result = Vec::new();
        for start_id in start_ids {
            for node in self.bfs(start_id, max_hops)? {
                if seen.insert(node.id.clone()) {
                    result.push(node);
                }
            }
        }
        Ok(result)
    }

    /// Returns clones of the nodes reached from `node_id` over one outgoing edge.
    ///
    /// On a `DiGraph`, petgraph's `neighbors` follows outgoing edges only, so
    /// predecessors are not included. Neighbour ids without a `GraphNode` record are
    /// skipped.
    ///
    /// # Errors
    ///
    /// `CodememError::NotFound` when `node_id` is not in the graph.
    pub fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, CodememError> {
        let &idx = self
            .id_to_index
            .get(node_id)
            .ok_or_else(|| CodememError::NotFound(format!("Node {node_id}")))?;
        Ok(self
            .graph
            .neighbors(idx)
            .filter_map(|neighbor_idx| self.graph.node_weight(neighbor_idx))
            .filter_map(|neighbor_id| self.nodes.get(neighbor_id))
            .cloned()
            .collect())
    }

    /// Returns the weakly connected components of the graph (edge direction ignored).
    ///
    /// Each component is a sorted list of node ids, and the list of components is itself
    /// sorted, so the result is deterministic despite `HashMap` iteration order.
    pub fn connected_components(&self) -> Vec<Vec<String>> {
        let mut visited: HashSet<NodeIndex> = HashSet::new();
        let mut components: Vec<Vec<String>> = Vec::new();
        for &start_idx in self.id_to_index.values() {
            if !visited.insert(start_idx) {
                continue;
            }
            let mut component: Vec<String> = Vec::new();
            let mut queue: VecDeque<NodeIndex> = VecDeque::from([start_idx]);
            while let Some(current) = queue.pop_front() {
                if let Some(node_id) = self.graph.node_weight(current) {
                    component.push(node_id.clone());
                }
                // Follow edges in both directions: reachability is undirected here.
                for neighbor in self.graph.neighbors_undirected(current) {
                    if visited.insert(neighbor) {
                        queue.push_back(neighbor);
                    }
                }
            }
            component.sort_unstable();
            components.push(component);
        }
        components.sort_unstable();
        components
    }

    /// Sets every node's `centrality` to its normalised degree centrality:
    /// `(in_degree + out_degree) / (n - 1)`, where `n` is the number of nodes.
    ///
    /// With zero or one node, every centrality is set to 0.0 (the normaliser would be
    /// zero). Degrees are counted per petgraph edge, so parallel edges each count.
    pub fn compute_centrality(&mut self) {
        let n = self.nodes.len();
        if n <= 1 {
            for node in self.nodes.values_mut() {
                node.centrality = 0.0;
            }
            return;
        }
        let denominator = (n - 1) as f64;
        // Borrow the fields separately so `nodes` can be updated while walking
        // `id_to_index`; this avoids an intermediate map and cloning every id.
        let graph = &self.graph;
        for (id, &idx) in &self.id_to_index {
            let degree = graph.neighbors_directed(idx, Direction::Incoming).count()
                + graph.neighbors_directed(idx, Direction::Outgoing).count();
            if let Some(node) = self.nodes.get_mut(id) {
                node.centrality = degree as f64 / denominator;
            }
        }
    }

    /// Returns clones of every node, in unspecified (`HashMap`) order.
    pub fn get_all_nodes(&self) -> Vec<GraphNode> {
        self.nodes.values().cloned().collect()
    }

    /// Borrows the node with id `id`, if present.
    pub fn get_node_ref(&self, id: &str) -> Option<&GraphNode> {
        self.nodes.get(id)
    }

    /// Borrows the edges listed in `edge_adj` for `node_id`.
    ///
    /// Edge ids that no longer resolve in `edges` are skipped. Returns an empty vector
    /// for unknown nodes.
    pub fn get_edges_ref(&self, node_id: &str) -> Vec<&Edge> {
        self.edge_adj
            .get(node_id)
            .map(|edge_ids| {
                edge_ids
                    .iter()
                    .filter_map(|eid| self.edges.get(eid))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Recomputes and caches both PageRank and betweenness centrality for every node.
    pub fn recompute_centrality(&mut self) {
        self.recompute_centrality_with_options(true);
    }

    /// Recomputes PageRank only for nodes whose `namespace` equals `namespace`.
    ///
    /// Cached scores for that namespace's nodes are dropped first. Scores for other
    /// namespaces, and for ids with no `GraphNode` record, are left untouched. The
    /// betweenness cache is not modified.
    pub fn recompute_centrality_for_namespace(&mut self, namespace: &str) {
        let nodes = &self.nodes;
        self.cached_pagerank.retain(|id, _| {
            !nodes
                .get(id)
                .is_some_and(|n| n.namespace.as_deref() == Some(namespace))
        });
        let scores = self.pagerank_for_namespace(
            namespace,
            codemem_core::PAGERANK_DAMPING_DEFAULT,
            codemem_core::PAGERANK_ITERATIONS_DEFAULT,
            codemem_core::PAGERANK_TOLERANCE_DEFAULT,
        );
        self.cached_pagerank.extend(scores);
        // The count is computed inside the log call, rather than from an id set built up
        // front, so no extra allocation is made just for diagnostics.
        tracing::debug!(
            namespace = %namespace,
            scores_updated = self
                .cached_pagerank
                .keys()
                .filter(|id| {
                    self.nodes
                        .get(id.as_str())
                        .is_some_and(|n| n.namespace.as_deref() == Some(namespace))
                })
                .count(),
            "PageRank recomputed for namespace"
        );
    }

    /// Recomputes the PageRank cache for the whole graph with the default parameters.
    ///
    /// When `include_betweenness` is true, betweenness is recomputed too. Otherwise the
    /// betweenness cache is cleared, so `get_betweenness` returns 0.0 until
    /// `ensure_betweenness_computed` or a full recompute runs.
    pub fn recompute_centrality_with_options(&mut self, include_betweenness: bool) {
        self.cached_pagerank = self.pagerank(
            codemem_core::PAGERANK_DAMPING_DEFAULT,
            codemem_core::PAGERANK_ITERATIONS_DEFAULT,
            codemem_core::PAGERANK_TOLERANCE_DEFAULT,
        );
        if include_betweenness {
            self.cached_betweenness = self.betweenness_centrality();
        } else {
            self.cached_betweenness.clear();
        }
    }

    /// Lazily fills the betweenness cache if it is empty and the graph has nodes.
    ///
    /// NOTE: if `betweenness_centrality` returns an empty map, every call recomputes it.
    pub fn ensure_betweenness_computed(&mut self) {
        if self.cached_betweenness.is_empty() && self.graph.node_count() > 0 {
            self.cached_betweenness = self.betweenness_centrality();
        }
    }

    /// Cached PageRank for `node_id`, or 0.0 when no score is cached.
    pub fn get_pagerank(&self, node_id: &str) -> f64 {
        self.cached_pagerank.get(node_id).copied().unwrap_or(0.0)
    }

    /// Cached betweenness for `node_id`, or 0.0 when no score is cached.
    pub fn get_betweenness(&self, node_id: &str) -> f64 {
        self.cached_betweenness.get(node_id).copied().unwrap_or(0.0)
    }

    /// Aggregates graph signals around the memory node `memory_id`.
    ///
    /// Neighbours are enumerated over outgoing edges and then incoming edges, so a node
    /// linked in both directions is counted once per direction. Neighbours whose
    /// `valid_to` is at or before the current time are skipped.
    ///
    /// A neighbour is a *code* neighbour when it has a `GraphNode` whose kind is not
    /// `NodeKind::Memory`. Every other neighbour, including ids with no `GraphNode`
    /// record, counts as a memory neighbour. For code neighbours, the maximum cached
    /// PageRank and betweenness are tracked; uncached scores count as 0.0.
    ///
    /// Each traversal adds the weight of the `Edge` whose `src`/`dst` match the
    /// traversal direction, or 0.0 if no such edge is found.
    ///
    /// Returns `None` when `memory_id` is not in the graph, or when no non-expired
    /// neighbour remains.
    pub fn raw_graph_metrics_for_memory(&self, memory_id: &str) -> Option<RawGraphMetrics> {
        let idx = *self.id_to_index.get(memory_id)?;
        // Edges touching the memory node; loop-invariant, so fetched once.
        let incident_edges = self.get_edges_ref(memory_id);
        let now = chrono::Utc::now();

        let mut max_pagerank = 0.0_f64;
        let mut max_betweenness = 0.0_f64;
        let mut code_neighbor_count = 0_usize;
        let mut total_edge_weight = 0.0_f64;
        let mut memory_neighbor_count = 0_usize;
        let mut memory_edge_weight = 0.0_f64;

        for direction in [Direction::Outgoing, Direction::Incoming] {
            for neighbor_idx in self.graph.neighbors_directed(idx, direction) {
                let Some(neighbor_id) = self.graph.node_weight(neighbor_idx) else {
                    continue;
                };
                let neighbor_node = self.nodes.get(neighbor_id.as_str());
                if neighbor_node.is_some_and(|n| n.valid_to.is_some_and(|vt| vt <= now)) {
                    continue;
                }
                let is_code_node = neighbor_node.is_some_and(|n| n.kind != NodeKind::Memory);

                // Match the edge in the direction just traversed. Matching either
                // orientation would attribute one edge's weight to both directions when
                // the pair is linked both ways. NOTE: with parallel edges in the same
                // direction, the first match is reused for each of them.
                let (src, dst) = match direction {
                    Direction::Outgoing => (memory_id, neighbor_id.as_str()),
                    Direction::Incoming => (neighbor_id.as_str(), memory_id),
                };
                let edge_w = incident_edges
                    .iter()
                    .find(|e| e.src == src && e.dst == dst)
                    .map_or(0.0, |e| e.weight);

                if is_code_node {
                    code_neighbor_count += 1;
                    total_edge_weight += edge_w;
                    max_pagerank = max_pagerank.max(self.get_pagerank(neighbor_id));
                    max_betweenness = max_betweenness.max(self.get_betweenness(neighbor_id));
                } else {
                    memory_neighbor_count += 1;
                    memory_edge_weight += edge_w;
                }
            }
        }

        if code_neighbor_count == 0 && memory_neighbor_count == 0 {
            return None;
        }
        Some(RawGraphMetrics {
            max_pagerank,
            max_betweenness,
            code_neighbor_count,
            total_edge_weight,
            memory_neighbor_count,
            memory_edge_weight,
        })
    }

    /// Test helper: the largest total degree (in + out) of any node, never below 1.0.
    /// Returns 1.0 when the graph has at most one node.
    #[cfg(test)]
    pub fn max_degree(&self) -> f64 {
        if self.nodes.len() <= 1 {
            return 1.0;
        }
        self.id_to_index
            .values()
            .map(|&idx| {
                let in_deg = self
                    .graph
                    .neighbors_directed(idx, Direction::Incoming)
                    .count();
                let out_deg = self
                    .graph
                    .neighbors_directed(idx, Direction::Outgoing)
                    .count();
                (in_deg + out_deg) as f64
            })
            .fold(1.0f64, f64::max)
    }
}
#[cfg(test)]
#[path = "../tests/graph_tests.rs"]
mod tests;