Skip to main content

codemem_engine/
graph_linking.rs

1use crate::CodememEngine;
2use codemem_core::{CodememError, Edge, MemoryNode, NodeKind, NodeMemoryResult, RelationshipType};
3use std::collections::HashSet;
4
5#[cfg(test)]
6#[path = "tests/graph_linking_tests.rs"]
7mod tests;
8
9impl CodememEngine {
10    // ── Auto-linking ─────────────────────────────────────────────────────
11
12    /// Scan memory content for file paths and qualified symbol names that exist
13    /// as graph nodes, and create RELATES_TO edges.
14    pub fn auto_link_to_code_nodes(
15        &self,
16        memory_id: &str,
17        content: &str,
18        existing_links: &[String],
19    ) -> usize {
20        let mut graph = match self.lock_graph() {
21            Ok(g) => g,
22            Err(_) => return 0,
23        };
24
25        let existing_set: HashSet<&str> = existing_links.iter().map(|s| s.as_str()).collect();
26
27        let mut candidates: Vec<String> = Vec::new();
28
29        for word in content.split_whitespace() {
30            let cleaned = word.trim_matches(|c: char| {
31                !c.is_alphanumeric() && c != '/' && c != '.' && c != '_' && c != '-' && c != ':'
32            });
33            if cleaned.is_empty() {
34                continue;
35            }
36            if cleaned.contains('/') || cleaned.contains('.') {
37                let file_id = format!("file:{cleaned}");
38                if !existing_set.contains(file_id.as_str()) {
39                    candidates.push(file_id);
40                }
41            }
42            if cleaned.contains("::") {
43                let sym_id = format!("sym:{cleaned}");
44                if !existing_set.contains(sym_id.as_str()) {
45                    candidates.push(sym_id);
46                }
47            }
48        }
49
50        let now = chrono::Utc::now();
51        let mut created = 0;
52        let mut seen = HashSet::new();
53
54        for candidate_id in &candidates {
55            if !seen.insert(candidate_id.clone()) {
56                continue;
57            }
58            if graph.get_node(candidate_id).ok().flatten().is_none() {
59                continue;
60            }
61            let edge = Edge {
62                id: format!("{memory_id}-RELATES_TO-{candidate_id}"),
63                src: memory_id.to_string(),
64                dst: candidate_id.clone(),
65                relationship: RelationshipType::RelatesTo,
66                weight: 0.5,
67                properties: std::collections::HashMap::from([(
68                    "auto_linked".to_string(),
69                    serde_json::json!(true),
70                )]),
71                created_at: now,
72                valid_from: None,
73                valid_to: None,
74            };
75            if self.storage.insert_graph_edge(&edge).is_ok() && graph.add_edge(edge).is_ok() {
76                created += 1;
77            }
78        }
79
80        created
81    }
82
83    // ── Tag-based Auto-linking ──────────────────────────────────────────
84
85    /// Create edges between this memory and other memories that share tags.
86    /// - `session:*` tags → PRECEDED_BY edges (temporal ordering within a session)
87    /// - Other shared tags → SHARES_THEME edges (topical overlap)
88    ///
89    /// This runs during `persist_memory` so the graph builds connectivity at
90    /// ingestion time, rather than relying solely on creative consolidation.
91    pub fn auto_link_by_tags(&self, memory: &MemoryNode) {
92        if memory.tags.is_empty() {
93            return;
94        }
95
96        // Phase 1: Collect sibling IDs and build edges WITHOUT holding the graph lock.
97        let now = chrono::Utc::now();
98        let mut linked = HashSet::new();
99        let mut edges_to_add = Vec::new();
100
101        for tag in &memory.tags {
102            let is_session_tag = tag.starts_with("session:");
103
104            let sibling_ids = match self.storage.find_memory_ids_by_tag(
105                tag,
106                memory.namespace.as_deref(),
107                &memory.id,
108            ) {
109                Ok(ids) => ids,
110                Err(_) => continue,
111            };
112
113            for sibling_id in sibling_ids {
114                if !linked.insert(sibling_id.clone()) {
115                    continue;
116                }
117
118                let (relationship, edge_label) = if is_session_tag {
119                    (RelationshipType::PrecededBy, "PRECEDED_BY")
120                } else {
121                    (RelationshipType::SharesTheme, "SHARES_THEME")
122                };
123
124                let edge_id = format!("{}-{edge_label}-{sibling_id}", memory.id);
125                edges_to_add.push(Edge {
126                    id: edge_id,
127                    src: sibling_id,
128                    dst: memory.id.clone(),
129                    relationship,
130                    weight: if is_session_tag { 0.8 } else { 0.5 },
131                    properties: std::collections::HashMap::from([(
132                        "auto_linked".to_string(),
133                        serde_json::json!(true),
134                    )]),
135                    created_at: now,
136                    valid_from: Some(now),
137                    valid_to: None,
138                });
139            }
140        }
141
142        if edges_to_add.is_empty() {
143            return;
144        }
145
146        // Phase 2: Acquire graph lock only for mutations.
147        let mut graph = match self.lock_graph() {
148            Ok(g) => g,
149            Err(_) => return,
150        };
151
152        for edge in edges_to_add {
153            if self.storage.insert_graph_edge(&edge).is_ok() {
154                let _ = graph.add_edge(edge);
155            }
156        }
157    }
158
159    // ── Node Memory Queries ──────────────────────────────────────────────
160
161    /// Retrieve all memories connected to a graph node via BFS traversal.
162    ///
163    /// Performs level-by-level BFS to track actual hop distance. For each
164    /// Memory node found, reports the relationship type from the edge that
165    /// connected it (or the edge leading into the path toward it).
166    pub fn get_node_memories(
167        &self,
168        node_id: &str,
169        max_depth: usize,
170        include_relationships: Option<&[RelationshipType]>,
171    ) -> Result<Vec<NodeMemoryResult>, CodememError> {
172        let graph = self.lock_graph()?;
173
174        // Manual BFS tracking (node_id, depth, relationship_from_parent_edge)
175        let mut results: Vec<NodeMemoryResult> = Vec::new();
176        let mut seen_memory_ids = HashSet::new();
177        let mut visited = HashSet::new();
178        let mut queue: std::collections::VecDeque<(String, usize, String)> =
179            std::collections::VecDeque::new();
180
181        visited.insert(node_id.to_string());
182        queue.push_back((node_id.to_string(), 0, String::new()));
183
184        while let Some((current_id, depth, parent_rel)) = queue.pop_front() {
185            // Collect Memory nodes (skip the start node itself)
186            if current_id != node_id {
187                if let Some(node) = graph.get_node_ref(&current_id) {
188                    if node.kind == NodeKind::Memory {
189                        let memory_id = node.memory_id.as_deref().unwrap_or(&node.id);
190                        if seen_memory_ids.insert(memory_id.to_string()) {
191                            if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
192                                results.push(NodeMemoryResult {
193                                    memory,
194                                    relationship: parent_rel.clone(),
195                                    depth,
196                                });
197                            }
198                        }
199                    }
200                }
201            }
202
203            if depth >= max_depth {
204                continue;
205            }
206
207            // Expand neighbors via edges, skipping Chunk nodes
208            for edge in graph.get_edges_ref(&current_id) {
209                let neighbor_id = if edge.src == current_id {
210                    &edge.dst
211                } else {
212                    &edge.src
213                };
214
215                if visited.contains(neighbor_id.as_str()) {
216                    continue;
217                }
218
219                // Apply relationship filter
220                if let Some(allowed) = include_relationships {
221                    if !allowed.contains(&edge.relationship) {
222                        continue;
223                    }
224                }
225
226                // Skip Chunk nodes (noisy, low-value for memory discovery)
227                if let Some(neighbor) = graph.get_node_ref(neighbor_id) {
228                    if neighbor.kind == NodeKind::Chunk {
229                        continue;
230                    }
231                }
232
233                visited.insert(neighbor_id.clone());
234                queue.push_back((
235                    neighbor_id.clone(),
236                    depth + 1,
237                    edge.relationship.to_string(),
238                ));
239            }
240        }
241
242        Ok(results)
243    }
244}