1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use arrow::array::{Int64Array, StringArray};
7use arrow::datatypes::DataType;
8use protobuf::Message;
9use scip::types::{symbol_information, Index, SymbolRole};
10
11use crate::graph::parquet_loader;
12use crate::graph::store_util::{escape, fwd_slash_path, unwind_edges_from_pairs};
13use crate::graph::GraphStore;
14use crate::model::{Span, SymbolKind};
15
16pub fn import_scip_index(
22 index_path: &Path,
23 store: &GraphStore,
24 project_root: Option<&Path>,
25) -> Result<ImportStats> {
26 let bytes = std::fs::read(index_path)
27 .with_context(|| format!("failed to read {}", index_path.display()))?;
28
29 let index = Index::parse_from_bytes(&bytes)
30 .with_context(|| format!("failed to parse SCIP index: {}", index_path.display()))?;
31
32 let mut stats = ImportStats::default();
33 let conn = store.connection()?;
34
35 let mut learned_store = project_root
37 .map(crate::learned::LearnedStore::load)
38 .unwrap_or_default();
39
40 let mut existing_calls: HashMap<String, std::collections::HashSet<String>> = HashMap::new();
43 if project_root.is_some() {
44 if let Ok(rows) = conn.query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id") {
45 for row in rows {
46 if row.len() < 2 {
47 continue;
48 }
49 let src = row[0].to_string().trim_matches('"').to_string();
50 let tgt = row[1].to_string().trim_matches('"').to_string();
51 existing_calls.entry(src).or_default().insert(tgt);
52 }
53 }
54 }
55
56 let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
59 let mut file_symbols: HashMap<String, Vec<(u32, u32, String)>> = HashMap::new();
60
61 let q = "MATCH (s:Symbol) RETURN s.id, s.file, s.name, s.start_line, s.end_line";
62 if let Ok(rows) = conn.query(q) {
63 for row in rows {
64 if row.len() < 5 {
65 continue;
66 }
67 let sid = row[0].to_string().trim_matches('"').to_string();
68 let sfile = row[1].to_string().trim_matches('"').to_string();
69 let sname = row[2].to_string().trim_matches('"').to_string();
70 let sstart: u32 = row[3].to_string().trim_matches('"').parse().unwrap_or(0);
71 let send: u32 = row[4].to_string().trim_matches('"').parse().unwrap_or(0);
72
73 file_name_to_ids
74 .entry((sfile.clone(), sname))
75 .or_default()
76 .push(sid.clone());
77
78 file_symbols
79 .entry(sfile)
80 .or_default()
81 .push((sstart, send, sid));
82 }
83 }
84
85 for syms in file_symbols.values_mut() {
87 syms.sort_by_key(|(s, e, _)| *e as i64 - *s as i64);
88 }
89
90 let mut scip_sym_to_file_name: HashMap<String, (String, String)> = HashMap::new();
92 for doc in &index.documents {
93 let file = &doc.relative_path;
94 for occ in &doc.occurrences {
95 if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
96 continue;
97 }
98 if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
99 continue;
100 }
101 let name = scip_sym_to_name(&occ.symbol);
102 scip_sym_to_file_name.insert(occ.symbol.clone(), (file.clone(), name));
103 }
104 }
105
106 let mut enrichments: Vec<(String, u32, u32, String)> = Vec::new();
108 let mut new_symbols: Vec<(String, String, String, String, u32, u32, String)> = Vec::new();
109
110 for doc in &index.documents {
111 let file = &doc.relative_path;
112
113 let sym_info_map: HashMap<&str, &scip::types::SymbolInformation> = doc
114 .symbols
115 .iter()
116 .map(|si| (si.symbol.as_str(), si))
117 .collect();
118
119 for occ in &doc.occurrences {
120 if (occ.symbol_roles & SymbolRole::Definition as i32) == 0 {
121 continue;
122 }
123 let scip_sym = &occ.symbol;
124 if scip_sym.starts_with("local ") || scip_sym.starts_with('<') {
125 continue;
126 }
127
128 let name = scip_sym_to_name(scip_sym);
129 let span = parse_range(&occ.range, file);
130 let si = sym_info_map.get(scip_sym.as_str());
131 let docstring = si
132 .and_then(|s| s.documentation.first())
133 .map(|s| s.as_str())
134 .unwrap_or("");
135
136 let key = (file.clone(), name.clone());
137 if let Some(ids) = file_name_to_ids.get(&key) {
138 for sid in ids {
139 enrichments.push((
140 sid.clone(),
141 span.start_line,
142 span.end_line,
143 docstring.to_string(),
144 ));
145 stats.symbols_enriched += 1;
146 }
147 } else {
148 let kind = si
149 .map(|s| scip_kind_to_prism(&s.kind.enum_value_or_default()))
150 .unwrap_or(SymbolKind::Function);
151 let sym_id = format!("{}::{}", file, name);
152 new_symbols.push((
153 sym_id.clone(),
154 name.clone(),
155 kind.as_str().to_string(),
156 file.clone(),
157 span.start_line,
158 span.end_line,
159 docstring.to_string(),
160 ));
161 stats.symbols_added += 1;
162 file_name_to_ids
163 .entry(key)
164 .or_default()
165 .push(sym_id.clone());
166 file_symbols.entry(file.clone()).or_default().push((
167 span.start_line,
168 span.end_line,
169 sym_id,
170 ));
171 }
172 }
173
174 stats.files_processed += 1;
175 }
176
177 const CHUNK: usize = 2000;
179 if !new_symbols.is_empty() {
180 let tmp = std::env::temp_dir();
181 let sym_pq = tmp.join("infigraph_scip_symbols.parquet");
182
183 let ids: Vec<&str> = new_symbols.iter().map(|(id, ..)| id.as_str()).collect();
184 let names: Vec<&str> = new_symbols
185 .iter()
186 .map(|(_, name, ..)| name.as_str())
187 .collect();
188 let kinds: Vec<&str> = new_symbols
189 .iter()
190 .map(|(_, _, kind, ..)| kind.as_str())
191 .collect();
192 let files: Vec<&str> = new_symbols
193 .iter()
194 .map(|(_, _, _, file, ..)| file.as_str())
195 .collect();
196 let start_lines: Vec<i64> = new_symbols
197 .iter()
198 .map(|(_, _, _, _, sl, ..)| *sl as i64)
199 .collect();
200 let end_lines: Vec<i64> = new_symbols.iter().map(|(.., el, _)| *el as i64).collect();
201 let docs: Vec<&str> = new_symbols.iter().map(|(.., doc)| doc.as_str()).collect();
202 let n = new_symbols.len();
203 let empty_str: Vec<&str> = vec![""; n];
204 let scip_lang: Vec<&str> = vec!["scip"; n];
205 let pub_vis: Vec<&str> = vec!["public"; n];
206 let zeros: Vec<i64> = vec![0; n];
207
208 let empty_str2: Vec<&str> = vec![""; n];
209 let pq_ok = parquet_loader::write_node_parquet(
210 &sym_pq,
211 &[
212 ("id", DataType::Utf8),
213 ("name", DataType::Utf8),
214 ("kind", DataType::Utf8),
215 ("file", DataType::Utf8),
216 ("start_line", DataType::Int64),
217 ("end_line", DataType::Int64),
218 ("signature_hash", DataType::Utf8),
219 ("language", DataType::Utf8),
220 ("visibility", DataType::Utf8),
221 ("parent", DataType::Utf8),
222 ("docstring", DataType::Utf8),
223 ("complexity", DataType::Int64),
224 ("parameters", DataType::Utf8),
225 ("return_type", DataType::Utf8),
226 ],
227 vec![
228 Arc::new(StringArray::from(ids)),
229 Arc::new(StringArray::from(names)),
230 Arc::new(StringArray::from(kinds)),
231 Arc::new(StringArray::from(files)),
232 Arc::new(Int64Array::from(start_lines)),
233 Arc::new(Int64Array::from(end_lines)),
234 Arc::new(StringArray::from(empty_str.clone())),
235 Arc::new(StringArray::from(scip_lang)),
236 Arc::new(StringArray::from(pub_vis)),
237 Arc::new(StringArray::from(empty_str)),
238 Arc::new(StringArray::from(docs)),
239 Arc::new(Int64Array::from(zeros)),
240 Arc::new(StringArray::from(empty_str2.clone())),
241 Arc::new(StringArray::from(empty_str2)),
242 ],
243 )
244 .is_ok();
245
246 let copy_ok = if pq_ok {
247 match conn.query(&format!(
248 "COPY Symbol (id, name, kind, file, start_line, end_line, signature_hash, language, visibility, parent, docstring, complexity, parameters, return_type) FROM '{}'",
249 fwd_slash_path(&sym_pq)
250 )) {
251 Ok(_) => true,
252 Err(e) => {
253 eprintln!("Auto-SCIP: COPY Symbol failed ({e}), falling back to UNWIND");
254 false
255 }
256 }
257 } else {
258 eprintln!("Auto-SCIP: parquet write failed, falling back to UNWIND");
259 false
260 };
261
262 if !copy_ok {
263 for chunk in new_symbols.chunks(CHUNK) {
264 let rows: Vec<String> = chunk
265 .iter()
266 .map(|(id, name, kind, file, start, end, doc)| {
267 format!(
268 "{{id: '{}', name: '{}', kind: '{}', file: '{}', sl: {}, el: {}, doc: '{}'}}",
269 escape(id),
270 escape(name),
271 escape(kind),
272 escape(file),
273 start,
274 end,
275 escape(doc)
276 )
277 })
278 .collect();
279 let _ = conn.query(&format!(
280 "UNWIND [{}] AS s CREATE (:Symbol {{id: s.id, name: s.name, kind: s.kind, file: s.file, start_line: s.sl, end_line: s.el, signature_hash: '', language: 'scip', visibility: 'public', parent: '', docstring: s.doc, complexity: 0, parameters: '', return_type: ''}})",
281 rows.join(", ")
282 ));
283 }
284 }
285 let _ = std::fs::remove_file(&sym_pq);
286 }
287
288 for chunk in enrichments.chunks(CHUNK) {
290 let rows: Vec<String> = chunk
291 .iter()
292 .map(|(id, start, end, doc)| {
293 format!(
294 "{{id: '{}', sl: {}, el: {}, doc: '{}'}}",
295 escape(id),
296 start,
297 end,
298 escape(doc)
299 )
300 })
301 .collect();
302 let _ = conn.query(&format!(
303 "UNWIND [{}] AS e MATCH (s:Symbol) WHERE s.id = e.id SET s.start_line = e.sl, s.end_line = e.el, s.docstring = e.doc",
304 rows.join(", ")
305 ));
306 }
307
308 let mut calls_to_create: Vec<(String, String)> = Vec::new();
310 let mut seen_edges: std::collections::HashSet<(String, String)> =
311 std::collections::HashSet::new();
312
313 for doc in &index.documents {
314 let file = &doc.relative_path;
315
316 for occ in &doc.occurrences {
317 if (occ.symbol_roles & SymbolRole::Definition as i32) != 0 {
318 continue;
319 }
320 if occ.symbol.starts_with("local ") || occ.symbol.starts_with('<') {
321 continue;
322 }
323
324 let ref_line = occ.range.first().copied().unwrap_or(0) as u32;
325
326 let container_id = if let Some(syms) = file_symbols.get(file.as_str()) {
327 syms.iter()
328 .find(|(start, end, _)| ref_line >= *start && ref_line <= *end)
329 .map(|(_, _, id)| id.clone())
330 } else {
331 None
332 };
333 let Some(container_id) = container_id else {
334 continue;
335 };
336
337 let target_id = if let Some((tfile, tname)) = scip_sym_to_file_name.get(&occ.symbol) {
338 file_name_to_ids
339 .get(&(tfile.clone(), tname.clone()))
340 .and_then(|ids| ids.first())
341 .cloned()
342 } else {
343 None
344 };
345 let Some(target_id) = target_id else {
346 continue;
347 };
348
349 if container_id == target_id {
350 continue;
351 }
352
353 if project_root.is_some() {
357 if let Some(existing_targets) = existing_calls.get(&container_id) {
358 let call_name = target_id.rsplit("::").next().unwrap_or(&target_id);
359 let target_file = target_id
360 .rsplit("::")
361 .nth(1)
362 .or_else(|| target_id.split("::").next())
363 .unwrap_or(&target_id);
364 let ts_had_different = existing_targets.iter().any(|ts_tgt| {
365 ts_tgt != &target_id
366 && ts_tgt.rsplit("::").next().unwrap_or(ts_tgt) == call_name
367 });
368 if ts_had_different {
369 let source_file = container_id.split("::").next().unwrap_or(&container_id);
370 learned_store.record_correction(
371 source_file,
372 call_name,
373 target_file,
374 &target_id,
375 );
376 stats.corrections_learned += 1;
377 }
378 }
379 }
380
381 let edge = (container_id, target_id);
382 if seen_edges.insert(edge.clone()) {
383 calls_to_create.push(edge);
384 }
385 }
386 }
387
388 if !calls_to_create.is_empty() {
390 let tmp = std::env::temp_dir();
391 let edge_pq = tmp.join("infigraph_scip_calls.parquet");
392 let refs: Vec<(&str, &str)> = calls_to_create
393 .iter()
394 .map(|(a, b)| (a.as_str(), b.as_str()))
395 .collect();
396 if parquet_loader::write_edge_parquet(&edge_pq, &refs).is_ok() {
397 if let Err(e) = conn.query(&format!("COPY CALLS FROM '{}'", fwd_slash_path(&edge_pq))) {
398 eprintln!("Auto-SCIP: COPY CALLS failed ({e}), falling back to UNWIND");
399 unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
400 }
401 } else {
402 unwind_edges_from_pairs(&conn, &refs, "CALLS", "Symbol", "Symbol");
403 }
404 stats.references_added = calls_to_create.len();
405 let _ = std::fs::remove_file(&edge_pq);
406 }
407
408 if let Some(root) = project_root {
410 if stats.corrections_learned > 0 {
411 if let Err(e) = learned_store.save(root) {
412 eprintln!("warning: failed to save learned patterns: {e}");
413 }
414 }
415 }
416
417 Ok(stats)
418}
419
420fn parse_range(range: &[i32], file: &str) -> Span {
421 let (start_line, start_col, end_line, end_col) = match range.len() {
422 4 => (range[0], range[1], range[2], range[3]),
423 3 => (range[0], range[1], range[0], range[2]),
424 _ => (0, 0, 0, 0),
425 };
426 Span {
427 file: file.to_string(),
428 start_line: start_line as u32,
429 start_col: start_col as u32,
430 end_line: end_line as u32,
431 end_col: end_col as u32,
432 }
433}
434
/// Extracts the plain name of the last descriptor of a SCIP symbol string.
///
/// SCIP descriptors carry a trailing suffix (`/` namespace, `#` type, `.` term,
/// `:` meta, `!` macro, `(…).` method) or are bracketed (`(name)` parameter,
/// `[name]` type parameter). Names that contain special characters are wrapped in
/// backticks, with literal backticks doubled. The suffix, brackets and backticks
/// are removed so that e.g. `… store/GraphStore#connection().` yields
/// `connection`.
///
/// Inputs without any descriptor suffix are still handled: the text after the
/// last separator is returned (`a.b` -> `b`, `foo()` -> `foo`).
fn scip_sym_to_name(scip_sym: &str) -> String {
    let is_suffix = |c: char| matches!(c, '/' | '#' | '.' | ':' | '!');
    let is_separator = |c: char| {
        matches!(
            c,
            '/' | '#' | '.' | ':' | '!' | ' ' | '(' | ')' | '[' | ']' | '`'
        )
    };

    // Only a `).` tail marks a method; a bare `)` may also close a parameter.
    let is_method = scip_sym.ends_with(").");
    // Drop exactly one trailing suffix character, otherwise splitting on the
    // separators would leave an empty tail.
    let mut rest = scip_sym.strip_suffix(is_suffix).unwrap_or(scip_sym);

    if let Some(open) = rest.strip_suffix(')') {
        rest = match open.rsplit_once('(') {
            // `name(disambiguator).` or a bare `name()`: the name precedes the parens.
            Some((before, inner)) if is_method || inner.is_empty() => before,
            // `(name)`: a parameter descriptor, the name sits inside the parens.
            Some((_, inner)) => inner,
            None => open,
        };
    } else if let Some(open) = rest.strip_suffix(']') {
        // `[name]`: a type-parameter descriptor.
        rest = open.rsplit_once('[').map_or(open, |(_, inner)| inner);
    }

    if let Some(body) = rest.strip_suffix('`') {
        // Escaped identifier: walk back to the opening backtick, stepping over
        // doubled backticks, which stand for a literal backtick inside the name.
        let bytes = body.as_bytes();
        let mut end = body.len();
        let start = loop {
            match body[..end].rfind('`') {
                Some(pos) if pos > 0 && bytes[pos - 1] == b'`' => end = pos - 1,
                Some(pos) => break pos + 1,
                None => break 0,
            }
        };
        return body[start..].replace("``", "`");
    }

    rest.rsplit(is_separator).next().unwrap_or(rest).to_string()
}
444
445fn scip_kind_to_prism(kind: &symbol_information::Kind) -> SymbolKind {
446 use symbol_information::Kind::*;
447 match kind {
448 Function | AbstractMethod | StaticMethod | PureVirtualMethod | ProtocolMethod
449 | TraitMethod | TypeClassMethod => SymbolKind::Function,
450 Method | MethodAlias | MethodReceiver | MethodSpecification => SymbolKind::Method,
451 Class | SingletonClass => SymbolKind::Class,
452 Struct => SymbolKind::Struct,
453 Interface => SymbolKind::Interface,
454 Trait | TypeClass => SymbolKind::Trait,
455 Enum | EnumMember => SymbolKind::Enum,
456 Module | Namespace | Package => SymbolKind::Module,
457 Variable | StaticVariable | Field | SelfParameter | Parameter => SymbolKind::Variable,
458 Constant => SymbolKind::Constant,
459 _ => SymbolKind::Function,
460 }
461}
462
/// Counters reported by [`import_scip_index`]. All of them start at zero.
#[derive(Debug, Default)]
pub struct ImportStats {
    /// SCIP documents walked during the definition pass.
    pub files_processed: usize,
    /// Definitions with no matching `(file, name)` in the store, queued as new symbols.
    pub symbols_added: usize,
    /// Enrichment updates queued for symbols that were already known.
    pub symbols_enriched: usize,
    /// Not written by the import code visible in this file.
    pub symbols_skipped: usize,
    /// Not written by the import code visible in this file.
    pub relations_added: usize,
    /// Distinct `CALLS` edges derived from reference occurrences.
    pub references_added: usize,
    /// Corrections recorded in the learned store.
    pub corrections_learned: usize,
}