use crate::parsing::resolution::{ImportBinding, InheritanceResolver, ResolutionScope};
use crate::parsing::{ScopeLevel, ScopeType};
use crate::{FileId, SymbolId};
use std::collections::{HashMap, HashSet};
/// Per-file name-resolution state for Swift source.
///
/// `resolve` searches, in order: local scopes (innermost first), enclosing type scopes
/// (innermost first), extension members, `protocol_scope`, `module_scope`, `import_scope`.
pub struct SwiftResolutionContext {
/// File this context was created for.
// Only assigned in `new`; nothing in this file reads it, hence the `dead_code` allowance.
#[allow(dead_code)]
file_id: FileId,
/// Stack of local symbol tables, pushed for function and block scopes; innermost is last.
local_scopes: Vec<HashMap<String, SymbolId>>,
/// Stack of member tables for enclosing class scopes; innermost is last.
type_scopes: Vec<HashMap<String, SymbolId>>,
/// Extension members, keyed by extended type name and then by member name.
extension_scope: HashMap<String, HashMap<String, SymbolId>>,
/// Consulted by `resolve` after extensions; no code in this file inserts into it.
protocol_scope: HashMap<String, SymbolId>,
/// Module-level symbols (registered with `ScopeLevel::Module` or `ScopeLevel::Global`).
module_scope: HashMap<String, SymbolId>,
/// Imported symbols (`ScopeLevel::Package` and import bindings that carry a resolved symbol).
import_scope: HashMap<String, SymbolId>,
/// Kinds of the currently open scopes; starts with `ScopeType::Global`.
scope_stack: Vec<ScopeType>,
/// Import bindings keyed by the name they expose in this file.
import_bindings: HashMap<String, ImportBinding>,
/// Type whose extension members are searched before other types' extension members.
current_type: Option<String>,
}
impl SwiftResolutionContext {
    /// Creates an empty resolution context for `file_id`.
    ///
    /// The scope stack starts with a single `ScopeType::Global` entry. All symbol tables are
    /// empty and no current type is set.
    pub fn new(file_id: FileId) -> Self {
        Self {
            file_id,
            local_scopes: Vec::new(),
            type_scopes: Vec::new(),
            extension_scope: HashMap::new(),
            protocol_scope: HashMap::new(),
            module_scope: HashMap::new(),
            import_scope: HashMap::new(),
            scope_stack: vec![ScopeType::Global],
            import_bindings: HashMap::new(),
            current_type: None,
        }
    }

    /// Sets (or clears, with `None`) the type whose extension members take precedence during
    /// extension lookup.
    pub fn set_current_type(&mut self, type_name: Option<String>) {
        self.current_type = type_name;
    }

    /// Returns the innermost local scope, creating one first if none exists yet.
    ///
    /// A scope created here has no matching entry on `scope_stack`.
    fn current_local_scope_mut(&mut self) -> &mut HashMap<String, SymbolId> {
        if self.local_scopes.is_empty() {
            self.local_scopes.push(HashMap::new());
        }
        self.local_scopes
            .last_mut()
            .expect("local_scopes was just made non-empty")
    }

    /// Returns the member table of the innermost enclosing class scope, if any.
    fn current_type_scope_mut(&mut self) -> Option<&mut HashMap<String, SymbolId>> {
        self.type_scopes.last_mut()
    }

    /// Looks `name` up in the local scopes, innermost first, so inner bindings shadow outer ones.
    fn resolve_in_locals(&self, name: &str) -> Option<SymbolId> {
        self.local_scopes
            .iter()
            .rev()
            .find_map(|scope| scope.get(name).copied())
    }

    /// Looks `name` up in the enclosing type scopes, innermost first.
    fn resolve_in_types(&self, name: &str) -> Option<SymbolId> {
        self.type_scopes
            .iter()
            .rev()
            .find_map(|scope| scope.get(name).copied())
    }

    /// Looks `name` up among extension members.
    ///
    /// Extensions of `current_type` are checked first. Otherwise the member is taken from the
    /// extended type with the lexicographically smallest name that declares it. `HashMap`
    /// iteration order is arbitrary, so taking the first hit would make the result vary between
    /// runs whenever several extended types declare the same member.
    fn resolve_in_extensions(&self, name: &str) -> Option<SymbolId> {
        let preferred = self
            .current_type
            .as_ref()
            .and_then(|type_name| self.extension_scope.get(type_name))
            .and_then(|members| members.get(name))
            .copied();
        if preferred.is_some() {
            return preferred;
        }
        self.extension_scope
            .iter()
            .filter_map(|(type_name, members)| members.get(name).map(|&id| (type_name, id)))
            .min_by_key(|&(type_name, _)| type_name)
            .map(|(_, id)| id)
    }

    /// Records `symbol_id` as the extension member `name` of `type_name`, replacing any earlier
    /// entry with the same type and member name.
    pub fn add_extension_symbol(&mut self, type_name: String, name: String, symbol_id: SymbolId) {
        self.extension_scope
            .entry(type_name)
            .or_default()
            .insert(name, symbol_id);
    }
}
impl ResolutionScope for SwiftResolutionContext {
    /// Exposes `self` as `Any` so callers can downcast to the Swift-specific context.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Registers `name` -> `symbol_id` at `scope_level`.
    ///
    /// - `Local`: innermost local scope (created on demand).
    /// - `Module`: if the innermost open scope is a class, the symbol is also recorded in that
    ///   class's member table. It is added to `module_scope` only when that name is not already
    ///   present there (the first registration wins).
    /// - `Package`: `import_scope`, replacing any earlier entry.
    /// - `Global`: `module_scope`, replacing any earlier entry.
    fn add_symbol(&mut self, name: String, symbol_id: SymbolId, scope_level: ScopeLevel) {
        match scope_level {
            ScopeLevel::Local => {
                self.current_local_scope_mut().insert(name, symbol_id);
            }
            ScopeLevel::Module => {
                if matches!(self.scope_stack.last(), Some(ScopeType::Class)) {
                    if let Some(scope) = self.current_type_scope_mut() {
                        scope.insert(name.clone(), symbol_id);
                    }
                }
                self.module_scope.entry(name).or_insert(symbol_id);
            }
            ScopeLevel::Package => {
                self.import_scope.insert(name, symbol_id);
            }
            ScopeLevel::Global => {
                self.module_scope.insert(name, symbol_id);
            }
        }
    }

    /// Resolves `name`, searching in order:
    /// 1. local scopes (innermost first);
    /// 2. enclosing type scopes (innermost first);
    /// 3. extension members;
    /// 4. `protocol_scope`, `module_scope`, then `import_scope`.
    ///
    /// If nothing matches and `name` is qualified (`Head.rest`), `Head` is resolved recursively.
    /// When `Head` resolves, `rest` is looked up in the enclosing type scopes. Only the first `.`
    /// splits the name, so `rest` is matched verbatim even if it contains further dots.
    fn resolve(&self, name: &str) -> Option<SymbolId> {
        if let Some(id) = self.resolve_in_locals(name) {
            return Some(id);
        }
        if let Some(id) = self.resolve_in_types(name) {
            return Some(id);
        }
        if let Some(id) = self.resolve_in_extensions(name) {
            return Some(id);
        }
        if let Some(&id) = self.protocol_scope.get(name) {
            return Some(id);
        }
        if let Some(&id) = self.module_scope.get(name) {
            return Some(id);
        }
        if let Some(&id) = self.import_scope.get(name) {
            return Some(id);
        }
        if let Some((head, tail)) = name.split_once('.') {
            // `head` contains no '.', so this recursion cannot reach this branch again.
            if self.resolve(head).is_some() {
                if let Some(id) = self.resolve_in_types(tail) {
                    return Some(id);
                }
            }
        }
        None
    }

    /// Empties the innermost local scope without popping it; does nothing if none exists.
    fn clear_local_scope(&mut self) {
        if let Some(scope) = self.local_scopes.last_mut() {
            scope.clear();
        }
    }

    /// Opens a scope of kind `scope_type`.
    ///
    /// Function and block scopes get a fresh local table and class scopes a fresh member table.
    /// Every kind is pushed onto the scope stack.
    fn enter_scope(&mut self, scope_type: ScopeType) {
        match scope_type {
            ScopeType::Function { .. } | ScopeType::Block => {
                self.local_scopes.push(HashMap::new());
            }
            ScopeType::Class => {
                self.type_scopes.push(HashMap::new());
            }
            _ => {}
        }
        self.scope_stack.push(scope_type);
    }

    /// Closes the innermost open scope and discards the local or member table it opened.
    ///
    /// Does nothing if the scope stack is empty.
    fn exit_scope(&mut self) {
        if let Some(scope) = self.scope_stack.pop() {
            match scope {
                ScopeType::Function { .. } | ScopeType::Block => {
                    self.local_scopes.pop();
                }
                ScopeType::Class => {
                    self.type_scopes.pop();
                }
                _ => {}
            }
        }
    }

    /// Lists the symbols currently registered, in this order:
    /// - the innermost local scope (`Local`);
    /// - the innermost type scope (`Module`);
    /// - `module_scope` (`Module`);
    /// - `import_scope` (`Package`).
    ///
    /// Class members registered at `Module` level are mirrored into `module_scope` by
    /// `add_symbol`. A `module_scope` entry identical to the innermost type-scope entry is
    /// therefore skipped, so the same `(name, id, Module)` tuple is not reported twice.
    fn symbols_in_scope(&self) -> Vec<(String, SymbolId, ScopeLevel)> {
        let mut results = Vec::new();
        if let Some(local) = self.local_scopes.last() {
            for (name, &id) in local {
                results.push((name.clone(), id, ScopeLevel::Local));
            }
        }
        let type_scope = self.type_scopes.last();
        if let Some(type_scope) = type_scope {
            for (name, &id) in type_scope {
                results.push((name.clone(), id, ScopeLevel::Module));
            }
        }
        for (name, &id) in &self.module_scope {
            let mirrored_member = type_scope.is_some_and(|scope| scope.get(name) == Some(&id));
            if !mirrored_member {
                results.push((name.clone(), id, ScopeLevel::Module));
            }
        }
        for (name, &id) in &self.import_scope {
            results.push((name.clone(), id, ScopeLevel::Package));
        }
        results
    }

    /// Resolves a relationship target by name alone.
    ///
    /// The source name, relation kind, and source file are ignored.
    fn resolve_relationship(
        &self,
        _from_name: &str,
        to_name: &str,
        _kind: crate::RelationKind,
        _from_file: FileId,
    ) -> Option<SymbolId> {
        self.resolve(to_name)
    }

    /// Intentionally a no-op.
    ///
    /// Imported symbols reach `import_scope` through `register_import_binding` instead.
    fn populate_imports(&mut self, _imports: &[crate::parsing::Import]) {}

    /// Stores `binding` under its exposed name.
    ///
    /// If the binding already carries a resolved symbol, that symbol also becomes resolvable
    /// through `import_scope`.
    fn register_import_binding(&mut self, binding: ImportBinding) {
        if let Some(symbol_id) = binding.resolved_symbol {
            self.import_scope
                .insert(binding.exposed_name.clone(), symbol_id);
        }
        self.import_bindings
            .insert(binding.exposed_name.clone(), binding);
    }

    /// Returns a clone of the import binding exposed as `name`, if any.
    fn import_binding(&self, name: &str) -> Option<ImportBinding> {
        self.import_bindings.get(name).cloned()
    }
}
/// Tracks Swift type relationships and the methods attached to types, so that method lookups
/// and subtype checks can walk superclass chains and protocol conformances (including protocols
/// inherited by other protocols).
#[derive(Default)]
pub struct SwiftInheritanceResolver {
    /// Superclass of each type (`child -> parent`); at most one per type.
    class_inheritance: HashMap<String, String>,
    /// Protocols each type conforms to, in registration order.
    ///
    /// When the child is itself a protocol, these are the protocols it inherits.
    protocol_conformance: HashMap<String, Vec<String>>,
    /// Methods registered via `add_protocol_extension_method`, keyed by protocol name.
    protocol_extensions: HashMap<String, HashSet<String>>,
    /// Methods registered via `add_type_extension_method`, keyed by type name.
    type_extensions: HashMap<String, HashSet<String>>,
    /// Methods declared directly on each type.
    type_methods: HashMap<String, HashSet<String>>,
}

impl SwiftInheritanceResolver {
    /// Creates an empty resolver.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `method` is provided by an extension of `protocol`.
    pub fn add_protocol_extension_method(&mut self, protocol: String, method: String) {
        self.protocol_extensions
            .entry(protocol)
            .or_default()
            .insert(method);
    }

    /// Records that `method` is provided by an extension of the type `type_name`.
    pub fn add_type_extension_method(&mut self, type_name: String, method: String) {
        self.type_extensions
            .entry(type_name)
            .or_default()
            .insert(method);
    }

    /// Returns `direct` followed by every protocol they transitively inherit (through
    /// `protocol_conformance`), breadth-first and without duplicates.
    ///
    /// Breadth-first order keeps the protocols a type names itself ahead of their ancestors, so
    /// a direct protocol is consulted before anything it inherits. Cycles terminate because each
    /// protocol is queued at most once.
    fn protocol_closure<'a>(&'a self, direct: &'a [String]) -> Vec<&'a str> {
        let mut seen: HashSet<&'a str> = HashSet::new();
        let mut ordered: Vec<&'a str> = Vec::new();
        for protocol in direct {
            if seen.insert(protocol.as_str()) {
                ordered.push(protocol.as_str());
            }
        }
        let mut cursor = 0;
        while let Some(&current) = ordered.get(cursor) {
            cursor += 1;
            if let Some(inherited) = self.protocol_conformance.get(current) {
                for protocol in inherited {
                    if seen.insert(protocol.as_str()) {
                        ordered.push(protocol.as_str());
                    }
                }
            }
        }
        ordered
    }

    /// Finds what supplies `method` for `ty`.
    ///
    /// Lookup order:
    /// 1. methods declared on `ty`;
    /// 2. `ty`'s type extensions;
    /// 3. the superclass chain (recursively);
    /// 4. extensions of `ty`'s protocols, direct ones first, then the protocols they inherit.
    ///
    /// Returns the owning type's name, or `"<Name> (extension)"` when the method comes from an
    /// extension. `visited` holds the types already explored and stops inheritance cycles.
    fn resolve_method_recursive(
        &self,
        ty: &str,
        method: &str,
        visited: &mut HashSet<String>,
    ) -> Option<String> {
        if !visited.insert(ty.to_string()) {
            return None;
        }
        if self
            .type_methods
            .get(ty)
            .is_some_and(|methods| methods.contains(method))
        {
            return Some(ty.to_string());
        }
        if self
            .type_extensions
            .get(ty)
            .is_some_and(|methods| methods.contains(method))
        {
            return Some(format!("{ty} (extension)"));
        }
        if let Some(parent) = self.class_inheritance.get(ty) {
            if let Some(found) = self.resolve_method_recursive(parent, method, visited) {
                return Some(found);
            }
        }
        let protocols = self.protocol_conformance.get(ty)?;
        self.protocol_closure(protocols)
            .into_iter()
            .find(|protocol| {
                self.protocol_extensions
                    .get(*protocol)
                    .is_some_and(|methods| methods.contains(method))
            })
            .map(|protocol| format!("{protocol} (extension)"))
    }

    /// Appends to `out` the superclasses of `ty` (nearest first, each followed by its own
    /// supertypes), then `ty`'s protocols including inherited ones.
    ///
    /// Every name is listed at most once; `visited` tracks names already explored or listed.
    fn collect_chain(&self, ty: &str, visited: &mut HashSet<String>, out: &mut Vec<String>) {
        if !visited.insert(ty.to_string()) {
            return;
        }
        if let Some(parent) = self.class_inheritance.get(ty) {
            // On a superclass cycle the parent has already been visited (it may be the queried
            // type itself) and must not be listed again.
            if !visited.contains(parent) {
                out.push(parent.clone());
            }
            self.collect_chain(parent, visited, out);
        }
        if let Some(protocols) = self.protocol_conformance.get(ty) {
            for protocol in self.protocol_closure(protocols) {
                if visited.insert(protocol.to_string()) {
                    out.push(protocol.to_string());
                }
            }
        }
    }

    /// Adds to `out` every method available on `ty`.
    ///
    /// This covers methods declared on `ty`, its type extensions, everything from its
    /// superclass chain, and the extension methods of its protocols and the protocols they
    /// inherit. `visited` stops superclass cycles.
    fn gather_methods(&self, ty: &str, visited: &mut HashSet<String>, out: &mut HashSet<String>) {
        if !visited.insert(ty.to_string()) {
            return;
        }
        if let Some(methods) = self.type_methods.get(ty) {
            out.extend(methods.iter().cloned());
        }
        if let Some(methods) = self.type_extensions.get(ty) {
            out.extend(methods.iter().cloned());
        }
        if let Some(parent) = self.class_inheritance.get(ty) {
            self.gather_methods(parent, visited, out);
        }
        if let Some(protocols) = self.protocol_conformance.get(ty) {
            for protocol in self.protocol_closure(protocols) {
                if let Some(methods) = self.protocol_extensions.get(protocol) {
                    out.extend(methods.iter().cloned());
                }
            }
        }
    }

    /// Returns whether `parent` is reachable from `child` through superclasses or through
    /// protocol conformance.
    ///
    /// Protocol conformance includes protocols inherited by conformed protocols. Returns `false`
    /// when `child` has already been visited (cycle guard). Equality of `child` and `parent` is
    /// not checked here.
    fn is_subtype_recursive(
        &self,
        child: &str,
        parent: &str,
        visited: &mut HashSet<String>,
    ) -> bool {
        if !visited.insert(child.to_string()) {
            return false;
        }
        if let Some(superclass) = self.class_inheritance.get(child) {
            if superclass == parent || self.is_subtype_recursive(superclass, parent, visited) {
                return true;
            }
        }
        self.protocol_conformance
            .get(child)
            .is_some_and(|protocols| self.protocol_closure(protocols).contains(&parent))
    }
}
impl InheritanceResolver for SwiftInheritanceResolver {
    /// Records an inheritance edge from `child` to `parent`.
    ///
    /// `"extends"` and `"class"` set `child`'s single superclass, replacing any earlier one.
    /// Every other `kind` (`"implements"`, `"protocol"`, `"conforms"`, or anything
    /// unrecognised) appends `parent` to `child`'s protocol-conformance list.
    fn add_inheritance(&mut self, child: String, parent: String, kind: &str) {
        let is_superclass = matches!(kind, "extends" | "class");
        if is_superclass {
            self.class_inheritance.insert(child, parent);
        } else {
            self.protocol_conformance
                .entry(child)
                .or_default()
                .push(parent);
        }
    }

    /// Adds `methods` to the set of methods declared directly on `type_name`.
    fn add_type_methods(&mut self, type_name: String, methods: Vec<String>) {
        let declared = self.type_methods.entry(type_name).or_default();
        declared.extend(methods);
    }

    /// Finds the type or extension that supplies `method_name` for `type_name`, if any.
    fn resolve_method(&self, type_name: &str, method_name: &str) -> Option<String> {
        self.resolve_method_recursive(type_name, method_name, &mut HashSet::new())
    }

    /// Lists the superclasses and protocols reachable from `type_name`.
    fn get_inheritance_chain(&self, type_name: &str) -> Vec<String> {
        let mut ancestors = Vec::new();
        self.collect_chain(type_name, &mut HashSet::new(), &mut ancestors);
        ancestors
    }

    /// Returns every method available on `type_name`, without duplicates.
    ///
    /// The order is unspecified because the names are collected in a `HashSet`.
    fn get_all_methods(&self, type_name: &str) -> Vec<String> {
        let mut available = HashSet::new();
        self.gather_methods(type_name, &mut HashSet::new(), &mut available);
        available.into_iter().collect()
    }

    /// Returns whether `child` is `parent` or inherits from / conforms to it.
    fn is_subtype(&self, child: &str, parent: &str) -> bool {
        child == parent || self.is_subtype_recursive(child, parent, &mut HashSet::new())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A module-level symbol registered at global scope resolves by name.
    #[test]
    fn test_resolution_context() {
        let mut context = SwiftResolutionContext::new(FileId(1));
        context.add_symbol(String::from("topLevel"), SymbolId(1), ScopeLevel::Module);

        let found = context.resolve("topLevel");
        assert_eq!(found, Some(SymbolId(1)));
    }

    /// Locals disappear when their function scope exits; module symbols stay visible throughout.
    #[test]
    fn test_scope_nesting() {
        let mut context = SwiftResolutionContext::new(FileId(1));
        context.add_symbol(String::from("outer"), SymbolId(1), ScopeLevel::Module);

        context.enter_scope(ScopeType::Function { hoisting: false });
        context.add_symbol(String::from("inner"), SymbolId(2), ScopeLevel::Local);
        assert_eq!(context.resolve("outer"), Some(SymbolId(1)));
        assert_eq!(context.resolve("inner"), Some(SymbolId(2)));
        context.exit_scope();

        assert_eq!(context.resolve("outer"), Some(SymbolId(1)));
        assert!(context.resolve("inner").is_none());
    }

    /// Superclass links, protocol conformance, and method lookup along the class chain.
    #[test]
    fn test_inheritance_resolver() {
        let mut hierarchy = SwiftInheritanceResolver::new();
        hierarchy.add_inheritance(String::from("Dog"), String::from("Animal"), "extends");
        hierarchy.add_inheritance(String::from("Animal"), String::from("Named"), "conforms");
        hierarchy.add_type_methods(String::from("Animal"), vec![String::from("makeSound")]);
        hierarchy.add_type_methods(String::from("Dog"), vec![String::from("fetch")]);

        let ancestors = hierarchy.get_inheritance_chain("Dog");
        assert!(ancestors.iter().any(|ancestor| ancestor == "Animal"));

        assert!(hierarchy.is_subtype("Dog", "Animal"));
        assert!(!hierarchy.is_subtype("Animal", "Dog"));

        assert_eq!(hierarchy.resolve_method("Dog", "fetch").as_deref(), Some("Dog"));
        assert_eq!(
            hierarchy.resolve_method("Dog", "makeSound").as_deref(),
            Some("Animal")
        );
    }

    /// A method supplied by a protocol extension resolves for a conforming type.
    #[test]
    fn test_protocol_extensions() {
        let mut hierarchy = SwiftInheritanceResolver::new();
        hierarchy.add_protocol_extension_method(String::from("Drawable"), String::from("draw"));
        hierarchy.add_inheritance(String::from("Rectangle"), String::from("Drawable"), "conforms");

        let provider = hierarchy
            .resolve_method("Rectangle", "draw")
            .expect("draw should be provided by the Drawable extension");
        assert!(provider.contains("Drawable"));
    }
}