Skip to main content

infigraph_core/scip/
mod.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use arrow::array::{Int64Array, StringArray};
7use arrow::datatypes::DataType;
8use protobuf::Message;
9use scip::types::{symbol_information, Index, SymbolRole};
10
11use crate::graph::parquet_loader;
12use crate::graph::store_util::{escape, fwd_slash_path, unwind_edges_from_pairs};
13use crate::graph::GraphStore;
14use crate::model::{Span, SymbolKind};
15
16/// Import a SCIP index.scip file into the Infigraph graph store.
17///
18/// Matches SCIP definitions to existing tree-sitter symbols by (file, name)
19/// and enriches them with compiler-grade type information. Builds cross-file
20/// CALLS edges from SCIP references using an in-memory symbol map for speed.
21pub fn import_scip_index(
22    index_path: &Path,
23    store: &GraphStore,
24    project_root: Option<&Path>,
25) -> Result<ImportStats> {
26    let bytes = std::fs::read(index_path)
27        .with_context(|| format!("failed to read {}", index_path.display()))?;
28
29    let index = Index::parse_from_bytes(&bytes)
30        .with_context(|| format!("failed to parse SCIP index: {}", index_path.display()))?;
31
32    let mut stats = ImportStats::default();
33    let _lock = store.write_lock()?;
34    let conn = store.connection()?;
35
36    // Load learned pattern store for recording SCIP corrections
37    let mut learned_store = project_root
38        .map(crate::learned::LearnedStore::load)
39        .unwrap_or_default();
40
41    // Pre-load existing CALLS edges from tree-sitter resolution.
42    // Used to detect when SCIP resolves differently (= a correction to learn from).
43    let mut existing_calls: HashMap<String, std::collections::HashSet<String>> = HashMap::new();
44    if project_root.is_some() {
45        if let Ok(rows) = conn.query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id") {
46            for row in rows {
47                if row.len() < 2 {
48                    continue;
49                }
50                let src = row[0].to_string().trim_matches('"').to_string();
51                let tgt = row[1].to_string().trim_matches('"').to_string();
52                existing_calls.entry(src).or_default().insert(tgt);
53            }
54        }
55    }
56
57    // Pre-load all symbols from graph into memory: (file, name) -> Vec<symbol_id>
58    // and file -> sorted Vec<(start_line, end_line, symbol_id)> for containment lookup
59    let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
60    let mut file_symbols: HashMap<String, Vec<(u32, u32, String)>> = HashMap::new();
61
62    let q = "MATCH (s:Symbol) RETURN s.id, s.file, s.name, s.start_line, s.end_line";
63    if let Ok(rows) = conn.query(q) {
64        for row in rows {
65            if row.len() < 5 {
66                continue;
67            }
68            let sid = row[0].to_string().trim_matches('"').to_string();
69            let sfile = row[1].to_string().trim_matches('"').to_string();
70            let sname = row[2].to_string().trim_matches('"').to_string();
71            let sstart: u32 = row[3].to_string().trim_matches('"').parse().unwrap_or(0);
72            let send: u32 = row[4].to_string().trim_matches('"').parse().unwrap_or(0);
73
74            file_name_to_ids
75                .entry((sfile.clone(), sname))
76                .or_default()
77                .push(sid.clone());
78
79            file_symbols
80                .entry(sfile)
81                .or_default()
82                .push((sstart, send, sid));
83        }
84    }
85
86    // Sort file_symbols by span size (smallest first) for containment lookup
87    for syms in file_symbols.values_mut() {
88        syms.sort_by_key(|(s, e, _)| *e as i64 - *s as i64);
89    }
90
91    // Build SCIP symbol -> definition file mapping (cross-file resolution)
92    let mut scip_sym_to_file_name: HashMap<String, (String, String)> = HashMap::new();
93    for doc in &index.documents {
94        let file = &doc.relative_path;
95        for occ in &doc.occurrences {
96            if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
97                continue;
98            }
99            if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
100                continue;
101            }
102            let name = scip_sym_to_name(&occ.symbol);
103            scip_sym_to_file_name.insert(occ.symbol.clone(), (file.clone(), name));
104        }
105    }
106
107    // Pass 1: collect enrichments and new symbols in memory
108    let mut enrichments: Vec<(String, u32, u32, String)> = Vec::new();
109    let mut new_symbols: Vec<(String, String, String, String, u32, u32, String)> = Vec::new();
110
111    for doc in &index.documents {
112        let file = &doc.relative_path;
113
114        let sym_info_map: HashMap<&str, &scip::types::SymbolInformation> = doc
115            .symbols
116            .iter()
117            .map(|si| (si.symbol.as_str(), si))
118            .collect();
119
120        for occ in &doc.occurrences {
121            if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
122                continue;
123            }
124            let scip_sym = &occ.symbol;
125            if scip_sym.starts_with("local ") || scip_sym.starts_with('<') {
126                continue;
127            }
128
129            let name = scip_sym_to_name(scip_sym);
130            let span = parse_range(&occ.range, file);
131            let si = sym_info_map.get(scip_sym.as_str());
132            let docstring = si
133                .and_then(|s| s.documentation.first())
134                .map(|s| s.as_str())
135                .unwrap_or("");
136
137            let key = (file.clone(), name.clone());
138            if let Some(ids) = file_name_to_ids.get(&key) {
139                for sid in ids {
140                    enrichments.push((
141                        sid.clone(),
142                        span.start_line,
143                        span.end_line,
144                        docstring.to_string(),
145                    ));
146                    stats.symbols_enriched += 1;
147                }
148            } else {
149                let kind = si
150                    .map(|s| scip_kind_to_prism(&s.kind.enum_value_or_default()))
151                    .unwrap_or(SymbolKind::Function);
152                let sym_id = format!("{}::{}", file, name);
153                new_symbols.push((
154                    sym_id.clone(),
155                    name.clone(),
156                    kind.as_str().to_string(),
157                    file.clone(),
158                    span.start_line,
159                    span.end_line,
160                    docstring.to_string(),
161                ));
162                stats.symbols_added += 1;
163                file_name_to_ids
164                    .entry(key)
165                    .or_default()
166                    .push(sym_id.clone());
167                file_symbols.entry(file.clone()).or_default().push((
168                    span.start_line,
169                    span.end_line,
170                    sym_id,
171                ));
172            }
173        }
174
175        stats.files_processed += 1;
176    }
177
178    // Bulk insert new SCIP symbols via Parquet COPY FROM
179    const CHUNK: usize = 2000;
180    if !new_symbols.is_empty() {
181        let tmp = std::env::temp_dir();
182        let sym_pq = tmp.join("infigraph_scip_symbols.parquet");
183
184        let ids: Vec<&str> = new_symbols.iter().map(|(id, ..)| id.as_str()).collect();
185        let names: Vec<&str> = new_symbols
186            .iter()
187            .map(|(_, name, ..)| name.as_str())
188            .collect();
189        let kinds: Vec<&str> = new_symbols
190            .iter()
191            .map(|(_, _, kind, ..)| kind.as_str())
192            .collect();
193        let files: Vec<&str> = new_symbols
194            .iter()
195            .map(|(_, _, _, file, ..)| file.as_str())
196            .collect();
197        let start_lines: Vec<i64> = new_symbols
198            .iter()
199            .map(|(_, _, _, _, sl, ..)| *sl as i64)
200            .collect();
201        let end_lines: Vec<i64> = new_symbols.iter().map(|(.., el, _)| *el as i64).collect();
202        let docs: Vec<&str> = new_symbols.iter().map(|(.., doc)| doc.as_str()).collect();
203        let n = new_symbols.len();
204        let empty_str: Vec<&str> = vec![""; n];
205        let scip_lang: Vec<&str> = vec!["scip"; n];
206        let pub_vis: Vec<&str> = vec!["public"; n];
207        let zeros: Vec<i64> = vec![0; n];
208
209        let empty_str2: Vec<&str> = vec![""; n];
210        let pq_ok = parquet_loader::write_node_parquet(
211            &sym_pq,
212            &[
213                ("id", DataType::Utf8),
214                ("name", DataType::Utf8),
215                ("kind", DataType::Utf8),
216                ("file", DataType::Utf8),
217                ("start_line", DataType::Int64),
218                ("end_line", DataType::Int64),
219                ("signature_hash", DataType::Utf8),
220                ("language", DataType::Utf8),
221                ("visibility", DataType::Utf8),
222                ("parent", DataType::Utf8),
223                ("docstring", DataType::Utf8),
224                ("complexity", DataType::Int64),
225                ("parameters", DataType::Utf8),
226                ("return_type", DataType::Utf8),
227            ],
228            vec![
229                Arc::new(StringArray::from(ids)),
230                Arc::new(StringArray::from(names)),
231                Arc::new(StringArray::from(kinds)),
232                Arc::new(StringArray::from(files)),
233                Arc::new(Int64Array::from(start_lines)),
234                Arc::new(Int64Array::from(end_lines)),
235                Arc::new(StringArray::from(empty_str.clone())),
236                Arc::new(StringArray::from(scip_lang)),
237                Arc::new(StringArray::from(pub_vis)),
238                Arc::new(StringArray::from(empty_str)),
239                Arc::new(StringArray::from(docs)),
240                Arc::new(Int64Array::from(zeros)),
241                Arc::new(StringArray::from(empty_str2.clone())),
242                Arc::new(StringArray::from(empty_str2)),
243            ],
244        )
245        .is_ok();
246
247        let copy_ok = if pq_ok {
248            match conn.query(&format!(
249                "COPY Symbol (id, name, kind, file, start_line, end_line, signature_hash, language, visibility, parent, docstring, complexity, parameters, return_type) FROM '{}'",
250                fwd_slash_path(&sym_pq)
251            )) {
252                Ok(_) => true,
253                Err(e) => {
254                    eprintln!("Auto-SCIP: COPY Symbol failed ({e}), falling back to UNWIND");
255                    false
256                }
257            }
258        } else {
259            eprintln!("Auto-SCIP: parquet write failed, falling back to UNWIND");
260            false
261        };
262
263        if !copy_ok {
264            for chunk in new_symbols.chunks(CHUNK) {
265                let rows: Vec<String> = chunk
266                    .iter()
267                    .map(|(id, name, kind, file, start, end, doc)| {
268                        format!(
269                            "{{id: '{}', name: '{}', kind: '{}', file: '{}', sl: {}, el: {}, doc: '{}'}}",
270                            escape(id),
271                            escape(name),
272                            escape(kind),
273                            escape(file),
274                            start,
275                            end,
276                            escape(doc)
277                        )
278                    })
279                    .collect();
280                let _ = conn.query(&format!(
281                    "UNWIND [{}] AS s CREATE (:Symbol {{id: s.id, name: s.name, kind: s.kind, file: s.file, start_line: s.sl, end_line: s.el, signature_hash: '', language: 'scip', visibility: 'public', parent: '', docstring: s.doc, complexity: 0, parameters: '', return_type: ''}})",
282                    rows.join(", ")
283                ));
284            }
285        }
286        let _ = std::fs::remove_file(&sym_pq);
287    }
288
289    // Bulk write enrichments via UNWIND (updates can't use COPY FROM)
290    for chunk in enrichments.chunks(CHUNK) {
291        let rows: Vec<String> = chunk
292            .iter()
293            .map(|(id, start, end, doc)| {
294                format!(
295                    "{{id: '{}', sl: {}, el: {}, doc: '{}'}}",
296                    escape(id),
297                    start,
298                    end,
299                    escape(doc)
300                )
301            })
302            .collect();
303        let _ = conn.query(&format!(
304            "UNWIND [{}] AS e MATCH (s:Symbol) WHERE s.id = e.id SET s.start_line = e.sl, s.end_line = e.el, s.docstring = e.doc",
305            rows.join(", ")
306        ));
307    }
308
309    // Pass 2: build CALLS edges from references (all in-memory)
310    let mut calls_to_create: Vec<(String, String)> = Vec::new();
311    let mut seen_edges: std::collections::HashSet<(String, String)> =
312        std::collections::HashSet::new();
313
314    for doc in &index.documents {
315        let file = &doc.relative_path;
316
317        for occ in &doc.occurrences {
318            if (occ.symbol_roles & SymbolRole::Definition as i32) != 0 {
319                continue;
320            }
321            if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
322                continue;
323            }
324
325            let ref_line = occ.range.first().copied().unwrap_or(0) as u32;
326
327            let container_id = if let Some(syms) = file_symbols.get(file.as_str()) {
328                syms.iter()
329                    .find(|(start, end, _)| ref_line >= *start && ref_line <= *end)
330                    .map(|(_, _, id)| id.clone())
331            } else {
332                None
333            };
334            let Some(container_id) = container_id else {
335                continue;
336            };
337
338            let target_id = if let Some((tfile, tname)) = scip_sym_to_file_name.get(&occ.symbol) {
339                file_name_to_ids
340                    .get(&(tfile.clone(), tname.clone()))
341                    .and_then(|ids| ids.first())
342                    .cloned()
343            } else {
344                None
345            };
346            let Some(target_id) = target_id else {
347                continue;
348            };
349
350            if container_id == target_id {
351                continue;
352            }
353
354            // Detect SCIP correction: if tree-sitter had a CALLS edge from
355            // container_id to a *different* target for the same call name,
356            // SCIP is overriding it — record as a learned pattern.
357            if project_root.is_some() {
358                if let Some(existing_targets) = existing_calls.get(&container_id) {
359                    let call_name = target_id.rsplit("::").next().unwrap_or(&target_id);
360                    let target_file = target_id
361                        .rsplit("::")
362                        .nth(1)
363                        .or_else(|| target_id.split("::").next())
364                        .unwrap_or(&target_id);
365                    let ts_had_different = existing_targets.iter().any(|ts_tgt| {
366                        ts_tgt != &target_id
367                            && ts_tgt.rsplit("::").next().unwrap_or(ts_tgt) == call_name
368                    });
369                    if ts_had_different {
370                        let source_file = container_id.split("::").next().unwrap_or(&container_id);
371                        learned_store.record_correction(
372                            source_file,
373                            call_name,
374                            target_file,
375                            &target_id,
376                        );
377                        stats.corrections_learned += 1;
378                    }
379                }
380            }
381
382            let edge = (container_id, target_id);
383            if seen_edges.insert(edge.clone()) {
384                calls_to_create.push(edge);
385            }
386        }
387    }
388
389    // Bulk write CALLS edges via Parquet COPY FROM
390    if !calls_to_create.is_empty() {
391        let tmp = std::env::temp_dir();
392        let edge_pq = tmp.join("infigraph_scip_calls.parquet");
393        let refs: Vec<(&str, &str)> = calls_to_create
394            .iter()
395            .map(|(a, b)| (a.as_str(), b.as_str()))
396            .collect();
397        if parquet_loader::write_edge_parquet(&edge_pq, &refs).is_ok() {
398            if let Err(e) = conn.query(&format!("COPY CALLS FROM '{}'", fwd_slash_path(&edge_pq))) {
399                eprintln!("Auto-SCIP: COPY CALLS failed ({e}), falling back to UNWIND");
400                unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
401            }
402        } else {
403            unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
404        }
405        stats.references_added = calls_to_create.len();
406        let _ = std::fs::remove_file(&edge_pq);
407    }
408
409    // Persist learned corrections (if any were recorded)
410    if let Some(root) = project_root {
411        if stats.corrections_learned > 0 {
412            if let Err(e) = learned_store.save(root) {
413                eprintln!("warning: failed to save learned patterns: {e}");
414            }
415        }
416    }
417
418    Ok(stats)
419}
420
421fn parse_range(range: &[i32], file: &str) -> Span {
422    let (start_line, start_col, end_line, end_col) = match range.len() {
423        4 => (range[0], range[1], range[2], range[3]),
424        3 => (range[0], range[1], range[0], range[2]),
425        _ => (0, 0, 0, 0),
426    };
427    Span {
428        file: file.to_string(),
429        start_line: start_line as u32,
430        start_col: start_col as u32,
431        end_line: end_line as u32,
432        end_col: end_col as u32,
433    }
434}
435
/// Extract the name of the final descriptor of a global SCIP symbol.
///
/// SCIP symbols end with a descriptor whose trailing punctuation encodes its
/// kind:
/// - `name/` (namespace), `name#` (type), `name.` (term)
/// - `name(disambiguator).` (method), `name:` (meta), `name!` (macro)
/// - `[name]` (type parameter), `(name)` (parameter)
///
/// Names outside the simple-identifier alphabet are wrapped in backticks, and
/// a literal backtick inside them is written twice. Examples:
/// - `rust-analyzer cargo foo 0.1.0 bar/Baz#qux().` yields `qux`.
/// - ``scip-typescript npm pkg 1.0.0 src/`foo.ts`/`` yields `foo.ts`.
///
/// The suffix must be removed before looking for the name. Splitting on the
/// last `#`/`.`/`/` instead lands on the suffix itself and produces `""`.
fn scip_sym_to_name(scip_sym: &str) -> String {
    let sym = scip_sym.trim_end();

    // Strip the descriptor suffix (and a method's `(disambiguator)`) so that
    // `rest` ends exactly where the final name ends.
    let rest = match sym.as_bytes().last() {
        Some(b'.') => {
            let term = &sym[..sym.len() - 1];
            term.strip_suffix(')')
                .and_then(|method| method.rfind('(').map(|open| &method[..open]))
                .unwrap_or(term)
        }
        Some(b'/' | b'#' | b':' | b'!' | b')' | b']') => &sym[..sym.len() - 1],
        _ => sym,
    };

    // Backtick-escaped name: walk back to the unpaired opening backtick.
    if let Some(escaped) = rest.strip_suffix('`') {
        let bytes = escaped.as_bytes();
        let mut i = bytes.len();
        while i > 0 {
            i -= 1;
            if bytes[i] != b'`' {
                continue;
            }
            if i > 0 && bytes[i - 1] == b'`' {
                // "``" inside an escaped name encodes one literal backtick.
                i -= 1;
            } else {
                return escaped[i + 1..].replace("``", "`");
            }
        }
        // Malformed (no opening backtick): keep everything we have.
        return escaped.replace("``", "`");
    }

    // Simple identifier: the trailing run of SCIP identifier characters.
    let is_ident = |c: char| c.is_alphanumeric() || matches!(c, '_' | '+' | '-' | '$');
    let start = rest
        .char_indices()
        .rev()
        .find(|&(_, c)| !is_ident(c))
        .map_or(0, |(i, c)| i + c.len_utf8());
    rest[start..].to_string()
}
445
446fn scip_kind_to_prism(kind: &symbol_information::Kind) -> SymbolKind {
447    use symbol_information::Kind::*;
448    match kind {
449        Function | AbstractMethod | StaticMethod | PureVirtualMethod | ProtocolMethod
450        | TraitMethod | TypeClassMethod => SymbolKind::Function,
451        Method | MethodAlias | MethodReceiver | MethodSpecification => SymbolKind::Method,
452        Class | SingletonClass => SymbolKind::Class,
453        Struct => SymbolKind::Struct,
454        Interface => SymbolKind::Interface,
455        Trait | TypeClass => SymbolKind::Trait,
456        Enum | EnumMember => SymbolKind::Enum,
457        Module | Namespace | Package => SymbolKind::Module,
458        Variable | StaticVariable | Field | SelfParameter | Parameter => SymbolKind::Variable,
459        Constant => SymbolKind::Constant,
460        _ => SymbolKind::Function,
461    }
462}
463
/// Counters describing what a SCIP import did.
///
/// Also derives `Clone`/`PartialEq`/`Eq`, so results can be copied and
/// compared in tests and reports.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ImportStats {
    /// Number of SCIP documents (files) visited.
    pub files_processed: usize,
    /// SCIP definitions with no matching `(file, name)` symbol, inserted as new nodes.
    pub symbols_added: usize,
    /// Enrichment updates queued for existing symbols matched by `(file, name)`.
    /// A definition matching several symbols counts once per symbol.
    pub symbols_enriched: usize,
    /// NOTE(review): not incremented by `import_scip_index`; kept for API compatibility.
    pub symbols_skipped: usize,
    /// NOTE(review): not incremented by `import_scip_index`; kept for API compatibility.
    pub relations_added: usize,
    /// Distinct caller -> callee `CALLS` edges derived from SCIP references.
    pub references_added: usize,
    /// References where SCIP's target differed from a tree-sitter edge with the same callee name.
    pub corrections_learned: usize,
}