Skip to main content

brainwires_knowledge/knowledge/
relationship_graph.rs

1//! Relationship Graph Storage
2//!
3//! Stores and queries entity relationships for enhanced context retrieval.
4//! Uses an in-memory graph with optional persistence to LanceDB.
5
6use crate::knowledge::entity::{Entity, EntityStore, EntityType, Relationship};
7use std::collections::{HashMap, HashSet, VecDeque};
8
9// Re-export graph types from core (canonical definitions)
10pub use brainwires_core::graph::{EdgeType, GraphEdge, GraphNode};
11
12/// Relationship graph for entity context
13#[derive(Debug, Default)]
14pub struct RelationshipGraph {
15    nodes: HashMap<String, GraphNode>,
16    edges: Vec<GraphEdge>,
17    adjacency: HashMap<String, Vec<usize>>, // node -> edge indices
18}
19
20impl RelationshipGraph {
21    /// Create a new empty relationship graph.
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Build graph from entity store
27    pub fn from_entity_store(store: &EntityStore) -> Self {
28        let mut graph = Self::new();
29
30        // Add nodes from top entities
31        for entity in store.get_top_entities(100) {
32            graph.add_node(GraphNode {
33                entity_name: entity.name.clone(),
34                entity_type: entity.entity_type.clone(),
35                message_ids: entity.message_ids.clone(),
36                mention_count: entity.mention_count,
37                importance: Self::calculate_importance(entity),
38            });
39        }
40
41        graph
42    }
43
44    /// Calculate importance score for an entity.
45    ///
46    /// Combines log-scaled mention count, entity-type bonus, and message-spread
47    /// proxy into a score in `[0.0, 1.0]`.
48    ///
49    /// **Known limitation**: `ln(1) = 0`, so the mention-count component
50    /// contributes nothing for entities seen exactly once. The type bonus and
51    /// message-spread proxy still apply, so the score is non-zero, but a
52    /// single-mention entity is scored identically regardless of how many
53    /// times it was seen (just once vs. genuinely once). Use
54    /// `ln(mention_count + 1)` to remove this discontinuity if needed.
55    pub fn calculate_importance(entity: &Entity) -> f32 {
56        let mut score = 0.0;
57
58        // Base score from mentions
59        score += (entity.mention_count as f32).ln().max(0.0) * 0.3;
60
61        // Type-based importance
62        score += match entity.entity_type {
63            EntityType::File => 0.4,
64            EntityType::Function => 0.3,
65            EntityType::Type => 0.35,
66            EntityType::Error => 0.25,
67            EntityType::Concept => 0.2,
68            EntityType::Variable => 0.1,
69            EntityType::Command => 0.15,
70        };
71
72        // Recency (would need timestamp context)
73        // For now, use message count as proxy
74        score += (entity.message_ids.len() as f32 * 0.05).min(0.2);
75
76        score.clamp(0.0, 1.0)
77    }
78
79    /// Add a node to the graph
80    pub fn add_node(&mut self, node: GraphNode) {
81        let name = node.entity_name.clone();
82        if !self.adjacency.contains_key(&name) {
83            self.adjacency.insert(name.clone(), Vec::new());
84        }
85        self.nodes.insert(name, node);
86    }
87
88    /// Add an edge to the graph
89    pub fn add_edge(&mut self, edge: GraphEdge) {
90        let idx = self.edges.len();
91
92        // Update adjacency list for both directions
93        if let Some(adj) = self.adjacency.get_mut(&edge.from) {
94            adj.push(idx);
95        }
96        if let Some(adj) = self.adjacency.get_mut(&edge.to) {
97            adj.push(idx);
98        }
99
100        self.edges.push(edge);
101    }
102
103    /// Add relationship as edge
104    pub fn add_relationship(&mut self, rel: &Relationship) {
105        let (from, to, edge_type, message_id) = match rel {
106            Relationship::CoOccurs {
107                entity_a,
108                entity_b,
109                message_id,
110            } => (
111                entity_a.clone(),
112                entity_b.clone(),
113                EdgeType::CoOccurs,
114                Some(message_id.clone()),
115            ),
116            Relationship::Contains {
117                container,
118                contained,
119            } => (
120                container.clone(),
121                contained.clone(),
122                EdgeType::Contains,
123                None,
124            ),
125            Relationship::References { from, to } => {
126                (from.clone(), to.clone(), EdgeType::References, None)
127            }
128            Relationship::DependsOn {
129                dependent,
130                dependency,
131            } => (
132                dependent.clone(),
133                dependency.clone(),
134                EdgeType::DependsOn,
135                None,
136            ),
137            Relationship::Modifies {
138                modifier, modified, ..
139            } => (modifier.clone(), modified.clone(), EdgeType::Modifies, None),
140            Relationship::Defines {
141                definer, defined, ..
142            } => (definer.clone(), defined.clone(), EdgeType::Defines, None),
143        };
144
145        // Only add edge if both nodes exist
146        if self.nodes.contains_key(&from) && self.nodes.contains_key(&to) {
147            self.add_edge(GraphEdge {
148                from,
149                to,
150                weight: edge_type.weight(),
151                edge_type,
152                message_id,
153            });
154        }
155    }
156
157    /// Get node by name
158    pub fn get_node(&self, name: &str) -> Option<&GraphNode> {
159        self.nodes.get(name)
160    }
161
162    /// Get all neighbors of a node
163    pub fn get_neighbors(&self, name: &str) -> Vec<&GraphNode> {
164        let mut neighbors = Vec::new();
165
166        if let Some(edge_indices) = self.adjacency.get(name) {
167            for &idx in edge_indices {
168                if let Some(edge) = self.edges.get(idx) {
169                    let neighbor_name = if edge.from == name {
170                        &edge.to
171                    } else {
172                        &edge.from
173                    };
174                    if let Some(node) = self.nodes.get(neighbor_name) {
175                        neighbors.push(node);
176                    }
177                }
178            }
179        }
180
181        neighbors
182    }
183
184    /// Get edges for a node
185    pub fn get_edges(&self, name: &str) -> Vec<&GraphEdge> {
186        self.adjacency
187            .get(name)
188            .map(|indices| {
189                indices
190                    .iter()
191                    .filter_map(|&idx| self.edges.get(idx))
192                    .collect()
193            })
194            .unwrap_or_default()
195    }
196
197    /// Find shortest path between two entities using BFS
198    pub fn find_path(&self, from: &str, to: &str) -> Option<Vec<String>> {
199        if !self.nodes.contains_key(from) || !self.nodes.contains_key(to) {
200            return None;
201        }
202
203        if from == to {
204            return Some(vec![from.to_string()]);
205        }
206
207        let mut visited = HashSet::new();
208        let mut queue = VecDeque::new();
209        let mut parent: HashMap<String, String> = HashMap::new();
210
211        queue.push_back(from.to_string());
212        visited.insert(from.to_string());
213
214        while let Some(current) = queue.pop_front() {
215            for neighbor in self.get_neighbors(&current) {
216                if !visited.contains(&neighbor.entity_name) {
217                    visited.insert(neighbor.entity_name.clone());
218                    parent.insert(neighbor.entity_name.clone(), current.clone());
219
220                    if neighbor.entity_name == to {
221                        // Reconstruct path
222                        let mut path = vec![to.to_string()];
223                        let mut node = to.to_string();
224                        while let Some(p) = parent.get(&node) {
225                            path.push(p.clone());
226                            node = p.clone();
227                        }
228                        path.reverse();
229                        return Some(path);
230                    }
231
232                    queue.push_back(neighbor.entity_name.clone());
233                }
234            }
235        }
236
237        None
238    }
239
240    /// Get all context related to an entity (traverses graph)
241    pub fn get_entity_context(&self, entity: &str, max_depth: usize) -> EntityContext {
242        let mut context = EntityContext {
243            root: entity.to_string(),
244            related_entities: Vec::new(),
245            message_ids: HashSet::new(),
246        };
247
248        if let Some(node) = self.nodes.get(entity) {
249            context.message_ids.extend(node.message_ids.clone());
250        }
251
252        let mut visited = HashSet::new();
253        let mut queue: VecDeque<(String, usize)> = VecDeque::new();
254
255        queue.push_back((entity.to_string(), 0));
256        visited.insert(entity.to_string());
257
258        while let Some((current, depth)) = queue.pop_front() {
259            if depth >= max_depth {
260                continue;
261            }
262
263            for edge in self.get_edges(&current) {
264                let neighbor = if edge.from == current {
265                    &edge.to
266                } else {
267                    &edge.from
268                };
269
270                if !visited.contains(neighbor) {
271                    visited.insert(neighbor.clone());
272
273                    if let Some(node) = self.nodes.get(neighbor) {
274                        context.related_entities.push(RelatedEntity {
275                            name: neighbor.clone(),
276                            entity_type: node.entity_type.clone(),
277                            relationship: edge.edge_type.clone(),
278                            distance: depth + 1,
279                            relevance: edge.weight * (0.8_f32).powi((depth + 1) as i32),
280                        });
281                        context.message_ids.extend(node.message_ids.clone());
282                    }
283
284                    queue.push_back((neighbor.clone(), depth + 1));
285                }
286            }
287        }
288
289        // Sort by relevance
290        context.related_entities.sort_by(|a, b| {
291            b.relevance
292                .partial_cmp(&a.relevance)
293                .unwrap_or(std::cmp::Ordering::Equal)
294        });
295
296        context
297    }
298
299    /// Find entities most relevant to a query (by name matching)
300    pub fn search(&self, query: &str, limit: usize) -> Vec<&GraphNode> {
301        let query_lower = query.to_lowercase();
302        let query_words: HashSet<_> = query_lower.split_whitespace().collect();
303
304        let mut scored: Vec<_> = self
305            .nodes
306            .values()
307            .map(|node| {
308                let name_lower = node.entity_name.to_lowercase();
309                let mut score = 0.0;
310
311                // Exact match
312                if name_lower == query_lower {
313                    score += 1.0;
314                }
315                // Contains query
316                else if name_lower.contains(&query_lower) {
317                    score += 0.7;
318                }
319                // Query contains name
320                else if query_lower.contains(&name_lower) {
321                    score += 0.5;
322                }
323                // Word overlap
324                else {
325                    let name_words: HashSet<_> =
326                        name_lower.split(|c: char| !c.is_alphanumeric()).collect();
327                    let overlap = query_words.intersection(&name_words).count();
328                    score += overlap as f32 * 0.3;
329                }
330
331                // Boost by importance
332                score *= 1.0 + node.importance * 0.5;
333
334                (node, score)
335            })
336            .filter(|(_, score)| *score > 0.0)
337            .collect();
338
339        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
340
341        scored
342            .into_iter()
343            .take(limit)
344            .map(|(node, _)| node)
345            .collect()
346    }
347
348    /// Get graph statistics
349    pub fn stats(&self) -> GraphStats {
350        let mut type_counts = HashMap::new();
351        for node in self.nodes.values() {
352            *type_counts.entry(node.entity_type.as_str()).or_insert(0) += 1;
353        }
354
355        let mut edge_type_counts = HashMap::new();
356        for edge in &self.edges {
357            *edge_type_counts
358                .entry(format!("{:?}", edge.edge_type))
359                .or_insert(0) += 1;
360        }
361
362        GraphStats {
363            node_count: self.nodes.len(),
364            edge_count: self.edges.len(),
365            nodes_by_type: type_counts,
366            edges_by_type: edge_type_counts,
367        }
368    }
369
370    // ============ SEAL Integration Methods ============
371
372    /// Get entities that would be impacted by changes to a given entity.
373    /// Traverses the graph to find dependent entities up to a specified depth.
374    pub fn get_impact_set(&self, entity: &str, depth: usize) -> Vec<ImpactedEntity> {
375        let mut impacts = Vec::new();
376        let mut visited = HashSet::new();
377        let mut queue: VecDeque<(String, usize, f32)> = VecDeque::new();
378
379        if !self.nodes.contains_key(entity) {
380            return impacts;
381        }
382
383        queue.push_back((entity.to_string(), 0, 1.0));
384        visited.insert(entity.to_string());
385
386        while let Some((current, current_depth, current_impact)) = queue.pop_front() {
387            if current_depth >= depth {
388                continue;
389            }
390
391            for edge in self.get_edges(&current) {
392                let neighbor = if edge.from == current {
393                    &edge.to
394                } else {
395                    &edge.from
396                };
397
398                if !visited.contains(neighbor) {
399                    visited.insert(neighbor.clone());
400
401                    // Calculate impact factor based on edge type and weight
402                    let impact_factor = match edge.edge_type {
403                        EdgeType::DependsOn => 0.9,
404                        EdgeType::Contains => 0.8,
405                        EdgeType::Modifies => 0.7,
406                        EdgeType::References => 0.5,
407                        EdgeType::Defines => 0.6,
408                        EdgeType::CoOccurs => 0.3,
409                    };
410
411                    let propagated_impact = current_impact * impact_factor * edge.weight;
412
413                    if let Some(node) = self.nodes.get(neighbor) {
414                        impacts.push(ImpactedEntity {
415                            name: neighbor.clone(),
416                            entity_type: node.entity_type.clone(),
417                            distance: current_depth + 1,
418                            impact_score: propagated_impact,
419                            impact_path: vec![current.clone(), neighbor.clone()],
420                        });
421                    }
422
423                    queue.push_back((neighbor.clone(), current_depth + 1, propagated_impact));
424                }
425            }
426        }
427
428        // Sort by impact score (highest first)
429        impacts.sort_by(|a, b| {
430            b.impact_score
431                .partial_cmp(&a.impact_score)
432                .unwrap_or(std::cmp::Ordering::Equal)
433        });
434
435        impacts
436    }
437
438    /// Find clusters of related entities using connected component analysis
439    pub fn find_clusters(&self) -> Vec<EntityCluster> {
440        let mut clusters = Vec::new();
441        let mut visited = HashSet::new();
442
443        for node_name in self.nodes.keys() {
444            if visited.contains(node_name) {
445                continue;
446            }
447
448            // BFS to find all connected nodes
449            let mut cluster_nodes = Vec::new();
450            let mut queue = VecDeque::new();
451            queue.push_back(node_name.clone());
452            visited.insert(node_name.clone());
453
454            while let Some(current) = queue.pop_front() {
455                if let Some(node) = self.nodes.get(&current) {
456                    cluster_nodes.push(node.clone());
457                }
458
459                for neighbor in self.get_neighbors(&current) {
460                    if !visited.contains(&neighbor.entity_name) {
461                        visited.insert(neighbor.entity_name.clone());
462                        queue.push_back(neighbor.entity_name.clone());
463                    }
464                }
465            }
466
467            if !cluster_nodes.is_empty() {
468                // Calculate cluster metrics
469                let total_importance: f32 = cluster_nodes.iter().map(|n| n.importance).sum();
470                let avg_importance = total_importance / cluster_nodes.len() as f32;
471
472                // Find dominant type
473                let mut type_counts = HashMap::new();
474                for node in &cluster_nodes {
475                    *type_counts.entry(node.entity_type.clone()).or_insert(0) += 1;
476                }
477                let dominant_type = type_counts
478                    .into_iter()
479                    .max_by_key(|(_, count)| *count)
480                    .map(|(t, _)| t);
481
482                clusters.push(EntityCluster {
483                    id: clusters.len(),
484                    nodes: cluster_nodes,
485                    avg_importance,
486                    dominant_type,
487                });
488            }
489        }
490
491        // Sort clusters by size (largest first)
492        clusters.sort_by(|a, b| b.nodes.len().cmp(&a.nodes.len()));
493
494        clusters
495    }
496
497    /// Suggest related entities given a set of entities.
498    /// Uses co-occurrence and relationship analysis.
499    pub fn suggest_related(&self, entities: &[&str]) -> Vec<SuggestedEntity> {
500        let mut scores: HashMap<String, f32> = HashMap::new();
501        let entity_set: HashSet<_> = entities.iter().copied().collect();
502
503        for entity in entities {
504            // Get direct neighbors
505            for neighbor in self.get_neighbors(entity) {
506                if !entity_set.contains(neighbor.entity_name.as_str()) {
507                    *scores.entry(neighbor.entity_name.clone()).or_default() += neighbor.importance;
508                }
509            }
510
511            // Get second-degree neighbors with lower weight
512            for first_neighbor in self.get_neighbors(entity) {
513                if entity_set.contains(first_neighbor.entity_name.as_str()) {
514                    continue;
515                }
516                for second_neighbor in self.get_neighbors(&first_neighbor.entity_name) {
517                    if !entity_set.contains(second_neighbor.entity_name.as_str())
518                        && second_neighbor.entity_name != *entity
519                    {
520                        *scores
521                            .entry(second_neighbor.entity_name.clone())
522                            .or_default() += second_neighbor.importance * 0.5;
523                    }
524                }
525            }
526        }
527
528        // Convert to suggestions
529        let mut suggestions: Vec<_> = scores
530            .into_iter()
531            .filter_map(|(name, score)| {
532                self.nodes.get(&name).map(|node| SuggestedEntity {
533                    name: name.clone(),
534                    entity_type: node.entity_type.clone(),
535                    relevance_score: score,
536                    reason: self.get_suggestion_reason(&name, entities),
537                })
538            })
539            .collect();
540
541        // Sort by relevance
542        suggestions.sort_by(|a, b| {
543            b.relevance_score
544                .partial_cmp(&a.relevance_score)
545                .unwrap_or(std::cmp::Ordering::Equal)
546        });
547
548        suggestions.truncate(10);
549        suggestions
550    }
551
552    /// Get a reason for suggesting an entity
553    fn get_suggestion_reason(&self, suggested: &str, source_entities: &[&str]) -> String {
554        for source in source_entities {
555            // Check direct relationship
556            let edges = self.get_edges(source);
557            for edge in edges {
558                let other = if edge.from == *source {
559                    &edge.to
560                } else {
561                    &edge.from
562                };
563                if other == suggested {
564                    return format!("{:?} by {}", edge.edge_type, source);
565                }
566            }
567        }
568        "Related through graph".to_string()
569    }
570
571    /// Get the most central nodes in the graph (by connectivity)
572    pub fn get_central_nodes(&self, limit: usize) -> Vec<&GraphNode> {
573        let mut centrality: Vec<_> = self
574            .nodes
575            .iter()
576            .map(|(name, node)| {
577                let degree = self.adjacency.get(name).map(|v| v.len()).unwrap_or(0);
578                let weighted_score = node.importance * 0.7 + (degree as f32 / 10.0).min(0.3);
579                (node, weighted_score)
580            })
581            .collect();
582
583        centrality.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
584
585        centrality.into_iter().take(limit).map(|(n, _)| n).collect()
586    }
587
588    // ============ Spectral Graph Methods ============
589
590    /// Convert this graph to a dense weighted adjacency matrix.
591    ///
592    /// Returns `(adjacency_matrix, node_names)` where `node_names[i]` is the
593    /// entity name for row/column `i`. Multi-edges between the same pair are
594    /// summed.
595    #[cfg(feature = "spectral")]
596    fn to_adjacency_matrix(&self) -> (ndarray::Array2<f32>, Vec<String>) {
597        let names: Vec<String> = self.nodes.keys().cloned().collect();
598        let n = names.len();
599        let idx: HashMap<&str, usize> = names
600            .iter()
601            .enumerate()
602            .map(|(i, s)| (s.as_str(), i))
603            .collect();
604
605        let mut adj = ndarray::Array2::<f32>::zeros((n, n));
606        for edge in &self.edges {
607            if let (Some(&i), Some(&j)) = (idx.get(edge.from.as_str()), idx.get(edge.to.as_str())) {
608                adj[[i, j]] += edge.weight;
609                adj[[j, i]] += edge.weight;
610            }
611        }
612
613        (adj, names)
614    }
615
616    /// Find semantic communities within connected components using spectral clustering.
617    ///
618    /// Unlike `find_clusters` which only finds connected components, this method
619    /// discovers tightly-coupled groups *within* a connected component by analyzing
620    /// the graph's spectral properties (Fiedler vector of the Laplacian).
621    ///
622    /// # Arguments
623    ///
624    /// * `k` - Number of clusters to find. If the graph has fewer natural clusters,
625    ///   fewer may be returned.
626    #[cfg(feature = "spectral")]
627    pub fn spectral_clusters(&self, k: usize) -> Vec<EntityCluster> {
628        if self.nodes.is_empty() || k == 0 {
629            return Vec::new();
630        }
631
632        let (adj, names) = self.to_adjacency_matrix();
633        let assignments = match crate::spectral::graph_ops::spectral_cluster(&adj, k) {
634            Some(a) => a,
635            None => return self.find_clusters(), // fall back to connected components
636        };
637
638        // Group nodes by cluster assignment
639        let max_cluster = assignments.iter().copied().max().unwrap_or(0);
640        let mut cluster_nodes: Vec<Vec<GraphNode>> = vec![Vec::new(); max_cluster + 1];
641
642        for (i, &cluster_id) in assignments.iter().enumerate() {
643            if let Some(node) = self.nodes.get(&names[i]) {
644                cluster_nodes[cluster_id].push(node.clone());
645            }
646        }
647
648        // Build EntityCluster for each non-empty group
649        cluster_nodes
650            .into_iter()
651            .enumerate()
652            .filter(|(_, nodes)| !nodes.is_empty())
653            .map(|(id, nodes)| {
654                let avg_importance =
655                    nodes.iter().map(|n| n.importance).sum::<f32>() / nodes.len() as f32;
656                let mut type_counts = HashMap::new();
657                for node in &nodes {
658                    *type_counts.entry(node.entity_type.clone()).or_insert(0) += 1;
659                }
660                let dominant_type = type_counts
661                    .into_iter()
662                    .max_by_key(|(_, c)| *c)
663                    .map(|(t, _)| t);
664
665                EntityCluster {
666                    id,
667                    nodes,
668                    avg_importance,
669                    dominant_type,
670                }
671            })
672            .collect()
673    }
674
675    /// Compute spectral centrality for all nodes.
676    ///
677    /// Returns nodes sorted by spectral centrality (highest first). Nodes with
678    /// high centrality are structural bridges between communities — important
679    /// for understanding cross-cutting concerns in the codebase.
680    ///
681    /// This complements `get_central_nodes` which uses degree centrality.
682    /// Spectral centrality captures *structural position* rather than just
683    /// connection count.
684    #[cfg(feature = "spectral")]
685    pub fn spectral_central_nodes(&self, limit: usize) -> Vec<(&GraphNode, f32)> {
686        if self.nodes.is_empty() {
687            return Vec::new();
688        }
689
690        let (adj, names) = self.to_adjacency_matrix();
691        let scores = crate::spectral::graph_ops::spectral_centrality(&adj);
692
693        scores
694            .into_iter()
695            .filter_map(|(i, score)| self.nodes.get(&names[i]).map(|node| (node, score)))
696            .take(limit)
697            .collect()
698    }
699
700    /// Compute the algebraic connectivity of this graph.
701    ///
702    /// This is the second-smallest eigenvalue of the Laplacian, measuring how
703    /// well-connected the graph is:
704    /// - 0 = disconnected (multiple components)
705    /// - Small = bottleneck exists (near-disconnection)
706    /// - Large = well-connected
707    ///
708    /// Useful for monitoring knowledge graph health as entities accumulate.
709    #[cfg(feature = "spectral")]
710    pub fn connectivity(&self) -> f32 {
711        if self.nodes.len() < 2 {
712            return 0.0;
713        }
714        let (adj, _) = self.to_adjacency_matrix();
715        crate::spectral::graph_ops::algebraic_connectivity(&adj)
716    }
717
718    /// Prune redundant edges using spectral sparsification.
719    ///
720    /// Removes edges that are structurally redundant (many alternative paths
721    /// exist) while preserving edges that are critical for connectivity
722    /// (bridges, bottlenecks).
723    ///
724    /// # Arguments
725    ///
726    /// * `epsilon` - Approximation quality. 0.3 = aggressive pruning (~30% edges
727    ///   removed), 0.1 = conservative (~10% removed). The sparsified graph
728    ///   preserves spectral properties within (1 ± epsilon) of the original.
729    #[cfg(feature = "spectral")]
730    pub fn sparsify(&mut self, epsilon: f32) {
731        if self.nodes.len() < 4 {
732            return; // too small to benefit
733        }
734
735        let (adj, names) = self.to_adjacency_matrix();
736        let sparse_adj = crate::spectral::graph_ops::sparsify(&adj, epsilon);
737
738        let idx: HashMap<&str, usize> = names
739            .iter()
740            .enumerate()
741            .map(|(i, s)| (s.as_str(), i))
742            .collect();
743
744        // Rebuild edges: keep only those present in the sparsified adjacency
745        let mut new_edges = Vec::new();
746        let mut new_adjacency: HashMap<String, Vec<usize>> = HashMap::new();
747
748        // Initialize adjacency lists
749        for name in self.nodes.keys() {
750            new_adjacency.insert(name.clone(), Vec::new());
751        }
752
753        for edge in &self.edges {
754            if let (Some(&i), Some(&j)) = (idx.get(edge.from.as_str()), idx.get(edge.to.as_str())) {
755                if sparse_adj[[i, j]] > 0.0 {
756                    let edge_idx = new_edges.len();
757                    if let Some(adj_list) = new_adjacency.get_mut(&edge.from) {
758                        adj_list.push(edge_idx);
759                    }
760                    if let Some(adj_list) = new_adjacency.get_mut(&edge.to) {
761                        adj_list.push(edge_idx);
762                    }
763                    new_edges.push(edge.clone());
764                }
765            }
766        }
767
768        self.edges = new_edges;
769        self.adjacency = new_adjacency;
770    }
771}
772
773impl brainwires_core::graph::RelationshipGraphT for RelationshipGraph {
774    fn get_node(&self, name: &str) -> Option<&GraphNode> {
775        self.nodes.get(name)
776    }
777
778    fn get_neighbors(&self, name: &str) -> Vec<&GraphNode> {
779        RelationshipGraph::get_neighbors(self, name)
780    }
781
782    fn get_edges(&self, name: &str) -> Vec<&GraphEdge> {
783        RelationshipGraph::get_edges(self, name)
784    }
785
786    fn search(&self, query: &str, limit: usize) -> Vec<&GraphNode> {
787        RelationshipGraph::search(self, query, limit)
788    }
789
790    fn find_path(&self, from: &str, to: &str) -> Option<Vec<String>> {
791        RelationshipGraph::find_path(self, from, to)
792    }
793}
794
795/// Entity impacted by changes to another entity
796#[derive(Debug, Clone)]
797pub struct ImpactedEntity {
798    /// Entity name.
799    pub name: String,
800    /// Entity type.
801    pub entity_type: EntityType,
802    /// Graph distance from the change source.
803    pub distance: usize,
804    /// Computed impact score.
805    pub impact_score: f32,
806    /// Path of entities from source to this entity.
807    pub impact_path: Vec<String>,
808}
809
810/// A cluster of related entities
811#[derive(Debug)]
812pub struct EntityCluster {
813    /// Cluster identifier.
814    pub id: usize,
815    /// Nodes in this cluster.
816    pub nodes: Vec<GraphNode>,
817    /// Average importance of nodes.
818    pub avg_importance: f32,
819    /// Most common entity type in the cluster.
820    pub dominant_type: Option<EntityType>,
821}
822
823/// A suggested related entity
824#[derive(Debug)]
825pub struct SuggestedEntity {
826    /// Entity name.
827    pub name: String,
828    /// Entity type.
829    pub entity_type: EntityType,
830    /// How relevant this suggestion is.
831    pub relevance_score: f32,
832    /// Why this entity was suggested.
833    pub reason: String,
834}
835
/// Context gathered around a single root entity.
///
/// NOTE(review): presumably built by `RelationshipGraph::get_entity_context`
/// — confirm.
#[derive(Debug)]
pub struct EntityContext {
    /// Name of the entity the context was gathered for.
    pub root: String,
    /// Entities related to the root, with relationship and distance details.
    pub related_entities: Vec<RelatedEntity>,
    /// Message IDs relevant to this context (a set, so each ID appears once).
    pub message_ids: HashSet<String>,
}
846
/// An entity related to an [`EntityContext`] root, with relationship info.
#[derive(Debug)]
pub struct RelatedEntity {
    /// Entity name.
    pub name: String,
    /// Entity type.
    pub entity_type: EntityType,
    /// Type of relationship to the root.
    ///
    /// NOTE(review): for entities more than one hop away this presumably
    /// reflects a single edge on the route rather than the whole path — confirm.
    pub relationship: EdgeType,
    /// Graph distance from the root.
    pub distance: usize,
    /// Relevance score (higher presumably means more relevant — confirm).
    pub relevance: f32,
}
861
/// Summary statistics for a [`RelationshipGraph`].
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers can
/// snapshot stats and compare them (e.g. before and after an update). Every
/// field is an integer or a standard-library map, so these traits are
/// available without extra bounds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphStats {
    /// Total number of nodes.
    pub node_count: usize,
    /// Total number of edges.
    pub edge_count: usize,
    /// Node counts grouped by entity type, keyed by a static type label.
    pub nodes_by_type: HashMap<&'static str, usize>,
    /// Edge counts grouped by edge type, keyed by a string label for the type.
    pub edges_by_type: HashMap<String, usize>,
}
874
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture graph:
    /// `src/main.rs` --Contains--> `main` --References--> `Config`.
    fn create_test_graph() -> RelationshipGraph {
        let mut graph = RelationshipGraph::new();

        // Nodes
        graph.add_node(GraphNode {
            entity_name: "src/main.rs".to_string(),
            entity_type: EntityType::File,
            message_ids: vec!["msg1".to_string(), "msg2".to_string()],
            mention_count: 5,
            importance: 0.8,
        });

        graph.add_node(GraphNode {
            entity_name: "main".to_string(),
            entity_type: EntityType::Function,
            message_ids: vec!["msg1".to_string()],
            mention_count: 2,
            importance: 0.6,
        });

        graph.add_node(GraphNode {
            entity_name: "Config".to_string(),
            entity_type: EntityType::Type,
            message_ids: vec!["msg2".to_string()],
            mention_count: 3,
            importance: 0.7,
        });

        // Edges
        graph.add_edge(GraphEdge {
            from: "src/main.rs".to_string(),
            to: "main".to_string(),
            edge_type: EdgeType::Contains,
            weight: 0.9,
            message_id: Some("msg1".to_string()),
        });

        graph.add_edge(GraphEdge {
            from: "main".to_string(),
            to: "Config".to_string(),
            edge_type: EdgeType::References,
            weight: 0.6,
            message_id: Some("msg2".to_string()),
        });

        graph
    }

    /// Adds an `isolated` node with no edges, giving the graph a second,
    /// disconnected component.
    fn add_isolated_node(graph: &mut RelationshipGraph) {
        graph.add_node(GraphNode {
            entity_name: "isolated".to_string(),
            entity_type: EntityType::Concept,
            message_ids: vec![],
            mention_count: 1,
            importance: 0.1,
        });
    }

    #[test]
    fn test_add_and_get_node() {
        let graph = create_test_graph();

        let node = graph
            .get_node("src/main.rs")
            .expect("fixture adds src/main.rs");
        assert_eq!(node.mention_count, 5);
    }

    #[test]
    fn test_get_neighbors() {
        let graph = create_test_graph();

        let neighbors = graph.get_neighbors("src/main.rs");
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity_name, "main");
    }

    #[test]
    fn test_find_path() {
        let graph = create_test_graph();

        let path = graph
            .find_path("src/main.rs", "Config")
            .expect("Config is reachable from src/main.rs");
        // `main` is the only node linking the two endpoints, so the route is
        // fully determined.
        assert_eq!(path, ["src/main.rs", "main", "Config"]);
    }

    #[test]
    fn test_get_entity_context() {
        let graph = create_test_graph();

        let context = graph.get_entity_context("src/main.rs", 2);
        assert_eq!(context.root, "src/main.rs");
        assert!(!context.related_entities.is_empty());
        assert!(!context.message_ids.is_empty());
    }

    #[test]
    fn test_search() {
        let graph = create_test_graph();

        let results = graph.search("main", 5);
        assert!(!results.is_empty());
        // The query exactly equals the `main` function node's name, so that node
        // must be among the results (other nodes may match as well).
        assert!(results.iter().any(|n| n.entity_name == "main"));
    }

    #[test]
    fn test_graph_stats() {
        let graph = create_test_graph();

        let stats = graph.stats();
        assert_eq!(stats.node_count, 3);
        assert_eq!(stats.edge_count, 2);
    }

    /// `find_path` yields `None` when the target lives in another component.
    #[test]
    fn test_empty_path() {
        let mut graph = create_test_graph();
        add_isolated_node(&mut graph);

        let path = graph.find_path("src/main.rs", "isolated");
        assert!(path.is_none());
    }

    // ============ SEAL Integration Tests ============

    #[test]
    fn test_get_impact_set() {
        let graph = create_test_graph();

        let impacts = graph.get_impact_set("src/main.rs", 2);
        assert!(!impacts.is_empty());

        // `main` shares a direct edge with src/main.rs, so it must be impacted.
        assert!(impacts.iter().any(|i| i.name == "main"));
    }

    #[test]
    fn test_get_impact_set_empty() {
        let graph = create_test_graph();

        // Non-existent entity should return empty
        let impacts = graph.get_impact_set("nonexistent", 2);
        assert!(impacts.is_empty());
    }

    #[test]
    fn test_find_clusters() {
        let mut graph = create_test_graph();
        // A disconnected node forms a second cluster.
        add_isolated_node(&mut graph);

        let clusters = graph.find_clusters();
        assert_eq!(clusters.len(), 2);

        // Clusters are expected largest-first: the connected trio, then the
        // isolated node on its own.
        assert_eq!(clusters[0].nodes.len(), 3);
        assert_eq!(clusters[1].nodes.len(), 1);
    }

    #[test]
    fn test_suggest_related() {
        let graph = create_test_graph();

        let suggestions = graph.suggest_related(&["src/main.rs"]);

        // Should suggest main (direct neighbor)
        assert!(suggestions.iter().any(|s| s.name == "main"));
    }

    #[test]
    fn test_get_central_nodes() {
        let graph = create_test_graph();

        let central = graph.get_central_nodes(2);
        assert!(!central.is_empty());

        // One of the connected nodes should rank among the most central.
        assert!(central
            .iter()
            .any(|n| n.entity_name == "src/main.rs" || n.entity_name == "main"));
    }
}