use crate::parsing::resolution::ImportBinding;
use crate::parsing::{InheritanceResolver, ResolutionScope, ScopeLevel, ScopeType};
use crate::{FileId, SymbolId};
use std::collections::HashMap;
/// Per-file symbol table used while resolving names in a Rust source file.
///
/// Symbols are kept in separate buckets according to where they were
/// registered; the lookup order between buckets is defined by the
/// `ResolutionScope` implementation for this type.
pub struct RustResolutionContext {
    /// File this context was created for. It is stored but never read in this
    /// file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    file_id: FileId,
    /// Names registered at local (function/block) level.
    local_scope: HashMap<String, SymbolId>,
    /// Names registered as imported symbols.
    imported_symbols: HashMap<String, SymbolId>,
    /// Names registered at module level.
    module_symbols: HashMap<String, SymbolId>,
    /// Names registered at crate (global) level.
    crate_symbols: HashMap<String, SymbolId>,
    /// Scopes currently entered, innermost last.
    scope_stack: Vec<ScopeType>,
    /// Raw `(path, alias)` pairs recorded by `add_import`, in insertion order.
    imports: Vec<(String, Option<String>)>,
    /// Import bindings keyed by the name they expose.
    import_bindings: HashMap<String, ImportBinding>,
}

impl RustResolutionContext {
    /// Creates a context for `file_id` in which every bucket starts empty.
    pub fn new(file_id: FileId) -> Self {
        RustResolutionContext {
            file_id,
            local_scope: HashMap::default(),
            imported_symbols: HashMap::default(),
            module_symbols: HashMap::default(),
            crate_symbols: HashMap::default(),
            scope_stack: Vec::default(),
            imports: Vec::default(),
            import_bindings: HashMap::default(),
        }
    }

    /// Records an import path together with its optional alias.
    pub fn add_import(&mut self, path: String, alias: Option<String>) {
        let record = (path, alias);
        self.imports.push(record);
    }

    /// Binds `name` in the local bucket, replacing any earlier local binding.
    pub fn add_local(&mut self, name: String, symbol_id: SymbolId) {
        self.local_scope.insert(name, symbol_id);
    }

    /// Binds `name` in the imported bucket.
    ///
    /// The aliasing flag is accepted but currently not used.
    pub fn add_import_symbol(&mut self, name: String, symbol_id: SymbolId, _is_aliased: bool) {
        self.imported_symbols.insert(name, symbol_id);
    }

    /// Binds `name` in the module-level bucket.
    pub fn add_module_symbol(&mut self, name: String, symbol_id: SymbolId) {
        self.module_symbols.insert(name, symbol_id);
    }

    /// Binds `name` in the crate-level bucket.
    pub fn add_crate_symbol(&mut self, name: String, symbol_id: SymbolId) {
        self.crate_symbols.insert(name, symbol_id);
    }

    /// Returns `true` when `name` is bound in the imported bucket.
    pub fn is_imported(&self, name: &str) -> bool {
        self.imported_symbols.contains_key(name)
    }
}
impl ResolutionScope for RustResolutionContext {
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Registers `name` in the bucket selected by `scope_level`, replacing any
    /// existing binding for that name in the same bucket.
    ///
    /// `ScopeLevel::Package` maps to the imported-symbols bucket and
    /// `ScopeLevel::Global` to the crate-level bucket.
    fn add_symbol(&mut self, name: String, symbol_id: SymbolId, scope_level: ScopeLevel) {
        match scope_level {
            ScopeLevel::Local => {
                self.local_scope.insert(name, symbol_id);
            }
            ScopeLevel::Module => {
                self.module_symbols.insert(name, symbol_id);
            }
            ScopeLevel::Package => {
                self.imported_symbols.insert(name, symbol_id);
            }
            ScopeLevel::Global => {
                self.crate_symbols.insert(name, symbol_id);
            }
        }
    }

    /// Resolves `name` to a symbol id.
    ///
    /// The exact string is looked up in the local, imported, module and crate
    /// buckets, in that order. A qualified path therefore resolves directly
    /// only if it was registered verbatim.
    ///
    /// Otherwise, if `name` has exactly two `::`-separated segments (`A::b`),
    /// it resolves to whatever `b` resolves to, provided `A` itself resolves.
    /// Paths with three or more segments that were not registered verbatim
    /// yield `None`.
    fn resolve(&self, name: &str) -> Option<SymbolId> {
        if let Some(&id) = self
            .local_scope
            .get(name)
            .or_else(|| self.imported_symbols.get(name))
            .or_else(|| self.module_symbols.get(name))
            .or_else(|| self.crate_symbols.get(name))
        {
            return Some(id);
        }

        // Two-segment fallback, e.g. `Type::method` or `module::func`.
        // `split_once` plus the `contains` check is equivalent to requiring
        // exactly two `::`-separated parts.
        let (qualifier, member) = name.split_once("::")?;
        if member.contains("::") {
            return None;
        }
        // Neither half contains `::`, so each recursive call stops after the
        // direct lookups above.
        self.resolve(qualifier)?;
        self.resolve(member)
    }

    fn clear_local_scope(&mut self) {
        self.local_scope.clear();
    }

    /// Pushes `scope_type` as the new innermost scope.
    fn enter_scope(&mut self, scope_type: ScopeType) {
        self.scope_stack.push(scope_type);
    }

    /// Pops the innermost scope.
    ///
    /// Local bindings are cleared when the remaining innermost scope is
    /// module-level or global, or when no scope remains.
    fn exit_scope(&mut self) {
        self.scope_stack.pop();
        if matches!(
            self.scope_stack.last(),
            None | Some(ScopeType::Module | ScopeType::Global)
        ) {
            self.clear_local_scope();
        }
    }

    /// Lists every registered binding, tagged with its scope level.
    ///
    /// Buckets appear in the order local, imported (`Package`), module, then
    /// crate (`Global`). Order within a bucket follows `HashMap` iteration.
    fn symbols_in_scope(&self) -> Vec<(String, SymbolId, ScopeLevel)> {
        let total = self.local_scope.len()
            + self.imported_symbols.len()
            + self.module_symbols.len()
            + self.crate_symbols.len();
        let mut symbols = Vec::with_capacity(total);
        symbols.extend(
            self.local_scope
                .iter()
                .map(|(name, &id)| (name.clone(), id, ScopeLevel::Local)),
        );
        symbols.extend(
            self.imported_symbols
                .iter()
                .map(|(name, &id)| (name.clone(), id, ScopeLevel::Package)),
        );
        symbols.extend(
            self.module_symbols
                .iter()
                .map(|(name, &id)| (name.clone(), id, ScopeLevel::Module)),
        );
        symbols.extend(
            self.crate_symbols
                .iter()
                .map(|(name, &id)| (name.clone(), id, ScopeLevel::Global)),
        );
        symbols
    }

    /// Resolves the target of a relationship.
    ///
    /// Every relationship kind currently resolves `to_name` the same way, via
    /// `self.resolve`. The source name, kind and file are accepted for
    /// interface compatibility but are not consulted.
    fn resolve_relationship(
        &self,
        _from_name: &str,
        to_name: &str,
        _kind: crate::RelationKind,
        _from_file: FileId,
    ) -> Option<SymbolId> {
        self.resolve(to_name)
    }

    /// Records each import's path and alias, in the order given.
    fn populate_imports(&mut self, imports: &[crate::parsing::Import]) {
        for import in imports {
            self.add_import(import.path.clone(), import.alias.clone());
        }
    }

    /// Stores `binding` under its exposed name, replacing any earlier binding
    /// with that name.
    fn register_import_binding(&mut self, binding: ImportBinding) {
        self.import_bindings
            .insert(binding.exposed_name.clone(), binding);
    }

    /// Returns a clone of the binding exposed as `name`, if any.
    fn import_binding(&self, name: &str) -> Option<ImportBinding> {
        self.import_bindings.get(name).cloned()
    }
}
/// Tracks trait implementations and method ownership for Rust types, so that
/// a method call can be attributed either to an inherent impl or to a trait.
#[derive(Clone)]
pub struct RustTraitResolver {
    /// Type name -> traits it implements, each paired with a file id, in
    /// registration order.
    type_to_traits: HashMap<String, Vec<(String, FileId)>>,
    /// Trait name -> method names declared by that trait.
    trait_methods: HashMap<String, Vec<String>>,
    /// Explicit `(type, method)` -> trait mapping.
    ///
    /// NOTE(review): no code in this file inserts into this map, so lookups
    /// against it always miss unless it is populated elsewhere. Confirm
    /// whether a writer is missing.
    type_method_to_trait: HashMap<(String, String), String>,
    /// Type name -> methods registered as inherent (non-trait) methods.
    inherent_methods: HashMap<String, Vec<String>>,
}

impl Default for RustTraitResolver {
    /// Same as [`RustTraitResolver::new`].
    fn default() -> Self {
        RustTraitResolver::new()
    }
}

impl RustTraitResolver {
    /// Creates a resolver with no recorded types, traits or methods.
    pub fn new() -> Self {
        RustTraitResolver {
            type_to_traits: HashMap::default(),
            trait_methods: HashMap::default(),
            type_method_to_trait: HashMap::default(),
            inherent_methods: HashMap::default(),
        }
    }

    /// Returns `true` if `method_name` was registered as an inherent method of
    /// `type_name`.
    fn is_inherent_method(&self, type_name: &str, method_name: &str) -> bool {
        match self.inherent_methods.get(type_name) {
            Some(methods) => methods.iter().any(|m| m == method_name),
            None => false,
        }
    }
}
impl InheritanceResolver for RustTraitResolver {
    /// Records that `child` implements `parent`.
    ///
    /// Only the `"implements"` relationship is tracked; any other `kind` is
    /// ignored.
    ///
    /// NOTE(review): the file containing the impl is not passed in, so a
    /// fixed `FileId::new(1)` placeholder is stored with the entry.
    fn add_inheritance(&mut self, child: String, parent: String, kind: &str) {
        if kind != "implements" {
            return;
        }
        let implemented = self.type_to_traits.entry(child).or_default();
        implemented.push((parent, FileId::new(1).unwrap()));
    }

    /// Returns the owner of `method_name` on `type_name`.
    ///
    /// Lookup order:
    /// 1. The type itself, if the method is inherent.
    /// 2. The explicit `(type, method)` mapping.
    /// 3. The first implemented trait, in registration order, whose method
    ///    list contains the method.
    fn resolve_method(&self, type_name: &str, method_name: &str) -> Option<String> {
        if self.is_inherent_method(type_name, method_name) {
            return Some(type_name.to_string());
        }

        let key = (type_name.to_string(), method_name.to_string());
        if let Some(trait_name) = self.type_method_to_trait.get(&key) {
            return Some(trait_name.clone());
        }

        let implemented = self.type_to_traits.get(type_name)?;
        implemented
            .iter()
            .map(|(trait_name, _)| trait_name)
            .find(|trait_name| {
                self.trait_methods
                    .get(trait_name.as_str())
                    .into_iter()
                    .flatten()
                    .any(|m| m == method_name)
            })
            .cloned()
    }

    /// Returns `type_name` followed by each distinct trait it implements, in
    /// registration order.
    fn get_inheritance_chain(&self, type_name: &str) -> Vec<String> {
        let mut chain = vec![type_name.to_string()];
        let implemented = self.type_to_traits.get(type_name).into_iter().flatten();
        for (trait_name, _) in implemented {
            let already_listed = chain.iter().any(|entry| entry == trait_name);
            if !already_listed {
                chain.push(trait_name.clone());
            }
        }
        chain
    }

    /// Returns `true` if `child` has been recorded as implementing `parent`.
    fn is_subtype(&self, child: &str, parent: &str) -> bool {
        self.type_to_traits
            .get(child)
            .into_iter()
            .flatten()
            .any(|(trait_name, _)| trait_name == parent)
    }

    /// Appends `methods` to the inherent methods of `type_name`.
    ///
    /// Existing entries are kept and duplicates are not removed.
    fn add_type_methods(&mut self, type_name: String, methods: Vec<String>) {
        let known = self.inherent_methods.entry(type_name).or_default();
        known.extend(methods);
    }

    /// Returns every method name available on `type_name`.
    ///
    /// Inherent methods come first, as stored (including any duplicates).
    /// They are followed by trait methods not already listed, in trait
    /// registration order.
    fn get_all_methods(&self, type_name: &str) -> Vec<String> {
        let mut all_methods = self
            .inherent_methods
            .get(type_name)
            .cloned()
            .unwrap_or_default();
        let trait_method_lists = self
            .type_to_traits
            .get(type_name)
            .into_iter()
            .flatten()
            .filter_map(|(trait_name, _)| self.trait_methods.get(trait_name));
        for method in trait_method_lists.flatten() {
            if !all_methods.contains(method) {
                all_methods.push(method.clone());
            }
        }
        all_methods
    }
}
impl RustTraitResolver {
    /// Records that `type_name` implements `trait_name`, found in `file_id`.
    ///
    /// Repeated registrations of the same pair are kept as separate entries.
    pub fn add_trait_impl(&mut self, type_name: String, trait_name: String, file_id: FileId) {
        self.type_to_traits
            .entry(type_name)
            .or_default()
            .push((trait_name, file_id));
    }

    /// Sets the method names declared by `trait_name`.
    ///
    /// Any list previously registered for that trait is replaced, not merged.
    pub fn add_trait_methods(&mut self, trait_name: String, methods: Vec<String>) {
        self.trait_methods.insert(trait_name, methods);
    }

    /// Appends `methods` to the inherent methods of `type_name`.
    ///
    /// Existing entries are kept and duplicates are not removed.
    pub fn add_inherent_methods(&mut self, type_name: String, methods: Vec<String>) {
        self.inherent_methods
            .entry(type_name)
            .or_default()
            .extend(methods);
    }

    /// Returns the trait that provides `method_name` on `type_name`, if any.
    ///
    /// Returns `None` when the method is inherent to the type. Otherwise the
    /// explicit `(type, method)` mapping is consulted first. After that comes
    /// the first implemented trait, in registration order, whose method list
    /// contains `method_name`.
    ///
    /// If several *distinct* traits declare the method, a warning is printed
    /// to stderr and the first one wins. The same trait registered more than
    /// once for a type is not counted as ambiguity.
    pub fn resolve_method_trait(&self, type_name: &str, method_name: &str) -> Option<&str> {
        if self.is_inherent_method(type_name, method_name) {
            return None;
        }
        if let Some(trait_name) = self
            .type_method_to_trait
            .get(&(type_name.to_string(), method_name.to_string()))
        {
            return Some(trait_name);
        }

        let traits = self.type_to_traits.get(type_name)?;
        // Distinct traits, in registration order, that declare `method_name`.
        let mut matching_traits: Vec<&str> = Vec::new();
        for (trait_name, _) in traits {
            if let Some(methods) = self.trait_methods.get(trait_name) {
                // Compare borrowed strings instead of allocating a `String`
                // per trait just to call `Vec::contains`.
                if methods.iter().any(|m| m == method_name)
                    && !matching_traits.contains(&trait_name.as_str())
                {
                    matching_traits.push(trait_name.as_str());
                }
            }
        }
        if matching_traits.len() > 1 {
            eprintln!(
                "WARNING: Ambiguous method '{method_name}' on type '{type_name}' - found in traits: {matching_traits:?}"
            );
        }
        matching_traits.first().copied()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FileId, SymbolId};

    /// Qualified paths resolve only when registered verbatim.
    ///
    /// The resolver does not reduce `crate::init::init_global_dirs` to
    /// `init_global_dirs` on its own. Registering both spellings makes both
    /// resolve.
    #[test]
    fn test_resolve_qualified_module_path() {
        let mut context = RustResolutionContext::new(FileId::new(1).unwrap());
        let symbol_id = SymbolId::new(42).unwrap();

        context.add_symbol(
            "init_global_dirs".to_string(),
            symbol_id,
            ScopeLevel::Global,
        );
        assert_eq!(context.resolve("init_global_dirs"), Some(symbol_id));
        assert_eq!(
            context.resolve("crate::init::init_global_dirs"),
            None,
            "three-segment paths are not reduced to their last segment; \
             update this assertion if resolve() learns to strip module prefixes"
        );

        // Register the fully qualified path alongside the bare name.
        context.add_symbol(
            "crate::init::init_global_dirs".to_string(),
            symbol_id,
            ScopeLevel::Global,
        );
        assert_eq!(context.resolve("init_global_dirs"), Some(symbol_id));
        assert_eq!(
            context.resolve("crate::init::init_global_dirs"),
            Some(symbol_id),
            "With fix applied, qualified path should resolve!"
        );
    }

    /// A two-segment path `A::b` resolves to `b` only when `A` resolves too.
    #[test]
    fn test_resolve_two_segment_path_via_qualifier() {
        let mut context = RustResolutionContext::new(FileId::new(1).unwrap());
        let type_id = SymbolId::new(1).unwrap();
        let method_id = SymbolId::new(2).unwrap();

        context.add_symbol("Config".to_string(), type_id, ScopeLevel::Module);
        context.add_symbol("load".to_string(), method_id, ScopeLevel::Module);

        assert_eq!(context.resolve("Config::load"), Some(method_id));
        assert_eq!(context.resolve("Missing::load"), None);
        assert_eq!(context.resolve("Config::missing"), None);
    }

    /// Inherent methods are not attributed to a trait. Otherwise the first
    /// registered trait declaring the method wins.
    #[test]
    fn test_resolve_method_trait_prefers_first_trait() {
        let mut resolver = RustTraitResolver::new();
        resolver.add_trait_impl(
            "Foo".to_string(),
            "Display".to_string(),
            FileId::new(1).unwrap(),
        );
        resolver.add_trait_impl(
            "Foo".to_string(),
            "Debug".to_string(),
            FileId::new(1).unwrap(),
        );
        resolver.add_trait_methods("Display".to_string(), vec!["fmt".to_string()]);
        resolver.add_trait_methods("Debug".to_string(), vec!["fmt".to_string()]);
        resolver.add_inherent_methods("Foo".to_string(), vec!["new".to_string()]);

        assert_eq!(resolver.resolve_method_trait("Foo", "new"), None);
        assert_eq!(resolver.resolve_method_trait("Foo", "fmt"), Some("Display"));
        assert_eq!(resolver.resolve_method_trait("Foo", "other"), None);
    }
}