Skip to main content

infigraph_core/graph/
store_parquet.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use anyhow::Result;
5use arrow::array::{Int64Array, StringArray};
6use arrow::datatypes::DataType;
7use kuzu::Connection;
8
9use super::parquet_loader;
10use super::store::GraphStore;
11use super::store_util::{escape, fwd_slash_path, unwind_edges_from_pairs};
12use crate::model::{FileExtraction, RelationKind};
13
/// Allocate a scratch directory under the system temp dir and return its path.
///
/// The directory is named `infigraph_pq_<pid>_<n>`, where `<n>` is drawn from a
/// process-wide atomic counter, so successive calls in one process never return
/// the same path. Creating the directory is best-effort: a failure is ignored and
/// the path is returned regardless.
fn unique_tmp_dir() -> std::path::PathBuf {
    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    let dir_name = format!("infigraph_pq_{}_{}", std::process::id(), seq);
    let scratch = std::env::temp_dir().join(dir_name);
    let _ = std::fs::create_dir_all(&scratch);
    scratch
}
22
23impl GraphStore {
24    /// Create Folder nodes and edges for a set of file paths in bulk.
25    /// More efficient than per-file upsert_folder_hierarchy calls.
26    pub fn upsert_folders_bulk(&self, file_paths: &[&str]) -> Result<()> {
27        let _lock = self.write_lock()?;
28        let conn = self.connection()?;
29        self.upsert_folders_bulk_conn(&conn, file_paths)
30    }
31
32    /// Caller must hold WriteLock.
33    pub fn upsert_folders_bulk_conn(
34        &self,
35        conn: &Connection<'_>,
36        file_paths: &[&str],
37    ) -> Result<()> {
38        let mut all_folders: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
39        for file_path in file_paths {
40            let parts: Vec<&str> = file_path.rsplitn(2, '/').collect();
41            if parts.len() < 2 {
42                continue;
43            }
44            let dir_path = parts[1];
45            let segments: Vec<&str> = dir_path.split('/').collect();
46            for i in 0..segments.len() {
47                all_folders.insert(segments[..=i].join("/"));
48            }
49        }
50
51        if all_folders.is_empty() {
52            return Ok(());
53        }
54
55        // Write Folder nodes to parquet
56        let folder_pq = unique_tmp_dir().join("infigraph_folders.parquet");
57        {
58            let ids: Vec<&str> = all_folders.iter().map(|s| s.as_str()).collect();
59            let names: Vec<&str> = all_folders
60                .iter()
61                .map(|fp| fp.rsplit_once('/').map(|(_, n)| n).unwrap_or(fp.as_str()))
62                .collect();
63            let paths: Vec<&str> = all_folders.iter().map(|s| s.as_str()).collect();
64            parquet_loader::write_node_parquet(
65                &folder_pq,
66                &[
67                    ("id", DataType::Utf8),
68                    ("name", DataType::Utf8),
69                    ("path", DataType::Utf8),
70                ],
71                vec![
72                    Arc::new(StringArray::from(ids)),
73                    Arc::new(StringArray::from(names)),
74                    Arc::new(StringArray::from(paths)),
75                ],
76            )?;
77        }
78
79        // Collect edge pairs in memory
80        let cf_pairs: Vec<(String, String)> = all_folders
81            .iter()
82            .filter_map(|child| {
83                child
84                    .rsplit_once('/')
85                    .map(|(p, _)| p)
86                    .and_then(|parent_path| {
87                        if all_folders.contains(parent_path) {
88                            Some((parent_path.to_string(), child.clone()))
89                        } else {
90                            None
91                        }
92                    })
93            })
94            .collect();
95
96        let cfile_pairs: Vec<(String, String)> = file_paths
97            .iter()
98            .filter_map(|fp| {
99                let parts: Vec<&str> = fp.rsplitn(2, '/').collect();
100                if parts.len() < 2 {
101                    return None;
102                }
103                Some((parts[1].to_string(), fp.to_string()))
104            })
105            .collect();
106
107        let copy_ok = conn
108            .query(&format!(
109                "COPY Folder FROM '{}'",
110                fwd_slash_path(&folder_pq)
111            ))
112            .is_ok();
113
114        if copy_ok {
115            // Write edge parquet files and COPY FROM
116            let cf_pq = unique_tmp_dir().join("infigraph_contains_folder.parquet");
117            let cf_refs: Vec<(&str, &str)> = cf_pairs
118                .iter()
119                .map(|(a, b)| (a.as_str(), b.as_str()))
120                .collect();
121            parquet_loader::write_edge_parquet(&cf_pq, &cf_refs)?;
122            if let Err(e) = conn.query(&format!(
123                "COPY CONTAINS_FOLDER FROM '{}'",
124                fwd_slash_path(&cf_pq)
125            )) {
126                eprintln!("warn: COPY CONTAINS_FOLDER failed ({e}), using UNWIND fallback");
127                unwind_edges_from_pairs(conn, &cf_refs, "CONTAINS_FOLDER", "Folder", "Folder");
128            }
129            let _ = std::fs::remove_file(&cf_pq);
130
131            let cfile_pq = unique_tmp_dir().join("infigraph_contains_file.parquet");
132            let cfile_refs: Vec<(&str, &str)> = cfile_pairs
133                .iter()
134                .map(|(a, b)| (a.as_str(), b.as_str()))
135                .collect();
136            parquet_loader::write_edge_parquet(&cfile_pq, &cfile_refs)?;
137            if let Err(e) = conn.query(&format!(
138                "COPY CONTAINS_FILE FROM '{}'",
139                fwd_slash_path(&cfile_pq)
140            )) {
141                eprintln!("warn: COPY CONTAINS_FILE failed ({e}), using UNWIND fallback");
142                unwind_edges_from_pairs(conn, &cfile_refs, "CONTAINS_FILE", "Folder", "File");
143            }
144            let _ = std::fs::remove_file(&cfile_pq);
145        } else {
146            // Incremental path: some folders may already exist. Use UNWIND with MERGE semantics.
147            const CHUNK: usize = 500;
148            for chunk in all_folders.iter().collect::<Vec<_>>().chunks(CHUNK) {
149                let items: Vec<String> = chunk
150                    .iter()
151                    .map(|fp| {
152                        let name = fp.rsplit_once('/').map(|(_, n)| n).unwrap_or(fp);
153                        format!(
154                            "{{id: '{}', name: '{}', path: '{}'}}",
155                            escape(fp),
156                            escape(name),
157                            escape(fp)
158                        )
159                    })
160                    .collect();
161                let _ = conn.query(&format!(
162                    "UNWIND [{}] AS f MERGE (d:Folder {{id: f.id}}) ON CREATE SET d.name = f.name, d.path = f.path ON MATCH SET d.name = f.name, d.path = f.path",
163                    items.join(", ")
164                ));
165            }
166            let cf_refs: Vec<(&str, &str)> = cf_pairs
167                .iter()
168                .map(|(a, b)| (a.as_str(), b.as_str()))
169                .collect();
170            unwind_edges_from_pairs(conn, &cf_refs, "CONTAINS_FOLDER", "Folder", "Folder");
171            let cfile_refs: Vec<(&str, &str)> = cfile_pairs
172                .iter()
173                .map(|(a, b)| (a.as_str(), b.as_str()))
174                .collect();
175            unwind_edges_from_pairs(conn, &cfile_refs, "CONTAINS_FILE", "Folder", "File");
176        }
177
178        let _ = std::fs::remove_file(&folder_pq);
179        Ok(())
180    }
181
182    /// Bulk write all extractions using COPY FROM Parquet -- binary format eliminates escaping issues.
183    /// Used for --full index. Incremental index still uses upsert_file_conn_no_delete.
184    pub fn upsert_all_parquet(&self, extractions: &[FileExtraction]) -> Result<()> {
185        if extractions.is_empty() {
186            return Ok(());
187        }
188        let _lock = self.write_lock()?;
189        let conn = self.connection()?;
190        self.upsert_all_parquet_conn(&conn, extractions)
191    }
192
193    /// Caller must hold WriteLock.
194    pub fn upsert_all_parquet_conn(
195        &self,
196        conn: &Connection<'_>,
197        extractions: &[FileExtraction],
198    ) -> Result<()> {
199        if extractions.is_empty() {
200            return Ok(());
201        }
202
203        let tmp = unique_tmp_dir();
204
205        let mut known_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
206        for e in extractions {
207            for sym in &e.symbols {
208                known_ids.insert(sym.id.clone());
209            }
210        }
211        let mut sym_seen: std::collections::HashSet<String> = std::collections::HashSet::new();
212        let known_module_ids: std::collections::HashSet<String> =
213            extractions.iter().map(|e| e.file.clone()).collect();
214
215        // Collect all data into vecs
216        let mut mod_ids = Vec::new();
217        let mut mod_names = Vec::new();
218        let mut mod_files = Vec::new();
219        let mut mod_langs = Vec::new();
220        let mut mod_hashes = Vec::new();
221        let mut mod_summaries = Vec::new();
222        let mut file_ids = Vec::new();
223        let mut file_names = Vec::new();
224        let mut file_paths = Vec::new();
225        let mut file_langs = Vec::new();
226        let mut file_symcounts: Vec<i64> = Vec::new();
227        let mut sym_ids = Vec::new();
228        let mut sym_names = Vec::new();
229        let mut sym_kinds = Vec::new();
230        let mut sym_files = Vec::new();
231        let mut sym_slines: Vec<i64> = Vec::new();
232        let mut sym_elines: Vec<i64> = Vec::new();
233        let mut sym_sighashes = Vec::new();
234        let mut sym_languages = Vec::new();
235        let mut sym_visibilities = Vec::new();
236        let mut sym_parents = Vec::new();
237        let mut sym_docstrings = Vec::new();
238        let mut sym_complexities: Vec<i64> = Vec::new();
239        let mut sym_parameters = Vec::new();
240        let mut sym_return_types = Vec::new();
241        let mut contains_pairs: Vec<(String, String)> = Vec::new();
242        let mut defines_pairs: Vec<(String, String)> = Vec::new();
243
244        let mut calls_seen: std::collections::HashSet<(String, String)> =
245            std::collections::HashSet::new();
246        let mut inh_seen: std::collections::HashSet<(String, String)> =
247            std::collections::HashSet::new();
248        let mut test_seen: std::collections::HashSet<(String, String)> =
249            std::collections::HashSet::new();
250        let mut imp_seen: std::collections::HashSet<(String, String)> =
251            std::collections::HashSet::new();
252        let mut reads_seen: std::collections::HashSet<(String, String)> =
253            std::collections::HashSet::new();
254        let mut writes_seen: std::collections::HashSet<(String, String)> =
255            std::collections::HashSet::new();
256        let mut calls_pairs: Vec<(String, String)> = Vec::new();
257        let mut inh_pairs: Vec<(String, String)> = Vec::new();
258        let mut test_pairs: Vec<(String, String)> = Vec::new();
259        let mut imp_pairs: Vec<(String, String)> = Vec::new();
260        let mut reads_pairs: Vec<(String, String)> = Vec::new();
261        let mut writes_pairs: Vec<(String, String)> = Vec::new();
262        let mut custom_seen: std::collections::HashMap<
263            String,
264            std::collections::HashSet<(String, String)>,
265        > = std::collections::HashMap::new();
266        let mut custom_pairs: std::collections::HashMap<String, Vec<(String, String)>> =
267            std::collections::HashMap::new();
268
269        let mut stmt_ids: Vec<String> = Vec::new();
270        let mut stmt_kinds: Vec<String> = Vec::new();
271        let mut stmt_conditions: Vec<String> = Vec::new();
272        let mut stmt_slines: Vec<i64> = Vec::new();
273        let mut stmt_elines: Vec<i64> = Vec::new();
274        let mut stmt_depths: Vec<i64> = Vec::new();
275        let mut stmt_parents_sym = Vec::new();
276        let mut has_stmt_pairs: Vec<(String, String)> = Vec::new();
277
278        for e in extractions {
279            let mod_name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
280            mod_ids.push(e.file.clone());
281            mod_names.push(mod_name.to_string());
282            mod_files.push(e.file.clone());
283            mod_langs.push(e.language.clone());
284            mod_hashes.push(e.content_hash.clone());
285            mod_summaries.push(String::new());
286
287            file_ids.push(e.file.clone());
288            file_names.push(mod_name.to_string());
289            file_paths.push(e.file.clone());
290            file_langs.push(e.language.clone());
291            file_symcounts.push(e.symbols.len() as i64);
292
293            for sym in &e.symbols {
294                if sym_seen.insert(sym.id.clone()) {
295                    sym_ids.push(sym.id.clone());
296                    sym_names.push(sym.name.clone());
297                    sym_kinds.push(sym.kind.as_str().to_string());
298                    sym_files.push(e.file.clone());
299                    sym_slines.push(sym.span.start_line as i64);
300                    sym_elines.push(sym.span.end_line as i64);
301                    sym_sighashes.push(sym.signature_hash.clone());
302                    sym_languages.push(sym.language.clone());
303                    sym_visibilities.push(sym.visibility.as_deref().unwrap_or("").to_string());
304                    sym_parents.push(sym.parent.as_deref().unwrap_or("").to_string());
305                    sym_docstrings.push(sym.docstring.as_deref().unwrap_or("").to_string());
306                    sym_complexities.push(sym.complexity as i64);
307                    sym_parameters.push(sym.parameters.as_deref().unwrap_or("").to_string());
308                    sym_return_types.push(sym.return_type.as_deref().unwrap_or("").to_string());
309                    contains_pairs.push((e.file.clone(), sym.id.clone()));
310                    defines_pairs.push((e.file.clone(), sym.id.clone()));
311                }
312            }
313
314            for rel in &e.relations {
315                let src = rel.source_id.clone();
316                let tgt = rel.target_id.clone();
317                match &rel.kind {
318                    RelationKind::Imports | RelationKind::ImportedBy => {
319                        if known_module_ids.contains(&src)
320                            && known_module_ids.contains(&tgt)
321                            && imp_seen.insert((src.clone(), tgt.clone()))
322                        {
323                            imp_pairs.push((src, tgt));
324                        }
325                    }
326                    RelationKind::Custom(name) => {
327                        if known_ids.contains(&src)
328                            && known_ids.contains(&tgt)
329                            && custom_seen
330                                .entry(name.clone())
331                                .or_default()
332                                .insert((src.clone(), tgt.clone()))
333                        {
334                            custom_pairs
335                                .entry(name.clone())
336                                .or_default()
337                                .push((src, tgt));
338                        }
339                    }
340                    _ => {
341                        if !known_ids.contains(&src) || !known_ids.contains(&tgt) {
342                            continue;
343                        }
344                        match &rel.kind {
345                            RelationKind::Calls | RelationKind::CalledBy
346                                if calls_seen.insert((src.clone(), tgt.clone())) =>
347                            {
348                                calls_pairs.push((src, tgt));
349                            }
350                            RelationKind::Inherits | RelationKind::InheritedBy
351                                if inh_seen.insert((src.clone(), tgt.clone())) =>
352                            {
353                                inh_pairs.push((src, tgt));
354                            }
355                            RelationKind::TestedBy | RelationKind::Tests
356                                if test_seen.insert((src.clone(), tgt.clone())) =>
357                            {
358                                test_pairs.push((src, tgt));
359                            }
360                            RelationKind::Reads
361                                if reads_seen.insert((src.clone(), tgt.clone())) =>
362                            {
363                                reads_pairs.push((src, tgt));
364                            }
365                            RelationKind::Writes
366                                if writes_seen.insert((src.clone(), tgt.clone())) =>
367                            {
368                                writes_pairs.push((src, tgt));
369                            }
370                            _ => {}
371                        }
372                    }
373                }
374            }
375
376            for stmt in &e.statements {
377                stmt_ids.push(stmt.id.clone());
378                stmt_kinds.push(stmt.kind.as_str().to_string());
379                stmt_conditions.push(stmt.condition.clone());
380                stmt_slines.push(stmt.start_line as i64);
381                stmt_elines.push(stmt.end_line as i64);
382                stmt_depths.push(stmt.depth as i64);
383                stmt_parents_sym.push(stmt.parent_symbol.clone());
384                if known_ids.contains(&stmt.parent_symbol) {
385                    has_stmt_pairs.push((stmt.parent_symbol.clone(), stmt.id.clone()));
386                }
387            }
388        }
389
390        // Write node parquet files
391        let mod_pq = tmp.join("infigraph_index_modules.parquet");
392        parquet_loader::write_node_parquet(
393            &mod_pq,
394            &[
395                ("id", DataType::Utf8),
396                ("name", DataType::Utf8),
397                ("file", DataType::Utf8),
398                ("language", DataType::Utf8),
399                ("content_hash", DataType::Utf8),
400                ("summary", DataType::Utf8),
401            ],
402            vec![
403                Arc::new(StringArray::from(mod_ids)),
404                Arc::new(StringArray::from(mod_names)),
405                Arc::new(StringArray::from(mod_files)),
406                Arc::new(StringArray::from(mod_langs)),
407                Arc::new(StringArray::from(mod_hashes)),
408                Arc::new(StringArray::from(mod_summaries)),
409            ],
410        )?;
411
412        let file_pq = tmp.join("infigraph_index_files.parquet");
413        parquet_loader::write_node_parquet(
414            &file_pq,
415            &[
416                ("id", DataType::Utf8),
417                ("name", DataType::Utf8),
418                ("path", DataType::Utf8),
419                ("language", DataType::Utf8),
420                ("symbol_count", DataType::Int64),
421            ],
422            vec![
423                Arc::new(StringArray::from(file_ids)),
424                Arc::new(StringArray::from(file_names)),
425                Arc::new(StringArray::from(file_paths)),
426                Arc::new(StringArray::from(file_langs)),
427                Arc::new(Int64Array::from(file_symcounts)),
428            ],
429        )?;
430
431        let sym_pq = tmp.join("infigraph_index_symbols.parquet");
432        parquet_loader::write_node_parquet(
433            &sym_pq,
434            &[
435                ("id", DataType::Utf8),
436                ("name", DataType::Utf8),
437                ("kind", DataType::Utf8),
438                ("file", DataType::Utf8),
439                ("start_line", DataType::Int64),
440                ("end_line", DataType::Int64),
441                ("signature_hash", DataType::Utf8),
442                ("language", DataType::Utf8),
443                ("visibility", DataType::Utf8),
444                ("parent", DataType::Utf8),
445                ("docstring", DataType::Utf8),
446                ("complexity", DataType::Int64),
447                ("parameters", DataType::Utf8),
448                ("return_type", DataType::Utf8),
449            ],
450            vec![
451                Arc::new(StringArray::from(sym_ids)),
452                Arc::new(StringArray::from(sym_names)),
453                Arc::new(StringArray::from(sym_kinds)),
454                Arc::new(StringArray::from(sym_files)),
455                Arc::new(Int64Array::from(sym_slines)),
456                Arc::new(Int64Array::from(sym_elines)),
457                Arc::new(StringArray::from(sym_sighashes)),
458                Arc::new(StringArray::from(sym_languages)),
459                Arc::new(StringArray::from(sym_visibilities)),
460                Arc::new(StringArray::from(sym_parents)),
461                Arc::new(StringArray::from(sym_docstrings)),
462                Arc::new(Int64Array::from(sym_complexities)),
463                Arc::new(StringArray::from(sym_parameters)),
464                Arc::new(StringArray::from(sym_return_types)),
465            ],
466        )?;
467
468        // COPY FROM parquet -- node tables first
469        conn.query(&format!("COPY Module FROM '{}'", fwd_slash_path(&mod_pq)))
470            .map_err(|e| anyhow::anyhow!("COPY Module failed: {e}"))?;
471        conn.query(&format!("COPY File FROM '{}'", fwd_slash_path(&file_pq)))
472            .map_err(|e| anyhow::anyhow!("COPY File failed: {e}"))?;
473        conn.query(&format!(
474            "COPY Symbol (id, name, kind, file, start_line, end_line, signature_hash, language, visibility, parent, docstring, complexity, parameters, return_type) FROM '{}'",
475            fwd_slash_path(&sym_pq)
476        )).map_err(|e| anyhow::anyhow!("COPY Symbol failed: {e}"))?;
477
478        let stmt_pq = tmp.join("infigraph_index_statements.parquet");
479        if !stmt_ids.is_empty() {
480            parquet_loader::write_node_parquet(
481                &stmt_pq,
482                &[
483                    ("id", DataType::Utf8),
484                    ("kind", DataType::Utf8),
485                    ("condition", DataType::Utf8),
486                    ("start_line", DataType::Int64),
487                    ("end_line", DataType::Int64),
488                    ("depth", DataType::Int64),
489                    ("parent_symbol", DataType::Utf8),
490                ],
491                vec![
492                    Arc::new(StringArray::from(stmt_ids)),
493                    Arc::new(StringArray::from(stmt_kinds)),
494                    Arc::new(StringArray::from(stmt_conditions)),
495                    Arc::new(Int64Array::from(stmt_slines)),
496                    Arc::new(Int64Array::from(stmt_elines)),
497                    Arc::new(Int64Array::from(stmt_depths)),
498                    Arc::new(StringArray::from(stmt_parents_sym)),
499                ],
500            )?;
501            conn.query(&format!(
502                "COPY Statement FROM '{}'",
503                fwd_slash_path(&stmt_pq)
504            ))
505            .map_err(|e| anyhow::anyhow!("COPY Statement failed: {e}"))?;
506        }
507
508        // Edge tables -- write parquet and COPY FROM with in-memory UNWIND fallback
509        #[allow(clippy::type_complexity)]
510        let edge_tables: Vec<(&str, &[(String, String)], &str, &str)> = vec![
511            ("CONTAINS", &contains_pairs, "Module", "Symbol"),
512            ("DEFINES", &defines_pairs, "File", "Symbol"),
513            ("CALLS", &calls_pairs, "Symbol", "Symbol"),
514            ("INHERITS", &inh_pairs, "Symbol", "Symbol"),
515            ("TESTED_BY", &test_pairs, "Symbol", "Symbol"),
516            ("IMPORTS", &imp_pairs, "Module", "Module"),
517            ("READS", &reads_pairs, "Symbol", "Symbol"),
518            ("WRITES", &writes_pairs, "Symbol", "Symbol"),
519            ("HAS_STATEMENT", &has_stmt_pairs, "Symbol", "Statement"),
520        ];
521
522        for (table, pairs, src_label, dst_label) in &edge_tables {
523            if pairs.is_empty() {
524                continue;
525            }
526            let edge_pq = tmp.join(format!("infigraph_index_{}.parquet", table.to_lowercase()));
527            let refs: Vec<(&str, &str)> = pairs
528                .iter()
529                .map(|(a, b)| (a.as_str(), b.as_str()))
530                .collect();
531            parquet_loader::write_edge_parquet(&edge_pq, &refs)?;
532            if let Err(e) = conn.query(&format!("COPY {table} FROM '{}'", fwd_slash_path(&edge_pq)))
533            {
534                eprintln!("warn: COPY {table} via parquet failed ({e}), falling back to UNWIND");
535                unwind_edges_from_pairs(&conn, &refs, table, src_label, dst_label);
536            }
537            let _ = std::fs::remove_file(&edge_pq);
538        }
539
540        // Custom edge tables
541        for (edge_name, pairs) in &custom_pairs {
542            if pairs.is_empty() {
543                continue;
544            }
545            let _ = super::schema::ensure_custom_edge_table(&conn, edge_name);
546            let edge_pq = tmp.join(format!(
547                "infigraph_index_{}.parquet",
548                edge_name.to_lowercase()
549            ));
550            let refs: Vec<(&str, &str)> = pairs
551                .iter()
552                .map(|(a, b)| (a.as_str(), b.as_str()))
553                .collect();
554            parquet_loader::write_edge_parquet(&edge_pq, &refs)?;
555            if let Err(e) = conn.query(&format!(
556                "COPY {} FROM '{}'",
557                edge_name,
558                fwd_slash_path(&edge_pq)
559            )) {
560                eprintln!(
561                    "warn: COPY {} via parquet failed ({e}), falling back to UNWIND",
562                    edge_name
563                );
564                unwind_edges_from_pairs(&conn, &refs, edge_name, "Symbol", "Symbol");
565            }
566            let _ = std::fs::remove_file(&edge_pq);
567        }
568
569        // Cleanup node parquet files
570        let _ = std::fs::remove_file(&mod_pq);
571        let _ = std::fs::remove_file(&file_pq);
572        let _ = std::fs::remove_file(&sym_pq);
573        let _ = std::fs::remove_file(&stmt_pq);
574
575        Ok(())
576    }
577}