Skip to main content

infigraph_core/scip/
mod.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use arrow::array::{Int64Array, StringArray};
7use arrow::datatypes::DataType;
8use protobuf::Message;
9use scip::types::{symbol_information, Index, SymbolRole};
10
11use crate::graph::parquet_loader;
12use crate::graph::store_util::{escape, fwd_slash_path, unwind_edges_from_pairs};
13use crate::graph::GraphStore;
14use crate::model::{Span, SymbolKind};
15
16/// Import a SCIP index.scip file into the Infigraph graph store.
17///
18/// Matches SCIP definitions to existing tree-sitter symbols by (file, name)
19/// and enriches them with compiler-grade type information. Builds cross-file
20/// CALLS edges from SCIP references using an in-memory symbol map for speed.
21pub fn import_scip_index(
22    index_path: &Path,
23    store: &GraphStore,
24    project_root: Option<&Path>,
25) -> Result<ImportStats> {
26    let bytes = std::fs::read(index_path)
27        .with_context(|| format!("failed to read {}", index_path.display()))?;
28
29    let index = Index::parse_from_bytes(&bytes)
30        .with_context(|| format!("failed to parse SCIP index: {}", index_path.display()))?;
31
32    let mut stats = ImportStats::default();
33    let conn = store.connection()?;
34
35    // Load learned pattern store for recording SCIP corrections
36    let mut learned_store = project_root
37        .map(crate::learned::LearnedStore::load)
38        .unwrap_or_default();
39
40    // Pre-load existing CALLS edges from tree-sitter resolution.
41    // Used to detect when SCIP resolves differently (= a correction to learn from).
42    let mut existing_calls: HashMap<String, std::collections::HashSet<String>> = HashMap::new();
43    if project_root.is_some() {
44        if let Ok(rows) = conn.query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id") {
45            for row in rows {
46                if row.len() < 2 {
47                    continue;
48                }
49                let src = row[0].to_string().trim_matches('"').to_string();
50                let tgt = row[1].to_string().trim_matches('"').to_string();
51                existing_calls.entry(src).or_default().insert(tgt);
52            }
53        }
54    }
55
56    // Pre-load all symbols from graph into memory: (file, name) -> Vec<symbol_id>
57    // and file -> sorted Vec<(start_line, end_line, symbol_id)> for containment lookup
58    let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
59    let mut file_symbols: HashMap<String, Vec<(u32, u32, String)>> = HashMap::new();
60
61    let q = "MATCH (s:Symbol) RETURN s.id, s.file, s.name, s.start_line, s.end_line";
62    if let Ok(rows) = conn.query(q) {
63        for row in rows {
64            if row.len() < 5 {
65                continue;
66            }
67            let sid = row[0].to_string().trim_matches('"').to_string();
68            let sfile = row[1].to_string().trim_matches('"').to_string();
69            let sname = row[2].to_string().trim_matches('"').to_string();
70            let sstart: u32 = row[3].to_string().trim_matches('"').parse().unwrap_or(0);
71            let send: u32 = row[4].to_string().trim_matches('"').parse().unwrap_or(0);
72
73            file_name_to_ids
74                .entry((sfile.clone(), sname))
75                .or_default()
76                .push(sid.clone());
77
78            file_symbols
79                .entry(sfile)
80                .or_default()
81                .push((sstart, send, sid));
82        }
83    }
84
85    // Sort file_symbols by span size (smallest first) for containment lookup
86    for syms in file_symbols.values_mut() {
87        syms.sort_by_key(|(s, e, _)| *e as i64 - *s as i64);
88    }
89
90    // Build SCIP symbol -> definition file mapping (cross-file resolution)
91    let mut scip_sym_to_file_name: HashMap<String, (String, String)> = HashMap::new();
92    for doc in &index.documents {
93        let file = &doc.relative_path;
94        for occ in &doc.occurrences {
95            if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
96                continue;
97            }
98            if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
99                continue;
100            }
101            let name = scip_sym_to_name(&occ.symbol);
102            scip_sym_to_file_name.insert(occ.symbol.clone(), (file.clone(), name));
103        }
104    }
105
106    // Pass 1: collect enrichments and new symbols in memory
107    let mut enrichments: Vec<(String, u32, u32, String)> = Vec::new();
108    let mut new_symbols: Vec<(String, String, String, String, u32, u32, String)> = Vec::new();
109
110    for doc in &index.documents {
111        let file = &doc.relative_path;
112
113        let sym_info_map: HashMap<&str, &scip::types::SymbolInformation> = doc
114            .symbols
115            .iter()
116            .map(|si| (si.symbol.as_str(), si))
117            .collect();
118
119        for occ in &doc.occurrences {
120            if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
121                continue;
122            }
123            let scip_sym = &occ.symbol;
124            if scip_sym.starts_with("local ") || scip_sym.starts_with('<') {
125                continue;
126            }
127
128            let name = scip_sym_to_name(scip_sym);
129            let span = parse_range(&occ.range, file);
130            let si = sym_info_map.get(scip_sym.as_str());
131            let docstring = si
132                .and_then(|s| s.documentation.first())
133                .map(|s| s.as_str())
134                .unwrap_or("");
135
136            let key = (file.clone(), name.clone());
137            if let Some(ids) = file_name_to_ids.get(&key) {
138                for sid in ids {
139                    enrichments.push((
140                        sid.clone(),
141                        span.start_line,
142                        span.end_line,
143                        docstring.to_string(),
144                    ));
145                    stats.symbols_enriched += 1;
146                }
147            } else {
148                let kind = si
149                    .map(|s| scip_kind_to_prism(&s.kind.enum_value_or_default()))
150                    .unwrap_or(SymbolKind::Function);
151                let sym_id = format!("{}::{}", file, name);
152                new_symbols.push((
153                    sym_id.clone(),
154                    name.clone(),
155                    kind.as_str().to_string(),
156                    file.clone(),
157                    span.start_line,
158                    span.end_line,
159                    docstring.to_string(),
160                ));
161                stats.symbols_added += 1;
162                file_name_to_ids
163                    .entry(key)
164                    .or_default()
165                    .push(sym_id.clone());
166                file_symbols.entry(file.clone()).or_default().push((
167                    span.start_line,
168                    span.end_line,
169                    sym_id,
170                ));
171            }
172        }
173
174        stats.files_processed += 1;
175    }
176
177    // Bulk insert new SCIP symbols via Parquet COPY FROM
178    const CHUNK: usize = 2000;
179    if !new_symbols.is_empty() {
180        let tmp = std::env::temp_dir();
181        let sym_pq = tmp.join("infigraph_scip_symbols.parquet");
182
183        let ids: Vec<&str> = new_symbols.iter().map(|(id, ..)| id.as_str()).collect();
184        let names: Vec<&str> = new_symbols
185            .iter()
186            .map(|(_, name, ..)| name.as_str())
187            .collect();
188        let kinds: Vec<&str> = new_symbols
189            .iter()
190            .map(|(_, _, kind, ..)| kind.as_str())
191            .collect();
192        let files: Vec<&str> = new_symbols
193            .iter()
194            .map(|(_, _, _, file, ..)| file.as_str())
195            .collect();
196        let start_lines: Vec<i64> = new_symbols
197            .iter()
198            .map(|(_, _, _, _, sl, ..)| *sl as i64)
199            .collect();
200        let end_lines: Vec<i64> = new_symbols.iter().map(|(.., el, _)| *el as i64).collect();
201        let docs: Vec<&str> = new_symbols.iter().map(|(.., doc)| doc.as_str()).collect();
202        let n = new_symbols.len();
203        let empty_str: Vec<&str> = vec![""; n];
204        let scip_lang: Vec<&str> = vec!["scip"; n];
205        let pub_vis: Vec<&str> = vec!["public"; n];
206        let zeros: Vec<i64> = vec![0; n];
207
208        let empty_str2: Vec<&str> = vec![""; n];
209        let pq_ok = parquet_loader::write_node_parquet(
210            &sym_pq,
211            &[
212                ("id", DataType::Utf8),
213                ("name", DataType::Utf8),
214                ("kind", DataType::Utf8),
215                ("file", DataType::Utf8),
216                ("start_line", DataType::Int64),
217                ("end_line", DataType::Int64),
218                ("signature_hash", DataType::Utf8),
219                ("language", DataType::Utf8),
220                ("visibility", DataType::Utf8),
221                ("parent", DataType::Utf8),
222                ("docstring", DataType::Utf8),
223                ("complexity", DataType::Int64),
224                ("parameters", DataType::Utf8),
225                ("return_type", DataType::Utf8),
226            ],
227            vec![
228                Arc::new(StringArray::from(ids)),
229                Arc::new(StringArray::from(names)),
230                Arc::new(StringArray::from(kinds)),
231                Arc::new(StringArray::from(files)),
232                Arc::new(Int64Array::from(start_lines)),
233                Arc::new(Int64Array::from(end_lines)),
234                Arc::new(StringArray::from(empty_str.clone())),
235                Arc::new(StringArray::from(scip_lang)),
236                Arc::new(StringArray::from(pub_vis)),
237                Arc::new(StringArray::from(empty_str)),
238                Arc::new(StringArray::from(docs)),
239                Arc::new(Int64Array::from(zeros)),
240                Arc::new(StringArray::from(empty_str2.clone())),
241                Arc::new(StringArray::from(empty_str2)),
242            ],
243        )
244        .is_ok();
245
246        let copy_ok = if pq_ok {
247            match conn.query(&format!(
248                "COPY Symbol (id, name, kind, file, start_line, end_line, signature_hash, language, visibility, parent, docstring, complexity, parameters, return_type) FROM '{}'",
249                fwd_slash_path(&sym_pq)
250            )) {
251                Ok(_) => true,
252                Err(e) => {
253                    eprintln!("Auto-SCIP: COPY Symbol failed ({e}), falling back to UNWIND");
254                    false
255                }
256            }
257        } else {
258            eprintln!("Auto-SCIP: parquet write failed, falling back to UNWIND");
259            false
260        };
261
262        if !copy_ok {
263            for chunk in new_symbols.chunks(CHUNK) {
264                let rows: Vec<String> = chunk
265                    .iter()
266                    .map(|(id, name, kind, file, start, end, doc)| {
267                        format!(
268                            "{{id: '{}', name: '{}', kind: '{}', file: '{}', sl: {}, el: {}, doc: '{}'}}",
269                            escape(id),
270                            escape(name),
271                            escape(kind),
272                            escape(file),
273                            start,
274                            end,
275                            escape(doc)
276                        )
277                    })
278                    .collect();
279                let _ = conn.query(&format!(
280                    "UNWIND [{}] AS s CREATE (:Symbol {{id: s.id, name: s.name, kind: s.kind, file: s.file, start_line: s.sl, end_line: s.el, signature_hash: '', language: 'scip', visibility: 'public', parent: '', docstring: s.doc, complexity: 0, parameters: '', return_type: ''}})",
281                    rows.join(", ")
282                ));
283            }
284        }
285        let _ = std::fs::remove_file(&sym_pq);
286    }
287
288    // Bulk write enrichments via UNWIND (updates can't use COPY FROM)
289    for chunk in enrichments.chunks(CHUNK) {
290        let rows: Vec<String> = chunk
291            .iter()
292            .map(|(id, start, end, doc)| {
293                format!(
294                    "{{id: '{}', sl: {}, el: {}, doc: '{}'}}",
295                    escape(id),
296                    start,
297                    end,
298                    escape(doc)
299                )
300            })
301            .collect();
302        let _ = conn.query(&format!(
303            "UNWIND [{}] AS e MATCH (s:Symbol) WHERE s.id = e.id SET s.start_line = e.sl, s.end_line = e.el, s.docstring = e.doc",
304            rows.join(", ")
305        ));
306    }
307
308    // Pass 2: build CALLS edges from references (all in-memory)
309    let mut calls_to_create: Vec<(String, String)> = Vec::new();
310    let mut seen_edges: std::collections::HashSet<(String, String)> =
311        std::collections::HashSet::new();
312
313    for doc in &index.documents {
314        let file = &doc.relative_path;
315
316        for occ in &doc.occurrences {
317            if (occ.symbol_roles & SymbolRole::Definition as i32) != 0 {
318                continue;
319            }
320            if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
321                continue;
322            }
323
324            let ref_line = occ.range.first().copied().unwrap_or(0) as u32;
325
326            let container_id = if let Some(syms) = file_symbols.get(file.as_str()) {
327                syms.iter()
328                    .find(|(start, end, _)| ref_line >= *start && ref_line <= *end)
329                    .map(|(_, _, id)| id.clone())
330            } else {
331                None
332            };
333            let Some(container_id) = container_id else {
334                continue;
335            };
336
337            let target_id = if let Some((tfile, tname)) = scip_sym_to_file_name.get(&occ.symbol) {
338                file_name_to_ids
339                    .get(&(tfile.clone(), tname.clone()))
340                    .and_then(|ids| ids.first())
341                    .cloned()
342            } else {
343                None
344            };
345            let Some(target_id) = target_id else {
346                continue;
347            };
348
349            if container_id == target_id {
350                continue;
351            }
352
353            // Detect SCIP correction: if tree-sitter had a CALLS edge from
354            // container_id to a *different* target for the same call name,
355            // SCIP is overriding it — record as a learned pattern.
356            if project_root.is_some() {
357                if let Some(existing_targets) = existing_calls.get(&container_id) {
358                    let call_name = target_id.rsplit("::").next().unwrap_or(&target_id);
359                    let target_file = target_id
360                        .rsplit("::")
361                        .nth(1)
362                        .or_else(|| target_id.split("::").next())
363                        .unwrap_or(&target_id);
364                    let ts_had_different = existing_targets.iter().any(|ts_tgt| {
365                        ts_tgt != &target_id
366                            && ts_tgt.rsplit("::").next().unwrap_or(ts_tgt) == call_name
367                    });
368                    if ts_had_different {
369                        let source_file = container_id.split("::").next().unwrap_or(&container_id);
370                        learned_store.record_correction(
371                            source_file,
372                            call_name,
373                            target_file,
374                            &target_id,
375                        );
376                        stats.corrections_learned += 1;
377                    }
378                }
379            }
380
381            let edge = (container_id, target_id);
382            if seen_edges.insert(edge.clone()) {
383                calls_to_create.push(edge);
384            }
385        }
386    }
387
388    // Bulk write CALLS edges via Parquet COPY FROM
389    if !calls_to_create.is_empty() {
390        let tmp = std::env::temp_dir();
391        let edge_pq = tmp.join("infigraph_scip_calls.parquet");
392        let refs: Vec<(&str, &str)> = calls_to_create
393            .iter()
394            .map(|(a, b)| (a.as_str(), b.as_str()))
395            .collect();
396        if parquet_loader::write_edge_parquet(&edge_pq, &refs).is_ok() {
397            if let Err(e) = conn.query(&format!("COPY CALLS FROM '{}'", fwd_slash_path(&edge_pq))) {
398                eprintln!("Auto-SCIP: COPY CALLS failed ({e}), falling back to UNWIND");
399                unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
400            }
401        } else {
402            unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
403        }
404        stats.references_added = calls_to_create.len();
405        let _ = std::fs::remove_file(&edge_pq);
406    }
407
408    // Persist learned corrections (if any were recorded)
409    if let Some(root) = project_root {
410        if stats.corrections_learned > 0 {
411            if let Err(e) = learned_store.save(root) {
412                eprintln!("warning: failed to save learned patterns: {e}");
413            }
414        }
415    }
416
417    Ok(stats)
418}
419
420fn parse_range(range: &[i32], file: &str) -> Span {
421    let (start_line, start_col, end_line, end_col) = match range.len() {
422        4 => (range[0], range[1], range[2], range[3]),
423        3 => (range[0], range[1], range[0], range[2]),
424        _ => (0, 0, 0, 0),
425    };
426    Span {
427        file: file.to_string(),
428        start_line: start_line as u32,
429        start_col: start_col as u32,
430        end_line: end_line as u32,
431        end_col: end_col as u32,
432    }
433}
434
/// Extract the display name of the last descriptor in a SCIP symbol string.
///
/// Every SCIP descriptor ends in a suffix that encodes its kind: `name/`
/// (namespace), `name#` (type), `name.` (term), `name(disambiguator).`
/// (method), `(name)` (parameter), `[name]` (type parameter), `name:` (meta)
/// and `name!` (macro). Names containing characters outside
/// `[A-Za-z0-9_+\-$]` are wrapped in backticks, with literal backticks doubled.
///
/// For example `rust-analyzer cargo foo 0.1.0 bar/Baz#qux().` yields `qux`, and
/// ``scip-typescript npm pkg 1.0.0 src/`index.ts`/`` yields `index.ts`. Input
/// without a recognised suffix (e.g. `local 42`) yields its trailing
/// identifier run.
fn scip_sym_to_name(scip_sym: &str) -> String {
    let mut rest = scip_sym;
    match rest.chars().last() {
        // Term `name.` or method `name(disambiguator).`.
        Some('.') => {
            rest = &rest[..rest.len() - 1];
            if rest.ends_with(')') {
                // Disambiguators are simple identifiers, so the last '(' opens it.
                if let Some(open) = rest.rfind('(') {
                    rest = &rest[..open];
                }
            }
        }
        // Namespace, type, meta, macro, parameter and type-parameter suffixes
        // (for `(name)` / `[name]` the opening bracket is left in front of the
        // name and ignored by the trailing-name scan).
        Some('/' | '#' | ':' | '!' | ')' | ']') => rest = &rest[..rest.len() - 1],
        _ => {}
    }
    scip_trailing_name(rest)
}

/// Return the (unescaped) SCIP name that ends exactly at the end of `s`.
fn scip_trailing_name(s: &str) -> String {
    if let Some(body) = s.strip_suffix('`') {
        // Escaped name: walk back to the opening backtick, skipping doubled
        // backticks, which encode a literal backtick inside the name.
        let bytes = body.as_bytes();
        let mut i = bytes.len();
        while i > 0 {
            if bytes[i - 1] == b'`' {
                if i >= 2 && bytes[i - 2] == b'`' {
                    i -= 2;
                    continue;
                }
                return body[i..].replace("``", "`");
            }
            i -= 1;
        }
        // Malformed escape with no opening backtick: keep everything.
        return body.replace("``", "`");
    }

    // Simple identifier: the trailing run of SCIP identifier characters.
    let start = s
        .trim_end_matches(|c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '+' | '-' | '$'))
        .len();
    s[start..].to_string()
}
444
445fn scip_kind_to_prism(kind: &symbol_information::Kind) -> SymbolKind {
446    use symbol_information::Kind::*;
447    match kind {
448        Function | AbstractMethod | StaticMethod | PureVirtualMethod | ProtocolMethod
449        | TraitMethod | TypeClassMethod => SymbolKind::Function,
450        Method | MethodAlias | MethodReceiver | MethodSpecification => SymbolKind::Method,
451        Class | SingletonClass => SymbolKind::Class,
452        Struct => SymbolKind::Struct,
453        Interface => SymbolKind::Interface,
454        Trait | TypeClass => SymbolKind::Trait,
455        Enum | EnumMember => SymbolKind::Enum,
456        Module | Namespace | Package => SymbolKind::Module,
457        Variable | StaticVariable | Field | SelfParameter | Parameter => SymbolKind::Variable,
458        Constant => SymbolKind::Constant,
459        _ => SymbolKind::Function,
460    }
461}
462
/// Counters describing what a single `import_scip_index` run did.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ImportStats {
    /// Number of SCIP documents walked in the definition pass.
    pub files_processed: usize,
    /// SCIP definitions with no matching `(file, name)` symbol, inserted as new nodes.
    pub symbols_added: usize,
    /// Existing symbols updated with SCIP span/documentation (one per matched id).
    pub symbols_enriched: usize,
    /// Reserved; not currently incremented by `import_scip_index`.
    pub symbols_skipped: usize,
    /// Reserved; not currently incremented by `import_scip_index`.
    pub relations_added: usize,
    /// Distinct `CALLS` edges derived from SCIP references.
    pub references_added: usize,
    /// Reference occurrences where SCIP resolved a call to a different target
    /// than an existing same-named `CALLS` edge.
    pub corrections_learned: usize,
}