Skip to main content

codemem_engine/persistence/
mod.rs

1//! Graph persistence: persist indexing results (file/package/symbol/chunk nodes,
2//! edges, embeddings, compaction) into the storage and graph backends.
3
4mod compaction;
5
6use crate::index::{CodeChunk, ResolvedEdge, Symbol};
7use crate::IndexAndResolveResult;
8use codemem_core::{
9    CodememError, Edge, GraphBackend, GraphConfig, GraphNode, NodeKind, RelationshipType,
10    VectorBackend,
11};
12use std::collections::{HashMap, HashSet};
13
/// Counts of what was persisted by `persist_index_results`.
///
/// All counters default to zero. The type is cheap to clone and supports
/// equality comparison so callers can diff or assert on results.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct IndexPersistResult {
    /// Number of file paths in the indexing results (one file node is built per path).
    pub files_created: usize,
    /// Number of distinct package (directory) nodes derived from the file paths.
    pub packages_created: usize,
    /// Number of symbols in the indexing results.
    pub symbols_stored: usize,
    /// Number of chunk nodes built from the indexing results.
    pub chunks_stored: usize,
    /// Number of resolved reference edges in the indexing results.
    pub edges_resolved: usize,
    /// Number of symbols for which an embedding was produced.
    pub symbols_embedded: usize,
    /// Number of chunks for which an embedding was produced.
    pub chunks_embedded: usize,
    /// Chunks removed by auto-compaction (zero when auto-compaction is disabled).
    pub chunks_pruned: usize,
    /// Symbols removed by auto-compaction (zero when auto-compaction is disabled).
    pub symbols_pruned: usize,
}
27
28/// Return the edge weight for a given relationship type, using config overrides
29/// for the three most common types (Contains, Calls, Imports).
30pub fn edge_weight_for(rel: &RelationshipType, config: &GraphConfig) -> f64 {
31    match rel {
32        RelationshipType::Calls => config.calls_edge_weight,
33        RelationshipType::Imports => config.imports_edge_weight,
34        RelationshipType::Contains => config.contains_edge_weight,
35        RelationshipType::Implements | RelationshipType::Inherits => 0.8,
36        RelationshipType::DependsOn => 0.7,
37        RelationshipType::CoChanged => 0.6,
38        RelationshipType::EvolvedInto | RelationshipType::Summarizes => 0.7,
39        RelationshipType::PartOf => 0.4,
40        RelationshipType::RelatesTo | RelationshipType::SharesTheme => 0.3,
41        _ => 0.5,
42    }
43}
44
/// Counts gathered while persisting graph nodes, before the embedding phase runs.
struct GraphPersistCounts {
    /// Number of chunk nodes that were built and handed to storage/graph.
    chunks_stored: usize,
    /// Number of distinct package (directory) nodes derived from the file paths.
    packages_created: usize,
}
50
51impl super::CodememEngine {
52    /// Persist all indexing results (file nodes, package tree, symbol nodes, chunk nodes,
53    /// edges, embeddings, compaction) into storage and the in-memory graph.
54    ///
55    /// This is the full persistence pipeline called after `Indexer::index_and_resolve()`.
56    pub fn persist_index_results(
57        &self,
58        results: &IndexAndResolveResult,
59        namespace: Option<&str>,
60    ) -> Result<IndexPersistResult, CodememError> {
61        self.persist_index_results_with_progress(results, namespace, |_, _| {})
62    }
63
64    /// Like `persist_index_results`, but calls `on_progress(done, total)` during
65    /// the embedding phase so callers can display progress.
66    pub fn persist_index_results_with_progress(
67        &self,
68        results: &IndexAndResolveResult,
69        namespace: Option<&str>,
70        on_progress: impl Fn(usize, usize),
71    ) -> Result<IndexPersistResult, CodememError> {
72        let seen_files = &results.file_paths;
73
74        // 1. Persist all graph nodes and edges
75        let graph_counts = self.persist_graph_nodes(results, namespace)?;
76
77        // 2. Embed symbols and chunks
78        let (symbols_embedded, chunks_embedded) = self.embed_and_persist(
79            &results.symbols,
80            &results.chunks,
81            &results.edges,
82            on_progress,
83        )?;
84
85        // 3. Auto-compact
86        let (chunks_pruned, symbols_pruned) = if self.config.chunking.auto_compact {
87            self.compact_graph(seen_files)
88        } else {
89            (0, 0)
90        };
91
92        Ok(IndexPersistResult {
93            files_created: seen_files.len(),
94            packages_created: graph_counts.packages_created,
95            symbols_stored: results.symbols.len(),
96            chunks_stored: graph_counts.chunks_stored,
97            edges_resolved: results.edges.len(),
98            symbols_embedded,
99            chunks_embedded,
100            chunks_pruned,
101            symbols_pruned,
102        })
103    }
104
105    // ── Graph Node Persistence ───────────────────────────────────────────
106
107    /// Persist file, package, symbol, chunk nodes and all edges into storage
108    /// and the in-memory graph. Returns counts for the result struct.
109    fn persist_graph_nodes(
110        &self,
111        results: &IndexAndResolveResult,
112        namespace: Option<&str>,
113    ) -> Result<GraphPersistCounts, CodememError> {
114        let all_symbols = &results.symbols;
115        let all_chunks = &results.chunks;
116        let seen_files = &results.file_paths;
117        let edges = &results.edges;
118
119        let now = chrono::Utc::now();
120        let ns_string = namespace.map(|s| s.to_string());
121        let contains_weight = edge_weight_for(&RelationshipType::Contains, &self.config.graph);
122
123        let mut graph = self.lock_graph()?;
124
125        // ── File nodes
126        let file_nodes: Vec<GraphNode> = seen_files
127            .iter()
128            .map(|file_path| {
129                let mut payload = HashMap::new();
130                payload.insert(
131                    "file_path".to_string(),
132                    serde_json::Value::String(file_path.clone()),
133                );
134                GraphNode {
135                    id: format!("file:{file_path}"),
136                    kind: NodeKind::File,
137                    label: file_path.clone(),
138                    payload,
139                    centrality: 0.0,
140                    memory_id: None,
141                    namespace: ns_string.clone(),
142                }
143            })
144            .collect();
145        self.persist_nodes_to_storage_and_graph(&file_nodes, &mut graph);
146
147        // ── Package (directory) nodes
148        let (dir_nodes, dir_edges, created_dirs) =
149            self.build_package_tree(seen_files, &ns_string, contains_weight, now, &graph);
150        self.persist_nodes_to_storage_and_graph(&dir_nodes, &mut graph);
151        self.persist_edges_to_storage_and_graph(&dir_edges, &mut graph);
152
153        // ── Symbol nodes + file→symbol edges
154        let (sym_nodes, sym_edges) =
155            Self::build_symbol_nodes(all_symbols, &ns_string, contains_weight, now);
156
157        // Clean up stale symbols: single pass over in-memory graph to collect
158        // existing symbols grouped by file, then diff against new parse results.
159        //
160        // Lock protocol: We collect old symbols while holding the graph lock,
161        // then drop it so `cleanup_stale_symbols` can acquire graph + vector
162        // locks internally. The re-acquire below is safe: cleanup only removes
163        // stale nodes that won't conflict with the inserts that follow.
164        let mut old_syms_by_file: HashMap<String, HashSet<String>> = HashMap::new();
165        for node in graph.get_all_nodes() {
166            if !node.id.starts_with("sym:") {
167                continue;
168            }
169            let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) else {
170                continue;
171            };
172            if !seen_files.contains(fp) {
173                continue;
174            }
175            old_syms_by_file
176                .entry(fp.to_string())
177                .or_default()
178                .insert(node.id);
179        }
180        drop(graph);
181        for file_path in seen_files {
182            let new_sym_ids: HashSet<String> = sym_nodes
183                .iter()
184                .filter(|n| {
185                    n.payload.get("file_path").and_then(|v| v.as_str()) == Some(file_path.as_str())
186                })
187                .map(|n| n.id.clone())
188                .collect();
189            let empty = HashSet::new();
190            let old_sym_ids = old_syms_by_file.get(file_path).unwrap_or(&empty);
191            if let Err(e) = self.cleanup_stale_symbols(file_path, old_sym_ids, &new_sym_ids) {
192                tracing::warn!("Failed to cleanup stale symbols for {file_path}: {e}");
193            }
194        }
195        let mut graph = self.lock_graph()?; // Re-acquire lock
196
197        self.persist_nodes_to_storage_and_graph(&sym_nodes, &mut graph);
198        self.persist_edges_to_storage_and_graph(&sym_edges, &mut graph);
199
200        // ── Resolved reference edges
201        let ref_edges = Self::build_reference_edges(edges, &self.config.graph, now);
202        self.persist_edges_to_storage_and_graph(&ref_edges, &mut graph);
203
204        // ── Chunk nodes + file→chunk / symbol→chunk edges
205        for file_path in seen_files {
206            let prefix = format!("chunk:{file_path}:");
207            let _ = self.storage.delete_graph_nodes_by_prefix(&prefix);
208        }
209        let (chunk_nodes, chunk_edges) =
210            Self::build_chunk_nodes(all_chunks, &ns_string, contains_weight, now);
211        let chunk_count = chunk_nodes.len();
212        self.persist_nodes_to_storage_and_graph(&chunk_nodes, &mut graph);
213        self.persist_edges_to_storage_and_graph(&chunk_edges, &mut graph);
214
215        drop(graph);
216
217        Ok(GraphPersistCounts {
218            packages_created: created_dirs,
219            chunks_stored: chunk_count,
220        })
221    }
222
223    /// Batch-insert nodes into both SQLite and the in-memory graph.
224    fn persist_nodes_to_storage_and_graph(
225        &self,
226        nodes: &[GraphNode],
227        graph: &mut crate::GraphEngine,
228    ) {
229        let _ = self.storage.insert_graph_nodes_batch(nodes);
230        for node in nodes {
231            let _ = graph.add_node(node.clone());
232        }
233    }
234
235    /// Batch-insert edges into both SQLite and the in-memory graph.
236    fn persist_edges_to_storage_and_graph(&self, edges: &[Edge], graph: &mut crate::GraphEngine) {
237        let _ = self.storage.insert_graph_edges_batch(edges);
238        for edge in edges {
239            let _ = graph.add_edge(edge.clone());
240        }
241    }
242
243    /// Build directory/package nodes and CONTAINS edges from file paths.
244    /// Returns (nodes, edges, number_of_dirs_created).
245    fn build_package_tree(
246        &self,
247        seen_files: &HashSet<String>,
248        ns_string: &Option<String>,
249        contains_weight: f64,
250        now: chrono::DateTime<chrono::Utc>,
251        graph: &crate::GraphEngine,
252    ) -> (Vec<GraphNode>, Vec<Edge>, usize) {
253        let mut created_dirs: HashSet<String> = HashSet::new();
254        let mut dir_nodes = Vec::new();
255        let mut dir_edges = Vec::new();
256
257        for file_path in seen_files {
258            let p = std::path::Path::new(file_path);
259            let mut ancestors: Vec<String> = Vec::new();
260            let mut current = p.parent();
261            while let Some(dir) = current {
262                let dir_str = dir.to_string_lossy().to_string();
263                if dir_str.is_empty() || dir_str == "." {
264                    break;
265                }
266                ancestors.push(dir_str);
267                current = dir.parent();
268            }
269            ancestors.reverse();
270            for (i, dir_str) in ancestors.iter().enumerate() {
271                let pkg_id = format!("pkg:{dir_str}/");
272                if created_dirs.insert(pkg_id.clone()) {
273                    dir_nodes.push(GraphNode {
274                        id: pkg_id.clone(),
275                        kind: NodeKind::Package,
276                        label: format!("{dir_str}/"),
277                        payload: HashMap::new(),
278                        centrality: 0.0,
279                        memory_id: None,
280                        namespace: ns_string.clone(),
281                    });
282                }
283                if i == 0 {
284                    continue;
285                }
286                let parent_pkg_id = format!("pkg:{}/", ancestors[i - 1]);
287                let edge_id = format!("contains:{parent_pkg_id}->{pkg_id}");
288                if graph
289                    .get_edges(&parent_pkg_id)
290                    .unwrap_or_default()
291                    .iter()
292                    .any(|e| e.id == edge_id)
293                {
294                    continue;
295                }
296                dir_edges.push(Edge {
297                    id: edge_id,
298                    src: parent_pkg_id,
299                    dst: pkg_id.clone(),
300                    relationship: RelationshipType::Contains,
301                    weight: contains_weight,
302                    valid_from: Some(now),
303                    valid_to: None,
304                    properties: HashMap::new(),
305                    created_at: now,
306                });
307            }
308            if let Some(last_dir) = ancestors.last() {
309                let parent_pkg_id = format!("pkg:{last_dir}/");
310                let file_node_id = format!("file:{file_path}");
311                let edge_id = format!("contains:{parent_pkg_id}->{file_node_id}");
312                dir_edges.push(Edge {
313                    id: edge_id,
314                    src: parent_pkg_id,
315                    dst: file_node_id,
316                    relationship: RelationshipType::Contains,
317                    weight: contains_weight,
318                    valid_from: Some(now),
319                    valid_to: None,
320                    properties: HashMap::new(),
321                    created_at: now,
322                });
323            }
324        }
325
326        let count = created_dirs.len();
327        (dir_nodes, dir_edges, count)
328    }
329
330    /// Build symbol graph nodes and file→symbol CONTAINS edges.
331    fn build_symbol_nodes(
332        symbols: &[Symbol],
333        ns_string: &Option<String>,
334        contains_weight: f64,
335        now: chrono::DateTime<chrono::Utc>,
336    ) -> (Vec<GraphNode>, Vec<Edge>) {
337        let mut sym_nodes = Vec::with_capacity(symbols.len());
338        let mut sym_edges = Vec::with_capacity(symbols.len());
339
340        for sym in symbols {
341            let kind = NodeKind::from(sym.kind);
342            let payload = Self::build_symbol_payload(sym);
343
344            let sym_node_id = format!("sym:{}", sym.qualified_name);
345            sym_nodes.push(GraphNode {
346                id: sym_node_id.clone(),
347                kind,
348                label: sym.qualified_name.clone(),
349                payload,
350                centrality: 0.0,
351                memory_id: None,
352                namespace: ns_string.clone(),
353            });
354
355            let file_node_id = format!("file:{}", sym.file_path);
356            sym_edges.push(Edge {
357                id: format!("contains:{file_node_id}->{sym_node_id}"),
358                src: file_node_id,
359                dst: sym_node_id,
360                relationship: RelationshipType::Contains,
361                weight: contains_weight,
362                valid_from: Some(now),
363                valid_to: None,
364                properties: HashMap::new(),
365                created_at: now,
366            });
367        }
368
369        (sym_nodes, sym_edges)
370    }
371
372    /// Build the payload HashMap for a symbol's graph node.
373    fn build_symbol_payload(sym: &Symbol) -> HashMap<String, serde_json::Value> {
374        let mut payload = HashMap::new();
375        payload.insert(
376            "symbol_kind".to_string(),
377            serde_json::Value::String(sym.kind.to_string()),
378        );
379        payload.insert(
380            "signature".to_string(),
381            serde_json::Value::String(sym.signature.clone()),
382        );
383        payload.insert(
384            "file_path".to_string(),
385            serde_json::Value::String(sym.file_path.clone()),
386        );
387        payload.insert("line_start".to_string(), serde_json::json!(sym.line_start));
388        payload.insert("line_end".to_string(), serde_json::json!(sym.line_end));
389        payload.insert(
390            "visibility".to_string(),
391            serde_json::Value::String(sym.visibility.to_string()),
392        );
393        if let Some(ref doc) = sym.doc_comment {
394            payload.insert(
395                "doc_comment".to_string(),
396                serde_json::Value::String(doc.clone()),
397            );
398        }
399        if !sym.parameters.is_empty() {
400            payload.insert(
401                "parameters".to_string(),
402                serde_json::to_value(&sym.parameters).unwrap_or_default(),
403            );
404        }
405        if let Some(ref ret) = sym.return_type {
406            payload.insert(
407                "return_type".to_string(),
408                serde_json::Value::String(ret.clone()),
409            );
410        }
411        if sym.is_async {
412            payload.insert("is_async".to_string(), serde_json::json!(true));
413        }
414        if !sym.attributes.is_empty() {
415            payload.insert(
416                "attributes".to_string(),
417                serde_json::to_value(&sym.attributes).unwrap_or_default(),
418            );
419        }
420        if !sym.throws.is_empty() {
421            payload.insert(
422                "throws".to_string(),
423                serde_json::to_value(&sym.throws).unwrap_or_default(),
424            );
425        }
426        if let Some(ref gp) = sym.generic_params {
427            payload.insert(
428                "generic_params".to_string(),
429                serde_json::Value::String(gp.clone()),
430            );
431        }
432        if sym.is_abstract {
433            payload.insert("is_abstract".to_string(), serde_json::json!(true));
434        }
435        if let Some(ref parent) = sym.parent {
436            payload.insert(
437                "parent".to_string(),
438                serde_json::Value::String(parent.clone()),
439            );
440        }
441        payload
442    }
443
444    /// Build edges from resolved cross-file references.
445    fn build_reference_edges(
446        edges: &[ResolvedEdge],
447        graph_config: &GraphConfig,
448        now: chrono::DateTime<chrono::Utc>,
449    ) -> Vec<Edge> {
450        edges
451            .iter()
452            .map(|edge| Edge {
453                id: format!(
454                    "ref:{}->{}:{}",
455                    edge.source_qualified_name, edge.target_qualified_name, edge.relationship
456                ),
457                src: format!("sym:{}", edge.source_qualified_name),
458                dst: format!("sym:{}", edge.target_qualified_name),
459                relationship: edge.relationship,
460                weight: edge_weight_for(&edge.relationship, graph_config),
461                valid_from: Some(now),
462                valid_to: None,
463                properties: HashMap::new(),
464                created_at: now,
465            })
466            .collect()
467    }
468
469    /// Build chunk graph nodes and file→chunk / symbol→chunk CONTAINS edges.
470    fn build_chunk_nodes(
471        chunks: &[CodeChunk],
472        ns_string: &Option<String>,
473        contains_weight: f64,
474        now: chrono::DateTime<chrono::Utc>,
475    ) -> (Vec<GraphNode>, Vec<Edge>) {
476        let mut chunk_nodes = Vec::with_capacity(chunks.len());
477        let mut chunk_edges = Vec::with_capacity(chunks.len() * 2);
478
479        for chunk in chunks {
480            let chunk_id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
481
482            let mut payload = HashMap::new();
483            payload.insert(
484                "file_path".to_string(),
485                serde_json::Value::String(chunk.file_path.clone()),
486            );
487            payload.insert(
488                "line_start".to_string(),
489                serde_json::json!(chunk.line_start),
490            );
491            payload.insert("line_end".to_string(), serde_json::json!(chunk.line_end));
492            payload.insert(
493                "node_kind".to_string(),
494                serde_json::Value::String(chunk.node_kind.clone()),
495            );
496            payload.insert(
497                "non_ws_chars".to_string(),
498                serde_json::json!(chunk.non_ws_chars),
499            );
500            if let Some(ref parent) = chunk.parent_symbol {
501                payload.insert(
502                    "parent_symbol".to_string(),
503                    serde_json::Value::String(parent.clone()),
504                );
505            }
506
507            chunk_nodes.push(GraphNode {
508                id: chunk_id.clone(),
509                kind: NodeKind::Chunk,
510                label: format!(
511                    "chunk:{}:{}..{}",
512                    chunk.file_path, chunk.line_start, chunk.line_end
513                ),
514                payload,
515                centrality: 0.0,
516                memory_id: None,
517                namespace: ns_string.clone(),
518            });
519
520            let file_node_id = format!("file:{}", chunk.file_path);
521            chunk_edges.push(Edge {
522                id: format!("contains:{file_node_id}->{chunk_id}"),
523                src: file_node_id,
524                dst: chunk_id.clone(),
525                relationship: RelationshipType::Contains,
526                weight: contains_weight,
527                valid_from: Some(now),
528                valid_to: None,
529                properties: HashMap::new(),
530                created_at: now,
531            });
532
533            if let Some(ref parent_sym) = chunk.parent_symbol {
534                let parent_node_id = format!("sym:{parent_sym}");
535                chunk_edges.push(Edge {
536                    id: format!("contains:{parent_node_id}->{chunk_id}"),
537                    src: parent_node_id,
538                    dst: chunk_id,
539                    relationship: RelationshipType::Contains,
540                    weight: contains_weight,
541                    valid_from: Some(now),
542                    valid_to: None,
543                    properties: HashMap::new(),
544                    created_at: now,
545                });
546            }
547        }
548
549        (chunk_nodes, chunk_edges)
550    }
551
552    // ── Embedding Persistence ────────────────────────────────────────────
553
554    /// Embed symbols and chunks, persisting embeddings to SQLite and the
555    /// vector index in batches with progress reporting.
556    ///
557    /// Returns (symbols_embedded, chunks_embedded).
558    fn embed_and_persist(
559        &self,
560        symbols: &[Symbol],
561        chunks: &[CodeChunk],
562        edges: &[ResolvedEdge],
563        on_progress: impl Fn(usize, usize),
564    ) -> Result<(usize, usize), CodememError> {
565        let mut symbols_embedded = 0usize;
566        let mut chunks_embedded = 0usize;
567
568        // Quick check: skip expensive text enrichment if embedding provider isn't loaded.
569        // This avoids triggering lazy init during lightweight operations (hooks).
570        if !self.embeddings_ready() {
571            return Ok((0, 0));
572        }
573
574        // Phase 1: Collect enriched texts without holding any lock.
575        let sym_texts: Vec<(String, String)> = symbols
576            .iter()
577            .map(|sym| {
578                let id = format!("sym:{}", sym.qualified_name);
579                let text = self.enrich_symbol_text(sym, edges);
580                (id, text)
581            })
582            .collect();
583        let chunk_texts: Vec<(String, String)> = chunks
584            .iter()
585            .map(|chunk| {
586                let id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
587                let text = self.enrich_chunk_text(chunk);
588                (id, text)
589            })
590            .collect();
591
592        // Phase 2+3: Embed in batches and persist progressively.
593        const EMBED_CHUNK_SIZE: usize = 64;
594
595        let all_pairs: Vec<(String, String)> = sym_texts.into_iter().chain(chunk_texts).collect();
596        let total = all_pairs.len();
597        let sym_count = symbols.len();
598        let mut done = 0usize;
599
600        for batch in all_pairs.chunks(EMBED_CHUNK_SIZE) {
601            let texts: Vec<&str> = batch.iter().map(|(_, t)| t.as_str()).collect();
602
603            let t0 = std::time::Instant::now();
604            let embed_result = {
605                let emb = self.lock_embeddings()?;
606                match emb {
607                    Some(emb_guard) => emb_guard.embed_batch(&texts),
608                    None => break,
609                }
610            };
611
612            match embed_result {
613                Ok(embeddings) => {
614                    let embed_ms = t0.elapsed().as_millis();
615
616                    let t1 = std::time::Instant::now();
617                    let pairs: Vec<(&str, &[f32])> = batch
618                        .iter()
619                        .zip(embeddings.iter())
620                        .map(|((id, _), emb_vec)| (id.as_str(), emb_vec.as_slice()))
621                        .collect();
622                    if let Err(e) = self.storage.store_embeddings_batch(&pairs) {
623                        tracing::warn!("Failed to batch-store embeddings: {e}");
624                    }
625                    let sqlite_ms = t1.elapsed().as_millis();
626
627                    let t2 = std::time::Instant::now();
628                    let batch_items: Vec<(String, Vec<f32>)> = batch
629                        .iter()
630                        .zip(embeddings.into_iter())
631                        .map(|((id, _), emb_vec)| (id.clone(), emb_vec))
632                        .collect();
633                    let batch_len = batch_items.len();
634                    {
635                        let mut vec = self.lock_vector()?;
636                        if let Err(e) = vec.insert_batch(&batch_items) {
637                            tracing::warn!("Failed to batch-insert into vector index: {e}");
638                        }
639                    }
640                    let vector_ms = t2.elapsed().as_millis();
641
642                    let syms_in_batch = batch_len.min(sym_count.saturating_sub(done));
643                    symbols_embedded += syms_in_batch;
644                    chunks_embedded += batch_len - syms_in_batch;
645                    done += batch_len;
646
647                    tracing::debug!(
648                        "Embed batch {}: embed={embed_ms}ms sqlite={sqlite_ms}ms vector={vector_ms}ms",
649                        batch_len
650                    );
651                }
652                Err(e) => {
653                    tracing::warn!("embed_batch failed for chunk of {} texts: {e}", batch.len());
654                }
655            }
656            on_progress(done, total);
657        }
658        self.save_index();
659
660        Ok((symbols_embedded, chunks_embedded))
661    }
662}