Skip to main content

infigraph_core/graph/
store_bulk.rs

1use std::collections::HashMap;
2
3use anyhow::{Context, Result};
4use kuzu::Connection;
5
6use super::schema::ensure_custom_edge_table;
7use super::store::GraphStore;
8use super::store_util::escape;
9use crate::model::{FileExtraction, RelationKind};
10
11impl GraphStore {
12    /// Bulk insert all extractions in minimal queries -- one UNWIND per node/edge type.
13    /// Much faster than calling upsert_file_conn_no_delete per file.
14    pub fn upsert_all_bulk(
15        &self,
16        conn: &Connection<'_>,
17        extractions: &[FileExtraction],
18    ) -> Result<()> {
19        if extractions.is_empty() {
20            return Ok(());
21        }
22
23        // 1. All Module nodes
24        let module_rows: Vec<String> = extractions
25            .iter()
26            .map(|e| {
27                let name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
28                format!(
29                    "{{id: '{}', name: '{}', file: '{}', language: '{}', content_hash: '{}'}}",
30                    escape(&e.file),
31                    escape(name),
32                    escape(&e.file),
33                    escape(&e.language),
34                    escape(&e.content_hash)
35                )
36            })
37            .collect();
38        conn.query(&format!("UNWIND [{}] AS m CREATE (:Module {{id: m.id, name: m.name, file: m.file, language: m.language, content_hash: m.content_hash}})", module_rows.join(", ")))
39            .context("bulk module insert")?;
40
41        // 2. All File nodes
42        let file_rows: Vec<String> = extractions
43            .iter()
44            .map(|e| {
45                let name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
46                format!(
47                    "{{id: '{}', name: '{}', path: '{}', language: '{}', symbol_count: {}}}",
48                    escape(&e.file),
49                    escape(name),
50                    escape(&e.file),
51                    escape(&e.language),
52                    e.symbols.len()
53                )
54            })
55            .collect();
56        conn.query(&format!("UNWIND [{}] AS f CREATE (:File {{id: f.id, name: f.name, path: f.path, language: f.language, symbol_count: f.symbol_count}})", file_rows.join(", ")))
57            .context("bulk file insert")?;
58
59        // 3. All Symbol nodes in chunks (query string size limit)
60        const SYM_CHUNK: usize = 2000;
61        let all_syms: Vec<String> = extractions.iter().flat_map(|e| {
62            e.symbols.iter().map(move |sym| format!(
63                "{{id: '{}', name: '{}', kind: '{}', file: '{}', start_line: {}, end_line: {}, signature_hash: '{}', language: '{}', visibility: '{}', parent: '{}', docstring: '{}', complexity: {}, parameters: '{}', return_type: '{}'}}",
64                escape(&sym.id), escape(&sym.name), sym.kind.as_str(), escape(&e.file),
65                sym.span.start_line, sym.span.end_line, escape(&sym.signature_hash),
66                escape(&sym.language), escape(sym.visibility.as_deref().unwrap_or("")),
67                escape(sym.parent.as_deref().unwrap_or("")),
68                escape(sym.docstring.as_deref().unwrap_or("")), sym.complexity,
69                escape(sym.parameters.as_deref().unwrap_or("")),
70                escape(sym.return_type.as_deref().unwrap_or(""))
71            ))
72        }).collect();
73        for chunk in all_syms.chunks(SYM_CHUNK) {
74            conn.query(&format!(
75                "UNWIND [{}] AS s CREATE (:Symbol {{id: s.id, name: s.name, kind: s.kind, file: s.file, start_line: s.start_line, end_line: s.end_line, signature_hash: s.signature_hash, language: s.language, visibility: s.visibility, parent: s.parent, docstring: s.docstring, complexity: s.complexity, parameters: s.parameters, return_type: s.return_type}})",
76                chunk.join(", ")
77            )).context("bulk symbol insert")?;
78        }
79
80        // 4. CONTAINS edges (module -> symbols) in chunks
81        let contains_pairs: Vec<String> = extractions
82            .iter()
83            .flat_map(|e| {
84                e.symbols.iter().map(move |sym| {
85                    format!("{{m: '{}', s: '{}'}}", escape(&e.file), escape(&sym.id))
86                })
87            })
88            .collect();
89        for chunk in contains_pairs.chunks(SYM_CHUNK) {
90            let _ = conn.query(&format!(
91                "UNWIND [{}] AS p MATCH (m:Module), (s:Symbol) WHERE m.id = p.m AND s.id = p.s CREATE (m)-[:CONTAINS]->(s)",
92                chunk.join(", ")
93            ));
94        }
95
96        // 5. DEFINES edges (file -> symbols) in chunks
97        let defines_pairs: Vec<String> = extractions
98            .iter()
99            .flat_map(|e| {
100                e.symbols.iter().map(move |sym| {
101                    format!("{{f: '{}', s: '{}'}}", escape(&e.file), escape(&sym.id))
102                })
103            })
104            .collect();
105        for chunk in defines_pairs.chunks(SYM_CHUNK) {
106            let _ = conn.query(&format!(
107                "UNWIND [{}] AS p MATCH (f:File), (s:Symbol) WHERE f.id = p.f AND s.id = p.s CREATE (f)-[:DEFINES]->(s)",
108                chunk.join(", ")
109            ));
110        }
111
112        // 6. All relation edges grouped by type
113        let mut calls_pairs: Vec<String> = Vec::new();
114        let mut inherits_pairs: Vec<String> = Vec::new();
115        let mut tested_by_pairs: Vec<String> = Vec::new();
116        let mut imports_pairs: Vec<String> = Vec::new();
117        let mut reads_pairs: Vec<String> = Vec::new();
118        let mut writes_pairs: Vec<String> = Vec::new();
119        let mut custom_pairs: HashMap<String, Vec<String>> = HashMap::new();
120        for e in extractions {
121            for rel in &e.relations {
122                let pair = format!(
123                    "{{a: '{}', b: '{}'}}",
124                    escape(&rel.source_id),
125                    escape(&rel.target_id)
126                );
127                match &rel.kind {
128                    RelationKind::Calls | RelationKind::CalledBy => calls_pairs.push(pair),
129                    RelationKind::Inherits | RelationKind::InheritedBy => inherits_pairs.push(pair),
130                    RelationKind::TestedBy | RelationKind::Tests => tested_by_pairs.push(pair),
131                    RelationKind::Imports | RelationKind::ImportedBy => imports_pairs.push(pair),
132                    RelationKind::Reads => reads_pairs.push(pair),
133                    RelationKind::Writes => writes_pairs.push(pair),
134                    RelationKind::Custom(name) => {
135                        custom_pairs.entry(name.clone()).or_default().push(pair);
136                    }
137                    _ => {}
138                }
139            }
140        }
141        for (pairs, rel_type) in [
142            (&calls_pairs, "CALLS"),
143            (&inherits_pairs, "INHERITS"),
144            (&tested_by_pairs, "TESTED_BY"),
145            (&reads_pairs, "READS"),
146            (&writes_pairs, "WRITES"),
147        ] {
148            for chunk in pairs.chunks(SYM_CHUNK) {
149                let _ = conn.query(&format!(
150                    "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:{rel_type}]->(b)",
151                    chunk.join(", ")
152                ));
153            }
154        }
155        for chunk in imports_pairs.chunks(SYM_CHUNK) {
156            let _ = conn.query(&format!(
157                "UNWIND [{}] AS p MATCH (a:Module), (b:Module) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:IMPORTS]->(b)",
158                chunk.join(", ")
159            ));
160        }
161        for (edge_name, pairs) in &custom_pairs {
162            if pairs.is_empty() {
163                continue;
164            }
165            let _ = ensure_custom_edge_table(conn, edge_name);
166            for chunk in pairs.chunks(SYM_CHUNK) {
167                let _ = conn.query(&format!(
168                    "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:{}]->(b)",
169                    chunk.join(", "),
170                    edge_name
171                ));
172            }
173        }
174
175        // Statement nodes + HAS_STATEMENT edges
176        let all_stmts: Vec<String> = extractions.iter().flat_map(|e| {
177            e.statements.iter().map(|s| format!(
178                "{{id: '{}', kind: '{}', condition: '{}', start_line: {}, end_line: {}, depth: {}, parent_symbol: '{}'}}",
179                escape(&s.id), s.kind.as_str(), escape(&s.condition),
180                s.start_line, s.end_line, s.depth, escape(&s.parent_symbol)
181            ))
182        }).collect();
183        for chunk in all_stmts.chunks(SYM_CHUNK) {
184            let _ = conn.query(&format!(
185                "UNWIND [{}] AS s CREATE (:Statement {{id: s.id, kind: s.kind, condition: s.condition, start_line: s.start_line, end_line: s.end_line, depth: s.depth, parent_symbol: s.parent_symbol}})",
186                chunk.join(", ")
187            ));
188        }
189        let stmt_edges: Vec<String> = extractions.iter().flat_map(|e| {
190            e.statements.iter().map(|s| format!("{{a: '{}', b: '{}'}}", escape(&s.parent_symbol), escape(&s.id)))
191        }).collect();
192        for chunk in stmt_edges.chunks(SYM_CHUNK) {
193            let _ = conn.query(&format!(
194                "UNWIND [{}] AS p MATCH (a:Symbol), (b:Statement) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:HAS_STATEMENT]->(b)",
195                chunk.join(", ")
196            ));
197        }
198
199        Ok(())
200    }
201}