// codemem_engine/persistence/mod.rs
1//! Graph persistence: persist indexing results (file/package/symbol/chunk nodes,
2//! edges, embeddings, compaction) into the storage and graph backends.
3
4mod compaction;
5
6use crate::IndexAndResolveResult;
7use codemem_core::{
8    CodememError, Edge, GraphBackend, GraphConfig, GraphNode, NodeKind, RelationshipType,
9    VectorBackend,
10};
11use std::collections::{HashMap, HashSet};
12
/// Counts of what was persisted by `persist_index_results`.
#[derive(Debug, Default)]
pub struct IndexPersistResult {
    /// File nodes created — one per indexed file path.
    pub files_created: usize,
    /// Distinct package (directory) nodes created for the file paths' ancestors.
    pub packages_created: usize,
    /// Symbol nodes handed to storage and the in-memory graph.
    pub symbols_stored: usize,
    /// Chunk nodes handed to storage and the in-memory graph.
    pub chunks_stored: usize,
    /// Resolved reference edges persisted (symbol → symbol).
    pub edges_resolved: usize,
    /// Symbols successfully embedded; may be lower than `symbols_stored`
    /// when an embed batch fails or no embedding provider is configured.
    pub symbols_embedded: usize,
    /// Chunks successfully embedded; same caveats as `symbols_embedded`.
    pub chunks_embedded: usize,
    /// Chunks removed by auto-compaction (0 when `auto_compact` is off).
    pub chunks_pruned: usize,
    /// Symbols removed by auto-compaction (0 when `auto_compact` is off).
    pub symbols_pruned: usize,
}
26
27/// Return the edge weight for a given relationship type, using config overrides
28/// for the three most common types (Contains, Calls, Imports).
29pub fn edge_weight_for(rel: &RelationshipType, config: &GraphConfig) -> f64 {
30    match rel {
31        RelationshipType::Calls => config.calls_edge_weight,
32        RelationshipType::Imports => config.imports_edge_weight,
33        RelationshipType::Contains => config.contains_edge_weight,
34        RelationshipType::Implements | RelationshipType::Inherits => 0.8,
35        RelationshipType::DependsOn => 0.7,
36        RelationshipType::CoChanged => 0.6,
37        RelationshipType::EvolvedInto | RelationshipType::Summarizes => 0.7,
38        RelationshipType::PartOf => 0.4,
39        RelationshipType::RelatesTo | RelationshipType::SharesTheme => 0.3,
40        _ => 0.5,
41    }
42}
43
44impl super::CodememEngine {
45    /// Persist all indexing results (file nodes, package tree, symbol nodes, chunk nodes,
46    /// edges, embeddings, compaction) into storage and the in-memory graph.
47    ///
48    /// This is the full persistence pipeline called after `Indexer::index_and_resolve()`.
49    pub fn persist_index_results(
50        &self,
51        results: &IndexAndResolveResult,
52        namespace: Option<&str>,
53    ) -> Result<IndexPersistResult, CodememError> {
54        self.persist_index_results_with_progress(results, namespace, |_, _| {})
55    }
56
57    /// Like `persist_index_results`, but calls `on_progress(done, total)` during
58    /// the embedding phase so callers can display progress.
59    pub fn persist_index_results_with_progress(
60        &self,
61        results: &IndexAndResolveResult,
62        namespace: Option<&str>,
63        on_progress: impl Fn(usize, usize),
64    ) -> Result<IndexPersistResult, CodememError> {
65        let all_symbols = &results.symbols;
66        let all_chunks = &results.chunks;
67        let seen_files = &results.file_paths;
68        let edges = &results.edges;
69
70        let now = chrono::Utc::now();
71        let ns_string = namespace.map(|s| s.to_string());
72
73        let mut graph = self.lock_graph()?;
74
75        // ── File nodes ──────────────────────────────────────────────────────
76        let mut file_nodes = Vec::with_capacity(seen_files.len());
77        for file_path in seen_files {
78            let node_id = format!("file:{file_path}");
79            let mut payload = HashMap::new();
80            payload.insert(
81                "file_path".to_string(),
82                serde_json::Value::String(file_path.clone()),
83            );
84            file_nodes.push(GraphNode {
85                id: node_id,
86                kind: NodeKind::File,
87                label: file_path.clone(),
88                payload,
89                centrality: 0.0,
90                memory_id: None,
91                namespace: ns_string.clone(),
92            });
93        }
94        let _ = self.storage.insert_graph_nodes_batch(&file_nodes);
95        for node in file_nodes {
96            let _ = graph.add_node(node);
97        }
98
99        // ── Package (directory) nodes ───────────────────────────────────────
100        let mut created_dirs: HashSet<String> = HashSet::new();
101        let mut dir_nodes = Vec::new();
102        let mut dir_edges = Vec::new();
103        for file_path in seen_files {
104            let p = std::path::Path::new(file_path);
105            let mut ancestors: Vec<String> = Vec::new();
106            let mut current = p.parent();
107            while let Some(dir) = current {
108                let dir_str = dir.to_string_lossy().to_string();
109                if dir_str.is_empty() || dir_str == "." {
110                    break;
111                }
112                ancestors.push(dir_str);
113                current = dir.parent();
114            }
115            ancestors.reverse();
116            for (i, dir_str) in ancestors.iter().enumerate() {
117                let pkg_id = format!("pkg:{dir_str}/");
118                if created_dirs.insert(pkg_id.clone()) {
119                    dir_nodes.push(GraphNode {
120                        id: pkg_id.clone(),
121                        kind: NodeKind::Package,
122                        label: format!("{dir_str}/"),
123                        payload: HashMap::new(),
124                        centrality: 0.0,
125                        memory_id: None,
126                        namespace: ns_string.clone(),
127                    });
128                }
129                // CONTAINS edge from parent dir to this dir
130                if i > 0 {
131                    let parent_pkg_id = format!("pkg:{}/", ancestors[i - 1]);
132                    let edge_id = format!("contains:{parent_pkg_id}->{pkg_id}");
133                    if !graph
134                        .get_edges(&parent_pkg_id)
135                        .unwrap_or_default()
136                        .iter()
137                        .any(|e| e.id == edge_id)
138                    {
139                        let edge = Edge {
140                            id: edge_id,
141                            src: parent_pkg_id,
142                            dst: pkg_id.clone(),
143                            relationship: RelationshipType::Contains,
144                            weight: edge_weight_for(
145                                &RelationshipType::Contains,
146                                &self.config.graph,
147                            ),
148                            valid_from: None,
149                            valid_to: None,
150                            properties: HashMap::new(),
151                            created_at: now,
152                        };
153                        dir_edges.push(edge);
154                    }
155                }
156            }
157            // CONTAINS edge from innermost directory to file
158            if let Some(last_dir) = ancestors.last() {
159                let parent_pkg_id = format!("pkg:{last_dir}/");
160                let file_node_id = format!("file:{file_path}");
161                let edge_id = format!("contains:{parent_pkg_id}->{file_node_id}");
162                dir_edges.push(Edge {
163                    id: edge_id,
164                    src: parent_pkg_id,
165                    dst: file_node_id,
166                    relationship: RelationshipType::Contains,
167                    weight: edge_weight_for(&RelationshipType::Contains, &self.config.graph),
168                    valid_from: None,
169                    valid_to: None,
170                    properties: HashMap::new(),
171                    created_at: now,
172                });
173            }
174        }
175        let _ = self.storage.insert_graph_nodes_batch(&dir_nodes);
176        for node in dir_nodes {
177            let _ = graph.add_node(node);
178        }
179        let _ = self.storage.insert_graph_edges_batch(&dir_edges);
180        for edge in dir_edges {
181            let _ = graph.add_edge(edge);
182        }
183
184        // ── Symbol nodes ────────────────────────────────────────────────────
185        let mut sym_nodes = Vec::with_capacity(all_symbols.len());
186        let mut sym_edges = Vec::with_capacity(all_symbols.len());
187        for sym in all_symbols {
188            let kind = NodeKind::from(sym.kind);
189
190            let mut payload = HashMap::new();
191            payload.insert(
192                "symbol_kind".to_string(),
193                serde_json::Value::String(sym.kind.to_string()),
194            );
195            payload.insert(
196                "signature".to_string(),
197                serde_json::Value::String(sym.signature.clone()),
198            );
199            payload.insert(
200                "file_path".to_string(),
201                serde_json::Value::String(sym.file_path.clone()),
202            );
203            payload.insert("line_start".to_string(), serde_json::json!(sym.line_start));
204            payload.insert("line_end".to_string(), serde_json::json!(sym.line_end));
205            payload.insert(
206                "visibility".to_string(),
207                serde_json::Value::String(sym.visibility.to_string()),
208            );
209            if let Some(ref doc) = sym.doc_comment {
210                payload.insert(
211                    "doc_comment".to_string(),
212                    serde_json::Value::String(doc.clone()),
213                );
214            }
215            if !sym.parameters.is_empty() {
216                payload.insert(
217                    "parameters".to_string(),
218                    serde_json::to_value(&sym.parameters).unwrap_or_default(),
219                );
220            }
221            if let Some(ref ret) = sym.return_type {
222                payload.insert(
223                    "return_type".to_string(),
224                    serde_json::Value::String(ret.clone()),
225                );
226            }
227            if sym.is_async {
228                payload.insert("is_async".to_string(), serde_json::json!(true));
229            }
230            if !sym.attributes.is_empty() {
231                payload.insert(
232                    "attributes".to_string(),
233                    serde_json::to_value(&sym.attributes).unwrap_or_default(),
234                );
235            }
236            if !sym.throws.is_empty() {
237                payload.insert(
238                    "throws".to_string(),
239                    serde_json::to_value(&sym.throws).unwrap_or_default(),
240                );
241            }
242            if let Some(ref gp) = sym.generic_params {
243                payload.insert(
244                    "generic_params".to_string(),
245                    serde_json::Value::String(gp.clone()),
246                );
247            }
248            if sym.is_abstract {
249                payload.insert("is_abstract".to_string(), serde_json::json!(true));
250            }
251            if let Some(ref parent) = sym.parent {
252                payload.insert(
253                    "parent".to_string(),
254                    serde_json::Value::String(parent.clone()),
255                );
256            }
257
258            let sym_node_id = format!("sym:{}", sym.qualified_name);
259            let node = GraphNode {
260                id: sym_node_id.clone(),
261                kind,
262                label: sym.qualified_name.clone(),
263                payload,
264                centrality: 0.0,
265                memory_id: None,
266                namespace: ns_string.clone(),
267            };
268            sym_nodes.push(node);
269
270            // CONTAINS edge: file → symbol
271            let file_node_id = format!("file:{}", sym.file_path);
272            sym_edges.push(Edge {
273                id: format!("contains:{file_node_id}->{sym_node_id}"),
274                src: file_node_id,
275                dst: sym_node_id,
276                relationship: RelationshipType::Contains,
277                weight: edge_weight_for(&RelationshipType::Contains, &self.config.graph),
278                valid_from: None,
279                valid_to: None,
280                properties: HashMap::new(),
281                created_at: now,
282            });
283        }
284        let _ = self.storage.insert_graph_nodes_batch(&sym_nodes);
285        for node in sym_nodes {
286            let _ = graph.add_node(node);
287        }
288        let _ = self.storage.insert_graph_edges_batch(&sym_edges);
289        for edge in sym_edges {
290            let _ = graph.add_edge(edge);
291        }
292
293        // ── Resolved reference edges ────────────────────────────────────────
294        let ref_edges: Vec<Edge> = edges
295            .iter()
296            .map(|edge| Edge {
297                id: format!(
298                    "ref:{}->{}:{}",
299                    edge.source_qualified_name, edge.target_qualified_name, edge.relationship
300                ),
301                src: format!("sym:{}", edge.source_qualified_name),
302                dst: format!("sym:{}", edge.target_qualified_name),
303                relationship: edge.relationship,
304                weight: edge_weight_for(&edge.relationship, &self.config.graph),
305                valid_from: None,
306                valid_to: None,
307                properties: HashMap::new(),
308                created_at: now,
309            })
310            .collect();
311        let _ = self.storage.insert_graph_edges_batch(&ref_edges);
312        for e in ref_edges {
313            let _ = graph.add_edge(e);
314        }
315
316        // ── Chunk nodes ─────────────────────────────────────────────────────
317        // Cleanup stale chunk nodes for files being re-indexed
318        for file_path in seen_files {
319            let prefix = format!("chunk:{file_path}:");
320            let _ = self.storage.delete_graph_nodes_by_prefix(&prefix);
321        }
322
323        let mut chunk_nodes = Vec::with_capacity(all_chunks.len());
324        let mut chunk_edges = Vec::with_capacity(all_chunks.len() * 2);
325        for chunk in all_chunks {
326            let chunk_id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
327
328            let mut payload = HashMap::new();
329            payload.insert(
330                "file_path".to_string(),
331                serde_json::Value::String(chunk.file_path.clone()),
332            );
333            payload.insert(
334                "line_start".to_string(),
335                serde_json::json!(chunk.line_start),
336            );
337            payload.insert("line_end".to_string(), serde_json::json!(chunk.line_end));
338            payload.insert(
339                "node_kind".to_string(),
340                serde_json::Value::String(chunk.node_kind.clone()),
341            );
342            payload.insert(
343                "non_ws_chars".to_string(),
344                serde_json::json!(chunk.non_ws_chars),
345            );
346            if let Some(ref parent) = chunk.parent_symbol {
347                payload.insert(
348                    "parent_symbol".to_string(),
349                    serde_json::Value::String(parent.clone()),
350                );
351            }
352
353            chunk_nodes.push(GraphNode {
354                id: chunk_id.clone(),
355                kind: NodeKind::Chunk,
356                label: format!(
357                    "chunk:{}:{}..{}",
358                    chunk.file_path, chunk.line_start, chunk.line_end
359                ),
360                payload,
361                centrality: 0.0,
362                memory_id: None,
363                namespace: ns_string.clone(),
364            });
365
366            // CONTAINS edge: file → chunk
367            let file_node_id = format!("file:{}", chunk.file_path);
368            chunk_edges.push(Edge {
369                id: format!("contains:{file_node_id}->{chunk_id}"),
370                src: file_node_id,
371                dst: chunk_id.clone(),
372                relationship: RelationshipType::Contains,
373                weight: edge_weight_for(&RelationshipType::Contains, &self.config.graph),
374                valid_from: None,
375                valid_to: None,
376                properties: HashMap::new(),
377                created_at: now,
378            });
379
380            // CONTAINS edge: parent symbol → chunk
381            if let Some(ref parent_sym) = chunk.parent_symbol {
382                let parent_node_id = format!("sym:{parent_sym}");
383                chunk_edges.push(Edge {
384                    id: format!("contains:{parent_node_id}->{chunk_id}"),
385                    src: parent_node_id,
386                    dst: chunk_id,
387                    relationship: RelationshipType::Contains,
388                    weight: edge_weight_for(&RelationshipType::Contains, &self.config.graph),
389                    valid_from: None,
390                    valid_to: None,
391                    properties: HashMap::new(),
392                    created_at: now,
393                });
394            }
395        }
396        let chunk_count = chunk_nodes.len();
397        let _ = self.storage.insert_graph_nodes_batch(&chunk_nodes);
398        for node in chunk_nodes {
399            let _ = graph.add_node(node);
400        }
401        let _ = self.storage.insert_graph_edges_batch(&chunk_edges);
402        for edge in chunk_edges {
403            let _ = graph.add_edge(edge);
404        }
405        drop(graph);
406
407        // ── Embed symbols and chunks ────────────────────────────────────────
408        // Phase 1: Collect enriched texts without holding any lock.
409        // enrich_symbol_text / enrich_chunk_text only read from the passed-in
410        // Symbol/CodeChunk and edges slice, so no lock is required.
411        //
412        // A12: Embedding bottleneck — The embedding provider is behind a Mutex,
413        // so `embed_batch` runs sequentially even though CPU-bound inference
414        // (Candle) could benefit from parallelism. For large codebases, this is
415        // the primary bottleneck. Potential fix: wrap the provider in an Arc and
416        // use `tokio::spawn_blocking` for CPU-bound Candle inference, or use a
417        // channel-based work queue to decouple embedding from persistence.
418        let mut symbols_embedded = 0usize;
419        let mut chunks_embedded = 0usize;
420
421        // Quick check: skip expensive text enrichment if no embedding provider.
422        // The per-batch loop also guards against provider disappearing mid-run.
423        let has_embeddings = self.lock_embeddings()?.is_some();
424        if has_embeddings {
425            let sym_texts: Vec<(String, String)> = all_symbols
426                .iter()
427                .map(|sym| {
428                    let id = format!("sym:{}", sym.qualified_name);
429                    let text = self.enrich_symbol_text(sym, edges);
430                    (id, text)
431                })
432                .collect();
433            let chunk_texts: Vec<(String, String)> = all_chunks
434                .iter()
435                .map(|chunk| {
436                    let id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
437                    let text = self.enrich_chunk_text(chunk);
438                    (id, text)
439                })
440                .collect();
441
442            // Phase 2+3: Embed in chunks and persist progressively.
443            // Instead of one giant embed_batch (which blocks with no progress for
444            // large codebases), we process in manageable chunks, persisting each
445            // batch and reporting progress.
446            //
447            // The embedding lock is acquired per-batch so that SQLite/vector
448            // writes don't hold it, and remote providers (Ollama/OpenAI) don't
449            // block other operations for the entire duration.
450            // Persistence batch size: how many items to embed + flush per round.
451            // Separate from the GPU batch size (configured on EmbeddingService).
452            const EMBED_CHUNK_SIZE: usize = 64;
453
454            // all_pairs is ordered: symbols first, then chunks (via chain).
455            // sym_count is used to attribute embedded items to the correct counter.
456            let all_pairs: Vec<(String, String)> =
457                sym_texts.into_iter().chain(chunk_texts).collect();
458            let total = all_pairs.len();
459            let sym_count = all_symbols.len();
460            let mut done = 0usize;
461
462            for batch in all_pairs.chunks(EMBED_CHUNK_SIZE) {
463                let texts: Vec<&str> = batch.iter().map(|(_, t)| t.as_str()).collect();
464
465                // Acquire embedding lock only for the embed_batch call, then drop.
466                let t0 = std::time::Instant::now();
467                let embed_result = {
468                    let emb = self.lock_embeddings()?;
469                    match emb {
470                        Some(emb_guard) => emb_guard.embed_batch(&texts),
471                        None => break, // Provider disappeared between check and use
472                    }
473                    // emb_guard dropped here — lock released before persistence I/O
474                };
475
476                match embed_result {
477                    Ok(embeddings) => {
478                        let embed_ms = t0.elapsed().as_millis();
479
480                        // Batch-store embeddings to SQLite
481                        let t1 = std::time::Instant::now();
482                        let pairs: Vec<(&str, &[f32])> = batch
483                            .iter()
484                            .zip(embeddings.iter())
485                            .map(|((id, _), emb_vec)| (id.as_str(), emb_vec.as_slice()))
486                            .collect();
487                        if let Err(e) = self.storage.store_embeddings_batch(&pairs) {
488                            tracing::warn!("Failed to batch-store embeddings: {e}");
489                        }
490                        let sqlite_ms = t1.elapsed().as_millis();
491
492                        // Batch-insert into in-memory vector index
493                        let t2 = std::time::Instant::now();
494                        let batch_items: Vec<(String, Vec<f32>)> = batch
495                            .iter()
496                            .zip(embeddings.into_iter())
497                            .map(|((id, _), emb_vec)| (id.clone(), emb_vec))
498                            .collect();
499                        let batch_len = batch_items.len();
500                        {
501                            let mut vec = self.lock_vector()?;
502                            if let Err(e) = vec.insert_batch(&batch_items) {
503                                tracing::warn!("Failed to batch-insert into vector index: {e}");
504                            }
505                        }
506                        let vector_ms = t2.elapsed().as_millis();
507
508                        // Update per-type counters using arithmetic instead of loop
509                        let syms_in_batch = batch_len.min(sym_count.saturating_sub(done));
510                        symbols_embedded += syms_in_batch;
511                        chunks_embedded += batch_len - syms_in_batch;
512                        done += batch_len;
513
514                        tracing::debug!(
515                            "Embed batch {}: embed={embed_ms}ms sqlite={sqlite_ms}ms vector={vector_ms}ms",
516                            batch_len
517                        );
518                    }
519                    Err(e) => {
520                        tracing::warn!(
521                            "embed_batch failed for chunk of {} texts: {e}",
522                            batch.len()
523                        );
524                        // Don't advance `done` — progress should reflect actual embeddings,
525                        // not failed batches. The total will no longer be reached, which
526                        // correctly signals incomplete embedding to the caller.
527                    }
528                }
529                on_progress(done, total);
530            }
531            self.save_index();
532        }
533
534        // ── Auto-compact ────────────────────────────────────────────────────
535        let (chunks_pruned, symbols_pruned) = if self.config.chunking.auto_compact {
536            self.compact_graph(seen_files)
537        } else {
538            (0, 0)
539        };
540
541        Ok(IndexPersistResult {
542            files_created: seen_files.len(),
543            packages_created: created_dirs.len(),
544            symbols_stored: all_symbols.len(),
545            chunks_stored: chunk_count,
546            edges_resolved: edges.len(),
547            symbols_embedded,
548            chunks_embedded,
549            chunks_pruned,
550            symbols_pruned,
551        })
552    }
553}