use crate::knowledge::entity::{Entity, EntityStore, EntityType, Relationship};
use std::collections::{HashMap, HashSet, VecDeque};
pub use brainwires_core::graph::{EdgeType, GraphEdge, GraphNode};
/// In-memory graph of entities (nodes) and the relationships between them (edges).
///
/// Edges are stored once in `edges` and indexed per endpoint in `adjacency`,
/// so lookups treat every edge as traversable in both directions.
#[derive(Debug, Default)]
pub struct RelationshipGraph {
/// Nodes keyed by entity name.
nodes: HashMap<String, GraphNode>,
/// Every edge in the graph; positions in this vector are the indices stored in `adjacency`.
edges: Vec<GraphEdge>,
/// For each node name, the indices into `edges` of the edges that touch that node.
adjacency: HashMap<String, Vec<usize>>, }
impl RelationshipGraph {
/// Creates an empty graph with no nodes and no edges.
pub fn new() -> Self {
Self::default()
}
/// Builds a graph with one node per entity returned by
/// `store.get_top_entities(100)`.
///
/// Only nodes are created here; no edges are added, so relationships have to
/// be inserted afterwards (for example with `add_relationship`).
pub fn from_entity_store(store: &EntityStore) -> Self {
    let mut graph = Self::default();

    for entity in store.get_top_entities(100) {
        let importance = Self::calculate_importance(entity);
        let node = GraphNode {
            importance,
            mention_count: entity.mention_count,
            message_ids: entity.message_ids.clone(),
            entity_type: entity.entity_type.clone(),
            entity_name: entity.name.clone(),
        };
        graph.add_node(node);
    }

    graph
}
/// Scores how important an entity is, in the range `[0.0, 1.0]`.
///
/// The score is the sum of three parts, clamped to that range:
/// - `ln(mention_count)`, floored at zero, times 0.3;
/// - a fixed bonus selected by the entity type (0.1 for variables up to 0.4 for files);
/// - 0.05 per message the entity appears in, capped at 0.2.
pub fn calculate_importance(entity: &Entity) -> f32 {
    // ln(0) is -inf and ln(1) is 0, so the floor keeps rarely-mentioned
    // entities from contributing a negative amount.
    let mention_part: f32 = (entity.mention_count as f32).ln().max(0.0) * 0.3;

    let type_part: f32 = match entity.entity_type {
        EntityType::File => 0.4,
        EntityType::Function => 0.3,
        EntityType::Type => 0.35,
        EntityType::Error => 0.25,
        EntityType::Concept => 0.2,
        EntityType::Variable => 0.1,
        EntityType::Command => 0.15,
    };

    let spread_part: f32 = (entity.message_ids.len() as f32 * 0.05).min(0.2);

    (mention_part + type_part + spread_part).clamp(0.0, 1.0)
}
/// Inserts `node`, replacing any existing node with the same name.
///
/// An adjacency list is created for the name if there is none yet; an
/// existing list (and therefore the node's existing edges) is left untouched.
pub fn add_node(&mut self, node: GraphNode) {
    let key = node.entity_name.clone();
    self.adjacency.entry(key.clone()).or_default();
    self.nodes.insert(key, node);
}
/// Appends `edge` to the graph and indexes it under both of its endpoints.
///
/// An endpoint that has no adjacency list (i.e. was never added with
/// `add_node`) is not indexed; the edge is still stored and counted.
/// A self-loop (`from == to`) is indexed once, so `get_edges` and
/// `get_neighbors` report it a single time.
pub fn add_edge(&mut self, edge: GraphEdge) {
    let idx = self.edges.len();

    if let Some(incident) = self.adjacency.get_mut(&edge.from) {
        incident.push(idx);
    }

    // For a self-loop both endpoints share one adjacency list; pushing the
    // index a second time would make every lookup see the edge twice.
    if edge.to != edge.from {
        if let Some(incident) = self.adjacency.get_mut(&edge.to) {
            incident.push(idx);
        }
    }

    self.edges.push(edge);
}
/// Converts `rel` into an edge and adds it to the graph.
///
/// The edge is only added when both endpoints are already nodes of the
/// graph; otherwise the relationship is silently ignored. The edge weight is
/// taken from `EdgeType::weight`, and only `CoOccurs` carries a message id.
pub fn add_relationship(&mut self, rel: &Relationship) {
    // Borrow the endpoint names first; they are cloned only if the edge is
    // actually going to be inserted.
    let (source, target, edge_type, message_id) = match rel {
        Relationship::CoOccurs {
            entity_a,
            entity_b,
            message_id,
        } => (entity_a, entity_b, EdgeType::CoOccurs, Some(message_id)),
        Relationship::Contains {
            container,
            contained,
        } => (container, contained, EdgeType::Contains, None),
        Relationship::References { from, to } => (from, to, EdgeType::References, None),
        Relationship::DependsOn {
            dependent,
            dependency,
        } => (dependent, dependency, EdgeType::DependsOn, None),
        Relationship::Modifies {
            modifier, modified, ..
        } => (modifier, modified, EdgeType::Modifies, None),
        Relationship::Defines {
            definer, defined, ..
        } => (definer, defined, EdgeType::Defines, None),
    };

    let both_known = self.nodes.contains_key(source) && self.nodes.contains_key(target);
    if !both_known {
        return;
    }

    let weight = edge_type.weight();
    self.add_edge(GraphEdge {
        from: source.clone(),
        to: target.clone(),
        weight,
        edge_type,
        message_id: message_id.cloned(),
    });
}
/// Returns the node registered under `name`, or `None` if there is none.
pub fn get_node(&self, name: &str) -> Option<&GraphNode> {
self.nodes.get(name)
}
/// Returns the nodes at the other end of every edge touching `name`.
///
/// Edges are followed in both directions. A neighbor connected through
/// several edges appears once per edge, and endpoints that are not
/// registered nodes are skipped. An unknown `name` yields an empty vector.
pub fn get_neighbors(&self, name: &str) -> Vec<&GraphNode> {
    let incident = match self.adjacency.get(name) {
        Some(indices) => indices,
        None => return Vec::new(),
    };

    incident
        .iter()
        .filter_map(|&idx| self.edges.get(idx))
        .filter_map(|edge| {
            // Whichever endpoint is not `name` is the neighbor.
            let other = if edge.from == name {
                &edge.to
            } else {
                &edge.from
            };
            self.nodes.get(other)
        })
        .collect()
}
/// Returns every edge that touches `name`, as either endpoint.
///
/// An unknown `name` yields an empty vector.
pub fn get_edges(&self, name: &str) -> Vec<&GraphEdge> {
    match self.adjacency.get(name) {
        Some(indices) => indices
            .iter()
            .filter_map(|&idx| self.edges.get(idx))
            .collect(),
        None => Vec::new(),
    }
}
pub fn find_path(&self, from: &str, to: &str) -> Option<Vec<String>> {
if !self.nodes.contains_key(from) || !self.nodes.contains_key(to) {
return None;
}
if from == to {
return Some(vec![from.to_string()]);
}
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
let mut parent: HashMap<String, String> = HashMap::new();
queue.push_back(from.to_string());
visited.insert(from.to_string());
while let Some(current) = queue.pop_front() {
for neighbor in self.get_neighbors(¤t) {
if !visited.contains(&neighbor.entity_name) {
visited.insert(neighbor.entity_name.clone());
parent.insert(neighbor.entity_name.clone(), current.clone());
if neighbor.entity_name == to {
let mut path = vec![to.to_string()];
let mut node = to.to_string();
while let Some(p) = parent.get(&node) {
path.push(p.clone());
node = p.clone();
}
path.reverse();
return Some(path);
}
queue.push_back(neighbor.entity_name.clone());
}
}
}
None
}
/// Collects the entities reachable from `entity` within `max_depth` hops.
///
/// Edges are followed in both directions, breadth first. Each reachable
/// node is reported once, through the first edge that reached it, with
/// `relevance = edge.weight * 0.8^distance`. The result also gathers the
/// message ids of the root node and of every reported node, and
/// `related_entities` is sorted by descending relevance.
///
/// If `entity` is not a node, the returned context has no related entities
/// and no message ids.
pub fn get_entity_context(&self, entity: &str, max_depth: usize) -> EntityContext {
    let mut context = EntityContext {
        root: entity.to_string(),
        related_entities: Vec::new(),
        message_ids: HashSet::new(),
    };

    if let Some(root_node) = self.nodes.get(entity) {
        context
            .message_ids
            .extend(root_node.message_ids.iter().cloned());
    }

    let mut visited: HashSet<String> = HashSet::new();
    let mut queue: VecDeque<(String, usize)> = VecDeque::new();
    visited.insert(entity.to_string());
    queue.push_back((entity.to_string(), 0));

    while let Some((current, depth)) = queue.pop_front() {
        // Nodes sitting at the depth limit are reported but not expanded.
        if depth >= max_depth {
            continue;
        }
        let next_depth = depth + 1;

        for edge in self.get_edges(&current) {
            let neighbor = if edge.from == current {
                &edge.to
            } else {
                &edge.from
            };
            if visited.contains(neighbor) {
                continue;
            }
            visited.insert(neighbor.clone());

            if let Some(node) = self.nodes.get(neighbor) {
                context.related_entities.push(RelatedEntity {
                    name: neighbor.clone(),
                    entity_type: node.entity_type.clone(),
                    relationship: edge.edge_type.clone(),
                    distance: next_depth,
                    relevance: edge.weight * (0.8_f32).powi(next_depth as i32),
                });
                context.message_ids.extend(node.message_ids.iter().cloned());
            }

            queue.push_back((neighbor.clone(), next_depth));
        }
    }

    context.related_entities.sort_by(|a, b| {
        b.relevance
            .partial_cmp(&a.relevance)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    context
}
/// Returns up to `limit` nodes whose names match `query`, best match first.
///
/// Matching is case-insensitive and scored as follows:
/// - 1.0 when the name equals the query;
/// - 0.7 when the name contains the query;
/// - 0.5 when the query contains the name;
/// - otherwise 0.3 per word shared between the whitespace-separated query
///   words and the name split on non-alphanumeric characters.
///
/// The base score is multiplied by `1 + importance * 0.5`; nodes whose final
/// score is not positive are dropped. Equal scores are ordered by entity
/// name so results do not depend on hash-map iteration order.
///
/// NOTE(review): an empty `query` is a substring of every name, so it
/// matches all nodes with the 0.7 tier — confirm callers expect that.
pub fn search(&self, query: &str, limit: usize) -> Vec<&GraphNode> {
    if limit == 0 {
        return Vec::new();
    }

    let query_lower = query.to_lowercase();
    let query_words: HashSet<&str> = query_lower.split_whitespace().collect();

    let mut scored: Vec<(&GraphNode, f32)> = self
        .nodes
        .values()
        .filter_map(|node| {
            let name_lower = node.entity_name.to_lowercase();

            let base: f32 = if name_lower == query_lower {
                1.0
            } else if name_lower.contains(&query_lower) {
                0.7
            } else if query_lower.contains(&name_lower) {
                0.5
            } else {
                let name_words: HashSet<&str> = name_lower
                    .split(|c: char| !c.is_alphanumeric())
                    .collect();
                let overlap = query_words.intersection(&name_words).count();
                overlap as f32 * 0.3
            };

            let score = base * (1.0 + node.importance * 0.5);
            if score > 0.0 {
                Some((node, score))
            } else {
                None
            }
        })
        .collect();

    scored.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            // Deterministic order for ties, which also makes the
            // truncation below reproducible.
            .then_with(|| a.0.entity_name.cmp(&b.0.entity_name))
    });
    scored.truncate(limit);

    scored.into_iter().map(|(node, _)| node).collect()
}
/// Summarises the graph: total node and edge counts, plus a breakdown of
/// nodes by entity type (via `EntityType::as_str`) and of edges by the
/// `Debug` name of their edge type.
pub fn stats(&self) -> GraphStats {
    let mut nodes_by_type: HashMap<&'static str, usize> = HashMap::new();
    for kind in self.nodes.values().map(|node| node.entity_type.as_str()) {
        *nodes_by_type.entry(kind).or_insert(0) += 1;
    }

    let mut edges_by_type: HashMap<String, usize> = HashMap::new();
    for label in self.edges.iter().map(|edge| format!("{:?}", edge.edge_type)) {
        *edges_by_type.entry(label).or_insert(0) += 1;
    }

    GraphStats {
        node_count: self.nodes.len(),
        edge_count: self.edges.len(),
        nodes_by_type,
        edges_by_type,
    }
}
/// Estimates which entities are affected by a change to `entity`.
///
/// Performs a breadth-first traversal (edges followed in both directions)
/// up to `depth` hops. The impact starts at 1.0 and is multiplied, per hop,
/// by a factor chosen from the edge type (0.9 for `DependsOn` down to 0.3
/// for `CoOccurs`) and by the edge weight. Each reachable node is reported
/// once, via the first route that reached it, and `impact_path` lists the
/// node names along that route starting at `entity`.
///
/// Results are sorted by descending impact score. An unknown `entity`
/// yields an empty vector.
pub fn get_impact_set(&self, entity: &str, depth: usize) -> Vec<ImpactedEntity> {
    if !self.nodes.contains_key(entity) {
        return Vec::new();
    }

    let mut impacts = Vec::new();
    let mut visited: HashSet<String> = HashSet::new();
    // Each entry: (node name, hops from the root, accumulated impact,
    // route from the root up to and including this node).
    let mut queue: VecDeque<(String, usize, f32, Vec<String>)> = VecDeque::new();

    visited.insert(entity.to_string());
    queue.push_back((entity.to_string(), 0, 1.0, vec![entity.to_string()]));

    while let Some((current, current_depth, current_impact, route)) = queue.pop_front() {
        if current_depth >= depth {
            continue;
        }

        for edge in self.get_edges(&current) {
            let neighbor = if edge.from == current {
                &edge.to
            } else {
                &edge.from
            };
            if visited.contains(neighbor) {
                continue;
            }
            visited.insert(neighbor.clone());

            let impact_factor = match edge.edge_type {
                EdgeType::DependsOn => 0.9,
                EdgeType::Contains => 0.8,
                EdgeType::Modifies => 0.7,
                EdgeType::References => 0.5,
                EdgeType::Defines => 0.6,
                EdgeType::CoOccurs => 0.3,
            };
            let propagated_impact = current_impact * impact_factor * edge.weight;

            let mut neighbor_route = route.clone();
            neighbor_route.push(neighbor.clone());

            if let Some(node) = self.nodes.get(neighbor) {
                impacts.push(ImpactedEntity {
                    name: neighbor.clone(),
                    entity_type: node.entity_type.clone(),
                    distance: current_depth + 1,
                    impact_score: propagated_impact,
                    impact_path: neighbor_route.clone(),
                });
            }

            queue.push_back((
                neighbor.clone(),
                current_depth + 1,
                propagated_impact,
                neighbor_route,
            ));
        }
    }

    impacts.sort_by(|a, b| {
        b.impact_score
            .partial_cmp(&a.impact_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    impacts
}
pub fn find_clusters(&self) -> Vec<EntityCluster> {
let mut clusters = Vec::new();
let mut visited = HashSet::new();
for node_name in self.nodes.keys() {
if visited.contains(node_name) {
continue;
}
let mut cluster_nodes = Vec::new();
let mut queue = VecDeque::new();
queue.push_back(node_name.clone());
visited.insert(node_name.clone());
while let Some(current) = queue.pop_front() {
if let Some(node) = self.nodes.get(¤t) {
cluster_nodes.push(node.clone());
}
for neighbor in self.get_neighbors(¤t) {
if !visited.contains(&neighbor.entity_name) {
visited.insert(neighbor.entity_name.clone());
queue.push_back(neighbor.entity_name.clone());
}
}
}
if !cluster_nodes.is_empty() {
let total_importance: f32 = cluster_nodes.iter().map(|n| n.importance).sum();
let avg_importance = total_importance / cluster_nodes.len() as f32;
let mut type_counts = HashMap::new();
for node in &cluster_nodes {
*type_counts.entry(node.entity_type.clone()).or_insert(0) += 1;
}
let dominant_type = type_counts
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(t, _)| t);
clusters.push(EntityCluster {
id: clusters.len(),
nodes: cluster_nodes,
avg_importance,
dominant_type,
});
}
}
clusters.sort_by(|a, b| b.nodes.len().cmp(&a.nodes.len()));
clusters
}
pub fn suggest_related(&self, entities: &[&str]) -> Vec<SuggestedEntity> {
let mut scores: HashMap<String, f32> = HashMap::new();
let entity_set: HashSet<_> = entities.iter().copied().collect();
for entity in entities {
for neighbor in self.get_neighbors(entity) {
if !entity_set.contains(neighbor.entity_name.as_str()) {
*scores.entry(neighbor.entity_name.clone()).or_default() += neighbor.importance;
}
}
for first_neighbor in self.get_neighbors(entity) {
if entity_set.contains(first_neighbor.entity_name.as_str()) {
continue;
}
for second_neighbor in self.get_neighbors(&first_neighbor.entity_name) {
if !entity_set.contains(second_neighbor.entity_name.as_str())
&& second_neighbor.entity_name != *entity
{
*scores
.entry(second_neighbor.entity_name.clone())
.or_default() += second_neighbor.importance * 0.5;
}
}
}
}
let mut suggestions: Vec<_> = scores
.into_iter()
.filter_map(|(name, score)| {
self.nodes.get(&name).map(|node| SuggestedEntity {
name: name.clone(),
entity_type: node.entity_type.clone(),
relevance_score: score,
reason: self.get_suggestion_reason(&name, entities),
})
})
.collect();
suggestions.sort_by(|a, b| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
suggestions.truncate(10);
suggestions
}
/// Produces a short explanation of why `suggested` was proposed.
///
/// Returns `"<edge type> by <source>"` for the first source entity (in the
/// given order) that shares an edge with `suggested`, or a generic message
/// when no source is directly connected to it.
fn get_suggestion_reason(&self, suggested: &str, source_entities: &[&str]) -> String {
    let direct_link = source_entities.iter().find_map(|source| {
        self.get_edges(source).into_iter().find_map(|edge| {
            let other = if edge.from == *source {
                &edge.to
            } else {
                &edge.from
            };
            if other == suggested {
                Some(format!("{:?} by {}", edge.edge_type, source))
            } else {
                None
            }
        })
    });

    match direct_link {
        Some(reason) => reason,
        None => "Related through graph".to_string(),
    }
}
/// Returns up to `limit` nodes ranked by a simple centrality score.
///
/// The score is `importance * 0.7` plus a degree bonus of
/// `min(degree / 10, 0.3)`, where the degree is the number of adjacency
/// entries of the node. Ties are ordered by entity name so the result does
/// not depend on hash-map iteration order.
pub fn get_central_nodes(&self, limit: usize) -> Vec<&GraphNode> {
    let mut ranked: Vec<(&GraphNode, f32)> = self
        .nodes
        .iter()
        .map(|(name, node)| {
            let degree = self.adjacency.get(name).map_or(0, Vec::len);
            let degree_bonus = (degree as f32 / 10.0).min(0.3);
            (node, node.importance * 0.7 + degree_bonus)
        })
        .collect();

    ranked.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.entity_name.cmp(&b.0.entity_name))
    });
    ranked.truncate(limit);

    ranked.into_iter().map(|(node, _)| node).collect()
}
/// Builds a symmetric weighted adjacency matrix for the graph.
///
/// Returns the matrix together with the node names in row/column order.
/// Every edge adds its weight to both `[from, to]` and `[to, from]`, so
/// parallel edges accumulate; edges with an endpoint that is not a node are
/// ignored.
#[cfg(feature = "spectral")]
fn to_adjacency_matrix(&self) -> (ndarray::Array2<f32>, Vec<String>) {
    let names: Vec<String> = self.nodes.keys().cloned().collect();
    let size = names.len();

    // Name -> row/column position.
    let mut position_of: HashMap<&str, usize> = HashMap::with_capacity(size);
    for (position, name) in names.iter().enumerate() {
        position_of.insert(name.as_str(), position);
    }

    let mut matrix = ndarray::Array2::<f32>::zeros((size, size));
    for edge in &self.edges {
        let row = match position_of.get(edge.from.as_str()) {
            Some(&row) => row,
            None => continue,
        };
        let col = match position_of.get(edge.to.as_str()) {
            Some(&col) => col,
            None => continue,
        };
        matrix[[row, col]] += edge.weight;
        matrix[[col, row]] += edge.weight;
    }

    (matrix, names)
}
/// Partitions the nodes into clusters using
/// `crate::spectral::graph_ops::spectral_cluster` on the adjacency matrix.
///
/// Returns an empty vector when the graph has no nodes or `k` is zero, and
/// falls back to `find_clusters` when the spectral routine returns `None`.
/// Empty clusters are omitted; each cluster's `id` is the assignment value
/// produced by the spectral routine.
#[cfg(feature = "spectral")]
pub fn spectral_clusters(&self, k: usize) -> Vec<EntityCluster> {
    if k == 0 || self.nodes.is_empty() {
        return Vec::new();
    }

    let (adj, names) = self.to_adjacency_matrix();
    let assignments = match crate::spectral::graph_ops::spectral_cluster(&adj, k) {
        Some(found) => found,
        None => return self.find_clusters(),
    };

    // One bucket per cluster id, sized by the largest id that was assigned.
    let max_cluster = assignments.iter().copied().max().unwrap_or(0);
    let mut buckets: Vec<Vec<GraphNode>> = vec![Vec::new(); max_cluster + 1];
    for (i, &cluster_id) in assignments.iter().enumerate() {
        if let Some(node) = self.nodes.get(&names[i]) {
            buckets[cluster_id].push(node.clone());
        }
    }

    let mut clusters = Vec::new();
    for (id, nodes) in buckets.into_iter().enumerate() {
        if nodes.is_empty() {
            continue;
        }

        let avg_importance =
            nodes.iter().map(|n| n.importance).sum::<f32>() / nodes.len() as f32;

        let mut type_counts = HashMap::new();
        for node in &nodes {
            *type_counts.entry(node.entity_type.clone()).or_insert(0) += 1;
        }
        let dominant_type = type_counts
            .into_iter()
            .max_by_key(|(_, count)| *count)
            .map(|(kind, _)| kind);

        clusters.push(EntityCluster {
            id,
            nodes,
            avg_importance,
            dominant_type,
        });
    }

    clusters
}
/// Returns up to `limit` nodes paired with the score computed by
/// `crate::spectral::graph_ops::spectral_centrality`, in the order that
/// routine reports them.
///
/// Returns an empty vector for an empty graph.
#[cfg(feature = "spectral")]
pub fn spectral_central_nodes(&self, limit: usize) -> Vec<(&GraphNode, f32)> {
    if self.nodes.is_empty() {
        return Vec::new();
    }

    let (adj, names) = self.to_adjacency_matrix();
    let scores = crate::spectral::graph_ops::spectral_centrality(&adj);

    let mut ranked = Vec::new();
    for (i, score) in scores {
        if ranked.len() >= limit {
            break;
        }
        if let Some(node) = self.nodes.get(&names[i]) {
            ranked.push((node, score));
        }
    }
    ranked
}
/// Returns the value computed by
/// `crate::spectral::graph_ops::algebraic_connectivity` for the adjacency
/// matrix, or `0.0` when the graph has fewer than two nodes.
#[cfg(feature = "spectral")]
pub fn connectivity(&self) -> f32 {
    if self.nodes.len() >= 2 {
        let (adj, _names) = self.to_adjacency_matrix();
        crate::spectral::graph_ops::algebraic_connectivity(&adj)
    } else {
        0.0
    }
}
/// Removes edges that `crate::spectral::graph_ops::sparsify` drops from the
/// adjacency matrix, then rebuilds the edge list and adjacency index.
///
/// An edge is kept only when both endpoints are nodes and the sparsified
/// matrix has a positive entry at `[from, to]`. Graphs with fewer than four
/// nodes are left unchanged. A kept self-loop is indexed once in its node's
/// adjacency list.
#[cfg(feature = "spectral")]
pub fn sparsify(&mut self, epsilon: f32) {
    if self.nodes.len() < 4 {
        return;
    }

    let (adj, names) = self.to_adjacency_matrix();
    let sparse_adj = crate::spectral::graph_ops::sparsify(&adj, epsilon);

    let idx: HashMap<&str, usize> = names
        .iter()
        .enumerate()
        .map(|(i, s)| (s.as_str(), i))
        .collect();

    let mut new_edges = Vec::new();
    // Every node keeps an adjacency list, even if it ends up with no edges.
    let mut new_adjacency: HashMap<String, Vec<usize>> = self
        .nodes
        .keys()
        .map(|name| (name.clone(), Vec::new()))
        .collect();

    for edge in &self.edges {
        if let (Some(&i), Some(&j)) = (idx.get(edge.from.as_str()), idx.get(edge.to.as_str())) {
            if sparse_adj[[i, j]] > 0.0 {
                let edge_idx = new_edges.len();
                if let Some(adj_list) = new_adjacency.get_mut(&edge.from) {
                    adj_list.push(edge_idx);
                }
                // A self-loop shares one adjacency list for both endpoints;
                // index it once so lookups do not see the edge twice.
                if edge.to != edge.from {
                    if let Some(adj_list) = new_adjacency.get_mut(&edge.to) {
                        adj_list.push(edge_idx);
                    }
                }
                new_edges.push(edge.clone());
            }
        }
    }

    self.edges = new_edges;
    self.adjacency = new_adjacency;
}
}
/// Exposes the graph through the `RelationshipGraphT` trait by delegating
/// each trait method to the identically named inherent method (or, for
/// `get_node`, directly to the node map).
impl brainwires_core::graph::RelationshipGraphT for RelationshipGraph {
/// Looks the node up by name in the node map.
fn get_node(&self, name: &str) -> Option<&GraphNode> {
self.nodes.get(name)
}
/// Delegates to the inherent `get_neighbors`.
fn get_neighbors(&self, name: &str) -> Vec<&GraphNode> {
RelationshipGraph::get_neighbors(self, name)
}
/// Delegates to the inherent `get_edges`.
fn get_edges(&self, name: &str) -> Vec<&GraphEdge> {
RelationshipGraph::get_edges(self, name)
}
/// Delegates to the inherent `search`.
fn search(&self, query: &str, limit: usize) -> Vec<&GraphNode> {
RelationshipGraph::search(self, query, limit)
}
/// Delegates to the inherent `find_path`.
fn find_path(&self, from: &str, to: &str) -> Option<Vec<String>> {
RelationshipGraph::find_path(self, from, to)
}
}
/// One entry of the result of `RelationshipGraph::get_impact_set`: an entity
/// reachable from the queried entity, with an estimate of how strongly it is
/// affected.
#[derive(Debug, Clone)]
pub struct ImpactedEntity {
/// Name of the impacted entity.
pub name: String,
/// Entity type of the impacted node.
pub entity_type: EntityType,
/// Number of edges between the queried entity and this one along the traversal.
pub distance: usize,
/// Impact propagated to this entity; each hop multiplies it by an edge-type factor and the edge weight.
pub impact_score: f32,
/// Node names on the route by which the traversal reached this entity, ending with `name`.
pub impact_path: Vec<String>,
}
/// A group of nodes produced by `find_clusters` or `spectral_clusters`.
#[derive(Debug)]
pub struct EntityCluster {
/// Identifier assigned by the clustering routine; it need not equal the cluster's position in the returned list.
pub id: usize,
/// Clones of the nodes that belong to this cluster.
pub nodes: Vec<GraphNode>,
/// Mean `importance` of `nodes`.
pub avg_importance: f32,
/// Most frequent entity type among `nodes`; `None` only if there are no nodes to count.
pub dominant_type: Option<EntityType>,
}
/// A suggestion returned by `RelationshipGraph::suggest_related`.
#[derive(Debug)]
pub struct SuggestedEntity {
/// Name of the suggested entity.
pub name: String,
/// Entity type of the suggested node.
pub entity_type: EntityType,
/// Accumulated score: full importance per direct-neighbor hit, half importance per second-degree hit.
pub relevance_score: f32,
/// Human-readable explanation, e.g. the edge type linking it to a source entity or a generic fallback message.
pub reason: String,
}
/// Neighborhood of an entity as returned by `RelationshipGraph::get_entity_context`.
#[derive(Debug)]
pub struct EntityContext {
/// Name of the entity the context was requested for.
pub root: String,
/// Entities reachable within the requested depth, sorted by descending relevance.
pub related_entities: Vec<RelatedEntity>,
/// Message ids of the root node and of every related entity.
pub message_ids: HashSet<String>,
}
/// An entity reachable from the root of an `EntityContext`.
#[derive(Debug)]
pub struct RelatedEntity {
/// Name of the related entity.
pub name: String,
/// Entity type of the related node.
pub entity_type: EntityType,
/// Type of the edge through which the traversal first reached this entity.
pub relationship: EdgeType,
/// Number of hops from the root at which this entity was reached.
pub distance: usize,
/// Weight of the reaching edge multiplied by `0.8` raised to `distance`.
pub relevance: f32,
}
/// Summary counts returned by `RelationshipGraph::stats`.
#[derive(Debug)]
pub struct GraphStats {
/// Total number of nodes.
pub node_count: usize,
/// Total number of stored edges.
pub edge_count: usize,
/// Node count per entity type, keyed by `EntityType::as_str`.
pub nodes_by_type: HashMap<&'static str, usize>,
/// Edge count per edge type, keyed by the `Debug` rendering of the edge type.
pub edges_by_type: HashMap<String, usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture:
    /// `src/main.rs` --Contains--> `main` --References--> `Config`.
    fn create_test_graph() -> RelationshipGraph {
        let nodes = vec![
            GraphNode {
                entity_name: "src/main.rs".to_string(),
                entity_type: EntityType::File,
                message_ids: vec!["msg1".to_string(), "msg2".to_string()],
                mention_count: 5,
                importance: 0.8,
            },
            GraphNode {
                entity_name: "main".to_string(),
                entity_type: EntityType::Function,
                message_ids: vec!["msg1".to_string()],
                mention_count: 2,
                importance: 0.6,
            },
            GraphNode {
                entity_name: "Config".to_string(),
                entity_type: EntityType::Type,
                message_ids: vec!["msg2".to_string()],
                mention_count: 3,
                importance: 0.7,
            },
        ];
        let edges = vec![
            GraphEdge {
                from: "src/main.rs".to_string(),
                to: "main".to_string(),
                edge_type: EdgeType::Contains,
                weight: 0.9,
                message_id: Some("msg1".to_string()),
            },
            GraphEdge {
                from: "main".to_string(),
                to: "Config".to_string(),
                edge_type: EdgeType::References,
                weight: 0.6,
                message_id: Some("msg2".to_string()),
            },
        ];

        let mut graph = RelationshipGraph::new();
        for node in nodes {
            graph.add_node(node);
        }
        for edge in edges {
            graph.add_edge(edge);
        }
        graph
    }

    /// A node that the tests add without connecting it to anything.
    fn isolated_node() -> GraphNode {
        GraphNode {
            entity_name: "isolated".to_string(),
            entity_type: EntityType::Concept,
            message_ids: vec![],
            mention_count: 1,
            importance: 0.1,
        }
    }

    #[test]
    fn test_add_and_get_node() {
        let graph = create_test_graph();

        let found = graph.get_node("src/main.rs");

        assert!(found.is_some());
        assert_eq!(found.unwrap().mention_count, 5);
    }

    #[test]
    fn test_get_neighbors() {
        let graph = create_test_graph();

        let adjacent = graph.get_neighbors("src/main.rs");

        assert_eq!(adjacent.len(), 1);
        assert_eq!(adjacent[0].entity_name, "main");
    }

    #[test]
    fn test_find_path() {
        let graph = create_test_graph();

        let route = graph.find_path("src/main.rs", "Config");

        assert!(route.is_some());
        let route = route.unwrap();
        assert_eq!(route.len(), 3);
        assert_eq!(route[0], "src/main.rs");
        assert_eq!(route[2], "Config");
    }

    #[test]
    fn test_get_entity_context() {
        let graph = create_test_graph();

        let context = graph.get_entity_context("src/main.rs", 2);

        assert_eq!(context.root, "src/main.rs");
        assert!(!context.related_entities.is_empty());
        assert!(!context.message_ids.is_empty());
    }

    #[test]
    fn test_search() {
        let graph = create_test_graph();

        let hits = graph.search("main", 5);

        assert!(!hits.is_empty());
        assert!(hits.iter().any(|n| n.entity_name == "main"));
    }

    #[test]
    fn test_graph_stats() {
        let summary = create_test_graph().stats();

        assert_eq!(summary.node_count, 3);
        assert_eq!(summary.edge_count, 2);
    }

    #[test]
    fn test_empty_path() {
        let mut graph = create_test_graph();
        graph.add_node(isolated_node());

        let route = graph.find_path("src/main.rs", "isolated");

        assert!(route.is_none());
    }

    #[test]
    fn test_get_impact_set() {
        let graph = create_test_graph();

        let impacts = graph.get_impact_set("src/main.rs", 2);

        assert!(!impacts.is_empty());
        let names: Vec<_> = impacts.iter().map(|i| i.name.as_str()).collect();
        assert!(names.contains(&"main"));
    }

    #[test]
    fn test_get_impact_set_empty() {
        let graph = create_test_graph();

        let impacts = graph.get_impact_set("nonexistent", 2);

        assert!(impacts.is_empty());
    }

    #[test]
    fn test_find_clusters() {
        let mut graph = create_test_graph();
        graph.add_node(isolated_node());

        let clusters = graph.find_clusters();

        // Largest cluster first: the connected trio, then the lone node.
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].nodes.len(), 3);
        assert_eq!(clusters[1].nodes.len(), 1);
    }

    #[test]
    fn test_suggest_related() {
        let graph = create_test_graph();

        let suggestions = graph.suggest_related(&["src/main.rs"]);

        let suggested_names: Vec<_> = suggestions.iter().map(|s| s.name.as_str()).collect();
        assert!(suggested_names.contains(&"main"));
    }

    #[test]
    fn test_get_central_nodes() {
        let graph = create_test_graph();

        let central = graph.get_central_nodes(2);

        assert!(!central.is_empty());
        let names: Vec<_> = central.iter().map(|n| n.entity_name.as_str()).collect();
        assert!(names.contains(&"src/main.rs") || names.contains(&"main"));
    }
}