// codemem_engine/graph_linking.rs

1use crate::CodememEngine;
2use codemem_core::{
3    CodememError, Edge, GraphBackend, MemoryNode, NodeKind, NodeMemoryResult, RelationshipType,
4};
5use std::collections::HashSet;
6
7#[cfg(test)]
8#[path = "tests/graph_linking_tests.rs"]
9mod tests;
10
11impl CodememEngine {
12    // ── Auto-linking ─────────────────────────────────────────────────────
13
14    /// Scan memory content for file paths and qualified symbol names that exist
15    /// as graph nodes, and create RELATES_TO edges.
16    pub fn auto_link_to_code_nodes(
17        &self,
18        memory_id: &str,
19        content: &str,
20        existing_links: &[String],
21    ) -> usize {
22        let mut graph = match self.lock_graph() {
23            Ok(g) => g,
24            Err(_) => return 0,
25        };
26
27        let existing_set: HashSet<&str> = existing_links.iter().map(|s| s.as_str()).collect();
28
29        let mut candidates: Vec<String> = Vec::new();
30
31        for word in content.split_whitespace() {
32            let cleaned = word.trim_matches(|c: char| {
33                !c.is_alphanumeric() && c != '/' && c != '.' && c != '_' && c != '-' && c != ':'
34            });
35            if cleaned.is_empty() {
36                continue;
37            }
38            if cleaned.contains('/') || cleaned.contains('.') {
39                let file_id = format!("file:{cleaned}");
40                if !existing_set.contains(file_id.as_str()) {
41                    candidates.push(file_id);
42                }
43            }
44            if cleaned.contains("::") {
45                let sym_id = format!("sym:{cleaned}");
46                if !existing_set.contains(sym_id.as_str()) {
47                    candidates.push(sym_id);
48                }
49            }
50        }
51
52        let now = chrono::Utc::now();
53        let mut created = 0;
54        let mut seen = HashSet::new();
55
56        for candidate_id in &candidates {
57            if !seen.insert(candidate_id.clone()) {
58                continue;
59            }
60            if graph.get_node(candidate_id).ok().flatten().is_none() {
61                continue;
62            }
63            let edge = Edge {
64                id: format!("{memory_id}-RELATES_TO-{candidate_id}"),
65                src: memory_id.to_string(),
66                dst: candidate_id.clone(),
67                relationship: RelationshipType::RelatesTo,
68                weight: 0.5,
69                properties: std::collections::HashMap::from([(
70                    "auto_linked".to_string(),
71                    serde_json::json!(true),
72                )]),
73                created_at: now,
74                valid_from: None,
75                valid_to: None,
76            };
77            if self.storage.insert_graph_edge(&edge).is_ok() && graph.add_edge(edge).is_ok() {
78                created += 1;
79            }
80        }
81
82        created
83    }
84
85    // ── Tag-based Auto-linking ──────────────────────────────────────────
86
87    /// Create edges between this memory and other memories that share tags.
88    /// - `session:*` tags → PRECEDED_BY edges (temporal ordering within a session)
89    /// - Other shared tags → SHARES_THEME edges (topical overlap)
90    ///
91    /// This runs during `persist_memory` so the graph builds connectivity at
92    /// ingestion time, rather than relying solely on creative consolidation.
93    pub fn auto_link_by_tags(&self, memory: &MemoryNode) {
94        if memory.tags.is_empty() {
95            return;
96        }
97
98        // Phase 1: Collect sibling IDs and build edges WITHOUT holding the graph lock.
99        let now = chrono::Utc::now();
100        let mut linked = HashSet::new();
101        let mut edges_to_add = Vec::new();
102
103        for tag in &memory.tags {
104            let is_session_tag = tag.starts_with("session:");
105
106            let sibling_ids = match self.storage.find_memory_ids_by_tag(
107                tag,
108                memory.namespace.as_deref(),
109                &memory.id,
110            ) {
111                Ok(ids) => ids,
112                Err(_) => continue,
113            };
114
115            for sibling_id in sibling_ids {
116                if !linked.insert(sibling_id.clone()) {
117                    continue;
118                }
119
120                let (relationship, edge_label) = if is_session_tag {
121                    (RelationshipType::PrecededBy, "PRECEDED_BY")
122                } else {
123                    (RelationshipType::SharesTheme, "SHARES_THEME")
124                };
125
126                let edge_id = format!("{}-{edge_label}-{sibling_id}", memory.id);
127                edges_to_add.push(Edge {
128                    id: edge_id,
129                    src: sibling_id,
130                    dst: memory.id.clone(),
131                    relationship,
132                    weight: if is_session_tag { 0.8 } else { 0.5 },
133                    properties: std::collections::HashMap::from([(
134                        "auto_linked".to_string(),
135                        serde_json::json!(true),
136                    )]),
137                    created_at: now,
138                    valid_from: Some(now),
139                    valid_to: None,
140                });
141            }
142        }
143
144        if edges_to_add.is_empty() {
145            return;
146        }
147
148        // Phase 2: Acquire graph lock only for mutations.
149        let mut graph = match self.lock_graph() {
150            Ok(g) => g,
151            Err(_) => return,
152        };
153
154        for edge in edges_to_add {
155            if self.storage.insert_graph_edge(&edge).is_ok() {
156                let _ = graph.add_edge(edge);
157            }
158        }
159    }
160
161    // ── Node Memory Queries ──────────────────────────────────────────────
162
163    /// Retrieve all memories connected to a graph node via BFS traversal.
164    ///
165    /// Performs level-by-level BFS to track actual hop distance. For each
166    /// Memory node found, reports the relationship type from the edge that
167    /// connected it (or the edge leading into the path toward it).
168    pub fn get_node_memories(
169        &self,
170        node_id: &str,
171        max_depth: usize,
172        include_relationships: Option<&[RelationshipType]>,
173    ) -> Result<Vec<NodeMemoryResult>, CodememError> {
174        let graph = self.lock_graph()?;
175
176        // Manual BFS tracking (node_id, depth, relationship_from_parent_edge)
177        let mut results: Vec<NodeMemoryResult> = Vec::new();
178        let mut seen_memory_ids = HashSet::new();
179        let mut visited = HashSet::new();
180        let mut queue: std::collections::VecDeque<(String, usize, String)> =
181            std::collections::VecDeque::new();
182
183        visited.insert(node_id.to_string());
184        queue.push_back((node_id.to_string(), 0, String::new()));
185
186        while let Some((current_id, depth, parent_rel)) = queue.pop_front() {
187            // Collect Memory nodes (skip the start node itself)
188            if current_id != node_id {
189                if let Some(node) = graph.get_node_ref(&current_id) {
190                    if node.kind == NodeKind::Memory {
191                        let memory_id = node.memory_id.as_deref().unwrap_or(&node.id);
192                        if seen_memory_ids.insert(memory_id.to_string()) {
193                            if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
194                                results.push(NodeMemoryResult {
195                                    memory,
196                                    relationship: parent_rel.clone(),
197                                    depth,
198                                });
199                            }
200                        }
201                    }
202                }
203            }
204
205            if depth >= max_depth {
206                continue;
207            }
208
209            // Expand neighbors via edges, skipping Chunk nodes
210            for edge in graph.get_edges_ref(&current_id) {
211                let neighbor_id = if edge.src == current_id {
212                    &edge.dst
213                } else {
214                    &edge.src
215                };
216
217                if visited.contains(neighbor_id.as_str()) {
218                    continue;
219                }
220
221                // Apply relationship filter
222                if let Some(allowed) = include_relationships {
223                    if !allowed.contains(&edge.relationship) {
224                        continue;
225                    }
226                }
227
228                // Skip Chunk nodes (noisy, low-value for memory discovery)
229                if let Some(neighbor) = graph.get_node_ref(neighbor_id) {
230                    if neighbor.kind == NodeKind::Chunk {
231                        continue;
232                    }
233                }
234
235                visited.insert(neighbor_id.clone());
236                queue.push_back((
237                    neighbor_id.clone(),
238                    depth + 1,
239                    edge.relationship.to_string(),
240                ));
241            }
242        }
243
244        Ok(results)
245    }
246}