Skip to main content

codemem_engine/persistence/
mod.rs

//! Graph persistence: persist indexing results (file/package/symbol/chunk nodes,
//! edges, embeddings, compaction) into the storage and graph backends.
3
4mod compaction;
5
6use crate::index::{CodeChunk, ResolvedEdge, Symbol};
7use crate::IndexAndResolveResult;
8use codemem_core::{
9    CodememError, Edge, GraphBackend, GraphConfig, GraphNode, NodeKind, RelationshipType,
10    VectorBackend,
11};
12use std::collections::{HashMap, HashSet};
13
/// Counts of what was persisted by `persist_index_results`.
///
/// All fields are plain counters, so the type is cheap to copy and compare
/// (handy for logging and for asserting on results in tests).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct IndexPersistResult {
    /// Number of file paths in the indexing result; one file node is persisted per path.
    pub files_created: usize,
    /// Number of distinct package (directory) nodes derived from those file paths.
    pub packages_created: usize,
    /// Number of symbols in the indexing result, each persisted as a graph node.
    pub symbols_stored: usize,
    /// Number of chunk nodes persisted.
    pub chunks_stored: usize,
    /// Number of resolved reference edges (symbol -> symbol) persisted.
    pub edges_resolved: usize,
    /// Symbols successfully embedded; 0 when no embedding provider is configured.
    pub symbols_embedded: usize,
    /// Chunks successfully embedded; 0 when no embedding provider is configured.
    pub chunks_embedded: usize,
    /// Chunk nodes reported as pruned by auto-compaction; 0 when auto-compaction is off.
    pub chunks_pruned: usize,
    /// Symbol nodes reported as pruned by auto-compaction; 0 when auto-compaction is off.
    pub symbols_pruned: usize,
}
27
28/// Return the edge weight for a given relationship type, using config overrides
29/// for the three most common types (Contains, Calls, Imports).
30pub fn edge_weight_for(rel: &RelationshipType, config: &GraphConfig) -> f64 {
31    match rel {
32        RelationshipType::Calls => config.calls_edge_weight,
33        RelationshipType::Imports => config.imports_edge_weight,
34        RelationshipType::Contains => config.contains_edge_weight,
35        RelationshipType::Implements | RelationshipType::Inherits => 0.8,
36        RelationshipType::DependsOn => 0.7,
37        RelationshipType::CoChanged => 0.6,
38        RelationshipType::EvolvedInto | RelationshipType::Summarizes => 0.7,
39        RelationshipType::PartOf => 0.4,
40        RelationshipType::RelatesTo | RelationshipType::SharesTheme => 0.3,
41        _ => 0.5,
42    }
43}
44
/// Intermediate counts from graph node persistence (before embedding).
///
/// Produced by `persist_graph_nodes` and copied into the final `IndexPersistResult`.
struct GraphPersistCounts {
    /// Number of distinct package (directory) nodes derived from the indexed file paths.
    packages_created: usize,
    /// Number of chunk nodes written for this indexing run.
    chunks_stored: usize,
}
50
51impl super::CodememEngine {
52    /// Persist all indexing results (file nodes, package tree, symbol nodes, chunk nodes,
53    /// edges, embeddings, compaction) into storage and the in-memory graph.
54    ///
55    /// This is the full persistence pipeline called after `Indexer::index_and_resolve()`.
56    pub fn persist_index_results(
57        &self,
58        results: &IndexAndResolveResult,
59        namespace: Option<&str>,
60    ) -> Result<IndexPersistResult, CodememError> {
61        self.persist_index_results_with_progress(results, namespace, |_, _| {})
62    }
63
64    /// Like `persist_index_results`, but calls `on_progress(done, total)` during
65    /// the embedding phase so callers can display progress.
66    pub fn persist_index_results_with_progress(
67        &self,
68        results: &IndexAndResolveResult,
69        namespace: Option<&str>,
70        on_progress: impl Fn(usize, usize),
71    ) -> Result<IndexPersistResult, CodememError> {
72        let seen_files = &results.file_paths;
73
74        // 1. Persist all graph nodes and edges
75        let graph_counts = self.persist_graph_nodes(results, namespace)?;
76
77        // 2. Embed symbols and chunks
78        let (symbols_embedded, chunks_embedded) = self.embed_and_persist(
79            &results.symbols,
80            &results.chunks,
81            &results.edges,
82            on_progress,
83        )?;
84
85        // 3. Auto-compact
86        let (chunks_pruned, symbols_pruned) = if self.config.chunking.auto_compact {
87            self.compact_graph(seen_files)
88        } else {
89            (0, 0)
90        };
91
92        Ok(IndexPersistResult {
93            files_created: seen_files.len(),
94            packages_created: graph_counts.packages_created,
95            symbols_stored: results.symbols.len(),
96            chunks_stored: graph_counts.chunks_stored,
97            edges_resolved: results.edges.len(),
98            symbols_embedded,
99            chunks_embedded,
100            chunks_pruned,
101            symbols_pruned,
102        })
103    }
104
105    // ── Graph Node Persistence ───────────────────────────────────────────
106
107    /// Persist file, package, symbol, chunk nodes and all edges into storage
108    /// and the in-memory graph. Returns counts for the result struct.
109    fn persist_graph_nodes(
110        &self,
111        results: &IndexAndResolveResult,
112        namespace: Option<&str>,
113    ) -> Result<GraphPersistCounts, CodememError> {
114        let all_symbols = &results.symbols;
115        let all_chunks = &results.chunks;
116        let seen_files = &results.file_paths;
117        let edges = &results.edges;
118
119        let now = chrono::Utc::now();
120        let ns_string = namespace.map(|s| s.to_string());
121        let contains_weight = edge_weight_for(&RelationshipType::Contains, &self.config.graph);
122
123        let mut graph = self.lock_graph()?;
124
125        // ── File nodes
126        let file_nodes: Vec<GraphNode> = seen_files
127            .iter()
128            .map(|file_path| {
129                let mut payload = HashMap::new();
130                payload.insert(
131                    "file_path".to_string(),
132                    serde_json::Value::String(file_path.clone()),
133                );
134                GraphNode {
135                    id: format!("file:{file_path}"),
136                    kind: NodeKind::File,
137                    label: file_path.clone(),
138                    payload,
139                    centrality: 0.0,
140                    memory_id: None,
141                    namespace: ns_string.clone(),
142                }
143            })
144            .collect();
145        self.persist_nodes_to_storage_and_graph(&file_nodes, &mut graph);
146
147        // ── Package (directory) nodes
148        let (dir_nodes, dir_edges, created_dirs) =
149            self.build_package_tree(seen_files, &ns_string, contains_weight, now, &graph);
150        self.persist_nodes_to_storage_and_graph(&dir_nodes, &mut graph);
151        self.persist_edges_to_storage_and_graph(&dir_edges, &mut graph);
152
153        // ── Symbol nodes + file→symbol edges
154        let (sym_nodes, sym_edges) =
155            Self::build_symbol_nodes(all_symbols, &ns_string, contains_weight, now);
156        self.persist_nodes_to_storage_and_graph(&sym_nodes, &mut graph);
157        self.persist_edges_to_storage_and_graph(&sym_edges, &mut graph);
158
159        // ── Resolved reference edges
160        let ref_edges = Self::build_reference_edges(edges, &self.config.graph, now);
161        self.persist_edges_to_storage_and_graph(&ref_edges, &mut graph);
162
163        // ── Chunk nodes + file→chunk / symbol→chunk edges
164        for file_path in seen_files {
165            let prefix = format!("chunk:{file_path}:");
166            let _ = self.storage.delete_graph_nodes_by_prefix(&prefix);
167        }
168        let (chunk_nodes, chunk_edges) =
169            Self::build_chunk_nodes(all_chunks, &ns_string, contains_weight, now);
170        let chunk_count = chunk_nodes.len();
171        self.persist_nodes_to_storage_and_graph(&chunk_nodes, &mut graph);
172        self.persist_edges_to_storage_and_graph(&chunk_edges, &mut graph);
173
174        drop(graph);
175
176        Ok(GraphPersistCounts {
177            packages_created: created_dirs,
178            chunks_stored: chunk_count,
179        })
180    }
181
182    /// Batch-insert nodes into both SQLite and the in-memory graph.
183    fn persist_nodes_to_storage_and_graph(
184        &self,
185        nodes: &[GraphNode],
186        graph: &mut crate::GraphEngine,
187    ) {
188        let _ = self.storage.insert_graph_nodes_batch(nodes);
189        for node in nodes {
190            let _ = graph.add_node(node.clone());
191        }
192    }
193
194    /// Batch-insert edges into both SQLite and the in-memory graph.
195    fn persist_edges_to_storage_and_graph(&self, edges: &[Edge], graph: &mut crate::GraphEngine) {
196        let _ = self.storage.insert_graph_edges_batch(edges);
197        for edge in edges {
198            let _ = graph.add_edge(edge.clone());
199        }
200    }
201
202    /// Build directory/package nodes and CONTAINS edges from file paths.
203    /// Returns (nodes, edges, number_of_dirs_created).
204    fn build_package_tree(
205        &self,
206        seen_files: &HashSet<String>,
207        ns_string: &Option<String>,
208        contains_weight: f64,
209        now: chrono::DateTime<chrono::Utc>,
210        graph: &crate::GraphEngine,
211    ) -> (Vec<GraphNode>, Vec<Edge>, usize) {
212        let mut created_dirs: HashSet<String> = HashSet::new();
213        let mut dir_nodes = Vec::new();
214        let mut dir_edges = Vec::new();
215
216        for file_path in seen_files {
217            let p = std::path::Path::new(file_path);
218            let mut ancestors: Vec<String> = Vec::new();
219            let mut current = p.parent();
220            while let Some(dir) = current {
221                let dir_str = dir.to_string_lossy().to_string();
222                if dir_str.is_empty() || dir_str == "." {
223                    break;
224                }
225                ancestors.push(dir_str);
226                current = dir.parent();
227            }
228            ancestors.reverse();
229            for (i, dir_str) in ancestors.iter().enumerate() {
230                let pkg_id = format!("pkg:{dir_str}/");
231                if created_dirs.insert(pkg_id.clone()) {
232                    dir_nodes.push(GraphNode {
233                        id: pkg_id.clone(),
234                        kind: NodeKind::Package,
235                        label: format!("{dir_str}/"),
236                        payload: HashMap::new(),
237                        centrality: 0.0,
238                        memory_id: None,
239                        namespace: ns_string.clone(),
240                    });
241                }
242                if i == 0 {
243                    continue;
244                }
245                let parent_pkg_id = format!("pkg:{}/", ancestors[i - 1]);
246                let edge_id = format!("contains:{parent_pkg_id}->{pkg_id}");
247                if graph
248                    .get_edges(&parent_pkg_id)
249                    .unwrap_or_default()
250                    .iter()
251                    .any(|e| e.id == edge_id)
252                {
253                    continue;
254                }
255                dir_edges.push(Edge {
256                    id: edge_id,
257                    src: parent_pkg_id,
258                    dst: pkg_id.clone(),
259                    relationship: RelationshipType::Contains,
260                    weight: contains_weight,
261                    valid_from: None,
262                    valid_to: None,
263                    properties: HashMap::new(),
264                    created_at: now,
265                });
266            }
267            if let Some(last_dir) = ancestors.last() {
268                let parent_pkg_id = format!("pkg:{last_dir}/");
269                let file_node_id = format!("file:{file_path}");
270                let edge_id = format!("contains:{parent_pkg_id}->{file_node_id}");
271                dir_edges.push(Edge {
272                    id: edge_id,
273                    src: parent_pkg_id,
274                    dst: file_node_id,
275                    relationship: RelationshipType::Contains,
276                    weight: contains_weight,
277                    valid_from: None,
278                    valid_to: None,
279                    properties: HashMap::new(),
280                    created_at: now,
281                });
282            }
283        }
284
285        let count = created_dirs.len();
286        (dir_nodes, dir_edges, count)
287    }
288
289    /// Build symbol graph nodes and file→symbol CONTAINS edges.
290    fn build_symbol_nodes(
291        symbols: &[Symbol],
292        ns_string: &Option<String>,
293        contains_weight: f64,
294        now: chrono::DateTime<chrono::Utc>,
295    ) -> (Vec<GraphNode>, Vec<Edge>) {
296        let mut sym_nodes = Vec::with_capacity(symbols.len());
297        let mut sym_edges = Vec::with_capacity(symbols.len());
298
299        for sym in symbols {
300            let kind = NodeKind::from(sym.kind);
301            let payload = Self::build_symbol_payload(sym);
302
303            let sym_node_id = format!("sym:{}", sym.qualified_name);
304            sym_nodes.push(GraphNode {
305                id: sym_node_id.clone(),
306                kind,
307                label: sym.qualified_name.clone(),
308                payload,
309                centrality: 0.0,
310                memory_id: None,
311                namespace: ns_string.clone(),
312            });
313
314            let file_node_id = format!("file:{}", sym.file_path);
315            sym_edges.push(Edge {
316                id: format!("contains:{file_node_id}->{sym_node_id}"),
317                src: file_node_id,
318                dst: sym_node_id,
319                relationship: RelationshipType::Contains,
320                weight: contains_weight,
321                valid_from: None,
322                valid_to: None,
323                properties: HashMap::new(),
324                created_at: now,
325            });
326        }
327
328        (sym_nodes, sym_edges)
329    }
330
331    /// Build the payload HashMap for a symbol's graph node.
332    fn build_symbol_payload(sym: &Symbol) -> HashMap<String, serde_json::Value> {
333        let mut payload = HashMap::new();
334        payload.insert(
335            "symbol_kind".to_string(),
336            serde_json::Value::String(sym.kind.to_string()),
337        );
338        payload.insert(
339            "signature".to_string(),
340            serde_json::Value::String(sym.signature.clone()),
341        );
342        payload.insert(
343            "file_path".to_string(),
344            serde_json::Value::String(sym.file_path.clone()),
345        );
346        payload.insert("line_start".to_string(), serde_json::json!(sym.line_start));
347        payload.insert("line_end".to_string(), serde_json::json!(sym.line_end));
348        payload.insert(
349            "visibility".to_string(),
350            serde_json::Value::String(sym.visibility.to_string()),
351        );
352        if let Some(ref doc) = sym.doc_comment {
353            payload.insert(
354                "doc_comment".to_string(),
355                serde_json::Value::String(doc.clone()),
356            );
357        }
358        if !sym.parameters.is_empty() {
359            payload.insert(
360                "parameters".to_string(),
361                serde_json::to_value(&sym.parameters).unwrap_or_default(),
362            );
363        }
364        if let Some(ref ret) = sym.return_type {
365            payload.insert(
366                "return_type".to_string(),
367                serde_json::Value::String(ret.clone()),
368            );
369        }
370        if sym.is_async {
371            payload.insert("is_async".to_string(), serde_json::json!(true));
372        }
373        if !sym.attributes.is_empty() {
374            payload.insert(
375                "attributes".to_string(),
376                serde_json::to_value(&sym.attributes).unwrap_or_default(),
377            );
378        }
379        if !sym.throws.is_empty() {
380            payload.insert(
381                "throws".to_string(),
382                serde_json::to_value(&sym.throws).unwrap_or_default(),
383            );
384        }
385        if let Some(ref gp) = sym.generic_params {
386            payload.insert(
387                "generic_params".to_string(),
388                serde_json::Value::String(gp.clone()),
389            );
390        }
391        if sym.is_abstract {
392            payload.insert("is_abstract".to_string(), serde_json::json!(true));
393        }
394        if let Some(ref parent) = sym.parent {
395            payload.insert(
396                "parent".to_string(),
397                serde_json::Value::String(parent.clone()),
398            );
399        }
400        payload
401    }
402
403    /// Build edges from resolved cross-file references.
404    fn build_reference_edges(
405        edges: &[ResolvedEdge],
406        graph_config: &GraphConfig,
407        now: chrono::DateTime<chrono::Utc>,
408    ) -> Vec<Edge> {
409        edges
410            .iter()
411            .map(|edge| Edge {
412                id: format!(
413                    "ref:{}->{}:{}",
414                    edge.source_qualified_name, edge.target_qualified_name, edge.relationship
415                ),
416                src: format!("sym:{}", edge.source_qualified_name),
417                dst: format!("sym:{}", edge.target_qualified_name),
418                relationship: edge.relationship,
419                weight: edge_weight_for(&edge.relationship, graph_config),
420                valid_from: None,
421                valid_to: None,
422                properties: HashMap::new(),
423                created_at: now,
424            })
425            .collect()
426    }
427
428    /// Build chunk graph nodes and file→chunk / symbol→chunk CONTAINS edges.
429    fn build_chunk_nodes(
430        chunks: &[CodeChunk],
431        ns_string: &Option<String>,
432        contains_weight: f64,
433        now: chrono::DateTime<chrono::Utc>,
434    ) -> (Vec<GraphNode>, Vec<Edge>) {
435        let mut chunk_nodes = Vec::with_capacity(chunks.len());
436        let mut chunk_edges = Vec::with_capacity(chunks.len() * 2);
437
438        for chunk in chunks {
439            let chunk_id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
440
441            let mut payload = HashMap::new();
442            payload.insert(
443                "file_path".to_string(),
444                serde_json::Value::String(chunk.file_path.clone()),
445            );
446            payload.insert(
447                "line_start".to_string(),
448                serde_json::json!(chunk.line_start),
449            );
450            payload.insert("line_end".to_string(), serde_json::json!(chunk.line_end));
451            payload.insert(
452                "node_kind".to_string(),
453                serde_json::Value::String(chunk.node_kind.clone()),
454            );
455            payload.insert(
456                "non_ws_chars".to_string(),
457                serde_json::json!(chunk.non_ws_chars),
458            );
459            if let Some(ref parent) = chunk.parent_symbol {
460                payload.insert(
461                    "parent_symbol".to_string(),
462                    serde_json::Value::String(parent.clone()),
463                );
464            }
465
466            chunk_nodes.push(GraphNode {
467                id: chunk_id.clone(),
468                kind: NodeKind::Chunk,
469                label: format!(
470                    "chunk:{}:{}..{}",
471                    chunk.file_path, chunk.line_start, chunk.line_end
472                ),
473                payload,
474                centrality: 0.0,
475                memory_id: None,
476                namespace: ns_string.clone(),
477            });
478
479            let file_node_id = format!("file:{}", chunk.file_path);
480            chunk_edges.push(Edge {
481                id: format!("contains:{file_node_id}->{chunk_id}"),
482                src: file_node_id,
483                dst: chunk_id.clone(),
484                relationship: RelationshipType::Contains,
485                weight: contains_weight,
486                valid_from: None,
487                valid_to: None,
488                properties: HashMap::new(),
489                created_at: now,
490            });
491
492            if let Some(ref parent_sym) = chunk.parent_symbol {
493                let parent_node_id = format!("sym:{parent_sym}");
494                chunk_edges.push(Edge {
495                    id: format!("contains:{parent_node_id}->{chunk_id}"),
496                    src: parent_node_id,
497                    dst: chunk_id,
498                    relationship: RelationshipType::Contains,
499                    weight: contains_weight,
500                    valid_from: None,
501                    valid_to: None,
502                    properties: HashMap::new(),
503                    created_at: now,
504                });
505            }
506        }
507
508        (chunk_nodes, chunk_edges)
509    }
510
511    // ── Embedding Persistence ────────────────────────────────────────────
512
513    /// Embed symbols and chunks, persisting embeddings to SQLite and the
514    /// vector index in batches with progress reporting.
515    ///
516    /// Returns (symbols_embedded, chunks_embedded).
517    fn embed_and_persist(
518        &self,
519        symbols: &[Symbol],
520        chunks: &[CodeChunk],
521        edges: &[ResolvedEdge],
522        on_progress: impl Fn(usize, usize),
523    ) -> Result<(usize, usize), CodememError> {
524        let mut symbols_embedded = 0usize;
525        let mut chunks_embedded = 0usize;
526
527        // Quick check: skip expensive text enrichment if no embedding provider.
528        let has_embeddings = self.lock_embeddings()?.is_some();
529        if !has_embeddings {
530            return Ok((0, 0));
531        }
532
533        // Phase 1: Collect enriched texts without holding any lock.
534        let sym_texts: Vec<(String, String)> = symbols
535            .iter()
536            .map(|sym| {
537                let id = format!("sym:{}", sym.qualified_name);
538                let text = self.enrich_symbol_text(sym, edges);
539                (id, text)
540            })
541            .collect();
542        let chunk_texts: Vec<(String, String)> = chunks
543            .iter()
544            .map(|chunk| {
545                let id = format!("chunk:{}:{}", chunk.file_path, chunk.index);
546                let text = self.enrich_chunk_text(chunk);
547                (id, text)
548            })
549            .collect();
550
551        // Phase 2+3: Embed in batches and persist progressively.
552        const EMBED_CHUNK_SIZE: usize = 64;
553
554        let all_pairs: Vec<(String, String)> = sym_texts.into_iter().chain(chunk_texts).collect();
555        let total = all_pairs.len();
556        let sym_count = symbols.len();
557        let mut done = 0usize;
558
559        for batch in all_pairs.chunks(EMBED_CHUNK_SIZE) {
560            let texts: Vec<&str> = batch.iter().map(|(_, t)| t.as_str()).collect();
561
562            let t0 = std::time::Instant::now();
563            let embed_result = {
564                let emb = self.lock_embeddings()?;
565                match emb {
566                    Some(emb_guard) => emb_guard.embed_batch(&texts),
567                    None => break,
568                }
569            };
570
571            match embed_result {
572                Ok(embeddings) => {
573                    let embed_ms = t0.elapsed().as_millis();
574
575                    let t1 = std::time::Instant::now();
576                    let pairs: Vec<(&str, &[f32])> = batch
577                        .iter()
578                        .zip(embeddings.iter())
579                        .map(|((id, _), emb_vec)| (id.as_str(), emb_vec.as_slice()))
580                        .collect();
581                    if let Err(e) = self.storage.store_embeddings_batch(&pairs) {
582                        tracing::warn!("Failed to batch-store embeddings: {e}");
583                    }
584                    let sqlite_ms = t1.elapsed().as_millis();
585
586                    let t2 = std::time::Instant::now();
587                    let batch_items: Vec<(String, Vec<f32>)> = batch
588                        .iter()
589                        .zip(embeddings.into_iter())
590                        .map(|((id, _), emb_vec)| (id.clone(), emb_vec))
591                        .collect();
592                    let batch_len = batch_items.len();
593                    {
594                        let mut vec = self.lock_vector()?;
595                        if let Err(e) = vec.insert_batch(&batch_items) {
596                            tracing::warn!("Failed to batch-insert into vector index: {e}");
597                        }
598                    }
599                    let vector_ms = t2.elapsed().as_millis();
600
601                    let syms_in_batch = batch_len.min(sym_count.saturating_sub(done));
602                    symbols_embedded += syms_in_batch;
603                    chunks_embedded += batch_len - syms_in_batch;
604                    done += batch_len;
605
606                    tracing::debug!(
607                        "Embed batch {}: embed={embed_ms}ms sqlite={sqlite_ms}ms vector={vector_ms}ms",
608                        batch_len
609                    );
610                }
611                Err(e) => {
612                    tracing::warn!("embed_batch failed for chunk of {} texts: {e}", batch.len());
613                }
614            }
615            on_progress(done, total);
616        }
617        self.save_index();
618
619        Ok((symbols_embedded, chunks_embedded))
620    }
621}