use std::collections::HashMap;
use crate::ir::ast::{ClassDefinition, Component, Import, Location, StoredDefinition};
/// A symbol reported by a [`SymbolLookup`], i.e. one that the resolver could
/// not find in the file it is resolving.
#[derive(Clone, Debug)]
pub struct ExternalSymbol {
    /// Dotted, fully-qualified name (e.g. `Modelica.Blocks.Continuous.PID`).
    pub qualified_name: String,
    /// Where the symbol is defined; presumably a file path or URI — TODO confirm
    /// against the `SymbolLookup` implementation.
    pub location: String,
    /// Line of the definition (0- vs 1-based is not established in this file).
    pub line: u32,
    /// Column of the definition (same indexing caveat as `line`).
    pub column: u32,
    /// The kind of declaration this symbol is.
    pub kind: SymbolCategory,
    /// Optional extra text supplied by the lookup (e.g. a signature — assumption).
    pub detail: Option<String>,
}
/// The kind of declaration an [`ExternalSymbol`] refers to.
///
/// `Hash` is derived alongside `Eq` so categories can be used as keys in
/// `HashMap`/`HashSet` (e.g. to group or filter symbols by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolCategory {
    Package,
    Model,
    Class,
    Block,
    Connector,
    Record,
    Type,
    Function,
    Operator,
    Component,
    Parameter,
    Constant,
}
/// A source of symbols defined outside the file being resolved.
///
/// [`ScopeResolver`] consults this when a name is not declared in its own AST.
pub trait SymbolLookup {
    /// Returns the parsed file that defines `qualified_name`, if known.
    ///
    /// The resolver uses this to inspect base classes declared in other files.
    fn get_ast_for_symbol(&self, qualified_name: &str) -> Option<&StoredDefinition>;

    /// Looks up a symbol by name.
    ///
    /// The resolver passes both dotted, fully-qualified candidates and bare names.
    fn lookup_symbol(&self, name: &str) -> Option<ExternalSymbol>;
}
/// The result of resolving a name with [`ScopeResolver`].
#[derive(Clone, Debug)]
pub enum ResolvedSymbol<'a> {
    /// A component (variable, parameter, sub-model instance, ...) of a local class.
    Component {
        /// The component declaration.
        component: &'a Component,
        /// The class that actually declares the component.
        defined_in: &'a ClassDefinition,
        /// `None` for a direct declaration. Otherwise, the `extends` name
        /// (as written) of the base class that supplies the component.
        inherited_via: Option<String>,
    },
    /// A class defined in the resolver's own file.
    Class(&'a ClassDefinition),
    /// A symbol supplied by the external [`SymbolLookup`].
    External(ExternalSymbol),
}
/// Resolves names at source positions within a single parsed file.
///
/// `L` is the external symbol source. It defaults to `dyn SymbolLookup`, the
/// type used by [`ScopeResolver::new`].
pub struct ScopeResolver<'a, L = dyn SymbolLookup>
where
    L: SymbolLookup + ?Sized,
{
    /// The parsed file whose classes are searched first.
    ast: &'a StoredDefinition,
    /// Optional source of symbols that are not declared in `ast`.
    lookup: Option<&'a L>,
}
impl<'a> ScopeResolver<'a, dyn SymbolLookup> {
    /// Creates a resolver that only knows about the classes in `ast`.
    ///
    /// No external [`SymbolLookup`] is consulted, so lookups that would need
    /// one simply find nothing.
    pub fn new(ast: &'a StoredDefinition) -> Self {
        let lookup = None;
        Self { lookup, ast }
    }
}
/// Position-aware name resolution over the resolver's file, with optional
/// fallback to the external [`SymbolLookup`].
///
/// All `line`/`col` arguments are 1-based unless the method name ends in
/// `_0indexed`.
impl<'a, L: SymbolLookup + ?Sized> ScopeResolver<'a, L> {
    /// Creates a resolver over `ast` that falls back to `lookup` for symbols
    /// not defined in this file.
    pub fn with_lookup(ast: &'a StoredDefinition, lookup: &'a L) -> Self {
        Self {
            ast,
            lookup: Some(lookup),
        }
    }

    /// Returns the file's `within` clause rendered as a dotted string, if any.
    pub fn within_prefix(&self) -> Option<String> {
        self.ast.within.as_ref().map(|w| w.to_string())
    }

    /// Returns the innermost class whose source span contains the position.
    ///
    /// Nested classes are searched to any depth.
    pub fn class_at(&self, line: u32, col: u32) -> Option<&'a ClassDefinition> {
        self.enclosing_classes(line, col).last().copied()
    }

    /// Same as [`Self::class_at`] but takes a 0-based `line`/`col`.
    pub fn class_at_0indexed(&self, line: u32, col: u32) -> Option<&'a ClassDefinition> {
        self.class_at(line.saturating_add(1), col.saturating_add(1))
    }

    /// Resolves a simple (undotted) `name` as seen from the given position.
    ///
    /// Lookup order:
    /// 1. a component declared directly in the innermost enclosing class;
    /// 2. a component inherited through `extends`, followed transitively;
    /// 3. for each enclosing class from innermost outward, its import aliases
    ///    (only when an external lookup is configured), then its nested classes;
    /// 4. a top-level class of this file;
    /// 5. the external lookup: `within`-relative names from the innermost
    ///    package outward, then the bare name.
    pub fn resolve(&self, name: &str, line: u32, col: u32) -> Option<ResolvedSymbol<'a>> {
        let scopes = self.enclosing_classes(line, col);

        if let Some(&class) = scopes.last() {
            if let Some(component) = class.components.get(name) {
                return Some(ResolvedSymbol::Component {
                    component,
                    defined_in: class,
                    inherited_via: None,
                });
            }
            if let Some((component, base_class, base_name)) =
                self.find_inherited_component(class, name)
            {
                return Some(ResolvedSymbol::Component {
                    component,
                    defined_in: base_class,
                    inherited_via: Some(base_name),
                });
            }
        }

        // Classes and imports of enclosing scopes are visible too (e.g. sibling
        // models inside the same package), so walk outward from the innermost.
        for &scope in scopes.iter().rev() {
            if let Some(resolved_path) = self.resolve_import_alias(scope, name)
                && let Some(lookup) = self.lookup
                && let Some(sym) = lookup.lookup_symbol(&resolved_path)
            {
                return Some(ResolvedSymbol::External(sym));
            }
            if let Some(nested) = scope.classes.get(name) {
                return Some(ResolvedSymbol::Class(nested));
            }
        }

        if let Some(class) = self.ast.class_list.get(name) {
            return Some(ResolvedSymbol::Class(class));
        }

        self.lookup_external(name)
    }

    /// Same as [`Self::resolve`] but takes a 0-based `line`/`col`.
    pub fn resolve_0indexed(&self, name: &str, line: u32, col: u32) -> Option<ResolvedSymbol<'a>> {
        self.resolve(name, line.saturating_add(1), col.saturating_add(1))
    }

    /// Lists the components visible in the innermost class at the position.
    ///
    /// Direct components come first. Inherited ones follow in depth-first
    /// `extends` order. An inherited component is omitted when a component
    /// with the same name was already listed, either directly or via an
    /// earlier base. Bases are resolved the same way as in [`Self::resolve`]:
    /// locally first, then through the external lookup. Cyclic `extends`
    /// chains are walked only once.
    pub fn visible_components(&self, line: u32, col: u32) -> Vec<ResolvedSymbol<'a>> {
        let Some(class) = self.class_at(line, col) else {
            return Vec::new();
        };

        let mut result: Vec<ResolvedSymbol<'a>> = class
            .components
            .values()
            .map(|component| ResolvedSymbol::Component {
                component,
                defined_in: class,
                inherited_via: None,
            })
            .collect();

        let mut seen_names: std::collections::HashSet<&'a str> =
            class.components.keys().map(String::as_str).collect();
        let mut visited = vec![class];
        self.collect_inherited_into(class, &mut seen_names, &mut visited, &mut result);
        result
    }

    /// Resolves a dotted name (e.g. `Interfaces.SISO`) as seen from the position.
    ///
    /// With an external lookup, for each enclosing class from innermost
    /// outward, it tries:
    /// - the import alias matching the first segment;
    /// - the name qualified by that class's full name.
    ///
    /// It then tries `within`-relative names (innermost package outward) and
    /// finally the bare name. If that fails and the name has at least two
    /// segments, the path is walked through local classes. The first segment
    /// is looked up in enclosing classes (innermost first), then among the
    /// file's top-level classes.
    pub fn resolve_qualified(
        &self,
        qualified_name: &str,
        line: u32,
        col: u32,
    ) -> Option<ResolvedSymbol<'a>> {
        let parts: Vec<&str> = qualified_name.split('.').collect();
        let (&first_part, rest_parts) = parts.split_first()?;
        let scopes = self.enclosing_classes(line, col);

        if let Some(lookup) = self.lookup {
            for depth in (1..=scopes.len()).rev() {
                let scope = scopes[depth - 1];
                if let Some(resolved_path) = self.resolve_import_alias(scope, first_part) {
                    let full_qualified = if rest_parts.is_empty() {
                        resolved_path
                    } else {
                        format!("{}.{}", resolved_path, rest_parts.join("."))
                    };
                    if let Some(sym) = lookup.lookup_symbol(&full_qualified) {
                        return Some(ResolvedSymbol::External(sym));
                    }
                }
                let relative_to_scope = format!(
                    "{}.{}",
                    self.qualified_name_of(&scopes[..depth]),
                    qualified_name
                );
                if let Some(sym) = lookup.lookup_symbol(&relative_to_scope) {
                    return Some(ResolvedSymbol::External(sym));
                }
            }
            if let Some(found) = self.lookup_external(qualified_name) {
                return Some(found);
            }
        }

        // Single-segment names are handled by `resolve`; only walk paths here.
        if rest_parts.is_empty() {
            return None;
        }
        let head = scopes
            .iter()
            .rev()
            .copied()
            .find_map(|scope| scope.classes.get(first_part))
            .or_else(|| self.ast.class_list.get(first_part))?;
        find_nested_class(head, rest_parts).map(ResolvedSymbol::Class)
    }

    /// Returns the classes containing the position, ordered from the outermost
    /// top-level class to the innermost nested class.
    fn enclosing_classes(&self, line: u32, col: u32) -> Vec<&'a ClassDefinition> {
        let mut chain = Vec::new();
        let mut next = Self::child_containing(self.ast.class_list.values(), line, col);
        while let Some(class) = next {
            chain.push(class);
            next = Self::child_containing(class.classes.values(), line, col);
        }
        chain
    }

    /// Among sibling `candidates`, returns the one whose span contains the
    /// position.
    ///
    /// If several do (overlapping spans), the one starting on the latest line
    /// wins, with the first such candidate winning ties.
    fn child_containing<I>(candidates: I, line: u32, col: u32) -> Option<&'a ClassDefinition>
    where
        I: IntoIterator<Item = &'a ClassDefinition>,
    {
        let mut best: Option<&'a ClassDefinition> = None;
        for candidate in candidates {
            let starts_later = best
                .is_none_or(|current| candidate.location.start_line > current.location.start_line);
            if starts_later && Self::position_in_location(&candidate.location, line, col) {
                best = Some(candidate);
            }
        }
        best
    }

    /// Dotted name of the innermost class in `scopes` (outermost first),
    /// prefixed by the file's non-empty `within` package.
    fn qualified_name_of(&self, scopes: &[&ClassDefinition]) -> String {
        let within = self.within_prefix().filter(|w| !w.is_empty());
        within
            .iter()
            .map(String::as_str)
            .chain(scopes.iter().map(|class| class.name.text.as_str()))
            .collect::<Vec<_>>()
            .join(".")
    }

    /// Candidate qualified spellings of `name` relative to the `within` package.
    ///
    /// The list starts with the innermost package, walks outward to the root
    /// package, and ends with the bare `name`. An empty `within` contributes
    /// no prefixed candidates.
    fn within_candidates(&self, name: &str) -> Vec<String> {
        let mut candidates = Vec::new();
        if let Some(within) = self.within_prefix().filter(|w| !w.is_empty()) {
            let mut package = within.as_str();
            loop {
                candidates.push(format!("{}.{}", package, name));
                match package.rfind('.') {
                    Some(dot) => package = &package[..dot],
                    None => break,
                }
            }
        }
        candidates.push(name.to_string());
        candidates
    }

    /// Returns the first external symbol found for any of
    /// [`Self::within_candidates`], or `None` without a lookup.
    fn lookup_external(&self, name: &str) -> Option<ResolvedSymbol<'a>> {
        let lookup = self.lookup?;
        self.within_candidates(name)
            .iter()
            .find_map(|candidate| lookup.lookup_symbol(candidate))
            .map(ResolvedSymbol::External)
    }

    /// Maps `alias` to the fully-qualified path it names via `class`'s imports.
    ///
    /// Handles renamed (`import A = B.C;`), qualified (`import B.C;`) and
    /// selective (`import B.{C, D};`) imports. Unqualified (`.*`) imports are
    /// ignored because they cannot be expanded without a lookup. The first
    /// matching import wins.
    fn resolve_import_alias(&self, class: &ClassDefinition, alias: &str) -> Option<String> {
        for import in &class.imports {
            match import {
                Import::Renamed {
                    alias: alias_token,
                    path,
                    ..
                } => {
                    if alias_token.text == alias {
                        return Some(path.to_string());
                    }
                }
                Import::Qualified { path, .. } => {
                    if let Some(last) = path.name.last()
                        && last.text == alias
                    {
                        return Some(path.to_string());
                    }
                }
                Import::Selective { path, names, .. } => {
                    for imported in names {
                        if imported.text == alias {
                            return Some(format!("{}.{}", path, imported.text));
                        }
                    }
                }
                Import::Unqualified { .. } => {}
            }
        }
        None
    }

    /// Finds a class of this file by name.
    ///
    /// A dotted name is walked as a path from the top-level classes. A simple
    /// name is searched breadth-first through all nesting levels, so
    /// shallower classes win over deeper ones.
    fn find_class_locally(&self, name: &str) -> Option<&'a ClassDefinition> {
        if name.contains('.') {
            let mut segments = name.split('.');
            let head = self.ast.class_list.get(segments.next()?)?;
            let rest: Vec<&str> = segments.collect();
            return find_nested_class(head, &rest);
        }

        if let Some(class) = self.ast.class_list.get(name) {
            return Some(class);
        }
        let mut level: Vec<&'a ClassDefinition> = self.ast.class_list.values().collect();
        while !level.is_empty() {
            if let Some(found) = level
                .iter()
                .copied()
                .find_map(|class| class.classes.get(name))
            {
                return Some(found);
            }
            level = level
                .iter()
                .copied()
                .flat_map(|class| class.classes.values())
                .collect();
        }
        None
    }

    /// Qualifies `name` with the innermost enclosing `within` package under
    /// which the external lookup knows it, walking outward.
    ///
    /// Returns `name` unchanged when nothing is found or no lookup is configured.
    fn resolve_class_name(&self, name: &str) -> String {
        let Some(lookup) = self.lookup else {
            return name.to_string();
        };
        let mut candidates = self.within_candidates(name);
        // The trailing bare `name` is the fallback anyway; no need to look it up.
        candidates.pop();
        candidates
            .into_iter()
            .find(|candidate| lookup.lookup_symbol(candidate).is_some())
            .unwrap_or_else(|| name.to_string())
    }

    /// Locates the class named in an `extends` clause.
    ///
    /// Returns the class together with a resolver over the file that defines
    /// it, which is needed to resolve that class's own `extends` clauses.
    /// Local classes are preferred. Otherwise the external lookup is used.
    fn resolve_base_class(&self, base_name: &str) -> Option<(&'a ClassDefinition, Self)> {
        if let Some(base_class) = self.find_class_locally(base_name) {
            return Some((
                base_class,
                Self {
                    ast: self.ast,
                    lookup: self.lookup,
                },
            ));
        }
        let lookup = self.lookup?;
        let qualified_base = self.resolve_class_name(base_name);
        let base_ast = lookup.get_ast_for_symbol(&qualified_base)?;
        let base_class = find_class_in_ast(base_ast, &qualified_base)
            .or_else(|| find_class_in_ast(base_ast, base_name))?;
        Some((
            base_class,
            Self {
                ast: base_ast,
                lookup: self.lookup,
            },
        ))
    }

    /// Searches `class`'s base classes (transitively) for a component `name`.
    ///
    /// Returns the component, the class declaring it, and the `extends` name
    /// under which that class was reached.
    fn find_inherited_component(
        &self,
        class: &'a ClassDefinition,
        name: &str,
    ) -> Option<(&'a Component, &'a ClassDefinition, String)> {
        // Seeded with `class` so `extends` chains that loop back terminate.
        let mut visited = vec![class];
        self.find_inherited_component_in(class, name, &mut visited)
    }

    /// Depth-first worker for [`Self::find_inherited_component`].
    ///
    /// `visited` holds classes already searched. Classes are compared by
    /// identity, so invalid cyclic inheritance cannot recurse forever.
    fn find_inherited_component_in(
        &self,
        class: &'a ClassDefinition,
        name: &str,
        visited: &mut Vec<&'a ClassDefinition>,
    ) -> Option<(&'a Component, &'a ClassDefinition, String)> {
        for ext in &class.extends {
            let base_name = ext.comp.to_string();
            let Some((base_class, base_scope)) = self.resolve_base_class(&base_name) else {
                continue;
            };
            if visited.iter().any(|seen| std::ptr::eq(*seen, base_class)) {
                continue;
            }
            visited.push(base_class);
            if let Some(component) = base_class.components.get(name) {
                return Some((component, base_class, base_name));
            }
            if let Some(found) = base_scope.find_inherited_component_in(base_class, name, visited)
            {
                return Some(found);
            }
        }
        None
    }

    /// Depth-first worker for [`Self::visible_components`].
    ///
    /// Appends each not-yet-seen inherited component to `out`, following
    /// base classes transitively and skipping classes already in `visited`.
    fn collect_inherited_into(
        &self,
        class: &'a ClassDefinition,
        seen_names: &mut std::collections::HashSet<&'a str>,
        visited: &mut Vec<&'a ClassDefinition>,
        out: &mut Vec<ResolvedSymbol<'a>>,
    ) {
        for ext in &class.extends {
            let base_name = ext.comp.to_string();
            let Some((base_class, base_scope)) = self.resolve_base_class(&base_name) else {
                continue;
            };
            if visited.iter().any(|seen| std::ptr::eq(*seen, base_class)) {
                continue;
            }
            visited.push(base_class);
            for component in base_class.components.values() {
                if seen_names.insert(component.name.as_str()) {
                    out.push(ResolvedSymbol::Component {
                        component,
                        defined_in: base_class,
                        inherited_via: Some(base_name.clone()),
                    });
                }
            }
            base_scope.collect_inherited_into(base_class, seen_names, visited, out);
        }
    }

    /// True when (`line`, `col`) lies within `loc`, inclusive at both ends.
    fn position_in_location(loc: &Location, line: u32, col: u32) -> bool {
        if line < loc.start_line || line > loc.end_line {
            return false;
        }
        if line == loc.start_line && col < loc.start_column {
            return false;
        }
        if line == loc.end_line && col > loc.end_column {
            return false;
        }
        true
    }
}
/// Finds the class named `qualified_name` inside a parsed file.
///
/// Strategies, in order:
/// 1. exact path relative to the file's (non-empty) `within` package, e.g.
///    `Modelica.Blocks.Continuous.PID` in a file `within Modelica.Blocks;`
///    walks `Continuous` → `PID`;
/// 2. exact path from the file's top-level classes (`MyPackage.InnerModel`);
/// 3. heuristic fallback: a top-level class with the same simple (last) name;
/// 4. a top-level entry keyed by the full string.
///
/// Exact matches come first so that a heuristic simple-name hit cannot shadow
/// the class the dotted path actually names.
pub fn find_class_in_ast<'a>(
    ast: &'a StoredDefinition,
    qualified_name: &str,
) -> Option<&'a ClassDefinition> {
    let within = ast.within.as_ref().map(|w| w.to_string());
    if let Some(relative) = within
        .as_deref()
        .filter(|w| !w.is_empty())
        .and_then(|w| qualified_name.strip_prefix(w))
        .and_then(|rest| rest.strip_prefix('.'))
        && let Some(class) = find_class_by_path(ast, relative)
    {
        return Some(class);
    }

    if let Some(class) = find_class_by_path(ast, qualified_name) {
        return Some(class);
    }

    // `rsplit` always yields at least one item, so the fallback is never used.
    let simple_name = qualified_name.rsplit('.').next().unwrap_or(qualified_name);
    ast.class_list
        .get(simple_name)
        .or_else(|| ast.class_list.get(qualified_name))
}

/// Walks a dotted `path` from the file's top-level classes through nested classes.
fn find_class_by_path<'a>(ast: &'a StoredDefinition, path: &str) -> Option<&'a ClassDefinition> {
    let mut segments = path.split('.');
    let head = ast.class_list.get(segments.next()?)?;
    let rest: Vec<&str> = segments.collect();
    find_nested_class(head, &rest)
}
/// Follows `path` through nested classes starting at `parent`.
///
/// Returns `parent` itself for an empty path, and `None` as soon as a segment
/// is not a nested class of the class reached so far.
pub fn find_nested_class<'a>(
    parent: &'a ClassDefinition,
    path: &[&str],
) -> Option<&'a ClassDefinition> {
    path.iter()
        .try_fold(parent, |current, segment| current.classes.get(*segment))
}
/// Lists the qualified names that `type_name` may refer to when written inside
/// the class `current_qualified`.
///
/// The list prefixes `type_name` with each enclosing package of that class,
/// innermost first, and ends with `type_name` unqualified. The class's own
/// name is not used as a prefix.
///
/// For example, `("Modelica.Blocks.Continuous.PID", "Interfaces.SISO")` gives
/// `Modelica.Blocks.Continuous.Interfaces.SISO`,
/// `Modelica.Blocks.Interfaces.SISO`, `Modelica.Interfaces.SISO` and
/// `Interfaces.SISO`.
pub fn resolve_type_candidates(current_qualified: &str, type_name: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut scope = current_qualified;
    // Drop one trailing segment per step; each remaining prefix is a package.
    while let Some(dot) = scope.rfind('.') {
        scope = &scope[..dot];
        out.push(format!("{}.{}", scope, type_name));
    }
    out.push(type_name.to_string());
    out
}
/// Maps the import aliases declared in a class to the fully-qualified names
/// they stand for.
///
/// Unqualified (`import A.B.*;`) imports are not recorded. When two imports
/// bind the same alias, the later one wins.
#[derive(Default, Debug)]
pub struct ImportResolver {
    /// Alias → fully-qualified path.
    aliases: HashMap<String, String>,
}

impl ImportResolver {
    /// Creates a resolver with no aliases.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the alias table from a class's import clauses.
    ///
    /// - `import A = B.C;` maps `A` to `B.C`.
    /// - `import B.C;` maps `C` to `B.C`.
    /// - `import B.{C, D};` maps `C` to `B.C` and `D` to `B.D`.
    pub fn from_imports(imports: &[Import]) -> Self {
        let mut aliases = HashMap::new();
        for import in imports {
            match import {
                Import::Renamed { alias, path, .. } => {
                    aliases.insert(alias.text.clone(), path.to_string());
                }
                Import::Qualified { path, .. } => {
                    if let Some(last) = path.name.last() {
                        aliases.insert(last.text.clone(), path.to_string());
                    }
                }
                Import::Selective { path, names, .. } => {
                    let base_path = path.to_string();
                    for name in names {
                        let target = format!("{}.{}", base_path, name.text);
                        aliases.insert(name.text.clone(), target);
                    }
                }
                // Wildcard imports cannot be expanded without a symbol index.
                Import::Unqualified { .. } => {}
            }
        }
        Self { aliases }
    }

    /// Returns the fully-qualified path bound to `alias`, if any.
    pub fn resolve(&self, alias: &str) -> Option<&str> {
        self.aliases.get(alias).map(String::as_str)
    }

    /// Iterates over `(alias, fully-qualified path)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.aliases
            .iter()
            .map(|(alias, target)| (alias.as_str(), target.as_str()))
    }
}
pub fn collect_inherited_components<'a>(
class: &'a ClassDefinition,
peer_classes: &'a indexmap::IndexMap<String, ClassDefinition>,
) -> HashMap<String, (&'a Component, String)> {
let mut result = HashMap::new();
collect_inherited_recursive(class, peer_classes, &mut result);
result
}
fn collect_inherited_recursive<'a>(
class: &'a ClassDefinition,
peer_classes: &'a indexmap::IndexMap<String, ClassDefinition>,
result: &mut HashMap<String, (&'a Component, String)>,
) {
for ext in &class.extends {
let base_name = ext.comp.to_string();
if let Some(base_class) = peer_classes.get(&base_name) {
for (comp_name, comp) in &base_class.components {
if !result.contains_key(comp_name) {
result.insert(comp_name.clone(), (comp, base_name.clone()));
}
}
collect_inherited_recursive(base_class, peer_classes, result);
}
}
}
// Unit tests for position-based scope resolution, class lookup in a parsed
// file, type-name candidates, import aliases and inheritance collection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::modelica_grammar::ModelicaGrammar;
    use crate::modelica_parser::parse;

    // Parses `code` as a file named "test.mo" and returns the resulting AST,
    // panicking if parsing fails or produces no definition.
    fn parse_test_code(code: &str) -> StoredDefinition {
        let mut grammar = ModelicaGrammar::new();
        parse(code, "test.mo", &mut grammar).expect("Failed to parse test code");
        grammar.modelica.expect("No AST produced")
    }

    // The fixture starts with a newline, so `class Outer` is line 2,
    // `Real x;` is line 3 (inside Outer only) and `Real y;` is line 5
    // (inside Inner).
    #[test]
    fn test_class_at_position() {
        let code = r#"
class Outer
Real x;
class Inner
Real y;
end Inner;
end Outer;
"#;
        let ast = parse_test_code(code);
        let resolver = ScopeResolver::new(&ast);
        let class = resolver.class_at(3, 5);
        assert!(class.is_some());
        assert_eq!(class.unwrap().name.text, "Outer");
        let class = resolver.class_at(5, 5);
        assert!(class.is_some());
        assert_eq!(class.unwrap().name.text, "Inner");
    }

    // Line 6 is the equation `x = y;` inside `Test`; `x` is declared directly
    // in that class, so it must resolve without an inheritance path.
    #[test]
    fn test_resolve_direct_component() {
        let code = r#"
class Test
Real x;
Real y;
equation
x = y;
end Test;
"#;
        let ast = parse_test_code(code);
        let resolver = ScopeResolver::new(&ast);
        let symbol = resolver.resolve("x", 6, 3);
        assert!(symbol.is_some());
        if let Some(ResolvedSymbol::Component {
            component,
            inherited_via,
            ..
        }) = symbol
        {
            assert_eq!(component.name, "x");
            assert!(inherited_via.is_none());
        } else {
            panic!("Expected Component");
        }
    }

    // Line 9 is `end Derived;`. The test relies on the class's location
    // covering its `end` clause (assumption about the parser's spans).
    // `v` is declared in `Base` and reached through `extends Base`.
    #[test]
    fn test_resolve_inherited_component() {
        let code = r#"
class Base
Real v;
end Base;
class Derived
extends Base;
equation
v = 1;
end Derived;
"#;
        let ast = parse_test_code(code);
        let resolver = ScopeResolver::new(&ast);
        let symbol = resolver.resolve("v", 9, 3);
        assert!(symbol.is_some());
        if let Some(ResolvedSymbol::Component {
            component,
            defined_in,
            inherited_via,
        }) = symbol
        {
            assert_eq!(component.name, "v");
            assert_eq!(defined_in.name.text, "Base");
            assert!(inherited_via.is_some());
        } else {
            panic!("Expected Component");
        }
    }

    // Position (1, 1) is the empty first line, outside every class, so the
    // name must be found among the file's top-level classes.
    #[test]
    fn test_resolve_class() {
        let code = r#"
class MyClass
Real x;
end MyClass;
"#;
        let ast = parse_test_code(code);
        let resolver = ScopeResolver::new(&ast);
        let symbol = resolver.resolve("MyClass", 1, 1);
        assert!(symbol.is_some());
        if let Some(ResolvedSymbol::Class(class)) = symbol {
            assert_eq!(class.name.text, "MyClass");
        } else {
            panic!("Expected Class");
        }
    }

    // A bare top-level class name is found directly.
    #[test]
    fn test_find_class_in_ast_simple() {
        let code = r#"
model TestModel
Real x;
end TestModel;
"#;
        let ast = parse_test_code(code);
        let result = find_class_in_ast(&ast, "TestModel");
        assert!(result.is_some());
        assert!(result.unwrap().components.contains_key("x"));
    }

    // With a `within` clause, the class must be found both by its full
    // qualified name and by its simple name.
    #[test]
    fn test_find_class_in_ast_with_within() {
        let code = r#"
within Modelica.Blocks.Continuous;
model PID
Real x;
end PID;
"#;
        let ast = parse_test_code(code);
        let result = find_class_in_ast(&ast, "Modelica.Blocks.Continuous.PID");
        assert!(result.is_some(), "Should find PID by qualified name");
        assert!(result.unwrap().components.contains_key("x"));
        let result = find_class_in_ast(&ast, "PID");
        assert!(result.is_some(), "Should find PID by simple name");
    }

    // A dotted path is walked from a top-level class into its nested classes.
    #[test]
    fn test_find_class_in_ast_nested() {
        let code = r#"
package MyPackage
model InnerModel
Real z;
end InnerModel;
end MyPackage;
"#;
        let ast = parse_test_code(code);
        let result = find_class_in_ast(&ast, "MyPackage.InnerModel");
        assert!(result.is_some(), "Should find nested InnerModel");
        assert!(result.unwrap().components.contains_key("z"));
    }

    // Candidates run from the innermost enclosing package outward and end
    // with the unqualified name; the class's own name is never a prefix.
    #[test]
    fn test_resolve_type_candidates() {
        let candidates =
            resolve_type_candidates("Modelica.Blocks.Continuous.PID", "Interfaces.SISO");
        assert_eq!(candidates.len(), 4);
        assert_eq!(candidates[0], "Modelica.Blocks.Continuous.Interfaces.SISO");
        assert_eq!(candidates[1], "Modelica.Blocks.Interfaces.SISO");
        assert_eq!(candidates[2], "Modelica.Interfaces.SISO");
        assert_eq!(candidates[3], "Interfaces.SISO");
        let candidates = resolve_type_candidates("PID", "SISO");
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0], "SISO");
        let candidates = resolve_type_candidates("Modelica.PID", "SISO");
        assert_eq!(candidates.len(), 2);
        assert_eq!(candidates[0], "Modelica.SISO");
        assert_eq!(candidates[1], "SISO");
    }

    // Covers the qualified, renamed and selective import forms.
    #[test]
    fn test_import_resolver() {
        let code = r#"
model Test
import Modelica.Blocks.Continuous.PID;
import SI = Modelica.Units.SI;
import Modelica.Constants.{pi, e};
end Test;
"#;
        let ast = parse_test_code(code);
        let class = ast.class_list.get("Test").expect("Test class not found");
        let resolver = ImportResolver::from_imports(&class.imports);
        assert_eq!(
            resolver.resolve("PID"),
            Some("Modelica.Blocks.Continuous.PID")
        );
        assert_eq!(resolver.resolve("SI"), Some("Modelica.Units.SI"));
        assert_eq!(resolver.resolve("pi"), Some("Modelica.Constants.pi"));
        assert_eq!(resolver.resolve("e"), Some("Modelica.Constants.e"));
    }

    // Inherited components are reported with the base they come from. The
    // directly declared `z` must not appear in the inherited map.
    #[test]
    fn test_collect_inherited_components() {
        let code = r#"
class Base
Real x;
Real y;
end Base;
class Derived
extends Base;
Real z;
end Derived;
"#;
        let ast = parse_test_code(code);
        let derived = ast.class_list.get("Derived").expect("Derived not found");
        let inherited = collect_inherited_components(derived, &ast.class_list);
        assert!(inherited.contains_key("x"), "Should inherit x");
        assert!(inherited.contains_key("y"), "Should inherit y");
        assert!(!inherited.contains_key("z"), "z is direct, not inherited");
        let (_, base_name) = inherited.get("x").unwrap();
        assert_eq!(base_name, "Base");
    }
}