use std::collections::{HashMap, HashSet};
use crate::ir::analysis::symbol_table::SymbolTable;
use crate::ir::analysis::symbols::{DefinedSymbol, add_loop_indices_to_defined};
use crate::ir::ast::{
ClassDefinition, ComponentReference, Equation, Expression, Import, Statement, Subscript,
};
/// Extra names that the reference checker accepts as resolvable.
///
/// A component reference whose first identifier appears in either set is
/// not reported as undefined, even when it is absent from the class's own
/// defined symbols and from the symbol table.
#[derive(Clone, Debug, Default)]
pub struct ReferenceCheckConfig {
    /// Names derived from import clauses (see [`collect_imported_packages`]).
    pub imported_packages: HashSet<String>,
    /// Any further names the caller wants treated as defined.
    pub additional_globals: HashSet<String>,
}

impl ReferenceCheckConfig {
    /// Creates a configuration with both name sets empty.
    ///
    /// Equivalent to `ReferenceCheckConfig::default()`.
    pub fn new() -> Self {
        Self {
            imported_packages: HashSet::new(),
            additional_globals: HashSet::new(),
        }
    }

    /// Creates a configuration whose imported-package set is derived from
    /// `imports` via [`collect_imported_packages`]; no additional globals.
    pub fn from_imports(imports: &[Import]) -> Self {
        Self::new().with_imported_packages(collect_imported_packages(imports))
    }

    /// Builder-style setter that replaces (does not extend) the
    /// imported-package set.
    pub fn with_imported_packages(self, packages: HashSet<String>) -> Self {
        Self {
            imported_packages: packages,
            ..self
        }
    }

    /// Builder-style setter that replaces (does not extend) the set of
    /// additional global names.
    pub fn with_additional_globals(self, globals: HashSet<String>) -> Self {
        Self {
            additional_globals: globals,
            ..self
        }
    }
}
/// Collects the names that an import list makes usable as the first
/// identifier of a component reference.
///
/// For every import form, the root segment of the imported path is
/// collected (e.g. `Modelica` for `import Modelica.Constants.*;`). This lets
/// fully qualified references such as `Modelica.Constants.pi` pass the check.
///
/// A qualified import (`import A.B.C;`) additionally makes its last segment
/// (`C`) usable under its short name, so that segment is collected too.
///
/// NOTE(review): the alias of a renamed import (`import X = A.B;`) and the
/// element names of a selective import (`import A.B.{C, D};`) are not
/// collected here, so references through them are still reported.
pub fn collect_imported_packages(imports: &[Import]) -> HashSet<String> {
    let mut packages = HashSet::new();
    for import in imports {
        let root = match import {
            Import::Qualified { path, .. } => {
                // `import A.B.C;` makes `C` usable without qualification.
                if let Some(short_name) = path.name.last() {
                    packages.insert(short_name.text.clone());
                }
                path.name.first()
            }
            Import::Renamed { path, .. } => path.name.first(),
            Import::Unqualified { path, .. } => path.name.first(),
            Import::Selective { path, .. } => path.name.first(),
        };
        if let Some(root) = root {
            packages.insert(root.text.clone());
        }
    }
    packages
}
/// One unresolved-reference diagnostic produced by the checker.
#[derive(Clone, Debug)]
pub struct ReferenceError {
    /// The identifier that could not be resolved.
    pub name: String,
    /// Start line of the offending identifier's source location.
    pub line: u32,
    /// Start column of the offending identifier's source location.
    pub col: u32,
    /// Human-readable description of the problem.
    pub message: String,
}

impl ReferenceError {
    /// Builds the diagnostic for a reference to an undefined variable at
    /// the given position.
    fn undefined_variable(name: &str, line: u32, col: u32) -> Self {
        let message = format!("Undefined variable '{}'", name);
        Self {
            name: name.to_owned(),
            line,
            col,
            message,
        }
    }
}
/// Aggregated output of a reference check over one class.
#[derive(Debug, Clone, Default)]
pub struct ReferenceCheckResult {
    /// Every unresolved reference, in the order the checker encountered it.
    pub errors: Vec<ReferenceError>,
    /// First identifier of every component reference visited, whether or
    /// not it resolved.
    pub used_symbols: HashSet<String>,
}
/// Checks all references in `class` with an empty [`ReferenceCheckConfig`].
///
/// This is the same as [`check_class_references_with_config`] when neither
/// imported packages nor additional globals are supplied.
pub fn check_class_references(
    class: &ClassDefinition,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
) -> ReferenceCheckResult {
    let empty_config = ReferenceCheckConfig::new();
    check_class_references_with_config(class, defined, scope, &empty_config)
}
/// Checks every reference made by `class`.
///
/// Traversal order is:
/// 1. every equation;
/// 2. every algorithm statement;
/// 3. each component's `start` expression.
///
/// Errors in the result appear in this same order. A name counts as defined
/// if it is in `defined`, in `scope`, or in one of `config`'s name sets.
pub fn check_class_references_with_config(
    class: &ClassDefinition,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
    config: &ReferenceCheckConfig,
) -> ReferenceCheckResult {
    let mut outcome = ReferenceCheckResult::default();

    for equation in class.iter_all_equations() {
        check_equation(equation, defined, scope, config, &mut outcome);
    }
    for statement in class.iter_all_statements() {
        check_statement(statement, defined, scope, config, &mut outcome);
    }
    for (_, component) in class.iter_components() {
        check_expression(&component.start, defined, scope, config, &mut outcome);
    }

    outcome
}
/// Recursively checks every reference made by a single equation.
///
/// For a `for` equation, the loop indices are added to a cloned copy of
/// `defined`. That copy is used for both the index ranges and the loop
/// body, so the indices resolve inside the loop but never leak to sibling
/// equations.
///
/// NOTE(review): because ranges are checked with the indices already in
/// scope, a range that mentions its own index (`for i in 1:i`) is not
/// flagged.
///
/// For a function-call equation only the arguments are checked; the callee
/// name is ignored.
fn check_equation(
    eq: &Equation,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
    config: &ReferenceCheckConfig,
    result: &mut ReferenceCheckResult,
) {
    match eq {
        Equation::Empty => {}
        Equation::Simple { lhs, rhs } => {
            check_expression(lhs, defined, scope, config, result);
            check_expression(rhs, defined, scope, config, result);
        }
        Equation::Connect { lhs, rhs } => {
            check_component_ref(lhs, defined, scope, config, result);
            check_component_ref(rhs, defined, scope, config, result);
        }
        Equation::For { indices, equations } => {
            let mut loop_defined = defined.clone();
            add_loop_indices_to_defined(indices, &mut loop_defined);
            for loop_index in indices {
                check_expression(&loop_index.range, &loop_defined, scope, config, result);
            }
            for body_eq in equations {
                check_equation(body_eq, &loop_defined, scope, config, result);
            }
        }
        Equation::When(branches) => {
            for branch in branches {
                check_expression(&branch.cond, defined, scope, config, result);
                for body_eq in &branch.eqs {
                    check_equation(body_eq, defined, scope, config, result);
                }
            }
        }
        Equation::If {
            cond_blocks,
            else_block,
        } => {
            for branch in cond_blocks {
                check_expression(&branch.cond, defined, scope, config, result);
                for body_eq in &branch.eqs {
                    check_equation(body_eq, defined, scope, config, result);
                }
            }
            // A missing `else` contributes nothing.
            for body_eq in else_block.iter().flatten() {
                check_equation(body_eq, defined, scope, config, result);
            }
        }
        Equation::FunctionCall { args, .. } => {
            for argument in args {
                check_expression(argument, defined, scope, config, result);
            }
        }
    }
}
/// Recursively checks every reference made by a single algorithm statement.
///
/// An assignment target is resolved like any other component reference.
/// For a function-call statement, the arguments and then the output
/// expressions are checked; the callee name is ignored.
///
/// A `for` statement extends a cloned copy of `defined` with its loop
/// indices, and that copy is used for the index ranges and the loop body.
/// `return` and `break` carry nothing to check.
fn check_statement(
    stmt: &Statement,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
    config: &ReferenceCheckConfig,
    result: &mut ReferenceCheckResult,
) {
    match stmt {
        Statement::Empty | Statement::Return { .. } | Statement::Break { .. } => {}
        Statement::Assignment { comp, value } => {
            check_component_ref(comp, defined, scope, config, result);
            check_expression(value, defined, scope, config, result);
        }
        Statement::FunctionCall { args, outputs, .. } => {
            // Arguments first, then output targets.
            for operand in args.iter().chain(outputs.iter()) {
                check_expression(operand, defined, scope, config, result);
            }
        }
        Statement::For { indices, equations } => {
            let mut loop_defined = defined.clone();
            add_loop_indices_to_defined(indices, &mut loop_defined);
            for loop_index in indices {
                check_expression(&loop_index.range, &loop_defined, scope, config, result);
            }
            for body_stmt in equations {
                check_statement(body_stmt, &loop_defined, scope, config, result);
            }
        }
        Statement::While(loop_block) => {
            check_expression(&loop_block.cond, defined, scope, config, result);
            for body_stmt in &loop_block.stmts {
                check_statement(body_stmt, defined, scope, config, result);
            }
        }
        Statement::If {
            cond_blocks,
            else_block,
        } => {
            for branch in cond_blocks {
                check_expression(&branch.cond, defined, scope, config, result);
                for body_stmt in &branch.stmts {
                    check_statement(body_stmt, defined, scope, config, result);
                }
            }
            // A missing `else` contributes nothing.
            for body_stmt in else_block.iter().flatten() {
                check_statement(body_stmt, defined, scope, config, result);
            }
        }
        Statement::When(branches) => {
            for branch in branches {
                check_expression(&branch.cond, defined, scope, config, result);
                for body_stmt in &branch.stmts {
                    check_statement(body_stmt, defined, scope, config, result);
                }
            }
        }
    }
}
/// Recursively checks every reference reachable from an expression.
///
/// Function calls are handled specially: the callee name is not resolved,
/// but subscript expressions on any part of the callee, and all arguments,
/// are checked.
///
/// An array comprehension adds its indices to a cloned copy of `defined`.
/// It checks the body expression first, then each index range, both with
/// the indices in scope.
fn check_expression(
    expr: &Expression,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
    config: &ReferenceCheckConfig,
    result: &mut ReferenceCheckResult,
) {
    match expr {
        // Leaves that cannot contain a reference.
        Expression::Empty | Expression::Terminal { .. } => {}
        Expression::ComponentReference(reference) => {
            check_component_ref(reference, defined, scope, config, result);
        }
        Expression::FunctionCall { comp, args } => {
            for callee_subs in comp.parts.iter().filter_map(|part| part.subs.as_ref()) {
                for sub in callee_subs {
                    if let Subscript::Expression(index_expr) = sub {
                        check_expression(index_expr, defined, scope, config, result);
                    }
                }
            }
            for argument in args {
                check_expression(argument, defined, scope, config, result);
            }
        }
        Expression::Binary { lhs, rhs, .. } => {
            check_expression(lhs, defined, scope, config, result);
            check_expression(rhs, defined, scope, config, result);
        }
        Expression::Unary { rhs, .. } => {
            check_expression(rhs, defined, scope, config, result);
        }
        Expression::Array { elements, .. } => {
            for element in elements {
                check_expression(element, defined, scope, config, result);
            }
        }
        Expression::Tuple { elements } => {
            for element in elements {
                check_expression(element, defined, scope, config, result);
            }
        }
        Expression::If {
            branches,
            else_branch,
        } => {
            for (condition, value) in branches {
                check_expression(condition, defined, scope, config, result);
                check_expression(value, defined, scope, config, result);
            }
            check_expression(else_branch, defined, scope, config, result);
        }
        Expression::Range { start, step, end } => {
            check_expression(start, defined, scope, config, result);
            if let Some(stride) = step {
                check_expression(stride, defined, scope, config, result);
            }
            check_expression(end, defined, scope, config, result);
        }
        Expression::Parenthesized { inner } => {
            check_expression(inner, defined, scope, config, result);
        }
        Expression::ArrayComprehension {
            expr: body,
            indices,
        } => {
            let mut comprehension_defined = defined.clone();
            add_loop_indices_to_defined(indices, &mut comprehension_defined);
            check_expression(body, &comprehension_defined, scope, config, result);
            for loop_index in indices {
                check_expression(&loop_index.range, &comprehension_defined, scope, config, result);
            }
        }
    }
}
/// Resolves the head identifier of a component reference and checks every
/// subscript expression along the reference.
///
/// Only the first part's identifier is looked up. It counts as defined if it
/// appears in `defined`, `scope`, `config.imported_packages`, or
/// `config.additional_globals`; otherwise an "undefined variable" error is
/// pushed at the identifier's start position.
///
/// The head identifier is always added to `used_symbols`, whether or not it
/// resolves. Identifiers of later parts are not looked up here, but their
/// subscripts are checked.
fn check_component_ref(
    comp_ref: &ComponentReference,
    defined: &HashMap<String, DefinedSymbol>,
    scope: &SymbolTable,
    config: &ReferenceCheckConfig,
    result: &mut ReferenceCheckResult,
) {
    for (position, part) in comp_ref.parts.iter().enumerate() {
        if position == 0 {
            let ident = &part.ident;
            let name = &ident.text;
            result.used_symbols.insert(name.clone());
            let resolvable = defined.contains_key(name)
                || scope.contains(name)
                || config.imported_packages.contains(name)
                || config.additional_globals.contains(name);
            if !resolvable {
                result.errors.push(ReferenceError::undefined_variable(
                    name,
                    ident.location.start_line,
                    ident.location.start_column,
                ));
            }
        }
        if let Some(subs) = &part.subs {
            for sub in subs {
                if let Subscript::Expression(index_expr) = sub {
                    check_expression(index_expr, defined, scope, config, result);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::analysis::symbols::collect_defined_symbols;
    use crate::modelica_grammar::ModelicaGrammar;
    use crate::modelica_parser::parse;

    /// Parses `code` into a stored definition, panicking if parsing fails.
    fn parse_test_code(code: &str) -> crate::ir::ast::StoredDefinition {
        let mut grammar = ModelicaGrammar::new();
        parse(code, "test.mo", &mut grammar).expect("Failed to parse test code");
        grammar.modelica.expect("No AST produced")
    }

    /// Parses `code` and checks the references of its `Test` class against
    /// an empty symbol table.
    ///
    /// `None` goes through [`check_class_references`]; `Some(config)` goes
    /// through [`check_class_references_with_config`].
    fn check_test_class(
        code: &str,
        config: Option<&ReferenceCheckConfig>,
    ) -> ReferenceCheckResult {
        let ast = parse_test_code(code);
        let class = ast.class_list.get("Test").expect("Test class not found");
        let defined = collect_defined_symbols(class);
        let scope = SymbolTable::new();
        match config {
            Some(config) => check_class_references_with_config(class, &defined, &scope, config),
            None => check_class_references(class, &defined, &scope),
        }
    }

    #[test]
    fn test_undefined_reference() {
        let code = r#"
model Test
Real x;
equation
x = y + 1.0;
end Test;
"#;
        let result = check_test_class(code, None);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].name, "y");
    }

    #[test]
    fn test_for_loop_index() {
        let code = r#"
model Test
Real x[10];
equation
for i in 1:10 loop
x[i] = i * 2.0;
end for;
end Test;
"#;
        let result = check_test_class(code, None);
        assert!(
            result.errors.is_empty(),
            "Expected no errors, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn test_used_symbols_tracking() {
        let code = r#"
model Test
Real x;
Real y;
equation
x = y + 1.0;
end Test;
"#;
        let result = check_test_class(code, None);
        assert!(result.used_symbols.contains("x"));
        assert!(result.used_symbols.contains("y"));
    }

    #[test]
    fn test_array_comprehension_index() {
        let code = r#"
model Test
Real x[10] = {i * 2 for i in 1:10};
end Test;
"#;
        let result = check_test_class(code, None);
        assert!(
            result.errors.is_empty(),
            "Expected no errors, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn test_imported_package_reference() {
        let code = r#"
model Test
Real x;
equation
x = Modelica.Constants.pi;
end Test;
"#;
        let result = check_test_class(code, None);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].name, "Modelica");

        let config = ReferenceCheckConfig::new()
            .with_imported_packages(["Modelica".to_string()].into_iter().collect());
        let result = check_test_class(code, Some(&config));
        assert!(
            result.errors.is_empty(),
            "Expected no errors with imported package, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn test_additional_globals_reference() {
        let code = r#"
model Test
Real x;
equation
x = PeerClass.value;
end Test;
"#;
        let result = check_test_class(code, None);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].name, "PeerClass");

        let config = ReferenceCheckConfig::new()
            .with_additional_globals(["PeerClass".to_string()].into_iter().collect());
        let result = check_test_class(code, Some(&config));
        assert!(
            result.errors.is_empty(),
            "Expected no errors with additional global, got: {:?}",
            result.errors
        );
    }
}