use std::collections::HashMap;
use crate::ir::ast::{ClassDefinition, Component, Equation, Expression, Location, Statement};
use crate::ir::visitor::Visitor;
use super::symbols::DefinedSymbol;
use super::type_checker::{TypeCheckResult, TypeError, TypeErrorSeverity};
use super::type_inference::SymbolType;
/// Selects which optional semantic checks the [`CheckVisitor`] should run.
///
/// Every flag is `false` in the [`Default`] value; [`CheckConfig::all`] turns
/// them all on. Of these flags, only `check_cardinality_context` is consulted
/// by the visitor in this module.
#[derive(Clone, Debug, Default)]
pub struct CheckConfig {
    /// Report `cardinality(...)` calls that appear outside an if/assert condition.
    pub check_cardinality_context: bool,
    /// Class-member access check (not consulted by the visitor in this module).
    pub check_class_member_access: bool,
    /// Scalar-subscript check (not consulted by the visitor in this module).
    pub check_scalar_subscripts: bool,
    /// Array-bounds check (not consulted by the visitor in this module).
    pub check_array_bounds: bool,
}

impl CheckConfig {
    /// Returns a configuration with every check enabled.
    pub fn all() -> Self {
        let enabled = true;
        Self {
            check_cardinality_context: enabled,
            check_class_member_access: enabled,
            check_scalar_subscripts: enabled,
            check_array_bounds: enabled,
        }
    }

    /// Returns a configuration with every check disabled (same as `Default`).
    pub fn none() -> Self {
        Default::default()
    }

    /// Returns a configuration that enables only the cardinality-placement check.
    pub fn cardinality_only() -> Self {
        let mut config = Self::none();
        config.check_cardinality_context = true;
        config
    }
}
/// Mutable traversal state that [`CheckVisitor`] keeps while walking a class.
///
/// All flags start out `false` and `current_component` starts out `None`.
#[derive(Clone, Debug, Default)]
struct CheckContext {
    /// While `true`, cardinality checks are skipped (condition is an allowed context).
    in_condition: bool,
    /// While `true`, cardinality checks are skipped (inside an `assert(...)` call).
    in_assert: bool,
    /// Name of the component whose declaration is currently being checked.
    current_component: Option<String>,
    /// Raised on entering a `for` statement and cleared on leaving it.
    // NOTE(review): a single bool is cleared when an inner loop exits even if an
    // outer loop is still open; consider a depth counter before relying on it.
    in_loop: bool,
    /// Raised on entering a `while` statement and cleared on leaving it.
    // NOTE(review): same nesting caveat as `in_loop`.
    in_while: bool,
    // Never set anywhere in this module yet; reserved for function-body checks,
    // hence the lint allowance.
    #[allow(dead_code)]
    in_function: bool,
}
pub struct CheckVisitor<'a> {
config: CheckConfig,
result: TypeCheckResult,
#[allow(dead_code)]
defined: &'a HashMap<String, DefinedSymbol>,
#[allow(dead_code)]
component_shapes: HashMap<String, Vec<usize>>,
context: CheckContext,
}
impl<'a> CheckVisitor<'a> {
pub fn new(config: CheckConfig, defined: &'a HashMap<String, DefinedSymbol>) -> Self {
Self {
config,
result: TypeCheckResult::new(),
defined,
component_shapes: HashMap::new(),
context: CheckContext::default(),
}
}
pub fn with_shapes(
config: CheckConfig,
defined: &'a HashMap<String, DefinedSymbol>,
component_shapes: HashMap<String, Vec<usize>>,
) -> Self {
Self {
config,
result: TypeCheckResult::new(),
defined,
component_shapes,
context: CheckContext::default(),
}
}
pub fn into_result(self) -> TypeCheckResult {
self.result
}
pub fn result(&self) -> &TypeCheckResult {
&self.result
}
fn add_cardinality_error(&mut self, location: Location) {
self.result.add_error(TypeError::new(
location,
SymbolType::Unknown,
SymbolType::Unknown,
"cardinality may only be used in the condition of an if-statement/equation or an assert.".to_string(),
TypeErrorSeverity::Error,
));
}
fn check_cardinality_in_expr(&mut self, expr: &Expression) {
if !self.config.check_cardinality_context {
return;
}
if self.context.in_condition || self.context.in_assert {
return;
}
if let Some(loc) = find_cardinality_call(expr) {
self.add_cardinality_error(loc);
}
}
}
impl Visitor for CheckVisitor<'_> {
    /// Checks a component's `start` expression for a misplaced `cardinality` call.
    ///
    /// `current_component` holds the component's name only for the duration of
    /// that check.
    fn enter_component(&mut self, node: &Component) {
        if self.config.check_cardinality_context && !matches!(node.start, Expression::Empty) {
            self.context.current_component = Some(node.name.clone());
            self.check_cardinality_in_expr(&node.start);
            self.context.current_component = None;
        }
    }

    /// Checks the expressions owned directly by an equation.
    ///
    /// Both sides of a simple equation are scanned, and so are the arguments of a
    /// function-call equation. The one exception is `assert(...)`: its arguments are
    /// skipped and `in_assert` stays raised until `exit_equation`.
    fn enter_equation(&mut self, node: &Equation) {
        match node {
            Equation::Simple { lhs, rhs } => {
                self.check_cardinality_in_expr(lhs);
                self.check_cardinality_in_expr(rhs);
            }
            Equation::FunctionCall { comp, args, .. } => {
                let is_assert = comp.parts.len() == 1
                    && comp.parts.first().is_some_and(|p| p.ident.text == "assert");
                if is_assert {
                    self.context.in_assert = true;
                } else {
                    for arg in args {
                        self.check_cardinality_in_expr(arg);
                    }
                }
            }
            // Not scanned here: if-equation conditions (an allowed context for
            // `cardinality`) and the expressions of every other equation kind.
            _ => {}
        }
    }

    /// Lowers `in_assert` again when leaving an `assert(...)` equation.
    fn exit_equation(&mut self, node: &Equation) {
        if let Equation::FunctionCall { comp, .. } = node {
            let is_assert = comp.parts.len() == 1
                && comp.parts.first().is_some_and(|p| p.ident.text == "assert");
            if is_assert {
                self.context.in_assert = false;
            }
        }
    }

    /// Checks the expressions owned directly by a statement and tracks loop state.
    ///
    /// The following are scanned:
    /// - assignment right-hand sides,
    /// - while-loop conditions,
    /// - arguments of non-`assert` call statements.
    ///
    /// If-statement conditions are an allowed context and are not scanned.
    fn enter_statement(&mut self, node: &Statement) {
        match node {
            Statement::Assignment { value, .. } => {
                self.check_cardinality_in_expr(value);
            }
            Statement::While(block) => {
                // The reported rule allows `cardinality` only in if-statement/equation
                // and assert conditions, so a while condition is checked like any
                // other expression.
                self.check_cardinality_in_expr(&block.cond);
                self.context.in_while = true;
            }
            Statement::For { .. } => {
                self.context.in_loop = true;
            }
            Statement::FunctionCall { comp, args, .. } => {
                let is_assert = comp.parts.len() == 1
                    && comp.parts.first().is_some_and(|p| p.ident.text == "assert");
                if is_assert {
                    self.context.in_assert = true;
                } else {
                    for arg in args {
                        self.check_cardinality_in_expr(arg);
                    }
                }
            }
            // If-statement conditions are an allowed context; other statement kinds
            // own no expressions that are scanned here.
            _ => {}
        }
    }

    /// Undoes the context changes made by `enter_statement`.
    // NOTE(review): `in_loop`/`in_while` are plain booleans, so leaving an inner
    // loop clears them even while an outer loop is still open.
    fn exit_statement(&mut self, node: &Statement) {
        match node {
            Statement::While(_) => {
                self.context.in_while = false;
            }
            Statement::For { .. } => {
                self.context.in_loop = false;
            }
            Statement::FunctionCall { comp, .. } => {
                let is_assert = comp.parts.len() == 1
                    && comp.parts.first().is_some_and(|p| p.ident.text == "assert");
                if is_assert {
                    self.context.in_assert = false;
                }
            }
            _ => {}
        }
    }
}
/// Returns the location of the first `cardinality(...)` call inside `expr`, or
/// `None` if there is none.
///
/// The search is depth-first in source order. Only an unqualified call matches,
/// meaning a single-part name `cardinality`; the reported location is that name's
/// identifier. Component references and terminals are leaves, so subscripts
/// inside a component reference are not searched.
fn find_cardinality_call(expr: &Expression) -> Option<Location> {
    match expr {
        Expression::FunctionCall { comp, args, .. } => match comp.parts.first() {
            Some(head) if comp.parts.len() == 1 && head.ident.text == "cardinality" => {
                Some(head.ident.location.clone())
            }
            // Any other call: look inside its arguments, left to right.
            _ => args.iter().find_map(find_cardinality_call),
        },
        Expression::Binary { lhs, rhs, .. } => {
            find_cardinality_call(lhs).or_else(|| find_cardinality_call(rhs))
        }
        Expression::Unary { rhs, .. } => find_cardinality_call(rhs),
        Expression::If {
            branches,
            else_branch,
        } => branches
            .iter()
            .find_map(|(cond, then_expr)| {
                find_cardinality_call(cond).or_else(|| find_cardinality_call(then_expr))
            })
            .or_else(|| find_cardinality_call(else_branch)),
        Expression::Array { elements, .. } => elements.iter().find_map(find_cardinality_call),
        Expression::Range { start, step, end } => find_cardinality_call(start)
            .or_else(|| step.as_ref().and_then(|s| find_cardinality_call(s)))
            .or_else(|| find_cardinality_call(end)),
        Expression::Tuple { elements } => elements.iter().find_map(find_cardinality_call),
        Expression::ArrayComprehension {
            expr: body,
            indices,
        } => find_cardinality_call(body)
            .or_else(|| indices.iter().find_map(|idx| find_cardinality_call(&idx.range))),
        Expression::Parenthesized { inner } => find_cardinality_call(inner),
        Expression::ComponentReference(_) | Expression::Terminal { .. } | Expression::Empty => None,
    }
}
/// Walks `class` with a [`CheckVisitor`] configured by `config` and returns the
/// diagnostics it collected.
///
/// No component-shape table is supplied.
pub fn check_class(
    class: &ClassDefinition,
    config: CheckConfig,
    defined: &HashMap<String, DefinedSymbol>,
) -> TypeCheckResult {
    use crate::ir::visitor::Visitable;

    let mut checker = CheckVisitor::with_shapes(config, defined, HashMap::new());
    class.accept(&mut checker);
    checker.into_result()
}
/// Convenience wrapper around [`check_class`] with every check enabled.
pub fn check_all(
    class: &ClassDefinition,
    defined: &HashMap<String, DefinedSymbol>,
) -> TypeCheckResult {
    let everything = CheckConfig::all();
    check_class(class, everything, defined)
}
/// Runs the semantic passes from `type_checker` over `class` and returns the
/// merged diagnostics. The visitor-based checks from [`check_all`] run last.
///
/// NOTE(review): misplaced `cardinality` calls are searched for by both
/// `type_checker::check_cardinality_context` and `check_all`, whose config
/// enables `check_cardinality_context`. Confirm this cannot report the same
/// error twice.
pub fn check_semantic(
    class: &ClassDefinition,
    defined: &HashMap<String, DefinedSymbol>,
) -> TypeCheckResult {
    use super::type_checker;

    // Standalone passes, in the order their results are merged.
    let passes = [
        type_checker::check_cardinality_context(class),
        type_checker::check_cardinality_arguments(class),
        type_checker::check_class_member_access(class),
        type_checker::check_scalar_subscripts(class),
        type_checker::check_array_bounds(class),
        type_checker::check_component_bindings(class),
        type_checker::check_assert_arguments(class),
        type_checker::check_builtin_attribute_modifiers(class),
        type_checker::check_break_return_context(class),
        type_checker::check_start_modification_dimensions(class),
    ];

    let mut combined = TypeCheckResult::new();
    for pass in passes {
        combined.merge(pass);
    }
    combined.merge(check_all(class, defined));
    combined
}
/// Runs [`check_semantic`] and then adds type-checking diagnostics for the class.
///
/// Type checking covers the class's equations and each of its algorithm
/// sections, resolved against `defined`.
pub fn check_semantic_with_types(
    class: &ClassDefinition,
    defined: &HashMap<String, DefinedSymbol>,
) -> TypeCheckResult {
    use super::type_checker;

    let mut combined = check_semantic(class, defined);
    let equation_diagnostics = type_checker::check_equations(&class.equations, defined);
    combined.merge(equation_diagnostics);
    for section in &class.algorithms {
        let section_diagnostics = type_checker::check_statements(section, defined);
        combined.merge(section_diagnostics);
    }
    combined
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects a config's four flags in declaration order.
    fn flags(config: &CheckConfig) -> [bool; 4] {
        [
            config.check_cardinality_context,
            config.check_class_member_access,
            config.check_scalar_subscripts,
            config.check_array_bounds,
        ]
    }

    #[test]
    fn test_check_config_all() {
        assert_eq!(flags(&CheckConfig::all()), [true; 4]);
    }

    #[test]
    fn test_check_config_none() {
        assert_eq!(flags(&CheckConfig::none()), [false; 4]);
    }

    #[test]
    fn test_check_config_cardinality_only() {
        assert_eq!(
            flags(&CheckConfig::cardinality_only()),
            [true, false, false, false]
        );
    }

    #[test]
    fn test_find_cardinality_call_not_found() {
        let empty = Expression::Empty;
        assert!(find_cardinality_call(&empty).is_none());
    }

    #[test]
    fn test_find_cardinality_call_in_array() {
        let array = Expression::Array {
            is_matrix: false,
            elements: vec![Expression::Empty, Expression::Empty],
        };
        assert!(find_cardinality_call(&array).is_none());
    }

    #[test]
    fn test_find_cardinality_call_in_parenthesized() {
        let wrapped = Expression::Parenthesized {
            inner: Box::new(Expression::Empty),
        };
        assert!(find_cardinality_call(&wrapped).is_none());
    }
}