Skip to main content

engram/graph/
mod.rs

1//! Knowledge graph visualization (RML-894 improvements)
2//!
3//! Provides:
4//! - Interactive graph visualization with vis.js
5//! - Graph clustering and community detection
6//! - Graph statistics and metrics
//! - Export to multiple formats (HTML, DOT, GEXF, JSON)
8//! - Filtering and traversal utilities
9//! - Temporal knowledge graph with validity periods (RML-1235)
10
11pub mod coactivation;
12pub mod conflicts;
13#[cfg(feature = "duckdb-graph")]
14pub mod duckdb_graph;
15pub mod temporal;
16pub mod triplets;
17
18use chrono::{DateTime, Utc};
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet, VecDeque};
21
22use crate::types::{CrossReference, Memory, MemoryId};
23
24/// Graph node
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct GraphNode {
27    pub id: MemoryId,
28    pub label: String,
29    pub memory_type: String,
30    pub importance: f32,
31    pub tags: Vec<String>,
32}
33
34/// Graph edge
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct GraphEdge {
37    pub from: MemoryId,
38    pub to: MemoryId,
39    pub edge_type: String,
40    pub score: f32,
41    pub confidence: f32,
42}
43
44/// Knowledge graph structure
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct KnowledgeGraph {
47    pub nodes: Vec<GraphNode>,
48    pub edges: Vec<GraphEdge>,
49}
50
51impl KnowledgeGraph {
52    /// Create graph from memories and cross-references
53    pub fn from_data(memories: &[Memory], crossrefs: &[CrossReference]) -> Self {
54        let nodes: Vec<GraphNode> = memories
55            .iter()
56            .map(|m| GraphNode {
57                id: m.id,
58                label: truncate_label(&m.content, 50),
59                memory_type: m.memory_type.as_str().to_string(),
60                importance: m.importance,
61                tags: m.tags.clone(),
62            })
63            .collect();
64
65        let memory_ids: std::collections::HashSet<MemoryId> =
66            memories.iter().map(|m| m.id).collect();
67
68        let edges: Vec<GraphEdge> = crossrefs
69            .iter()
70            .filter(|cr| memory_ids.contains(&cr.from_id) && memory_ids.contains(&cr.to_id))
71            .map(|cr| GraphEdge {
72                from: cr.from_id,
73                to: cr.to_id,
74                edge_type: cr.edge_type.as_str().to_string(),
75                score: cr.score,
76                confidence: cr.confidence,
77            })
78            .collect();
79
80        Self { nodes, edges }
81    }
82
83    /// Export as vis.js compatible JSON
84    pub fn to_visjs_json(&self) -> serde_json::Value {
85        let nodes: Vec<serde_json::Value> = self
86            .nodes
87            .iter()
88            .map(|n| {
89                serde_json::json!({
90                    "id": n.id,
91                    "label": n.label,
92                    "group": n.memory_type,
93                    "value": (n.importance * 10.0) as i32 + 5,
94                    "title": format!("Type: {}\nTags: {}", n.memory_type, n.tags.join(", "))
95                })
96            })
97            .collect();
98
99        let edges: Vec<serde_json::Value> = self
100            .edges
101            .iter()
102            .map(|e| {
103                serde_json::json!({
104                    "from": e.from,
105                    "to": e.to,
106                    "label": e.edge_type,
107                    "value": (e.score * e.confidence * 5.0) as i32 + 1,
108                    "title": format!("Score: {:.2}, Confidence: {:.2}", e.score, e.confidence)
109                })
110            })
111            .collect();
112
113        serde_json::json!({
114            "nodes": nodes,
115            "edges": edges
116        })
117    }
118
119    /// Export as standalone HTML with vis.js
120    pub fn to_html(&self) -> String {
121        let graph_data = self.to_visjs_json();
122
123        format!(
124            r#"<!DOCTYPE html>
125<html>
126<head>
127    <title>Engram Knowledge Graph</title>
128    <script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
129    <style>
130        body {{ margin: 0; padding: 0; font-family: system-ui, sans-serif; }}
131        #graph {{ width: 100vw; height: 100vh; }}
132        #controls {{
133            position: absolute;
134            top: 10px;
135            left: 10px;
136            background: white;
137            padding: 10px;
138            border-radius: 8px;
139            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
140        }}
141        #search {{ padding: 8px; width: 200px; border: 1px solid #ddd; border-radius: 4px; }}
142        .legend {{ display: flex; gap: 10px; margin-top: 10px; flex-wrap: wrap; }}
143        .legend-item {{ display: flex; align-items: center; gap: 5px; font-size: 12px; }}
144        .legend-dot {{ width: 12px; height: 12px; border-radius: 50%; }}
145    </style>
146</head>
147<body>
148    <div id="controls">
149        <input type="text" id="search" placeholder="Search nodes...">
150        <div class="legend">
151            <div class="legend-item"><span class="legend-dot" style="background: #97C2FC;"></span> note</div>
152            <div class="legend-item"><span class="legend-dot" style="background: #FFFF00;"></span> todo</div>
153            <div class="legend-item"><span class="legend-dot" style="background: #FB7E81;"></span> issue</div>
154            <div class="legend-item"><span class="legend-dot" style="background: #7BE141;"></span> decision</div>
155            <div class="legend-item"><span class="legend-dot" style="background: #FFA807;"></span> preference</div>
156            <div class="legend-item"><span class="legend-dot" style="background: #6E6EFD;"></span> learning</div>
157        </div>
158    </div>
159    <div id="graph"></div>
160    <script>
161        const data = {graph_data};
162
163        const options = {{
164            nodes: {{
165                shape: 'dot',
166                scaling: {{ min: 10, max: 30 }},
167                font: {{ size: 12, face: 'system-ui' }}
168            }},
169            edges: {{
170                arrows: 'to',
171                scaling: {{ min: 1, max: 5 }},
172                font: {{ size: 10, align: 'middle' }}
173            }},
174            groups: {{
175                note: {{ color: '#97C2FC' }},
176                todo: {{ color: '#FFFF00' }},
177                issue: {{ color: '#FB7E81' }},
178                decision: {{ color: '#7BE141' }},
179                preference: {{ color: '#FFA807' }},
180                learning: {{ color: '#6E6EFD' }},
181                context: {{ color: '#C2FABC' }},
182                credential: {{ color: '#FD6A6A' }}
183            }},
184            physics: {{
185                stabilization: {{ iterations: 100 }},
186                barnesHut: {{
187                    gravitationalConstant: -2000,
188                    springLength: 100
189                }}
190            }},
191            interaction: {{
192                hover: true,
193                tooltipDelay: 100
194            }}
195        }};
196
197        const container = document.getElementById('graph');
198        const network = new vis.Network(container, data, options);
199
200        // Search functionality
201        const searchInput = document.getElementById('search');
202        searchInput.addEventListener('input', function() {{
203            const query = this.value.toLowerCase();
204            if (query) {{
205                const matchingNodes = data.nodes.filter(n =>
206                    n.label.toLowerCase().includes(query)
207                ).map(n => n.id);
208                network.selectNodes(matchingNodes);
209                if (matchingNodes.length > 0) {{
210                    network.focus(matchingNodes[0], {{ scale: 1.5, animation: true }});
211                }}
212            }} else {{
213                network.unselectAll();
214            }}
215        }});
216
217        // Click to focus
218        network.on('click', function(params) {{
219            if (params.nodes.length > 0) {{
220                network.focus(params.nodes[0], {{ scale: 1.5, animation: true }});
221            }}
222        }});
223    </script>
224</body>
225</html>"#,
226            graph_data = serde_json::to_string(&graph_data).unwrap_or_default()
227        )
228    }
229}
230
/// Truncate content for display as a node label.
///
/// Only the first line of `content` is used. Lengths are counted in `char`s,
/// not bytes, so multi-byte UTF-8 text is never cut in the middle of a
/// character. A byte-slice cut would panic on such input.
///
/// If the first line has at most `max_len` characters it is returned unchanged.
/// Otherwise it is cut to `max_len - 3` characters and suffixed with `"..."`,
/// giving exactly `max_len` characters. When `max_len < 3` there is no room
/// for the ellipsis, so the line is simply cut to `max_len` characters.
fn truncate_label(content: &str, max_len: usize) -> String {
    let first_line = content.lines().next().unwrap_or(content);

    // `nth(max_len)` is `None` exactly when the line has at most `max_len` chars.
    if first_line.chars().nth(max_len).is_none() {
        return first_line.to_string();
    }

    if max_len < 3 {
        // Too short for "..."; `max_len - 3` would also underflow here.
        return first_line.chars().take(max_len).collect();
    }

    let head: String = first_line.chars().take(max_len - 3).collect();
    format!("{}...", head)
}
240
241// =============================================================================
242// Graph Statistics (RML-894)
243// =============================================================================
244
245/// Graph statistics and metrics
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct GraphStats {
248    /// Total number of nodes
249    pub node_count: usize,
250    /// Total number of edges
251    pub edge_count: usize,
252    /// Average degree (edges per node)
253    pub avg_degree: f32,
254    /// Graph density (actual edges / possible edges)
255    pub density: f32,
256    /// Number of connected components
257    pub component_count: usize,
258    /// Size of largest component
259    pub largest_component_size: usize,
260    /// Nodes by memory type
261    pub nodes_by_type: HashMap<String, usize>,
262    /// Edges by type
263    pub edges_by_type: HashMap<String, usize>,
264    /// Most connected nodes (top 10 by degree)
265    pub hub_nodes: Vec<(MemoryId, usize)>,
266    /// Isolated nodes (degree 0)
267    pub isolated_count: usize,
268}
269
270impl KnowledgeGraph {
271    /// Calculate graph statistics
272    pub fn stats(&self) -> GraphStats {
273        let node_count = self.nodes.len();
274        let edge_count = self.edges.len();
275
276        // Build adjacency for degree calculation
277        let mut degree: HashMap<MemoryId, usize> = HashMap::new();
278        for node in &self.nodes {
279            degree.insert(node.id, 0);
280        }
281        for edge in &self.edges {
282            *degree.entry(edge.from).or_insert(0) += 1;
283            *degree.entry(edge.to).or_insert(0) += 1;
284        }
285
286        let avg_degree = if node_count > 0 {
287            degree.values().sum::<usize>() as f32 / node_count as f32
288        } else {
289            0.0
290        };
291
292        // Density: edges / (n * (n-1) / 2) for undirected, edges / (n * (n-1)) for directed
293        let density = if node_count > 1 {
294            edge_count as f32 / (node_count * (node_count - 1)) as f32
295        } else {
296            0.0
297        };
298
299        // Count by type
300        let mut nodes_by_type: HashMap<String, usize> = HashMap::new();
301        for node in &self.nodes {
302            *nodes_by_type.entry(node.memory_type.clone()).or_insert(0) += 1;
303        }
304
305        let mut edges_by_type: HashMap<String, usize> = HashMap::new();
306        for edge in &self.edges {
307            *edges_by_type.entry(edge.edge_type.clone()).or_insert(0) += 1;
308        }
309
310        // Find hub nodes (top 10 by degree)
311        let mut degree_list: Vec<(MemoryId, usize)> =
312            degree.iter().map(|(&k, &v)| (k, v)).collect();
313        degree_list.sort_by(|a, b| b.1.cmp(&a.1));
314        let hub_nodes: Vec<(MemoryId, usize)> = degree_list.into_iter().take(10).collect();
315
316        // Count isolated nodes
317        let isolated_count = degree.values().filter(|&&d| d == 0).count();
318
319        // Find connected components using BFS
320        let components = self.find_connected_components();
321        let component_count = components.len();
322        let largest_component_size = components.iter().map(|c| c.len()).max().unwrap_or(0);
323
324        GraphStats {
325            node_count,
326            edge_count,
327            avg_degree,
328            density,
329            component_count,
330            largest_component_size,
331            nodes_by_type,
332            edges_by_type,
333            hub_nodes,
334            isolated_count,
335        }
336    }
337
338    /// Find connected components using BFS
339    fn find_connected_components(&self) -> Vec<Vec<MemoryId>> {
340        let node_ids: HashSet<MemoryId> = self.nodes.iter().map(|n| n.id).collect();
341
342        // Build adjacency list (undirected)
343        let mut adj: HashMap<MemoryId, Vec<MemoryId>> = HashMap::new();
344        for id in &node_ids {
345            adj.insert(*id, Vec::new());
346        }
347        for edge in &self.edges {
348            if let Some(list) = adj.get_mut(&edge.from) {
349                list.push(edge.to);
350            }
351            if let Some(list) = adj.get_mut(&edge.to) {
352                list.push(edge.from);
353            }
354        }
355
356        let mut visited: HashSet<MemoryId> = HashSet::new();
357        let mut components = Vec::new();
358
359        for &start in &node_ids {
360            if visited.contains(&start) {
361                continue;
362            }
363
364            let mut component = Vec::new();
365            let mut queue = VecDeque::new();
366            queue.push_back(start);
367            visited.insert(start);
368
369            while let Some(node) = queue.pop_front() {
370                component.push(node);
371                if let Some(neighbors) = adj.get(&node) {
372                    for &neighbor in neighbors {
373                        if !visited.contains(&neighbor) {
374                            visited.insert(neighbor);
375                            queue.push_back(neighbor);
376                        }
377                    }
378                }
379            }
380
381            components.push(component);
382        }
383
384        components
385    }
386
387    /// Calculate centrality scores for nodes
388    pub fn centrality(&self) -> HashMap<MemoryId, CentralityScores> {
389        let mut results: HashMap<MemoryId, CentralityScores> = HashMap::new();
390
391        // Build adjacency
392        let mut in_degree: HashMap<MemoryId, usize> = HashMap::new();
393        let mut out_degree: HashMap<MemoryId, usize> = HashMap::new();
394
395        for node in &self.nodes {
396            in_degree.insert(node.id, 0);
397            out_degree.insert(node.id, 0);
398        }
399
400        for edge in &self.edges {
401            *out_degree.entry(edge.from).or_insert(0) += 1;
402            *in_degree.entry(edge.to).or_insert(0) += 1;
403        }
404
405        let max_degree = self.nodes.len().saturating_sub(1).max(1) as f32;
406
407        for node in &self.nodes {
408            let in_d = *in_degree.get(&node.id).unwrap_or(&0) as f32;
409            let out_d = *out_degree.get(&node.id).unwrap_or(&0) as f32;
410
411            results.insert(
412                node.id,
413                CentralityScores {
414                    in_degree: in_d / max_degree,
415                    out_degree: out_d / max_degree,
416                    degree: (in_d + out_d) / (2.0 * max_degree),
417                    // Simplified closeness based on direct connections
418                    closeness: (in_d + out_d) / (2.0 * max_degree),
419                },
420            );
421        }
422
423        results
424    }
425}
426
427/// Centrality scores for a node
428#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct CentralityScores {
430    /// Normalized in-degree centrality
431    pub in_degree: f32,
432    /// Normalized out-degree centrality
433    pub out_degree: f32,
434    /// Combined degree centrality
435    pub degree: f32,
436    /// Closeness centrality (simplified)
437    pub closeness: f32,
438}
439
440// =============================================================================
441// Graph Filtering (RML-894)
442// =============================================================================
443
444/// Filter options for graph queries
445#[derive(Debug, Clone, Default)]
446pub struct GraphFilter {
447    /// Filter by memory types
448    pub memory_types: Option<Vec<String>>,
449    /// Filter by tags (any match)
450    pub tags: Option<Vec<String>>,
451    /// Filter by edge types
452    pub edge_types: Option<Vec<String>>,
453    /// Minimum importance threshold
454    pub min_importance: Option<f32>,
455    /// Maximum importance threshold
456    pub max_importance: Option<f32>,
457    /// Created after this date
458    pub created_after: Option<DateTime<Utc>>,
459    /// Created before this date
460    pub created_before: Option<DateTime<Utc>>,
461    /// Minimum edge confidence
462    pub min_confidence: Option<f32>,
463    /// Minimum edge score
464    pub min_score: Option<f32>,
465    /// Maximum number of nodes
466    pub limit: Option<usize>,
467}
468
469impl GraphFilter {
470    pub fn new() -> Self {
471        Self::default()
472    }
473
474    pub fn with_types(mut self, types: Vec<String>) -> Self {
475        self.memory_types = Some(types);
476        self
477    }
478
479    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
480        self.tags = Some(tags);
481        self
482    }
483
484    pub fn with_min_importance(mut self, min: f32) -> Self {
485        self.min_importance = Some(min);
486        self
487    }
488
489    pub fn with_min_confidence(mut self, min: f32) -> Self {
490        self.min_confidence = Some(min);
491        self
492    }
493
494    pub fn with_limit(mut self, limit: usize) -> Self {
495        self.limit = Some(limit);
496        self
497    }
498}
499
500impl KnowledgeGraph {
501    /// Apply filter to create a subgraph
502    pub fn filter(&self, filter: &GraphFilter) -> KnowledgeGraph {
503        // Filter nodes
504        let mut filtered_nodes: Vec<GraphNode> = self
505            .nodes
506            .iter()
507            .filter(|n| {
508                // Type filter
509                if let Some(ref types) = filter.memory_types {
510                    if !types.contains(&n.memory_type) {
511                        return false;
512                    }
513                }
514
515                // Tag filter (any match)
516                if let Some(ref tags) = filter.tags {
517                    if !n.tags.iter().any(|t| tags.contains(t)) {
518                        return false;
519                    }
520                }
521
522                // Importance filter
523                if let Some(min) = filter.min_importance {
524                    if n.importance < min {
525                        return false;
526                    }
527                }
528                if let Some(max) = filter.max_importance {
529                    if n.importance > max {
530                        return false;
531                    }
532                }
533
534                true
535            })
536            .cloned()
537            .collect();
538
539        // Apply limit
540        if let Some(limit) = filter.limit {
541            filtered_nodes.truncate(limit);
542        }
543
544        // Get set of valid node IDs
545        let valid_ids: HashSet<MemoryId> = filtered_nodes.iter().map(|n| n.id).collect();
546
547        // Filter edges
548        let filtered_edges: Vec<GraphEdge> = self
549            .edges
550            .iter()
551            .filter(|e| {
552                // Both endpoints must be in filtered nodes
553                if !valid_ids.contains(&e.from) || !valid_ids.contains(&e.to) {
554                    return false;
555                }
556
557                // Edge type filter
558                if let Some(ref types) = filter.edge_types {
559                    if !types.contains(&e.edge_type) {
560                        return false;
561                    }
562                }
563
564                // Confidence filter
565                if let Some(min) = filter.min_confidence {
566                    if e.confidence < min {
567                        return false;
568                    }
569                }
570
571                // Score filter
572                if let Some(min) = filter.min_score {
573                    if e.score < min {
574                        return false;
575                    }
576                }
577
578                true
579            })
580            .cloned()
581            .collect();
582
583        KnowledgeGraph {
584            nodes: filtered_nodes,
585            edges: filtered_edges,
586        }
587    }
588
589    /// Get subgraph centered on a node with given depth
590    pub fn neighborhood(&self, center: MemoryId, depth: usize) -> KnowledgeGraph {
591        let mut visited: HashSet<MemoryId> = HashSet::new();
592        let mut current_level: HashSet<MemoryId> = HashSet::new();
593        current_level.insert(center);
594        visited.insert(center);
595
596        // Build adjacency
597        let mut adj: HashMap<MemoryId, Vec<MemoryId>> = HashMap::new();
598        for edge in &self.edges {
599            adj.entry(edge.from).or_default().push(edge.to);
600            adj.entry(edge.to).or_default().push(edge.from);
601        }
602
603        // BFS to depth
604        for _ in 0..depth {
605            let mut next_level: HashSet<MemoryId> = HashSet::new();
606            for &node in &current_level {
607                if let Some(neighbors) = adj.get(&node) {
608                    for &neighbor in neighbors {
609                        if !visited.contains(&neighbor) {
610                            visited.insert(neighbor);
611                            next_level.insert(neighbor);
612                        }
613                    }
614                }
615            }
616            current_level = next_level;
617        }
618
619        // Filter to visited nodes
620        let nodes: Vec<GraphNode> = self
621            .nodes
622            .iter()
623            .filter(|n| visited.contains(&n.id))
624            .cloned()
625            .collect();
626
627        let edges: Vec<GraphEdge> = self
628            .edges
629            .iter()
630            .filter(|e| visited.contains(&e.from) && visited.contains(&e.to))
631            .cloned()
632            .collect();
633
634        KnowledgeGraph { nodes, edges }
635    }
636}
637
638// =============================================================================
639// DOT Export (RML-894)
640// =============================================================================
641
642impl KnowledgeGraph {
643    /// Export as DOT format for Graphviz
644    pub fn to_dot(&self) -> String {
645        let mut dot = String::from("digraph knowledge_graph {\n");
646        dot.push_str("    rankdir=LR;\n");
647        dot.push_str("    node [shape=box, style=rounded];\n\n");
648
649        // Color mapping for memory types
650        let colors: HashMap<&str, &str> = [
651            ("note", "#97C2FC"),
652            ("todo", "#FFFF00"),
653            ("issue", "#FB7E81"),
654            ("decision", "#7BE141"),
655            ("preference", "#FFA807"),
656            ("learning", "#6E6EFD"),
657            ("context", "#C2FABC"),
658            ("credential", "#FD6A6A"),
659        ]
660        .into_iter()
661        .collect();
662
663        // Write nodes
664        for node in &self.nodes {
665            let color = colors.get(node.memory_type.as_str()).unwrap_or(&"#CCCCCC");
666            let label = node.label.replace('"', "\\\"");
667            dot.push_str(&format!(
668                "    \"{}\" [label=\"{}\", fillcolor=\"{}\", style=\"filled,rounded\"];\n",
669                node.id, label, color
670            ));
671        }
672
673        dot.push('\n');
674
675        // Write edges
676        for edge in &self.edges {
677            let style = match edge.edge_type.as_str() {
678                "related_to" => "solid",
679                "part_of" => "dashed",
680                "depends_on" => "bold",
681                "contradicts" => "dotted",
682                "supports" => "solid",
683                "references" => "dashed",
684                _ => "solid",
685            };
686            dot.push_str(&format!(
687                "    \"{}\" -> \"{}\" [label=\"{}\", style={}, penwidth={}];\n",
688                edge.from,
689                edge.to,
690                edge.edge_type,
691                style,
692                (edge.score * 2.0 + 0.5).min(3.0)
693            ));
694        }
695
696        dot.push_str("}\n");
697        dot
698    }
699
700    /// Export as GEXF format for Gephi
701    pub fn to_gexf(&self) -> String {
702        let mut gexf = String::from(
703            r#"<?xml version="1.0" encoding="UTF-8"?>
704<gexf xmlns="http://gexf.net/1.3" version="1.3">
705  <meta>
706    <creator>Engram</creator>
707    <description>Knowledge Graph Export</description>
708  </meta>
709  <graph mode="static" defaultedgetype="directed">
710    <attributes class="node">
711      <attribute id="0" title="type" type="string"/>
712      <attribute id="1" title="importance" type="float"/>
713    </attributes>
714    <attributes class="edge">
715      <attribute id="0" title="score" type="float"/>
716      <attribute id="1" title="confidence" type="float"/>
717    </attributes>
718    <nodes>
719"#,
720        );
721
722        for node in &self.nodes {
723            let label = node
724                .label
725                .replace('&', "&amp;")
726                .replace('<', "&lt;")
727                .replace('>', "&gt;")
728                .replace('"', "&quot;");
729            gexf.push_str(&format!(
730                r#"      <node id="{}" label="{}">
731        <attvalues>
732          <attvalue for="0" value="{}"/>
733          <attvalue for="1" value="{}"/>
734        </attvalues>
735      </node>
736"#,
737                node.id, label, node.memory_type, node.importance
738            ));
739        }
740
741        gexf.push_str("    </nodes>\n    <edges>\n");
742
743        for (i, edge) in self.edges.iter().enumerate() {
744            gexf.push_str(&format!(
745                r#"      <edge id="{}" source="{}" target="{}" label="{}">
746        <attvalues>
747          <attvalue for="0" value="{}"/>
748          <attvalue for="1" value="{}"/>
749        </attvalues>
750      </edge>
751"#,
752                i, edge.from, edge.to, edge.edge_type, edge.score, edge.confidence
753            ));
754        }
755
756        gexf.push_str("    </edges>\n  </graph>\n</gexf>\n");
757        gexf
758    }
759}
760
761// =============================================================================
762// Community Detection (RML-894)
763// =============================================================================
764
765/// A cluster/community of nodes
766#[derive(Debug, Clone, Serialize, Deserialize)]
767pub struct GraphCluster {
768    /// Cluster identifier
769    pub id: usize,
770    /// Node IDs in this cluster
771    pub members: Vec<MemoryId>,
772    /// Dominant memory type in cluster
773    pub dominant_type: Option<String>,
774    /// Common tags across cluster
775    pub common_tags: Vec<String>,
776    /// Internal edge count
777    pub internal_edges: usize,
778    /// Cluster cohesion score
779    pub cohesion: f32,
780}
781
782impl KnowledgeGraph {
783    /// Detect communities using label propagation algorithm
784    pub fn detect_communities(&self, max_iterations: usize) -> Vec<GraphCluster> {
785        if self.nodes.is_empty() {
786            return Vec::new();
787        }
788
789        // Initialize: each node in its own community
790        let mut labels: HashMap<MemoryId, usize> = self
791            .nodes
792            .iter()
793            .enumerate()
794            .map(|(i, n)| (n.id, i))
795            .collect();
796
797        // Build adjacency
798        let mut adj: HashMap<MemoryId, Vec<(MemoryId, f32)>> = HashMap::new();
799        for node in &self.nodes {
800            adj.insert(node.id, Vec::new());
801        }
802        for edge in &self.edges {
803            let weight = edge.score * edge.confidence;
804            adj.entry(edge.from).or_default().push((edge.to, weight));
805            adj.entry(edge.to).or_default().push((edge.from, weight));
806        }
807
808        // Label propagation
809        let node_ids: Vec<MemoryId> = self.nodes.iter().map(|n| n.id).collect();
810
811        for _ in 0..max_iterations {
812            let mut changed = false;
813
814            for &node_id in &node_ids {
815                if let Some(neighbors) = adj.get(&node_id) {
816                    if neighbors.is_empty() {
817                        continue;
818                    }
819
820                    // Count weighted votes for each label
821                    let mut votes: HashMap<usize, f32> = HashMap::new();
822                    for &(neighbor, weight) in neighbors {
823                        if let Some(&label) = labels.get(&neighbor) {
824                            *votes.entry(label).or_insert(0.0) += weight;
825                        }
826                    }
827
828                    // Pick label with most votes
829                    if let Some((&best_label, _)) = votes.iter().max_by(|a, b| a.1.total_cmp(b.1)) {
830                        let current = labels.get(&node_id).copied().unwrap_or(0);
831                        if best_label != current {
832                            labels.insert(node_id, best_label);
833                            changed = true;
834                        }
835                    }
836                }
837            }
838
839            if !changed {
840                break;
841            }
842        }
843
844        // Group nodes by label
845        let mut clusters_map: HashMap<usize, Vec<MemoryId>> = HashMap::new();
846        for (node_id, label) in &labels {
847            clusters_map.entry(*label).or_default().push(*node_id);
848        }
849
850        // Build cluster objects
851        let node_map: HashMap<MemoryId, &GraphNode> =
852            self.nodes.iter().map(|n| (n.id, n)).collect();
853
854        let mut clusters: Vec<GraphCluster> = clusters_map
855            .into_iter()
856            .enumerate()
857            .map(|(new_id, (_, members))| {
858                // Find dominant type
859                let mut type_counts: HashMap<&str, usize> = HashMap::new();
860                let mut all_tags: HashMap<&str, usize> = HashMap::new();
861
862                for &member_id in &members {
863                    if let Some(node) = node_map.get(&member_id) {
864                        *type_counts.entry(node.memory_type.as_str()).or_insert(0) += 1;
865                        for tag in &node.tags {
866                            *all_tags.entry(tag.as_str()).or_insert(0) += 1;
867                        }
868                    }
869                }
870
871                let dominant_type = type_counts
872                    .into_iter()
873                    .max_by_key(|(_, count)| *count)
874                    .map(|(t, _)| t.to_string());
875
876                // Common tags (present in > 50% of members)
877                let threshold = members.len() / 2;
878                let common_tags: Vec<String> = all_tags
879                    .into_iter()
880                    .filter(|(_, count)| *count > threshold)
881                    .map(|(tag, _)| tag.to_string())
882                    .collect();
883
884                // Count internal edges
885                let member_set: HashSet<MemoryId> = members.iter().copied().collect();
886                let internal_edges = self
887                    .edges
888                    .iter()
889                    .filter(|e| member_set.contains(&e.from) && member_set.contains(&e.to))
890                    .count();
891
892                // Cohesion: internal edges / possible internal edges
893                let n = members.len();
894                let possible = if n > 1 { n * (n - 1) } else { 1 };
895                let cohesion = internal_edges as f32 / possible as f32;
896
897                GraphCluster {
898                    id: new_id,
899                    members,
900                    dominant_type,
901                    common_tags,
902                    internal_edges,
903                    cohesion,
904                }
905            })
906            .collect();
907
908        // Sort by size (largest first)
909        clusters.sort_by(|a, b| b.members.len().cmp(&a.members.len()));
910
911        // Renumber IDs
912        for (i, cluster) in clusters.iter_mut().enumerate() {
913            cluster.id = i;
914        }
915
916        clusters
917    }
918}
919
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test node whose label is derived from its id.
    fn make_node(id: MemoryId, memory_type: &str, tags: Vec<&str>) -> GraphNode {
        GraphNode {
            id,
            label: format!("Node {}", id),
            memory_type: memory_type.to_string(),
            importance: 0.5,
            tags: tags.into_iter().map(String::from).collect(),
        }
    }

    /// Build a test edge with a fixed score (0.8) and confidence (0.9).
    fn make_edge(from: MemoryId, to: MemoryId, edge_type: &str) -> GraphEdge {
        GraphEdge {
            from,
            to,
            edge_type: edge_type.to_string(),
            score: 0.8,
            confidence: 0.9,
        }
    }

    /// Node ids of `graph`, sorted, so tests can assert exact membership
    /// independently of the order in which nodes were emitted.
    fn sorted_node_ids(graph: &KnowledgeGraph) -> Vec<MemoryId> {
        let mut ids: Vec<MemoryId> = graph.nodes.iter().map(|n| n.id).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_truncate_label() {
        assert_eq!(truncate_label("short", 50), "short");
        assert_eq!(
            truncate_label("this is a very long label that should be truncated", 20),
            "this is a very lo..."
        );
    }

    #[test]
    fn test_graph_stats() {
        let id1: MemoryId = 1;
        let id2: MemoryId = 2;
        let id3: MemoryId = 3;

        let graph = KnowledgeGraph {
            nodes: vec![
                make_node(id1, "note", vec!["rust"]),
                make_node(id2, "note", vec!["rust"]),
                make_node(id3, "todo", vec!["python"]),
            ],
            edges: vec![
                make_edge(id1, id2, "related_to"),
                make_edge(id2, id3, "depends_on"),
            ],
        };

        let stats = graph.stats();
        assert_eq!(stats.node_count, 3);
        assert_eq!(stats.edge_count, 2);
        assert_eq!(stats.nodes_by_type.get("note"), Some(&2));
        assert_eq!(stats.nodes_by_type.get("todo"), Some(&1));
        assert_eq!(stats.isolated_count, 0);
        assert_eq!(stats.component_count, 1);
    }

    #[test]
    fn test_graph_filter() {
        let id1: MemoryId = 1;
        let id2: MemoryId = 2;
        let id3: MemoryId = 3;

        let graph = KnowledgeGraph {
            nodes: vec![
                make_node(id1, "note", vec!["rust"]),
                make_node(id2, "note", vec!["python"]),
                make_node(id3, "todo", vec!["rust"]),
            ],
            edges: vec![
                make_edge(id1, id2, "related_to"),
                make_edge(id2, id3, "depends_on"),
            ],
        };

        // Filter by type: only the two notes survive, and only the edge
        // between them is kept.
        let filter = GraphFilter::new().with_types(vec!["note".to_string()]);
        let filtered = graph.filter(&filter);
        assert_eq!(sorted_node_ids(&filtered), vec![id1, id2]);
        assert_eq!(filtered.edges.len(), 1); // Only edge between notes
        assert_eq!((filtered.edges[0].from, filtered.edges[0].to), (id1, id2));

        // Filter by tag: id1 and id3 have "rust".
        let filter = GraphFilter::new().with_tags(vec!["rust".to_string()]);
        let filtered = graph.filter(&filter);
        assert_eq!(sorted_node_ids(&filtered), vec![id1, id3]);
        // Both edges touch id2, which was filtered out, so none survive.
        assert!(filtered.edges.is_empty());
    }

    #[test]
    fn test_neighborhood() {
        let id1: MemoryId = 1;
        let id2: MemoryId = 2;
        let id3: MemoryId = 3;
        let id4: MemoryId = 4;

        // Simple chain: id1 -> id2 -> id3 -> id4
        let graph = KnowledgeGraph {
            nodes: vec![
                make_node(id1, "note", vec![]),
                make_node(id2, "note", vec![]),
                make_node(id3, "note", vec![]),
                make_node(id4, "note", vec![]),
            ],
            edges: vec![
                make_edge(id1, id2, "related_to"),
                make_edge(id2, id3, "related_to"),
                make_edge(id3, id4, "related_to"),
            ],
        };

        // Depth 1 from id1 should include exactly id1, id2
        let subgraph = graph.neighborhood(id1, 1);
        assert_eq!(sorted_node_ids(&subgraph), vec![id1, id2]);

        // Depth 2 from id1 should include exactly id1, id2, id3 (id4 is 3 hops away)
        let subgraph = graph.neighborhood(id1, 2);
        assert_eq!(sorted_node_ids(&subgraph), vec![id1, id2, id3]);
    }

    #[test]
    fn test_to_dot() {
        let id1: MemoryId = 1;
        let id2: MemoryId = 2;

        let graph = KnowledgeGraph {
            nodes: vec![
                make_node(id1, "note", vec![]),
                make_node(id2, "todo", vec![]),
            ],
            edges: vec![make_edge(id1, id2, "related_to")],
        };

        let dot = graph.to_dot();
        assert!(dot.contains("digraph knowledge_graph"));
        assert!(dot.contains(&id1.to_string()));
        assert!(dot.contains(&id2.to_string()));
        assert!(dot.contains("related_to"));
    }

    #[test]
    fn test_community_detection() {
        // Create two clusters
        let a1: MemoryId = 1;
        let a2: MemoryId = 2;
        let a3: MemoryId = 3;
        let b1: MemoryId = 4;
        let b2: MemoryId = 5;

        let graph = KnowledgeGraph {
            nodes: vec![
                make_node(a1, "note", vec!["cluster-a"]),
                make_node(a2, "note", vec!["cluster-a"]),
                make_node(a3, "note", vec!["cluster-a"]),
                make_node(b1, "todo", vec!["cluster-b"]),
                make_node(b2, "todo", vec!["cluster-b"]),
            ],
            edges: vec![
                // Cluster A - densely connected
                make_edge(a1, a2, "related_to"),
                make_edge(a2, a3, "related_to"),
                make_edge(a1, a3, "related_to"),
                // Cluster B - connected
                make_edge(b1, b2, "related_to"),
                // Weak link between clusters
                GraphEdge {
                    from: a3,
                    to: b1,
                    edge_type: "related_to".to_string(),
                    score: 0.1, // weak
                    confidence: 0.1,
                },
            ],
        };

        let communities = graph.detect_communities(10);
        // Should detect at least the general structure
        assert!(!communities.is_empty());
        // Largest community should have at least 2 members
        assert!(communities[0].members.len() >= 2);

        // Clusters are renumbered 0..len after sorting, every cluster has at
        // least one member, and cohesion is a ratio of internal edges to
        // possible ordered pairs, so it must lie in [0, 1] for this graph.
        for (i, cluster) in communities.iter().enumerate() {
            assert_eq!(cluster.id, i);
            assert!(!cluster.members.is_empty());
            assert!((0.0..=1.0).contains(&cluster.cohesion));
        }

        // Clusters are sorted largest-first.
        assert!(communities
            .windows(2)
            .all(|pair| pair[0].members.len() >= pair[1].members.len()));
    }
}