use std::collections::{HashMap, HashSet};
use crate::ir::analysis::symbol_trait::SymbolInfo;
use crate::ir::analysis::type_inference::{SymbolType, type_from_name};
use crate::ir::ast::{
Causality, ClassDefinition, ClassType, Component, ComponentReference, Expression, ForIndex,
Variability,
};
use crate::ir::visitor::{Visitable, Visitor};
#[derive(Clone, Debug)]
pub struct DefinedSymbol {
pub name: String,
pub line: u32,
pub col: u32,
pub is_parameter: bool,
pub is_constant: bool,
pub is_class: bool,
pub has_default: bool,
pub declared_type: SymbolType,
pub shape: Vec<usize>,
pub function_return: Option<(SymbolType, Vec<usize>)>,
}
impl DefinedSymbol {
pub fn type_name(&self) -> String {
self.declared_type.to_string()
}
}
impl DefinedSymbol {
pub fn from_component(name: &str, comp: &Component) -> (String, Self) {
let line = comp
.type_name
.name
.first()
.map(|t| t.location.start_line)
.unwrap_or(1);
let col = comp
.type_name
.name
.first()
.map(|t| t.location.start_column)
.unwrap_or(1);
let has_start = !matches!(comp.start, Expression::Empty);
let is_parameter = matches!(comp.variability, Variability::Parameter(_));
let is_constant = matches!(comp.variability, Variability::Constant(_));
let declared_type = type_from_name(&comp.type_name.to_string());
(
name.to_string(),
Self {
name: name.to_string(),
line,
col,
is_parameter,
is_constant,
is_class: false,
has_default: has_start,
declared_type,
shape: comp.shape.clone(),
function_return: None,
},
)
}
pub fn from_class(name: &str, class: &ClassDefinition) -> (String, Self) {
let function_return = if matches!(class.class_type, ClassType::Function) {
class
.components
.values()
.find(|c| matches!(c.causality, Causality::Output(_)))
.map(|output| {
(
type_from_name(&output.type_name.to_string()),
output.shape.clone(),
)
})
} else {
None
};
(
name.to_string(),
Self {
name: name.to_string(),
line: class.name.location.start_line,
col: class.name.location.start_column,
is_parameter: false,
is_constant: false,
is_class: true,
has_default: true,
declared_type: SymbolType::Class(name.to_string()),
shape: vec![],
function_return,
},
)
}
pub fn loop_index(name: &str, line: u32, col: u32) -> Self {
Self {
name: name.to_string(),
line,
col,
is_parameter: false,
is_constant: false,
is_class: false,
has_default: true,
declared_type: SymbolType::Integer,
shape: vec![],
function_return: None,
}
}
}
pub fn add_loop_indices_to_defined(
indices: &[ForIndex],
defined: &mut HashMap<String, DefinedSymbol>,
) {
for index in indices {
defined.insert(
index.ident.text.clone(),
DefinedSymbol::loop_index(
&index.ident.text,
index.ident.location.start_line,
index.ident.location.start_column,
),
);
}
}
pub fn is_class_instance_type(type_name: &str) -> bool {
!matches!(
type_name,
"Real" | "Integer" | "Boolean" | "String" | "StateSelect" | "ExternalObject"
)
}
pub fn is_class_instance(symbol_type: &SymbolType) -> bool {
match symbol_type {
SymbolType::Class(name) => {
!matches!(name.as_str(), "StateSelect" | "ExternalObject")
}
SymbolType::Enumeration(_) => true,
SymbolType::Array(inner, _) => is_class_instance(inner),
_ => false,
}
}
impl SymbolInfo for DefinedSymbol {
fn name(&self) -> &str {
&self.name
}
fn symbol_type(&self) -> SymbolType {
self.declared_type.clone()
}
fn line(&self) -> u32 {
self.line
}
fn column(&self) -> u32 {
self.col
}
fn is_parameter(&self) -> bool {
self.is_parameter
}
fn is_constant(&self) -> bool {
self.is_constant
}
fn is_class(&self) -> bool {
self.is_class
}
}
pub fn collect_defined_symbols(class: &ClassDefinition) -> HashMap<String, DefinedSymbol> {
let mut defined = HashMap::new();
for (name, comp) in class.iter_components() {
let (sym_name, symbol) = DefinedSymbol::from_component(name, comp);
defined.insert(sym_name, symbol);
}
for (name, nested_class) in class.iter_classes() {
let (sym_name, symbol) = DefinedSymbol::from_class(name, nested_class);
defined.insert(sym_name, symbol);
}
defined
}
pub fn collect_used_symbols(class: &ClassDefinition) -> HashSet<String> {
let mut collector = SymbolCollectorVisitor::new();
class.accept(&mut collector);
collector.into_symbols()
}
pub fn collect_used_symbols_simple(class: &ClassDefinition) -> HashSet<String> {
crate::ir::visitor::collect_component_refs(class, |cref| {
cref.parts.first().map(|p| p.ident.text.clone())
})
}
pub struct SymbolCollectorVisitor {
symbols: HashSet<String>,
}
impl SymbolCollectorVisitor {
pub fn new() -> Self {
Self {
symbols: HashSet::new(),
}
}
pub fn symbols(&self) -> &HashSet<String> {
&self.symbols
}
pub fn into_symbols(self) -> HashSet<String> {
self.symbols
}
}
impl Default for SymbolCollectorVisitor {
fn default() -> Self {
Self::new()
}
}
impl Visitor for SymbolCollectorVisitor {
fn enter_component_reference(&mut self, node: &ComponentReference) {
if let Some(first) = node.parts.first() {
self.symbols.insert(first.ident.text.clone());
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::modelica_grammar::ModelicaGrammar;
use crate::modelica_parser::parse;
fn parse_test_code(code: &str) -> crate::ir::ast::StoredDefinition {
let mut grammar = ModelicaGrammar::new();
parse(code, "test.mo", &mut grammar).expect("Failed to parse test code");
grammar.modelica.expect("No AST produced")
}
#[test]
fn test_collect_used_symbols_basic() {
let code = r#"
model Test
Real x;
Real y;
equation
x = y + 1.0;
y = x * 2.0;
end Test;
"#;
let ast = parse_test_code(code);
let class = ast.class_list.get("Test").expect("Test class not found");
let symbols = collect_used_symbols(class);
assert!(symbols.contains("x"));
assert!(symbols.contains("y"));
}
#[test]
fn test_collect_used_symbols_with_for_loop() {
let code = r#"
model Test
Real x[10];
equation
for i in 1:10 loop
x[i] = i * 2.0;
end for;
end Test;
"#;
let ast = parse_test_code(code);
let class = ast.class_list.get("Test").expect("Test class not found");
let symbols = collect_used_symbols(class);
assert!(symbols.contains("x"));
assert!(symbols.contains("i"));
}
#[test]
fn test_collect_used_symbols_with_function_call() {
let code = r#"
model Test
Real x;
Real y;
equation
x = sin(y);
y = cos(x + 1.0);
end Test;
"#;
let ast = parse_test_code(code);
let class = ast.class_list.get("Test").expect("Test class not found");
let symbols = collect_used_symbols(class);
assert!(symbols.contains("x"));
assert!(symbols.contains("y"));
assert!(symbols.contains("sin"));
assert!(symbols.contains("cos"));
}
#[test]
fn test_symbol_collector_visitor_directly() {
let code = r#"
model Test
Real x;
Real y;
equation
x = y + 1.0;
end Test;
"#;
let ast = parse_test_code(code);
let class = ast.class_list.get("Test").expect("Test class not found");
let mut collector = SymbolCollectorVisitor::new();
class.accept(&mut collector);
assert!(collector.symbols().contains("x"));
assert!(collector.symbols().contains("y"));
let symbols = collector.into_symbols();
assert!(symbols.contains("x"));
assert!(symbols.contains("y"));
}
#[test]
fn test_collect_used_symbols_simple_matches_regular() {
let code = r#"
model Test
Real x;
Real y;
Real z;
equation
x = y + z;
y = sin(x);
end Test;
"#;
let ast = parse_test_code(code);
let class = ast.class_list.get("Test").expect("Test class not found");
let regular = collect_used_symbols(class);
let simple = collect_used_symbols_simple(class);
assert_eq!(
regular, simple,
"Both collection methods should return same symbols"
);
assert!(regular.contains("x"));
assert!(regular.contains("y"));
assert!(regular.contains("z"));
assert!(regular.contains("sin"));
}
}