Skip to main content

infigraph_core/graph/
store_bulk.rs

1use std::collections::HashMap;
2
3use anyhow::{Context, Result};
4use kuzu::Connection;
5
6use super::schema::ensure_custom_edge_table;
7use super::store::GraphStore;
8use super::store_util::escape;
9use crate::model::{FileExtraction, RelationKind};
10
11impl GraphStore {
12    /// Bulk insert all extractions in minimal queries -- one UNWIND per node/edge type.
13    /// Much faster than calling upsert_file_conn_no_delete per file.
14    /// Caller must hold WriteLock.
15    pub fn upsert_all_bulk(
16        &self,
17        conn: &Connection<'_>,
18        extractions: &[FileExtraction],
19    ) -> Result<()> {
20        if extractions.is_empty() {
21            return Ok(());
22        }
23
24        // 1. All Module nodes
25        let module_rows: Vec<String> = extractions
26            .iter()
27            .map(|e| {
28                let name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
29                format!(
30                    "{{id: '{}', name: '{}', file: '{}', language: '{}', content_hash: '{}'}}",
31                    escape(&e.file),
32                    escape(name),
33                    escape(&e.file),
34                    escape(&e.language),
35                    escape(&e.content_hash)
36                )
37            })
38            .collect();
39        conn.query(&format!("UNWIND [{}] AS m CREATE (:Module {{id: m.id, name: m.name, file: m.file, language: m.language, content_hash: m.content_hash}})", module_rows.join(", ")))
40            .context("bulk module insert")?;
41
42        // 2. All File nodes
43        let file_rows: Vec<String> = extractions
44            .iter()
45            .map(|e| {
46                let name = e.file.rsplit_once('/').map(|(_, f)| f).unwrap_or(&e.file);
47                format!(
48                    "{{id: '{}', name: '{}', path: '{}', language: '{}', symbol_count: {}}}",
49                    escape(&e.file),
50                    escape(name),
51                    escape(&e.file),
52                    escape(&e.language),
53                    e.symbols.len()
54                )
55            })
56            .collect();
57        conn.query(&format!("UNWIND [{}] AS f CREATE (:File {{id: f.id, name: f.name, path: f.path, language: f.language, symbol_count: f.symbol_count}})", file_rows.join(", ")))
58            .context("bulk file insert")?;
59
60        // 3. All Symbol nodes in chunks (query string size limit)
61        const SYM_CHUNK: usize = 2000;
62        let all_syms: Vec<String> = extractions.iter().flat_map(|e| {
63            e.symbols.iter().map(move |sym| format!(
64                "{{id: '{}', name: '{}', kind: '{}', file: '{}', start_line: {}, end_line: {}, signature_hash: '{}', language: '{}', visibility: '{}', parent: '{}', docstring: '{}', complexity: {}, parameters: '{}', return_type: '{}'}}",
65                escape(&sym.id), escape(&sym.name), sym.kind.as_str(), escape(&e.file),
66                sym.span.start_line, sym.span.end_line, escape(&sym.signature_hash),
67                escape(&sym.language), escape(sym.visibility.as_deref().unwrap_or("")),
68                escape(sym.parent.as_deref().unwrap_or("")),
69                escape(sym.docstring.as_deref().unwrap_or("")), sym.complexity,
70                escape(sym.parameters.as_deref().unwrap_or("")),
71                escape(sym.return_type.as_deref().unwrap_or(""))
72            ))
73        }).collect();
74        for chunk in all_syms.chunks(SYM_CHUNK) {
75            conn.query(&format!(
76                "UNWIND [{}] AS s CREATE (:Symbol {{id: s.id, name: s.name, kind: s.kind, file: s.file, start_line: s.start_line, end_line: s.end_line, signature_hash: s.signature_hash, language: s.language, visibility: s.visibility, parent: s.parent, docstring: s.docstring, complexity: s.complexity, parameters: s.parameters, return_type: s.return_type}})",
77                chunk.join(", ")
78            )).context("bulk symbol insert")?;
79        }
80
81        // 4. CONTAINS edges (module -> symbols) in chunks
82        let contains_pairs: Vec<String> = extractions
83            .iter()
84            .flat_map(|e| {
85                e.symbols.iter().map(move |sym| {
86                    format!("{{m: '{}', s: '{}'}}", escape(&e.file), escape(&sym.id))
87                })
88            })
89            .collect();
90        for chunk in contains_pairs.chunks(SYM_CHUNK) {
91            let _ = conn.query(&format!(
92                "UNWIND [{}] AS p MATCH (m:Module), (s:Symbol) WHERE m.id = p.m AND s.id = p.s CREATE (m)-[:CONTAINS]->(s)",
93                chunk.join(", ")
94            ));
95        }
96
97        // 5. DEFINES edges (file -> symbols) in chunks
98        let defines_pairs: Vec<String> = extractions
99            .iter()
100            .flat_map(|e| {
101                e.symbols.iter().map(move |sym| {
102                    format!("{{f: '{}', s: '{}'}}", escape(&e.file), escape(&sym.id))
103                })
104            })
105            .collect();
106        for chunk in defines_pairs.chunks(SYM_CHUNK) {
107            let _ = conn.query(&format!(
108                "UNWIND [{}] AS p MATCH (f:File), (s:Symbol) WHERE f.id = p.f AND s.id = p.s CREATE (f)-[:DEFINES]->(s)",
109                chunk.join(", ")
110            ));
111        }
112
113        // 6. All relation edges grouped by type
114        let mut calls_pairs: Vec<String> = Vec::new();
115        let mut inherits_pairs: Vec<String> = Vec::new();
116        let mut tested_by_pairs: Vec<String> = Vec::new();
117        let mut imports_pairs: Vec<String> = Vec::new();
118        let mut reads_pairs: Vec<String> = Vec::new();
119        let mut writes_pairs: Vec<String> = Vec::new();
120        let mut custom_pairs: HashMap<String, Vec<String>> = HashMap::new();
121        for e in extractions {
122            for rel in &e.relations {
123                let pair = format!(
124                    "{{a: '{}', b: '{}'}}",
125                    escape(&rel.source_id),
126                    escape(&rel.target_id)
127                );
128                match &rel.kind {
129                    RelationKind::Calls | RelationKind::CalledBy => calls_pairs.push(pair),
130                    RelationKind::Inherits | RelationKind::InheritedBy => inherits_pairs.push(pair),
131                    RelationKind::TestedBy | RelationKind::Tests => tested_by_pairs.push(pair),
132                    RelationKind::Imports | RelationKind::ImportedBy => imports_pairs.push(pair),
133                    RelationKind::Reads => reads_pairs.push(pair),
134                    RelationKind::Writes => writes_pairs.push(pair),
135                    RelationKind::Custom(name) => {
136                        custom_pairs.entry(name.clone()).or_default().push(pair);
137                    }
138                    _ => {}
139                }
140            }
141        }
142        for (pairs, rel_type) in [
143            (&calls_pairs, "CALLS"),
144            (&inherits_pairs, "INHERITS"),
145            (&tested_by_pairs, "TESTED_BY"),
146            (&reads_pairs, "READS"),
147            (&writes_pairs, "WRITES"),
148        ] {
149            for chunk in pairs.chunks(SYM_CHUNK) {
150                let _ = conn.query(&format!(
151                    "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:{rel_type}]->(b)",
152                    chunk.join(", ")
153                ));
154            }
155        }
156        for chunk in imports_pairs.chunks(SYM_CHUNK) {
157            let _ = conn.query(&format!(
158                "UNWIND [{}] AS p MATCH (a:Module), (b:Module) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:IMPORTS]->(b)",
159                chunk.join(", ")
160            ));
161        }
162        for (edge_name, pairs) in &custom_pairs {
163            if pairs.is_empty() {
164                continue;
165            }
166            let _ = ensure_custom_edge_table(conn, edge_name);
167            for chunk in pairs.chunks(SYM_CHUNK) {
168                let _ = conn.query(&format!(
169                    "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:{}]->(b)",
170                    chunk.join(", "),
171                    edge_name
172                ));
173            }
174        }
175
176        // Statement nodes + HAS_STATEMENT edges
177        let all_stmts: Vec<String> = extractions.iter().flat_map(|e| {
178            e.statements.iter().map(|s| format!(
179                "{{id: '{}', kind: '{}', condition: '{}', start_line: {}, end_line: {}, depth: {}, parent_symbol: '{}'}}",
180                escape(&s.id), s.kind.as_str(), escape(&s.condition),
181                s.start_line, s.end_line, s.depth, escape(&s.parent_symbol)
182            ))
183        }).collect();
184        for chunk in all_stmts.chunks(SYM_CHUNK) {
185            let _ = conn.query(&format!(
186                "UNWIND [{}] AS s CREATE (:Statement {{id: s.id, kind: s.kind, condition: s.condition, start_line: s.start_line, end_line: s.end_line, depth: s.depth, parent_symbol: s.parent_symbol}})",
187                chunk.join(", ")
188            ));
189        }
190        let stmt_edges: Vec<String> = extractions
191            .iter()
192            .flat_map(|e| {
193                e.statements.iter().map(|s| {
194                    format!(
195                        "{{a: '{}', b: '{}'}}",
196                        escape(&s.parent_symbol),
197                        escape(&s.id)
198                    )
199                })
200            })
201            .collect();
202        for chunk in stmt_edges.chunks(SYM_CHUNK) {
203            let _ = conn.query(&format!(
204                "UNWIND [{}] AS p MATCH (a:Symbol), (b:Statement) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:HAS_STATEMENT]->(b)",
205                chunk.join(", ")
206            ));
207        }
208
209        Ok(())
210    }
211}