Skip to main content

infiniloom_engine/embedding/
import_resolver.rs

1//! Import-aware call graph resolution for embedding chunks.
2//!
3//! This module builds per-file import scope maps and resolves raw call names
4//! against imported symbols, enabling more accurate `called_by` relationships.
5//! Supports Rust, TypeScript, and Python import patterns.
6//!
7//! # Design
8//!
9//! Instead of matching calls purely by unqualified name (which produces false
10//! positives when multiple files define symbols with the same name), the resolver
11//! uses each file's import statements to determine which definition a call refers to.
12//!
13//! The resolver produces two new fields on `ChunkContext`:
14//! - `qualified_calls`: calls successfully resolved to a qualified name via imports
15//! - `unresolved_calls`: calls that could not be matched to any import or local symbol
16
17use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
18
19use super::types::EmbedChunk;
20
/// Resolves raw call names to qualified names using per-file import maps.
///
/// A resolver is built from a set of chunks. Afterwards a raw call name seen in a
/// given file can be resolved to a qualified name (module + symbol) or reported as
/// unresolved.
///
/// The type derives `Debug` (diagnostics), `Clone`, and `Default`. A default
/// resolver holds no imports and no symbols.
#[derive(Debug, Clone, Default)]
pub struct ImportResolver {
    /// Per-file import maps: file_path -> (imported_name -> source_module)
    file_imports: HashMap<String, HashMap<String, String>>,
    /// Per-file local symbols: file_path -> set of symbol names defined in that file
    file_symbols: HashMap<String, HashSet<String>>,
}
32
33impl ImportResolver {
34    /// Build an `ImportResolver` from the generated chunks.
35    ///
36    /// For each chunk we:
37    /// 1. Record its symbol name in `file_symbols[file]`
38    /// 2. Parse its `context.imports` strings into `file_imports[file]`
39    pub fn from_chunks(chunks: &[EmbedChunk]) -> Self {
40        let mut file_imports: HashMap<String, HashMap<String, String>> = HashMap::new();
41        let mut file_symbols: HashMap<String, HashSet<String>> = HashMap::new();
42
43        for chunk in chunks {
44            let file = &chunk.source.file;
45
46            // Record this chunk's symbol name as a local symbol for its file
47            if !chunk.source.symbol.is_empty() && chunk.source.symbol != "<top_level>" {
48                file_symbols
49                    .entry(file.clone())
50                    .or_default()
51                    .insert(chunk.source.symbol.clone());
52            }
53
54            // Parse import strings for this chunk's file
55            for import_str in &chunk.context.imports {
56                let parsed = parse_import(import_str);
57                for (name, source) in parsed {
58                    file_imports
59                        .entry(file.clone())
60                        .or_default()
61                        .insert(name, source);
62                }
63            }
64        }
65
66        Self { file_imports, file_symbols }
67    }
68
69    /// Resolve a call from a given file.
70    ///
71    /// Resolution order:
72    /// 1. If `call_name` is defined in the same file, return `"file::call_name"`
73    /// 2. If `call_name` appears in the file's imports, return `"source_module::call_name"`
74    /// 3. Otherwise return `None` (unresolved)
75    pub fn resolve_call(&self, file: &str, call_name: &str) -> Option<String> {
76        // 1. Check same-file symbols
77        if let Some(symbols) = self.file_symbols.get(file) {
78            if symbols.contains(call_name) {
79                return Some(format!("{}::{}", file, call_name));
80            }
81        }
82
83        // 2. Check imports for this file
84        if let Some(imports) = self.file_imports.get(file) {
85            if let Some(source) = imports.get(call_name) {
86                return Some(format!("{}::{}", source, call_name));
87            }
88        }
89
90        // 3. Unresolved
91        None
92    }
93
94    /// Resolve all calls for all chunks, populating `qualified_calls` and `unresolved_calls`.
95    ///
96    /// Also builds an improved reverse call map using qualified names for the
97    /// `called_by` pass, reducing false-positive matches.
98    pub fn resolve_all_calls(&self, chunks: &mut [EmbedChunk]) {
99        for chunk in chunks.iter_mut() {
100            let file = &chunk.source.file;
101            let mut qualified = BTreeSet::new();
102            let mut unresolved = BTreeSet::new();
103
104            for call_name in &chunk.context.calls {
105                match self.resolve_call(file, call_name) {
106                    Some(qname) => {
107                        qualified.insert(qname);
108                    },
109                    None => {
110                        unresolved.insert(call_name.clone());
111                    },
112                }
113            }
114
115            chunk.context.qualified_calls = qualified.into_iter().collect();
116            chunk.context.unresolved_calls = unresolved.into_iter().collect();
117        }
118    }
119
120    /// Build a reverse call map using qualified names for more accurate `called_by`.
121    ///
122    /// Returns a map from qualified callee name -> set of caller identifiers (FQN or symbol name).
123    /// This is used alongside the existing unqualified matching to improve accuracy.
124    pub fn build_qualified_reverse_map(
125        &self,
126        chunks: &[EmbedChunk],
127    ) -> BTreeMap<String, BTreeSet<String>> {
128        let mut reverse: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
129
130        for chunk in chunks {
131            let caller_fqn = chunk
132                .source
133                .fqn
134                .as_deref()
135                .unwrap_or(&chunk.source.symbol)
136                .to_owned();
137
138            for qcall in &chunk.context.qualified_calls {
139                reverse
140                    .entry(qcall.clone())
141                    .or_default()
142                    .insert(caller_fqn.clone());
143            }
144        }
145
146        reverse
147    }
148}
149
150/// Parse a single import string into (imported_name, source_module) pairs.
151///
152/// Supports three patterns:
153/// - **Rust**: `use crate::auth::jwt::verify_token` -> `("verify_token", "crate::auth::jwt")`
154/// - **TypeScript**: `import { verify } from './auth/jwt'` -> `("verify", "./auth/jwt")`
155/// - **Python**: `from auth.jwt import verify` -> `("verify", "auth.jwt")`
156///
157/// Also handles multi-import forms:
158/// - Rust: `use crate::auth::{Token, verify}` -> two entries
159/// - Python: `from auth import verify, Token` -> two entries
160/// - TypeScript: `import { verify, Token } from './auth'` -> two entries
161fn parse_import(import_str: &str) -> Vec<(String, String)> {
162    let trimmed = import_str.trim();
163
164    // Try Rust pattern: `use path::to::module::Symbol` or `use path::{A, B}`
165    if let Some(result) = parse_rust_import(trimmed) {
166        return result;
167    }
168
169    // Try TypeScript/JavaScript pattern: `import { X, Y } from 'module'`
170    if let Some(result) = parse_typescript_import(trimmed) {
171        return result;
172    }
173
174    // Try Python pattern: `from module import X, Y`
175    if let Some(result) = parse_python_import(trimmed) {
176        return result;
177    }
178
179    Vec::new()
180}
181
/// Parse a Rust `use` declaration into `(imported_name, source_module)` pairs.
///
/// Handles:
/// - `use crate::auth::jwt::verify_token;` -> `("verify_token", "crate::auth::jwt")`
/// - `use std::collections::HashMap as Map;` -> `("Map", "std::collections")`
/// - `use crate::auth::{Token, verify};` -> one pair per item
/// - `use std::{collections::HashMap, io};` -> `("HashMap", "std::collections")`, `("io", "std")`
/// - `use a::{b::{c, d}, e};` (nested groups)
/// - `use crate::auth::{self, Token};` -> `self` imports `auth` from `crate`
/// - `use serde;` -> `("serde", "")`
/// - an optional visibility prefix: `pub use ...`, `pub(crate) use ...`
///
/// Returns `None` when the string is not a `use` declaration, when its braces are
/// malformed, or when it names nothing.
fn parse_rust_import(s: &str) -> Option<Vec<(String, String)>> {
    let decl = strip_rust_visibility(s.trim_start());
    let tree = decl.strip_prefix("use ")?.trim_end_matches(';').trim();

    let mut results = Vec::new();
    collect_rust_use_tree("", tree, &mut results)?;

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}

/// Drop a leading `pub` / `pub(...)` visibility qualifier, if there is one.
fn strip_rust_visibility(s: &str) -> &str {
    let rest = match s.strip_prefix("pub") {
        Some(rest) => rest,
        None => return s,
    };

    if rest.starts_with('(') {
        // `pub(crate)`, `pub(super)`, `pub(in path)`
        match rest.find(')') {
            Some(close) => rest[close + 1..].trim_start(),
            None => s,
        }
    } else if rest.starts_with(char::is_whitespace) {
        rest.trim_start()
    } else {
        // Something like `public_thing`: not a visibility qualifier.
        s
    }
}

/// Expand one use-tree (`a::b`, `a::b as c`, `a::{...}`, `self`) found under `prefix`,
/// appending the resulting pairs to `out`. Returns `None` on malformed braces.
fn collect_rust_use_tree(
    prefix: &str,
    tree: &str,
    out: &mut Vec<(String, String)>,
) -> Option<()> {
    let tree = tree.trim();
    if tree.is_empty() {
        return Some(());
    }

    // Group: `head::{item, item, ...}` or a bare `{...}`
    if let Some(open) = tree.find('{') {
        let inner = tree[open + 1..].strip_suffix('}')?;
        let head = tree[..open].trim_end();
        let head = if head.is_empty() {
            head
        } else {
            // A non-empty head must be joined to the group with `::`.
            head.strip_suffix("::")?
        };

        let group_prefix = join_rust_path(prefix, head);
        for item in split_top_level_commas(inner) {
            collect_rust_use_tree(&group_prefix, item, out)?;
        }
        return Some(());
    }

    // A stray closing brace in a leaf means the braces were unbalanced.
    if tree.contains('}') {
        return None;
    }

    // Leaf: `path::to::Name` optionally followed by ` as Alias`
    let (path, alias) = match tree.find(" as ") {
        Some(pos) => (tree[..pos].trim_end(), Some(tree[pos + 4..].trim())),
        None => (tree, None),
    };
    let (module_rel, name) = match path.rfind("::") {
        Some(sep) => (&path[..sep], &path[sep + 2..]),
        None => ("", path),
    };
    let module = join_rust_path(prefix, module_rel);

    if name == "self" {
        // `a::b::{self}` imports the module `b` itself, which lives in `a`.
        let (parent, last) = match module.rfind("::") {
            Some(sep) => (&module[..sep], &module[sep + 2..]),
            None => ("", module.as_str()),
        };
        if !last.is_empty() {
            out.push((alias.unwrap_or(last).to_owned(), parent.to_owned()));
        }
    } else {
        out.push((alias.unwrap_or(name).to_owned(), module));
    }

    Some(())
}

/// Join two `::`-separated path fragments, tolerating an empty side.
fn join_rust_path(prefix: &str, segment: &str) -> String {
    match (prefix.is_empty(), segment.is_empty()) {
        (true, _) => segment.to_owned(),
        (false, true) => prefix.to_owned(),
        (false, false) => format!("{}::{}", prefix, segment),
    }
}

/// Split on commas that are not nested inside `{...}`.
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut start = 0;

    for (idx, ch) in s.char_indices() {
        match ch {
            '{' => depth += 1,
            '}' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(&s[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    parts.push(&s[start..]);

    parts
}
239
/// Parse a TypeScript/JavaScript import statement into `(imported_name, module)` pairs.
///
/// Handles:
/// - `import { verify, Token } from './auth/jwt'`
/// - `import { verify as check } from './auth'` (the alias is the imported name)
/// - `import verify from './auth/jwt'` (default import)
/// - `import * as auth from './auth'` (namespace import, recorded as `auth`)
/// - `import React, { useState } from 'react'` (default plus named)
/// - `import type { Foo } from './types'` and `import { type Foo } from './types'`
///
/// The module specifier may use single or double quotes and may be followed by `;`.
///
/// Returns `None` when the string does not start with `import `, has no ` from `
/// clause, has an unterminated `{` list, or names nothing.
fn parse_typescript_import(s: &str) -> Option<Vec<(String, String)>> {
    let s = s.strip_prefix("import ")?;

    // Extract the `from 'module'` or `from "module"` part
    let from_pos = s.rfind(" from ")?;
    let names_part = s[..from_pos].trim();

    // Strip the statement terminator BEFORE the quotes: with `'./auth';` the
    // closing quote is not the last character, so stripping quotes first would
    // leave it behind.
    let module = s[from_pos + 6..]
        .trim()
        .trim_end_matches(';')
        .trim_end()
        .trim_matches(|c: char| c == '\'' || c == '"');

    // `import type ...` only affects type-checking; the imported names are the same.
    let names_part = names_part
        .strip_prefix("type ")
        .map_or(names_part, |rest| rest.trim_start());

    // Split into the part before the braces (default / namespace import) and the
    // brace list itself. An opening brace without a closing one is malformed.
    let (plain_part, named_part) = match names_part.find('{') {
        Some(open) => {
            let named = names_part[open + 1..].strip_suffix('}')?;
            (&names_part[..open], Some(named))
        }
        None => (names_part, None),
    };

    let mut results: Vec<(String, String)> = Vec::new();

    // Default import and/or `* as ns`
    for clause in plain_part.split(',') {
        let clause = clause.trim();
        if clause.is_empty() {
            continue;
        }
        let name = match clause.strip_prefix('*') {
            Some(rest) => match rest.trim_start().strip_prefix("as ") {
                Some(alias) => alias.trim(),
                // `*` without an alias binds no name.
                None => continue,
            },
            None => clause,
        };
        if !name.is_empty() {
            results.push((name.to_owned(), module.to_owned()));
        }
    }

    // Named imports: { A, B as C, type D }
    if let Some(named) = named_part {
        for item in named.split(',') {
            let item = item.trim();
            if item.is_empty() {
                continue;
            }
            // Check for an alias first so `type as t` (importing the export named
            // `type`) is not mistaken for an inline `type` modifier.
            let name = match item.find(" as ") {
                Some(pos) => item[pos + 4..].trim(),
                None => item
                    .strip_prefix("type ")
                    .map_or(item, |rest| rest.trim_start()),
            };
            results.push((name.to_owned(), module.to_owned()));
        }
    }

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}
297
/// Parse a Python import statement into `(imported_name, source_module)` pairs.
///
/// Handles:
/// - `from auth.jwt import verify`
/// - `from auth.jwt import verify, Token`
/// - `from auth.jwt import verify as check` (the alias is the imported name)
/// - `from auth.jwt import (verify, Token)` (parenthesised list)
/// - `import auth.jwt` (maps `jwt` -> `auth`)
/// - `import os` (maps `os` -> empty module)
/// - `import numpy as np` (maps `np` -> empty module)
/// - `import os, sys` (one pair per module)
///
/// Returns `None` when the string is neither form, when a `from` statement has no
/// ` import ` clause, or when it names nothing.
fn parse_python_import(s: &str) -> Option<Vec<(String, String)>> {
    let results: Vec<(String, String)> = if let Some(rest) = s.strip_prefix("from ") {
        // `from module import names`
        let import_pos = rest.find(" import ")?;
        let module = rest[..import_pos].trim();
        let names_part = rest[import_pos + 8..].trim();

        // Names may be wrapped in parentheses; drop them before splitting.
        let names_part = names_part.strip_prefix('(').map_or(names_part, |inner| {
            inner.strip_suffix(')').unwrap_or(inner).trim()
        });

        names_part
            .split(',')
            .filter_map(|item| {
                let item = item.trim();
                if item.is_empty() {
                    return None;
                }
                let (name, alias) = split_python_alias(item);
                Some((alias.unwrap_or(name).to_owned(), module.to_owned()))
            })
            .collect()
    } else if let Some(rest) = s.strip_prefix("import ") {
        // `import module.submodule[, other [as alias]]`
        rest.trim()
            .trim_end_matches(';')
            .split(',')
            .filter_map(|item| {
                let item = item.trim();
                if item.is_empty() {
                    return None;
                }
                let (path, alias) = split_python_alias(item);
                // For `import auth.jwt`, map `jwt` -> `auth`; a top-level module
                // such as `import os` has no parent.
                let (parent, name) = match path.rfind('.') {
                    Some(last_dot) => (&path[..last_dot], &path[last_dot + 1..]),
                    None => ("", path),
                };
                Some((alias.unwrap_or(name).to_owned(), parent.to_owned()))
            })
            .collect()
    } else {
        return None;
    };

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}

/// Split `name as alias` into `(name, Some(alias))`; without an alias returns `(item, None)`.
fn split_python_alias(item: &str) -> (&str, Option<&str>) {
    match item.find(" as ") {
        Some(pos) => (item[..pos].trim_end(), Some(item[pos + 4..].trim())),
        None => (item, None),
    }
}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `(imported_name, source_module)` pair from string slices.
    fn pair(name: &str, module: &str) -> (String, String) {
        (name.to_owned(), module.to_owned())
    }

    // --- Rust `use` statements ---

    #[test]
    fn test_parse_rust_simple_import() {
        let parsed = parse_import("use crate::auth::jwt::verify_token;");
        assert_eq!(parsed, vec![pair("verify_token", "crate::auth::jwt")]);
    }

    #[test]
    fn test_parse_rust_brace_import() {
        let parsed = parse_import("use crate::auth::{Token, verify};");
        assert_eq!(parsed.len(), 2);
        for expected in [pair("Token", "crate::auth"), pair("verify", "crate::auth")] {
            assert!(parsed.contains(&expected));
        }
    }

    #[test]
    fn test_parse_rust_alias_import() {
        let parsed = parse_import("use std::collections::HashMap as Map;");
        assert_eq!(parsed, vec![pair("Map", "std::collections")]);
    }

    #[test]
    fn test_parse_rust_super_import() {
        let parsed = parse_import("use super::types::EmbedChunk;");
        assert_eq!(parsed, vec![pair("EmbedChunk", "super::types")]);
    }

    // --- TypeScript `import ... from` statements ---

    #[test]
    fn test_parse_typescript_named_import() {
        let parsed = parse_import("import { verify } from './auth/jwt'");
        assert_eq!(parsed, vec![pair("verify", "./auth/jwt")]);
    }

    #[test]
    fn test_parse_typescript_multi_import() {
        let parsed = parse_import("import { verify, Token } from './auth'");
        assert_eq!(parsed.len(), 2);
        for expected in [pair("verify", "./auth"), pair("Token", "./auth")] {
            assert!(parsed.contains(&expected));
        }
    }

    #[test]
    fn test_parse_typescript_alias_import() {
        let parsed = parse_import("import { verify as check } from './auth'");
        assert_eq!(parsed, vec![pair("check", "./auth")]);
    }

    #[test]
    fn test_parse_typescript_default_import() {
        let parsed = parse_import("import Router from 'express'");
        assert_eq!(parsed, vec![pair("Router", "express")]);
    }

    #[test]
    fn test_parse_typescript_double_quotes() {
        let parsed = parse_import("import { verify } from \"./auth/jwt\"");
        assert_eq!(parsed, vec![pair("verify", "./auth/jwt")]);
    }

    // --- Python `from ... import` / `import` statements ---

    #[test]
    fn test_parse_python_from_import() {
        let parsed = parse_import("from auth.jwt import verify");
        assert_eq!(parsed, vec![pair("verify", "auth.jwt")]);
    }

    #[test]
    fn test_parse_python_multi_import() {
        let parsed = parse_import("from auth.jwt import verify, Token");
        assert_eq!(parsed.len(), 2);
        for expected in [pair("verify", "auth.jwt"), pair("Token", "auth.jwt")] {
            assert!(parsed.contains(&expected));
        }
    }

    #[test]
    fn test_parse_python_alias_import() {
        let parsed = parse_import("from auth.jwt import verify as check");
        assert_eq!(parsed, vec![pair("check", "auth.jwt")]);
    }

    #[test]
    fn test_parse_python_plain_import() {
        let parsed = parse_import("import os.path");
        assert_eq!(parsed, vec![pair("path", "os")]);
    }

    #[test]
    fn test_parse_python_toplevel_import() {
        let parsed = parse_import("import os");
        assert_eq!(parsed, vec![pair("os", "")]);
    }

    // --- Call resolution ---

    #[test]
    fn test_resolve_same_file() {
        let chunks = vec![make_chunk("src/lib.rs", "foo", &[], &[])];
        let outcome = ImportResolver::from_chunks(&chunks).resolve_call("src/lib.rs", "foo");
        assert_eq!(outcome, Some("src/lib.rs::foo".to_owned()));
    }

    #[test]
    fn test_resolve_via_import() {
        let chunks =
            vec![make_chunk("src/main.rs", "main", &["use crate::auth::verify;"], &["verify"])];
        let outcome = ImportResolver::from_chunks(&chunks).resolve_call("src/main.rs", "verify");
        assert_eq!(outcome, Some("crate::auth::verify".to_owned()));
    }

    #[test]
    fn test_resolve_unresolved() {
        let chunks = vec![make_chunk("src/main.rs", "main", &[], &["unknown_fn"])];
        let outcome =
            ImportResolver::from_chunks(&chunks).resolve_call("src/main.rs", "unknown_fn");
        assert_eq!(outcome, None);
    }

    #[test]
    fn test_resolve_all_calls() {
        let caller = make_chunk(
            "src/main.rs",
            "main",
            &["use crate::auth::verify;"],
            &["verify", "unknown"],
        );
        let callee = make_chunk("src/auth.rs", "verify", &[], &[]);
        let mut chunks = vec![caller, callee];

        let resolver = ImportResolver::from_chunks(&chunks);
        resolver.resolve_all_calls(&mut chunks);

        let main_context = &chunks[0].context;
        assert_eq!(main_context.qualified_calls, vec!["crate::auth::verify".to_owned()]);
        assert_eq!(main_context.unresolved_calls, vec!["unknown".to_owned()]);
    }

    #[test]
    fn test_unrecognized_import_format() {
        let parsed = parse_import("require('some-module')");
        assert!(parsed.is_empty());
    }

    /// Construct a minimal function chunk for `symbol` in `file` carrying the given
    /// import statements and raw call names; every other field gets a placeholder.
    fn make_chunk(file: &str, symbol: &str, imports: &[&str], calls: &[&str]) -> EmbedChunk {
        use super::super::types::{
            ChunkContext, ChunkKind, ChunkSource, RepoIdentifier, Visibility,
        };

        let source = ChunkSource {
            repo: RepoIdentifier::default(),
            file: file.to_owned(),
            lines: (1, 10),
            symbol: symbol.to_owned(),
            fqn: None,
            language: "Rust".to_owned(),
            parent: None,
            visibility: Visibility::Public,
            is_test: false,
            module_path: None,
            parent_chunk_id: None,
        };

        let context = ChunkContext {
            imports: imports.iter().map(|text| (*text).to_owned()).collect(),
            calls: calls.iter().map(|text| (*text).to_owned()).collect(),
            ..Default::default()
        };

        EmbedChunk {
            id: format!("ec_{}", symbol),
            full_hash: String::new(),
            content: String::new(),
            tokens: 0,
            kind: ChunkKind::Function,
            source,
            context,
            children_ids: Vec::new(),
            repr: String::from("code"),
            code_chunk_id: None,
            part: None,
        }
    }
}