Skip to main content

codemem_engine/
graph_linking.rs

1use crate::CodememEngine;
2use codemem_core::{CodememError, Edge, MemoryNode, NodeKind, NodeMemoryResult, RelationshipType};
3use std::collections::HashSet;
4
5#[cfg(test)]
6#[path = "tests/graph_linking_tests.rs"]
7mod tests;
8
9impl CodememEngine {
10    // ── Auto-linking ─────────────────────────────────────────────────────
11
12    /// Scan memory content for file paths and qualified symbol names that exist
13    /// as graph nodes, and create RELATES_TO edges.
14    pub fn auto_link_to_code_nodes(
15        &self,
16        memory_id: &str,
17        content: &str,
18        existing_links: &[String],
19    ) -> usize {
20        let mut graph = match self.lock_graph() {
21            Ok(g) => g,
22            Err(_) => return 0,
23        };
24
25        let existing_set: HashSet<&str> = existing_links.iter().map(|s| s.as_str()).collect();
26
27        let mut candidates: Vec<String> = Vec::new();
28
29        for word in content.split_whitespace() {
30            let cleaned = word.trim_matches(|c: char| {
31                !c.is_alphanumeric() && c != '/' && c != '.' && c != '_' && c != '-' && c != ':'
32            });
33            if cleaned.is_empty() {
34                continue;
35            }
36            if cleaned.contains('/') || cleaned.contains('.') {
37                let file_id = format!("file:{cleaned}");
38                if !existing_set.contains(file_id.as_str()) {
39                    candidates.push(file_id);
40                }
41            }
42            if cleaned.contains("::") {
43                let sym_id = format!("sym:{cleaned}");
44                if !existing_set.contains(sym_id.as_str()) {
45                    candidates.push(sym_id);
46                }
47            }
48        }
49
50        let now = chrono::Utc::now();
51        let mut created = 0;
52        let mut seen = HashSet::new();
53
54        for candidate_id in &candidates {
55            if !seen.insert(candidate_id.clone()) {
56                continue;
57            }
58            match graph.get_node(candidate_id).ok().flatten() {
59                // Skip if node doesn't exist or is expired
60                None => continue,
61                Some(n) if n.valid_to.is_some_and(|vt| vt <= now) => continue,
62                _ => {}
63            }
64            let edge = Edge {
65                id: format!("{memory_id}-RELATES_TO-{candidate_id}"),
66                src: memory_id.to_string(),
67                dst: candidate_id.clone(),
68                relationship: RelationshipType::RelatesTo,
69                weight: 0.5,
70                properties: std::collections::HashMap::from([(
71                    "auto_linked".to_string(),
72                    serde_json::json!(true),
73                )]),
74                created_at: now,
75                valid_from: None,
76                valid_to: None,
77            };
78            if self.storage.insert_graph_edge(&edge).is_ok() && graph.add_edge(edge).is_ok() {
79                created += 1;
80            }
81        }
82
83        created
84    }
85
86    // ── Tag-based Auto-linking ──────────────────────────────────────────
87
88    /// Create edges between this memory and other memories that share tags.
89    /// - `session:*` tags → PRECEDED_BY edges (temporal ordering within a session)
90    /// - Other shared tags → SHARES_THEME edges (topical overlap)
91    ///
92    /// This runs during `persist_memory` so the graph builds connectivity at
93    /// ingestion time, rather than relying solely on creative consolidation.
94    pub fn auto_link_by_tags(&self, memory: &MemoryNode) {
95        if memory.tags.is_empty() {
96            return;
97        }
98
99        // Phase 1: Collect sibling IDs and build edges WITHOUT holding the graph lock.
100        let now = chrono::Utc::now();
101        let mut linked = HashSet::new();
102        let mut edges_to_add = Vec::new();
103
104        for tag in &memory.tags {
105            let is_session_tag = tag.starts_with("session:");
106
107            let sibling_ids = match self.storage.find_memory_ids_by_tag(
108                tag,
109                memory.namespace.as_deref(),
110                &memory.id,
111            ) {
112                Ok(ids) => ids,
113                Err(_) => continue,
114            };
115
116            for sibling_id in sibling_ids {
117                if !linked.insert(sibling_id.clone()) {
118                    continue;
119                }
120
121                let (relationship, edge_label) = if is_session_tag {
122                    (RelationshipType::PrecededBy, "PRECEDED_BY")
123                } else {
124                    (RelationshipType::SharesTheme, "SHARES_THEME")
125                };
126
127                let edge_id = format!("{}-{edge_label}-{sibling_id}", memory.id);
128                edges_to_add.push(Edge {
129                    id: edge_id,
130                    src: sibling_id,
131                    dst: memory.id.clone(),
132                    relationship,
133                    weight: if is_session_tag { 0.8 } else { 0.5 },
134                    properties: std::collections::HashMap::from([(
135                        "auto_linked".to_string(),
136                        serde_json::json!(true),
137                    )]),
138                    created_at: now,
139                    valid_from: Some(now),
140                    valid_to: None,
141                });
142            }
143        }
144
145        if edges_to_add.is_empty() {
146            return;
147        }
148
149        // Phase 2: Acquire graph lock only for mutations.
150        let mut graph = match self.lock_graph() {
151            Ok(g) => g,
152            Err(_) => return,
153        };
154
155        for edge in edges_to_add {
156            if self.storage.insert_graph_edge(&edge).is_ok() {
157                let _ = graph.add_edge(edge);
158            }
159        }
160    }
161
162    // ── Node Memory Queries ──────────────────────────────────────────────
163
164    /// Retrieve all memories connected to a graph node via BFS traversal.
165    ///
166    /// Performs level-by-level BFS to track actual hop distance. For each
167    /// Memory node found, reports the relationship type from the edge that
168    /// connected it (or the edge leading into the path toward it).
169    pub fn get_node_memories(
170        &self,
171        node_id: &str,
172        max_depth: usize,
173        include_relationships: Option<&[RelationshipType]>,
174    ) -> Result<Vec<NodeMemoryResult>, CodememError> {
175        let graph = self.lock_graph()?;
176
177        // Manual BFS tracking (node_id, depth, relationship_from_parent_edge)
178        let mut results: Vec<NodeMemoryResult> = Vec::new();
179        let mut seen_memory_ids = HashSet::new();
180        let mut visited = HashSet::new();
181        let mut queue: std::collections::VecDeque<(String, usize, String)> =
182            std::collections::VecDeque::new();
183
184        visited.insert(node_id.to_string());
185        queue.push_back((node_id.to_string(), 0, String::new()));
186
187        while let Some((current_id, depth, parent_rel)) = queue.pop_front() {
188            // Collect Memory nodes (skip the start node itself)
189            if current_id != node_id {
190                if let Some(node) = graph.get_node_ref(&current_id) {
191                    if node.kind == NodeKind::Memory {
192                        let memory_id = node.memory_id.as_deref().unwrap_or(&node.id);
193                        if seen_memory_ids.insert(memory_id.to_string()) {
194                            if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
195                                results.push(NodeMemoryResult {
196                                    memory,
197                                    relationship: parent_rel.clone(),
198                                    depth,
199                                });
200                            }
201                        }
202                    }
203                }
204            }
205
206            if depth >= max_depth {
207                continue;
208            }
209
210            // Expand neighbors via edges, skipping Chunk nodes
211            for edge in graph.get_edges_ref(&current_id) {
212                let neighbor_id = if edge.src == current_id {
213                    &edge.dst
214                } else {
215                    &edge.src
216                };
217
218                if visited.contains(neighbor_id.as_str()) {
219                    continue;
220                }
221
222                // Apply relationship filter
223                if let Some(allowed) = include_relationships {
224                    if !allowed.contains(&edge.relationship) {
225                        continue;
226                    }
227                }
228
229                // Skip Chunk nodes (noisy, low-value for memory discovery)
230                if let Some(neighbor) = graph.get_node_ref(neighbor_id) {
231                    if neighbor.kind == NodeKind::Chunk {
232                        continue;
233                    }
234                }
235
236                visited.insert(neighbor_id.clone());
237                queue.push_back((
238                    neighbor_id.clone(),
239                    depth + 1,
240                    edge.relationship.to_string(),
241                ));
242            }
243        }
244
245        Ok(results)
246    }
247}