Skip to main content

infigraph_core/graph/
store_parquet.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use arrow::array::{Int64Array, StringArray};
5use arrow::datatypes::DataType;
6use kuzu::Connection;
7
8use super::parquet_loader;
9use super::store::GraphStore;
10use super::store_util::{escape, fwd_slash_path, unwind_edges_from_pairs};
11use crate::model::{FileExtraction, RelationKind};
12
13impl GraphStore {
14    /// Create Folder nodes and edges for a set of file paths in bulk.
15    /// More efficient than per-file upsert_folder_hierarchy calls.
16    pub fn upsert_folders_bulk(&self, file_paths: &[&str]) -> Result<()> {
17        let conn = self.connection()?;
18        self.upsert_folders_bulk_conn(&conn, file_paths)
19    }
20
21    pub fn upsert_folders_bulk_conn(
22        &self,
23        conn: &Connection<'_>,
24        file_paths: &[&str],
25    ) -> Result<()> {
26        let mut all_folders: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
27        for file_path in file_paths {
28            let parts: Vec<&str> = file_path.rsplitn(2, '/').collect();
29            if parts.len() < 2 {
30                continue;
31            }
32            let dir_path = parts[1];
33            let segments: Vec<&str> = dir_path.split('/').collect();
34            for i in 0..segments.len() {
35                all_folders.insert(segments[..=i].join("/"));
36            }
37        }
38
39        if all_folders.is_empty() {
40            return Ok(());
41        }
42
43        // Write Folder nodes to parquet
44        let folder_pq = std::env::temp_dir().join("infigraph_folders.parquet");
45        {
46            let ids: Vec<&str> = all_folders.iter().map(|s| s.as_str()).collect();
47            let names: Vec<&str> = all_folders
48                .iter()
49                .map(|fp| fp.rsplit_once('/').map(|(_, n)| n).unwrap_or(fp.as_str()))
50                .collect();
51            let paths: Vec<&str> = all_folders.iter().map(|s| s.as_str()).collect();
52            parquet_loader::write_node_parquet(
53                &folder_pq,
54                &[
55                    ("id", DataType::Utf8),
56                    ("name", DataType::Utf8),
57                    ("path", DataType::Utf8),
58                ],
59                vec![
60                    Arc::new(StringArray::from(ids)),
61                    Arc::new(StringArray::from(names)),
62                    Arc::new(StringArray::from(paths)),
63                ],
64            )?;
65        }
66
67        // Collect edge pairs in memory
68        let cf_pairs: Vec<(String, String)> = all_folders
69            .iter()
70            .filter_map(|child| {
71                child
72                    .rsplit_once('/')
73                    .map(|(p, _)| p)
74                    .and_then(|parent_path| {
75                        if all_folders.contains(parent_path) {
76                            Some((parent_path.to_string(), child.clone()))
77                        } else {
78                            None
79                        }
80                    })
81            })
82            .collect();
83
84        let cfile_pairs: Vec<(String, String)> = file_paths
85            .iter()
86            .filter_map(|fp| {
87                let parts: Vec<&str> = fp.rsplitn(2, '/').collect();
88                if parts.len() < 2 {
89                    return None;
90                }
91                Some((parts[1].to_string(), fp.to_string()))
92            })
93            .collect();
94
95        let copy_ok = conn
96            .query(&format!(
97                "COPY Folder FROM '{}'",
98                fwd_slash_path(&folder_pq)
99            ))
100            .is_ok();
101
102        if copy_ok {
103            // Write edge parquet files and COPY FROM
104            let cf_pq = std::env::temp_dir().join("infigraph_contains_folder.parquet");
105            let cf_refs: Vec<(&str, &str)> = cf_pairs
106                .iter()
107                .map(|(a, b)| (a.as_str(), b.as_str()))
108                .collect();
109            parquet_loader::write_edge_parquet(&cf_pq, &cf_refs)?;
110            if let Err(e) = conn.query(&format!(
111                "COPY CONTAINS_FOLDER FROM '{}'",
112                fwd_slash_path(&cf_pq)
113            )) {
114                eprintln!("warn: COPY CONTAINS_FOLDER failed ({e}), using UNWIND fallback");
115                unwind_edges_from_pairs(conn, &cf_refs, "CONTAINS_FOLDER", "Folder", "Folder");
116            }
117            let _ = std::fs::remove_file(&cf_pq);
118
119            let cfile_pq = std::env::temp_dir().join("infigraph_contains_file.parquet");
120            let cfile_refs: Vec<(&str, &str)> = cfile_pairs
121                .iter()
122                .map(|(a, b)| (a.as_str(), b.as_str()))
123                .collect();
124            parquet_loader::write_edge_parquet(&cfile_pq, &cfile_refs)?;
125            if let Err(e) = conn.query(&format!(
126                "COPY CONTAINS_FILE FROM '{}'",
127                fwd_slash_path(&cfile_pq)
128            )) {
129                eprintln!("warn: COPY CONTAINS_FILE failed ({e}), using UNWIND fallback");
130                unwind_edges_from_pairs(conn, &cfile_refs, "CONTAINS_FILE", "Folder", "File");
131            }
132            let _ = std::fs::remove_file(&cfile_pq);
133        } else {
134            // Incremental path: some folders may already exist. Use UNWIND with MERGE semantics.
135            const CHUNK: usize = 500;
136            for chunk in all_folders.iter().collect::<Vec<_>>().chunks(CHUNK) {
137                let items: Vec<String> = chunk
138                    .iter()
139                    .map(|fp| {
140                        let name = fp.rsplit_once('/').map(|(_, n)| n).unwrap_or(fp);
141                        format!(
142                            "{{id: '{}', name: '{}', path: '{}'}}",
143                            escape(fp),
144                            escape(name),
145                            escape(fp)
146                        )
147                    })
148                    .collect();
149                let _ = conn.query(&format!(
150                    "UNWIND [{}] AS f MERGE (d:Folder {{id: f.id}}) ON CREATE SET d.name = f.name, d.path = f.path ON MATCH SET d.name = f.name, d.path = f.path",
151                    items.join(", ")
152                ));
153            }
154            let cf_refs: Vec<(&str, &str)> = cf_pairs
155                .iter()
156                .map(|(a, b)| (a.as_str(), b.as_str()))
157                .collect();
158            unwind_edges_from_pairs(conn, &cf_refs, "CONTAINS_FOLDER", "Folder", "Folder");
159            let cfile_refs: Vec<(&str, &str)> = cfile_pairs
160                .iter()
161                .map(|(a, b)| (a.as_str(), b.as_str()))
162                .collect();
163            unwind_edges_from_pairs(conn, &cfile_refs, "CONTAINS_FILE", "Folder", "File");
164        }
165
166        let _ = std::fs::remove_file(&folder_pq);
167        Ok(())
168    }
169
170    /// Bulk write all extractions using COPY FROM Parquet -- binary format eliminates escaping issues.
171    /// Used for --full index. Incremental index still uses upsert_file_conn_no_delete.
172    pub fn upsert_all_parquet(&self, extractions: &[FileExtraction]) -> Result<()> {
173        if extractions.is_empty() {
174            return Ok(());
175        }
176
177        let conn = self.connection()?;
178        let tmp = std::env::temp_dir();
179
180        let mut known_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
181        for e in extractions {
182            for sym in &e.symbols {
183                known_ids.insert(sym.id.clone());
184            }
185        }
186        let mut sym_seen: std::collections::HashSet<String> = std::collections::HashSet::new();
187        let known_module_ids: std::collections::HashSet<String> =
188            extractions.iter().map(|e| e.file.clone()).collect();
189
190        // Collect all data into vecs
191        let mut mod_ids = Vec::new();
192        let mut mod_names = Vec::new();
193        let mut mod_files = Vec::new();
194        let mut mod_langs = Vec::new();
195        let mut mod_hashes = Vec::new();
196        let mut mod_summaries = Vec::new();
197        let mut file_ids = Vec::new();
198        let mut file_names = Vec::new();
199        let mut file_paths = Vec::new();
200        let mut file_langs = Vec::new();
201        let mut file_symcounts: Vec<i64> = Vec::new();
202        let mut sym_ids = Vec::new();
203        let mut sym_names = Vec::new();
204        let mut sym_kinds = Vec::new();
205        let mut sym_files = Vec::new();
206        let mut sym_slines: Vec<i64> = Vec::new();
207        let mut sym_elines: Vec<i64> = Vec::new();
208        let mut sym_sighashes = Vec::new();
209        let mut sym_languages = Vec::new();
210        let mut sym_visibilities = Vec::new();
211        let mut sym_parents = Vec::new();
212        let mut sym_docstrings = Vec::new();
213        let mut sym_complexities: Vec<i64> = Vec::new();
214        let mut sym_parameters = Vec::new();
215        let mut sym_return_types = Vec::new();
216        let mut contains_pairs: Vec<(String, String)> = Vec::new();
217        let mut defines_pairs: Vec<(String, String)> = Vec::new();
218
219        let mut calls_seen: std::collections::HashSet<(String, String)> =
220            std::collections::HashSet::new();
221        let mut inh_seen: std::collections::HashSet<(String, String)> =
222            std::collections::HashSet::new();
223        let mut test_seen: std::collections::HashSet<(String, String)> =
224            std::collections::HashSet::new();
225        let mut imp_seen: std::collections::HashSet<(String, String)> =
226            std::collections::HashSet::new();
227        let mut reads_seen: std::collections::HashSet<(String, String)> =
228            std::collections::HashSet::new();
229        let mut writes_seen: std::collections::HashSet<(String, String)> =
230            std::collections::HashSet::new();
231        let mut calls_pairs: Vec<(String, String)> = Vec::new();
232        let mut inh_pairs: Vec<(String, String)> = Vec::new();
233        let mut test_pairs: Vec<(String, String)> = Vec::new();
234        let mut imp_pairs: Vec<(String, String)> = Vec::new();
235        let mut reads_pairs: Vec<(String, String)> = Vec::new();
236        let mut writes_pairs: Vec<(String, String)> = Vec::new();
237        let mut custom_seen: std::collections::HashMap<
238            String,
239            std::collections::HashSet<(String, String)>,
240        > = std::collections::HashMap::new();
241        let mut custom_pairs: std::collections::HashMap<String, Vec<(String, String)>> =
242            std::collections::HashMap::new();
243
244        let mut stmt_ids: Vec<String> = Vec::new();
245        let mut stmt_kinds: Vec<String> = Vec::new();
246        let mut stmt_conditions: Vec<String> = Vec::new();
247        let mut stmt_slines: Vec<i64> = Vec::new();
248        let mut stmt_elines: Vec<i64> = Vec::new();
249        let mut stmt_depths: Vec<i64> = Vec::new();
250        let mut stmt_parents_sym = Vec::new();
251        let mut has_stmt_pairs: Vec<(String, String)> = Vec::new();
252
253        for e in extractions {
254            let mod_name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
255            mod_ids.push(e.file.clone());
256            mod_names.push(mod_name.to_string());
257            mod_files.push(e.file.clone());
258            mod_langs.push(e.language.clone());
259            mod_hashes.push(e.content_hash.clone());
260            mod_summaries.push(String::new());
261
262            file_ids.push(e.file.clone());
263            file_names.push(mod_name.to_string());
264            file_paths.push(e.file.clone());
265            file_langs.push(e.language.clone());
266            file_symcounts.push(e.symbols.len() as i64);
267
268            for sym in &e.symbols {
269                if sym_seen.insert(sym.id.clone()) {
270                    sym_ids.push(sym.id.clone());
271                    sym_names.push(sym.name.clone());
272                    sym_kinds.push(sym.kind.as_str().to_string());
273                    sym_files.push(e.file.clone());
274                    sym_slines.push(sym.span.start_line as i64);
275                    sym_elines.push(sym.span.end_line as i64);
276                    sym_sighashes.push(sym.signature_hash.clone());
277                    sym_languages.push(sym.language.clone());
278                    sym_visibilities.push(sym.visibility.as_deref().unwrap_or("").to_string());
279                    sym_parents.push(sym.parent.as_deref().unwrap_or("").to_string());
280                    sym_docstrings.push(sym.docstring.as_deref().unwrap_or("").to_string());
281                    sym_complexities.push(sym.complexity as i64);
282                    sym_parameters.push(sym.parameters.as_deref().unwrap_or("").to_string());
283                    sym_return_types.push(sym.return_type.as_deref().unwrap_or("").to_string());
284                    contains_pairs.push((e.file.clone(), sym.id.clone()));
285                    defines_pairs.push((e.file.clone(), sym.id.clone()));
286                }
287            }
288
289            for rel in &e.relations {
290                let src = rel.source_id.clone();
291                let tgt = rel.target_id.clone();
292                match &rel.kind {
293                    RelationKind::Imports | RelationKind::ImportedBy => {
294                        if known_module_ids.contains(&src)
295                            && known_module_ids.contains(&tgt)
296                            && imp_seen.insert((src.clone(), tgt.clone()))
297                        {
298                            imp_pairs.push((src, tgt));
299                        }
300                    }
301                    RelationKind::Custom(name) => {
302                        if known_ids.contains(&src)
303                            && known_ids.contains(&tgt)
304                            && custom_seen
305                                .entry(name.clone())
306                                .or_default()
307                                .insert((src.clone(), tgt.clone()))
308                        {
309                            custom_pairs
310                                .entry(name.clone())
311                                .or_default()
312                                .push((src, tgt));
313                        }
314                    }
315                    _ => {
316                        if !known_ids.contains(&src) || !known_ids.contains(&tgt) {
317                            continue;
318                        }
319                        match &rel.kind {
320                            RelationKind::Calls | RelationKind::CalledBy
321                                if calls_seen.insert((src.clone(), tgt.clone())) =>
322                            {
323                                calls_pairs.push((src, tgt));
324                            }
325                            RelationKind::Inherits | RelationKind::InheritedBy
326                                if inh_seen.insert((src.clone(), tgt.clone())) =>
327                            {
328                                inh_pairs.push((src, tgt));
329                            }
330                            RelationKind::TestedBy | RelationKind::Tests
331                                if test_seen.insert((src.clone(), tgt.clone())) =>
332                            {
333                                test_pairs.push((src, tgt));
334                            }
335                            RelationKind::Reads
336                                if reads_seen.insert((src.clone(), tgt.clone())) =>
337                            {
338                                reads_pairs.push((src, tgt));
339                            }
340                            RelationKind::Writes
341                                if writes_seen.insert((src.clone(), tgt.clone())) =>
342                            {
343                                writes_pairs.push((src, tgt));
344                            }
345                            _ => {}
346                        }
347                    }
348                }
349            }
350
351            for stmt in &e.statements {
352                stmt_ids.push(stmt.id.clone());
353                stmt_kinds.push(stmt.kind.as_str().to_string());
354                stmt_conditions.push(stmt.condition.clone());
355                stmt_slines.push(stmt.start_line as i64);
356                stmt_elines.push(stmt.end_line as i64);
357                stmt_depths.push(stmt.depth as i64);
358                stmt_parents_sym.push(stmt.parent_symbol.clone());
359                if known_ids.contains(&stmt.parent_symbol) {
360                    has_stmt_pairs.push((stmt.parent_symbol.clone(), stmt.id.clone()));
361                }
362            }
363        }
364
365        // Write node parquet files
366        let mod_pq = tmp.join("infigraph_index_modules.parquet");
367        parquet_loader::write_node_parquet(
368            &mod_pq,
369            &[
370                ("id", DataType::Utf8),
371                ("name", DataType::Utf8),
372                ("file", DataType::Utf8),
373                ("language", DataType::Utf8),
374                ("content_hash", DataType::Utf8),
375                ("summary", DataType::Utf8),
376            ],
377            vec![
378                Arc::new(StringArray::from(mod_ids)),
379                Arc::new(StringArray::from(mod_names)),
380                Arc::new(StringArray::from(mod_files)),
381                Arc::new(StringArray::from(mod_langs)),
382                Arc::new(StringArray::from(mod_hashes)),
383                Arc::new(StringArray::from(mod_summaries)),
384            ],
385        )?;
386
387        let file_pq = tmp.join("infigraph_index_files.parquet");
388        parquet_loader::write_node_parquet(
389            &file_pq,
390            &[
391                ("id", DataType::Utf8),
392                ("name", DataType::Utf8),
393                ("path", DataType::Utf8),
394                ("language", DataType::Utf8),
395                ("symbol_count", DataType::Int64),
396            ],
397            vec![
398                Arc::new(StringArray::from(file_ids)),
399                Arc::new(StringArray::from(file_names)),
400                Arc::new(StringArray::from(file_paths)),
401                Arc::new(StringArray::from(file_langs)),
402                Arc::new(Int64Array::from(file_symcounts)),
403            ],
404        )?;
405
406        let sym_pq = tmp.join("infigraph_index_symbols.parquet");
407        parquet_loader::write_node_parquet(
408            &sym_pq,
409            &[
410                ("id", DataType::Utf8),
411                ("name", DataType::Utf8),
412                ("kind", DataType::Utf8),
413                ("file", DataType::Utf8),
414                ("start_line", DataType::Int64),
415                ("end_line", DataType::Int64),
416                ("signature_hash", DataType::Utf8),
417                ("language", DataType::Utf8),
418                ("visibility", DataType::Utf8),
419                ("parent", DataType::Utf8),
420                ("docstring", DataType::Utf8),
421                ("complexity", DataType::Int64),
422                ("parameters", DataType::Utf8),
423                ("return_type", DataType::Utf8),
424            ],
425            vec![
426                Arc::new(StringArray::from(sym_ids)),
427                Arc::new(StringArray::from(sym_names)),
428                Arc::new(StringArray::from(sym_kinds)),
429                Arc::new(StringArray::from(sym_files)),
430                Arc::new(Int64Array::from(sym_slines)),
431                Arc::new(Int64Array::from(sym_elines)),
432                Arc::new(StringArray::from(sym_sighashes)),
433                Arc::new(StringArray::from(sym_languages)),
434                Arc::new(StringArray::from(sym_visibilities)),
435                Arc::new(StringArray::from(sym_parents)),
436                Arc::new(StringArray::from(sym_docstrings)),
437                Arc::new(Int64Array::from(sym_complexities)),
438                Arc::new(StringArray::from(sym_parameters)),
439                Arc::new(StringArray::from(sym_return_types)),
440            ],
441        )?;
442
443        // COPY FROM parquet -- node tables first
444        conn.query(&format!("COPY Module FROM '{}'", fwd_slash_path(&mod_pq)))
445            .map_err(|e| anyhow::anyhow!("COPY Module failed: {e}"))?;
446        conn.query(&format!("COPY File FROM '{}'", fwd_slash_path(&file_pq)))
447            .map_err(|e| anyhow::anyhow!("COPY File failed: {e}"))?;
448        conn.query(&format!(
449            "COPY Symbol (id, name, kind, file, start_line, end_line, signature_hash, language, visibility, parent, docstring, complexity, parameters, return_type) FROM '{}'",
450            fwd_slash_path(&sym_pq)
451        )).map_err(|e| anyhow::anyhow!("COPY Symbol failed: {e}"))?;
452
453        let stmt_pq = tmp.join("infigraph_index_statements.parquet");
454        if !stmt_ids.is_empty() {
455            parquet_loader::write_node_parquet(&stmt_pq, &[
456                ("id", DataType::Utf8), ("kind", DataType::Utf8), ("condition", DataType::Utf8),
457                ("start_line", DataType::Int64), ("end_line", DataType::Int64),
458                ("depth", DataType::Int64), ("parent_symbol", DataType::Utf8),
459            ], vec![
460                Arc::new(StringArray::from(stmt_ids)), Arc::new(StringArray::from(stmt_kinds)),
461                Arc::new(StringArray::from(stmt_conditions)),
462                Arc::new(Int64Array::from(stmt_slines)), Arc::new(Int64Array::from(stmt_elines)),
463                Arc::new(Int64Array::from(stmt_depths)), Arc::new(StringArray::from(stmt_parents_sym)),
464            ])?;
465            conn.query(&format!("COPY Statement FROM '{}'", fwd_slash_path(&stmt_pq)))
466                .map_err(|e| anyhow::anyhow!("COPY Statement failed: {e}"))?;
467        }
468
469        // Edge tables -- write parquet and COPY FROM with in-memory UNWIND fallback
470        #[allow(clippy::type_complexity)]
471        let edge_tables: Vec<(&str, &[(String, String)], &str, &str)> = vec![
472            ("CONTAINS", &contains_pairs, "Module", "Symbol"),
473            ("DEFINES", &defines_pairs, "File", "Symbol"),
474            ("CALLS", &calls_pairs, "Symbol", "Symbol"),
475            ("INHERITS", &inh_pairs, "Symbol", "Symbol"),
476            ("TESTED_BY", &test_pairs, "Symbol", "Symbol"),
477            ("IMPORTS", &imp_pairs, "Module", "Module"),
478            ("READS", &reads_pairs, "Symbol", "Symbol"),
479            ("WRITES", &writes_pairs, "Symbol", "Symbol"),
480            ("HAS_STATEMENT", &has_stmt_pairs, "Symbol", "Statement"),
481        ];
482
483        for (table, pairs, src_label, dst_label) in &edge_tables {
484            if pairs.is_empty() {
485                continue;
486            }
487            let edge_pq = tmp.join(format!("infigraph_index_{}.parquet", table.to_lowercase()));
488            let refs: Vec<(&str, &str)> = pairs
489                .iter()
490                .map(|(a, b)| (a.as_str(), b.as_str()))
491                .collect();
492            parquet_loader::write_edge_parquet(&edge_pq, &refs)?;
493            if let Err(e) = conn.query(&format!("COPY {table} FROM '{}'", fwd_slash_path(&edge_pq)))
494            {
495                eprintln!("warn: COPY {table} via parquet failed ({e}), falling back to UNWIND");
496                unwind_edges_from_pairs(&conn, &refs, table, src_label, dst_label);
497            }
498            let _ = std::fs::remove_file(&edge_pq);
499        }
500
501        // Custom edge tables
502        for (edge_name, pairs) in &custom_pairs {
503            if pairs.is_empty() {
504                continue;
505            }
506            let _ = super::schema::ensure_custom_edge_table(&conn, edge_name);
507            let edge_pq = tmp.join(format!(
508                "infigraph_index_{}.parquet",
509                edge_name.to_lowercase()
510            ));
511            let refs: Vec<(&str, &str)> = pairs
512                .iter()
513                .map(|(a, b)| (a.as_str(), b.as_str()))
514                .collect();
515            parquet_loader::write_edge_parquet(&edge_pq, &refs)?;
516            if let Err(e) = conn.query(&format!(
517                "COPY {} FROM '{}'",
518                edge_name,
519                fwd_slash_path(&edge_pq)
520            )) {
521                eprintln!(
522                    "warn: COPY {} via parquet failed ({e}), falling back to UNWIND",
523                    edge_name
524                );
525                unwind_edges_from_pairs(&conn, &refs, edge_name, "Symbol", "Symbol");
526            }
527            let _ = std::fs::remove_file(&edge_pq);
528        }
529
530        // Cleanup node parquet files
531        let _ = std::fs::remove_file(&mod_pq);
532        let _ = std::fs::remove_file(&file_pq);
533        let _ = std::fs::remove_file(&sym_pq);
534        let _ = std::fs::remove_file(&stmt_pq);
535
536        Ok(())
537    }
538}