//! Post-indexing resolution of cross-file CALLS and INHERITS edges
//! (`infigraph_core/resolve/mod.rs`).

use std::collections::{HashMap, HashSet};
use std::path::Path;

use anyhow::Result;

use crate::graph::store::GraphStore;
use crate::learned::LearnedStore;
use crate::model::{FileExtraction, RelationKind};

9/// Post-indexing pass that resolves call edges using cross-file symbol lookup.
10/// Builds symbol map from the full graph (not just re-indexed files) so
11/// incremental indexing doesn't lose cross-file resolution.
12pub fn resolve_calls_incremental(
13    store: &GraphStore,
14    extractions: &[FileExtraction],
15    learned_store: Option<&LearnedStore>,
16) -> Result<ResolveStats> {
17    if extractions.is_empty() {
18        return Ok(ResolveStats {
19            total_calls: 0,
20            resolved: 0,
21            unresolved: 0,
22            learned_resolved: 0,
23            inherits_resolved: 0,
24        });
25    }
26
27    let conn = store.connection()?;
28
29    // Build global symbol table from full graph: name -> [(id, file, kind)]
30    let mut symbol_map: HashMap<String, Vec<(String, String, String)>> = HashMap::new();
31    for (name, id, file, kind) in store.get_all_symbols()? {
32        symbol_map.entry(name).or_default().push((id, file, kind));
33    }
34
35    let mut stats = resolve_with_map(&conn, extractions, &symbol_map, learned_store)?;
36    stats.inherits_resolved = resolve_inherits(&conn, extractions, &symbol_map)?;
37    Ok(stats)
38}
39
40/// Post-indexing pass that resolves call edges using cross-file symbol lookup.
41///
42/// Problem: During extraction, `authenticate()` called in `main.py` creates
43/// a CALLS relation targeting `main.py::authenticate`. But the real symbol
44/// is `auth.py::authenticate`. This pass:
45///
46/// 1. Builds a symbol table from all extractions
47/// 2. For each CALLS relation where the target doesn't exist locally,
48///    searches the global symbol table by name
49/// 3. Creates the resolved CALLS edge in the graph
50pub fn resolve_calls(
51    store: &GraphStore,
52    extractions: &[FileExtraction],
53    learned_store: Option<&LearnedStore>,
54) -> Result<ResolveStats> {
55    let conn = store.connection()?;
56
57    // Build global symbol table: name -> list of (id, file, kind)
58    let mut symbol_map: HashMap<String, Vec<(String, String, String)>> = HashMap::new();
59    for ext in extractions {
60        for sym in &ext.symbols {
61            symbol_map.entry(sym.name.clone()).or_default().push((
62                sym.id.clone(),
63                ext.file.clone(),
64                sym.kind.as_str().to_string(),
65            ));
66        }
67    }
68
69    let mut stats = resolve_with_map(&conn, extractions, &symbol_map, learned_store)?;
70    stats.inherits_resolved = resolve_inherits(&conn, extractions, &symbol_map)?;
71    Ok(stats)
72}
73
74fn resolve_with_map(
75    conn: &kuzu::Connection<'_>,
76    extractions: &[FileExtraction],
77    symbol_map: &HashMap<String, Vec<(String, String, String)>>,
78    learned_store: Option<&LearnedStore>,
79) -> Result<ResolveStats> {
80    let mut resolved = 0;
81    let mut unresolved = 0;
82    let mut total_dangling = 0;
83    let mut resolved_pairs: Vec<(String, String)> = Vec::new();
84    let mut learned_resolved = 0usize;
85
86    // Build class-method index: "ClassName::method" -> symbol_id
87    let mut class_method_map: HashMap<String, Vec<(String, String)>> = HashMap::new();
88    for candidates in symbol_map.values() {
89        for (id, _file, kind) in candidates {
90            if kind == "Method" || kind == "Function" {
91                let parts: Vec<&str> = id.rsplitn(3, "::").collect();
92                if parts.len() >= 2 {
93                    let method = parts[0];
94                    let class = parts[1];
95                    let key = format!("{}::{}", class, method);
96                    class_method_map
97                        .entry(key)
98                        .or_default()
99                        .push((id.clone(), _file.clone()));
100                }
101            }
102        }
103    }
104
105    for ext in extractions {
106        let local_symbols: HashMap<&str, &str> = ext
107            .symbols
108            .iter()
109            .map(|s| (s.name.as_str(), s.id.as_str()))
110            .collect();
111
112        let imported_stems: std::collections::HashSet<String> = ext
113            .relations
114            .iter()
115            .filter(|r| r.kind == RelationKind::Imports)
116            .map(|r| {
117                let raw = r
118                    .target_id
119                    .rsplit(['/', '\\', '.'])
120                    .next()
121                    .unwrap_or(&r.target_id);
122                raw.to_lowercase()
123            })
124            .collect();
125
126        let source_is_sql = ext.file.ends_with(".sql");
127
128        for rel in &ext.relations {
129            if rel.kind != RelationKind::Calls {
130                continue;
131            }
132
133            let target_name = rel.target_id.rsplit("::").next().unwrap_or(&rel.target_id);
134
135            if local_symbols.contains_key(target_name) {
136                continue;
137            }
138
139            total_dangling += 1;
140
141            // Layer 3: Learned pattern lookup (from prior SCIP corrections).
142            if let Some(ls) = learned_store {
143                if let Some(pattern) = ls.lookup(&ext.file, target_name) {
144                    let target_exists = symbol_map.values().any(|candidates| {
145                        candidates
146                            .iter()
147                            .any(|(id, _, _)| *id == pattern.resolved_to_symbol)
148                    });
149                    if target_exists {
150                        resolved_pairs
151                            .push((rel.source_id.clone(), pattern.resolved_to_symbol.clone()));
152                        resolved += 1;
153                        learned_resolved += 1;
154                        continue;
155                    }
156                }
157            }
158
159            // Strategy 1: Receiver-aware resolution.
160            if let Some(ref receiver) = rel.receiver {
161                let qualified = format!("{}::{}", receiver, target_name);
162                if let Some(matches) = class_method_map.get(&qualified) {
163                    let best = if matches.len() == 1 {
164                        Some(matches[0].0.clone())
165                    } else {
166                        matches
167                            .iter()
168                            .find(|(_, f)| {
169                                let stem = std::path::Path::new(f)
170                                    .file_stem()
171                                    .and_then(|s| s.to_str())
172                                    .map(|s| s.to_lowercase())
173                                    .unwrap_or_default();
174                                imported_stems.contains(&stem)
175                            })
176                            .or(matches.first())
177                            .map(|(id, _)| id.clone())
178                    };
179                    if let Some(target_id) = best {
180                        resolved_pairs.push((rel.source_id.clone(), target_id));
181                        resolved += 1;
182                        continue;
183                    }
184                }
185            }
186
187            // Strategy 2: Enclosing-class preference.
188            let caller_class = rel.source_id.rsplit("::").nth(1).map(|s| s.to_string());
189
190            if let Some(candidates) = symbol_map.get(target_name) {
191                let cross_file: Vec<_> = candidates
192                    .iter()
193                    .filter(|(_, f, kind)| {
194                        if *f == ext.file {
195                            return false;
196                        }
197                        if source_is_sql && f.ends_with(".sql") && kind == "Function" {
198                            return false;
199                        }
200                        true
201                    })
202                    .collect();
203
204                let resolved_id = if cross_file.len() == 1 {
205                    Some(cross_file[0].0.clone())
206                } else if cross_file.len() > 1 {
207                    let by_receiver: Option<String> = rel.receiver.as_ref().and_then(|recv| {
208                        cross_file
209                            .iter()
210                            .find(|(id, _, _)| id.contains(&format!("::{}::{}", recv, target_name)))
211                            .map(|(id, _, _)| id.clone())
212                    });
213
214                    if by_receiver.is_some() {
215                        by_receiver
216                    } else if let Some(ref cls) = caller_class {
217                        let same_class = cross_file
218                            .iter()
219                            .find(|(id, _, _)| id.contains(&format!("::{cls}::")))
220                            .map(|(id, _, _)| id.clone());
221                        if same_class.is_some() {
222                            same_class
223                        } else {
224                            import_scope_match(&cross_file, &imported_stems, source_is_sql)
225                        }
226                    } else {
227                        import_scope_match(&cross_file, &imported_stems, source_is_sql)
228                    }
229                } else {
230                    None
231                };
232
233                if let Some(target_id) = resolved_id {
234                    resolved_pairs.push((rel.source_id.clone(), target_id));
235                    resolved += 1;
236                } else {
237                    unresolved += 1;
238                }
239            } else {
240                unresolved += 1;
241            }
242        }
243    }
244
245    // Batch insert resolved CALLS edges via COPY FROM parquet
246    if !resolved_pairs.is_empty() {
247        let mut known_ids: std::collections::HashSet<&str> = symbol_map
248            .values()
249            .flat_map(|v| v.iter().map(|(id, _, _)| id.as_str()))
250            .collect();
251        for ext in extractions {
252            for sym in &ext.symbols {
253                known_ids.insert(&sym.id);
254            }
255        }
256        let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
257        for ext in extractions {
258            for sym in &ext.symbols {
259                file_name_to_ids
260                    .entry((ext.file.clone(), sym.name.clone()))
261                    .or_default()
262                    .push(sym.id.clone());
263            }
264        }
265        for candidates in symbol_map.values() {
266            for (id, file, _kind) in candidates {
267                let name = id.rsplit("::").next().unwrap_or(id);
268                file_name_to_ids
269                    .entry((file.clone(), name.to_string()))
270                    .or_default()
271                    .push(id.clone());
272            }
273        }
274
275        let fixed_pairs: Vec<(String, String)> = resolved_pairs
276            .iter()
277            .flat_map(|(src, tgt)| {
278                if known_ids.contains(src.as_str()) {
279                    vec![(src.clone(), tgt.clone())]
280                } else if let Some(sep) = src.rfind("::") {
281                    let file_part = &src[..sep];
282                    let name_part = &src[sep + 2..];
283                    if let Some(ids) =
284                        file_name_to_ids.get(&(file_part.to_string(), name_part.to_string()))
285                    {
286                        ids.iter()
287                            .filter(|id| known_ids.contains(id.as_str()))
288                            .map(|id| (id.clone(), tgt.clone()))
289                            .collect::<Vec<_>>()
290                    } else {
291                        vec![(src.clone(), tgt.clone())]
292                    }
293                } else {
294                    vec![(src.clone(), tgt.clone())]
295                }
296            })
297            .collect();
298
299        let valid_pairs: Vec<&(String, String)> = fixed_pairs
300            .iter()
301            .filter(|(src, tgt)| {
302                known_ids.contains(src.as_str()) && known_ids.contains(tgt.as_str())
303            })
304            .collect();
305
306        let refs: Vec<(&str, &str)> = valid_pairs
307            .iter()
308            .map(|(a, b)| (a.as_str(), b.as_str()))
309            .collect();
310        let pq_path = std::env::temp_dir().join("infigraph_resolve_calls.parquet");
311        crate::graph::parquet_loader::write_edge_parquet(&pq_path, &refs)?;
312        let copy_result = conn.query(&format!(
313            "COPY CALLS FROM '{}'",
314            pq_path.to_string_lossy().replace('\\', "/")
315        ));
316        if let Err(e) = copy_result {
317            eprintln!("[resolve] COPY FROM parquet failed ({e}), falling back to UNWIND");
318            const CHUNK_SIZE: usize = 500;
319            for chunk in refs.chunks(CHUNK_SIZE) {
320                let pair_list: Vec<String> = chunk
321                    .iter()
322                    .map(|(a, b)| format!("{{a: '{}', b: '{}'}}", escape(a), escape(b)))
323                    .collect();
324                let _ = conn.query(&format!(
325                    "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:CALLS]->(b)",
326                    pair_list.join(", ")
327                ));
328            }
329        }
330        let _ = std::fs::remove_file(&pq_path);
331    }
332
333    Ok(ResolveStats {
334        total_calls: total_dangling,
335        resolved,
336        unresolved,
337        learned_resolved,
338        inherits_resolved: 0,
339    })
340}
341
342/// Targeted re-resolution for a subset of files.
343pub fn re_resolve_for_files(
344    store: &GraphStore,
345    files: &[String],
346    extractions: &[FileExtraction],
347    learned_store: Option<&LearnedStore>,
348) -> Result<ResolveStats> {
349    if files.is_empty() || extractions.is_empty() {
350        return Ok(ResolveStats {
351            total_calls: 0,
352            resolved: 0,
353            unresolved: 0,
354            learned_resolved: 0,
355            inherits_resolved: 0,
356        });
357    }
358
359    let conn = store.connection()?;
360
361    for file in files {
362        let escaped = escape(file);
363        let _ = conn.query(&format!(
364            "MATCH (a:Symbol)-[r:CALLS]->(b:Symbol) WHERE a.file = '{}' DELETE r",
365            escaped
366        ));
367        let _ = conn.query(&format!(
368            "MATCH (a:Symbol)-[r:INHERITS]->(b:Symbol) WHERE a.file = '{}' DELETE r",
369            escaped
370        ));
371    }
372
373    let mut symbol_map: HashMap<String, Vec<(String, String, String)>> = HashMap::new();
374    for (name, id, file, kind) in store.get_all_symbols()? {
375        symbol_map.entry(name).or_default().push((id, file, kind));
376    }
377
378    let target_files: std::collections::HashSet<&str> = files.iter().map(|f| f.as_str()).collect();
379    let filtered: Vec<&FileExtraction> = extractions
380        .iter()
381        .filter(|e| target_files.contains(e.file.as_str()))
382        .collect();
383
384    let filtered_owned: Vec<FileExtraction> = filtered.into_iter().cloned().collect();
385    let mut stats = resolve_with_map(&conn, &filtered_owned, &symbol_map, learned_store)?;
386    stats.inherits_resolved = resolve_inherits(&conn, &filtered_owned, &symbol_map)?;
387    Ok(stats)
388}
389
/// Counters reported by the CALLS and INHERITS resolution passes in this module.
///
/// `Default` yields the all-zero value returned when there is nothing to resolve.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ResolveStats {
    /// CALLS relations whose target name is not a symbol of the calling file.
    pub total_calls: usize,
    /// Dangling calls for which a target was chosen (learned ones included).
    pub resolved: usize,
    /// Dangling calls for which no target was found.
    pub unresolved: usize,
    /// Part of `resolved` that came from learned patterns.
    pub learned_resolved: usize,
    /// Count reported by the INHERITS resolution pass.
    pub inherits_resolved: usize,
}

399impl std::fmt::Display for ResolveStats {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        if self.learned_resolved > 0 {
402            write!(
403                f,
404                "Call resolution: {} cross-file calls, {} resolved ({} from learned patterns), {} unresolved (builtins/externals)",
405                self.total_calls, self.resolved, self.learned_resolved, self.unresolved
406            )?;
407        } else {
408            write!(
409                f,
410                "Call resolution: {} cross-file calls, {} resolved, {} unresolved (builtins/externals)",
411                self.total_calls, self.resolved, self.unresolved
412            )?;
413        }
414        if self.inherits_resolved > 0 {
415            write!(f, ", {} inheritance edges resolved", self.inherits_resolved)?;
416        }
417        Ok(())
418    }
419}
420
/// Symbol kinds that `resolve_inherits` accepts as INHERITS targets.
const TYPE_KINDS: &[&str] = &[
    "Class",
    "Interface",
    "Struct",
    "Trait",
    "Enum",
];

423fn resolve_inherits(
424    conn: &kuzu::Connection<'_>,
425    extractions: &[FileExtraction],
426    symbol_map: &HashMap<String, Vec<(String, String, String)>>,
427) -> Result<usize> {
428    let mut resolved_pairs: Vec<(String, String)> = Vec::new();
429
430    for ext in extractions {
431        let local_symbols: std::collections::HashSet<&str> =
432            ext.symbols.iter().map(|s| s.name.as_str()).collect();
433
434        let imported_stems: std::collections::HashSet<String> = ext
435            .relations
436            .iter()
437            .filter(|r| r.kind == RelationKind::Imports)
438            .map(|r| {
439                let raw = r
440                    .target_id
441                    .rsplit(['/', '\\', '.'])
442                    .next()
443                    .unwrap_or(&r.target_id);
444                raw.to_lowercase()
445            })
446            .collect();
447
448        for rel in &ext.relations {
449            if rel.kind != RelationKind::Inherits {
450                continue;
451            }
452
453            let target_name = rel.target_id.rsplit("::").next().unwrap_or(&rel.target_id);
454
455            if local_symbols.contains(target_name) {
456                continue;
457            }
458
459            if let Some(candidates) = symbol_map.get(target_name) {
460                let cross_file: Vec<_> = candidates
461                    .iter()
462                    .filter(|(_, f, kind)| *f != ext.file && TYPE_KINDS.contains(&kind.as_str()))
463                    .collect();
464
465                let resolved_id = if cross_file.len() == 1 {
466                    Some(cross_file[0].0.clone())
467                } else if cross_file.len() > 1 {
468                    let in_scope = cross_file.iter().find(|(_, f, _)| {
469                        let stem = std::path::Path::new(f)
470                            .file_stem()
471                            .and_then(|s| s.to_str())
472                            .map(|s| s.to_lowercase())
473                            .unwrap_or_default();
474                        imported_stems.contains(&stem)
475                    });
476                    let by_kind = cross_file.iter().find(|(_, _, k)| k == "Interface");
477                    in_scope
478                        .or(by_kind)
479                        .or(cross_file.first())
480                        .map(|(id, _, _)| id.clone())
481                } else {
482                    None
483                };
484
485                if let Some(target_id) = resolved_id {
486                    resolved_pairs.push((rel.source_id.clone(), target_id));
487                }
488            }
489        }
490    }
491
492    if resolved_pairs.is_empty() {
493        return Ok(0);
494    }
495
496    let count = resolved_pairs.len();
497
498    let mut known_ids: std::collections::HashSet<&str> = symbol_map
499        .values()
500        .flat_map(|v| v.iter().map(|(id, _, _)| id.as_str()))
501        .collect();
502    for ext in extractions {
503        for sym in &ext.symbols {
504            known_ids.insert(&sym.id);
505        }
506    }
507
508    let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
509    for ext in extractions {
510        for sym in &ext.symbols {
511            file_name_to_ids
512                .entry((ext.file.clone(), sym.name.clone()))
513                .or_default()
514                .push(sym.id.clone());
515        }
516    }
517    for candidates in symbol_map.values() {
518        for (id, file, _) in candidates {
519            let name = id.rsplit("::").next().unwrap_or(id);
520            file_name_to_ids
521                .entry((file.clone(), name.to_string()))
522                .or_default()
523                .push(id.clone());
524        }
525    }
526
527    let fixed_pairs: Vec<(String, String)> = resolved_pairs
528        .iter()
529        .flat_map(|(src, tgt)| {
530            if known_ids.contains(src.as_str()) {
531                vec![(src.clone(), tgt.clone())]
532            } else if let Some(sep) = src.rfind("::") {
533                let file_part = &src[..sep];
534                let name_part = &src[sep + 2..];
535                if let Some(ids) =
536                    file_name_to_ids.get(&(file_part.to_string(), name_part.to_string()))
537                {
538                    ids.iter()
539                        .filter(|id| known_ids.contains(id.as_str()))
540                        .map(|id| (id.clone(), tgt.clone()))
541                        .collect::<Vec<_>>()
542                } else {
543                    vec![(src.clone(), tgt.clone())]
544                }
545            } else {
546                vec![(src.clone(), tgt.clone())]
547            }
548        })
549        .collect();
550
551    let valid_pairs: Vec<&(String, String)> = fixed_pairs
552        .iter()
553        .filter(|(src, tgt)| known_ids.contains(src.as_str()) && known_ids.contains(tgt.as_str()))
554        .collect();
555
556    if valid_pairs.is_empty() {
557        return Ok(0);
558    }
559
560    let refs: Vec<(&str, &str)> = valid_pairs
561        .iter()
562        .map(|(a, b)| (a.as_str(), b.as_str()))
563        .collect();
564    let pq_path = std::env::temp_dir().join("infigraph_resolve_inherits.parquet");
565    crate::graph::parquet_loader::write_edge_parquet(&pq_path, &refs)?;
566    let copy_result = conn.query(&format!(
567        "COPY INHERITS FROM '{}'",
568        pq_path.to_string_lossy().replace('\\', "/")
569    ));
570    if let Err(e) = copy_result {
571        eprintln!("[resolve] COPY INHERITS FROM parquet failed ({e}), falling back to UNWIND");
572        const CHUNK_SIZE: usize = 500;
573        for chunk in refs.chunks(CHUNK_SIZE) {
574            let pair_list: Vec<String> = chunk
575                .iter()
576                .map(|(a, b)| format!("{{a: '{}', b: '{}'}}", escape(a), escape(b)))
577                .collect();
578            let _ = conn.query(&format!(
579                "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:INHERITS]->(b)",
580                pair_list.join(", ")
581            ));
582        }
583    }
584    let _ = std::fs::remove_file(&pq_path);
585
586    Ok(count)
587}
588
/// Picks one candidate from `cross_file` for an ambiguous by-name call.
///
/// Each candidate is an `(id, file, kind)` tuple. The result is the id of the
/// first candidate whose lower-cased file stem is in `imported_stems`. If there
/// is none and the caller is a SQL file (`source_is_sql`), the result is the id
/// of the first candidate of kind `Class`. Otherwise the result is `None`.
fn import_scope_match(
    cross_file: &[&(String, String, String)],
    imported_stems: &std::collections::HashSet<String>,
    source_is_sql: bool,
) -> Option<String> {
    // Only the first in-scope candidate is needed, so stop at the first hit.
    let in_scope = if imported_stems.is_empty() {
        None
    } else {
        cross_file.iter().find(|(_, file, _)| {
            let stem = std::path::Path::new(file)
                .file_stem()
                .and_then(|s| s.to_str())
                .map(str::to_lowercase)
                .unwrap_or_default();
            imported_stems.contains(&stem)
        })
    };

    in_scope
        .or_else(|| {
            if source_is_sql {
                cross_file.iter().find(|(_, _, kind)| kind == "Class")
            } else {
                None
            }
        })
        .map(|(id, _, _)| id.clone())
}

/// Escapes `s` for embedding inside a single-quoted Cypher string literal, as
/// the queries in this module build them.
///
/// Backslashes are doubled so that paths such as `src\new.py` are not read as
/// escape sequences (`\n`), and so that a trailing backslash cannot escape the
/// closing quote. Single quotes are backslash-escaped.
fn escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '\'' => out.push_str("\\'"),
            other => out.push(other),
        }
    }
    out
}