use std::collections::{HashMap, HashSet};
use crate::declaration::{Declaration, Module};
use crate::expression::{Block, Expression, ExpressionKind, Statement};
use crate::operator::BinaryOperator;
use crate::types::Type;
/// Runs the validation pass over every declaration in `module`.
///
/// All newtype declarations are gathered before anything is checked. This lets function bodies
/// and constants be validated against newtypes declared anywhere in the module. When two newtype
/// declarations share a name, the later one wins.
///
/// # Errors
///
/// Returns the first validation error found, visiting declarations in module order.
pub fn validate_module(module: &Module) -> Result<(), String> {
    let mut newtypes = HashMap::new();
    for declaration in module {
        let Declaration::Type(newtype) = declaration else {
            continue;
        };
        newtypes.insert(newtype.name.clone(), newtype.inner_type.clone());
    }

    let context = ValidationContext { newtypes };
    for declaration in module {
        // Only function bodies and constant initialisers contain expressions to validate.
        if let Declaration::Function(function) = declaration {
            context.validate_block(&function.body)?;
        } else if let Declaration::Constant(constant) = declaration {
            context.validate_expression(&constant.value)?;
        }
    }
    Ok(())
}
struct ValidationContext {
newtypes: HashMap<String, Type>,
}
impl ValidationContext {
fn validate_block(&self, block: &Block) -> Result<(), String> {
for statement in &block.statements {
self.validate_statement(statement)?;
}
if let Some(result) = &block.result {
self.validate_expression(result)?;
}
Ok(())
}
fn validate_statement(&self, statement: &Statement) -> Result<(), String> {
match statement {
Statement::Expression(expression) | Statement::Return(Some(expression)) => {
self.validate_expression(expression)?;
}
Statement::Let { value, .. } => {
self.validate_expression(value)?;
}
Statement::Assign(target, value) => {
self.validate_expression(target)?;
self.validate_expression(value)?;
Self::check_replace_types(target, value)?;
}
Statement::Label {
initial_arguments, ..
} => {
for argument in initial_arguments {
self.validate_expression(argument)?;
}
}
Statement::Jump { arguments, .. } => {
for argument in arguments {
self.validate_expression(argument)?;
}
}
Statement::MultiReplace {
targets, values, ..
} => {
for target in targets {
self.validate_expression(target)?;
}
for value in values {
self.validate_expression(value)?;
}
}
Statement::Defer(inner) => {
self.validate_statement(inner)?;
}
Statement::Return(None) => {}
}
Ok(())
}
fn validate_expression(&self, expression: &Expression) -> Result<(), String> {
match &expression.kind {
ExpressionKind::BinaryOperation(operator, left, right) => {
self.validate_expression(left)?;
self.validate_expression(right)?;
self.check_binary_operands(operator, left, right)?;
}
ExpressionKind::TypeConstruction(name, fields) => {
for (_, value) in fields {
self.validate_expression(value)?;
}
self.check_construction_fields(name, fields)?;
}
ExpressionKind::Replace(target, value) | ExpressionKind::OpAssign(_, target, value) => {
self.validate_expression(target)?;
self.validate_expression(value)?;
Self::check_replace_types(target, value)?;
}
ExpressionKind::Call(callee, arguments) => {
self.validate_expression(callee)?;
for argument in arguments {
self.validate_expression(argument)?;
}
}
ExpressionKind::UnaryOperation(_, operand)
| ExpressionKind::Dereference(operand)
| ExpressionKind::Convert(operand, _)
| ExpressionKind::Transmute(operand, _) => {
self.validate_expression(operand)?;
}
ExpressionKind::Field(object, _) => {
self.validate_expression(object)?;
}
ExpressionKind::Index(object, index) => {
self.validate_expression(object)?;
self.validate_expression(index)?;
}
ExpressionKind::ArrayLiteral(elements)
| ExpressionKind::TupleLiteral(elements)
| ExpressionKind::Print(elements) => {
for element in elements {
self.validate_expression(element)?;
}
}
ExpressionKind::Block(block) => {
self.validate_block(block)?;
}
ExpressionKind::If {
condition,
then_branch,
else_branch,
} => {
self.validate_expression(condition)?;
self.validate_block(then_branch)?;
if let Some(else_branch) = else_branch {
self.validate_block(else_branch)?;
}
}
ExpressionKind::Match { value, arms } => {
self.validate_expression(value)?;
for arm in arms {
self.validate_block(&arm.body)?;
}
}
ExpressionKind::Slice(array, start, end) => {
self.validate_expression(array)?;
if let Some(start) = start {
self.validate_expression(start)?;
}
if let Some(end) = end {
self.validate_expression(end)?;
}
}
ExpressionKind::Literal(_)
| ExpressionKind::Variable(_)
| ExpressionKind::SizeOf(_) => {}
}
Ok(())
}
fn resolve_underlying(&self, resolved_type: &Type) -> Type {
match resolved_type {
Type::Named(name) => self.newtypes.get(name).map_or_else(
|| resolved_type.clone(),
|inner| self.resolve_underlying(inner),
),
Type::Pointer(mutability, inner) => {
Type::Pointer(*mutability, Box::new(self.resolve_underlying(inner)))
}
other => other.clone(),
}
}
fn check_binary_operands(
&self,
operator: &BinaryOperator,
left: &Expression,
right: &Expression,
) -> Result<(), String> {
if matches!(operator, BinaryOperator::Logical(_)) {
return Ok(());
}
let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type)
else {
return Ok(());
};
if left_type == right_type {
return Ok(());
}
let left_resolved = self.resolve_underlying(left_type);
let right_resolved = self.resolve_underlying(right_type);
if left_resolved != right_resolved {
return Err(format!(
"type mismatch in '{operator}': left is {left_type}, right is {right_type}",
));
}
if matches!(left_type, Type::Named(_)) && matches!(right_type, Type::Named(_)) {
return Err(format!(
"cannot mix distinct types in '{operator}': left is {left_type}, right is {right_type}",
));
}
Ok(())
}
fn check_replace_types(target: &Expression, value: &Expression) -> Result<(), String> {
let Some(target_resolved) = &target.resolved_type else {
return Ok(());
};
let Some(value_type) = &value.resolved_type else {
return Ok(());
};
let target_type = match target_resolved {
Type::Pointer(_, inner) => inner.as_ref(),
other => other,
};
if target_type == value_type {
return Ok(());
}
if matches!(target_type, Type::Named(_)) && matches!(value_type, Type::Named(_)) {
return Err(format!(
"type mismatch in assignment: target is {target_type}, value is {value_type}",
));
}
Ok(())
}
fn check_construction_fields(
&self,
type_name: &str,
fields: &[(String, Expression)],
) -> Result<(), String> {
let Some(inner) = self.newtypes.get(type_name) else {
return Ok(());
};
let expected_fields: Vec<&str> = match inner {
Type::Tuple(field_types) => field_types
.iter()
.filter_map(|field_type| match field_type {
Type::Named(name) => Some(name.as_str()),
_ => None,
})
.collect(),
Type::Named(name) => vec![name.as_str()],
_ => return Ok(()),
};
let mut seen = HashSet::new();
for (field_name, _) in fields {
if !expected_fields.contains(&field_name.as_str()) {
return Err(format!("'{type_name}' has no field '{field_name}'"));
}
if !seen.insert(field_name.as_str()) {
return Err(format!(
"duplicate field '{field_name}' in '{type_name}' construction"
));
}
}
for expected in &expected_fields {
if !seen.contains(expected) {
return Err(format!(
"missing field '{expected}' in '{type_name}' construction"
));
}
}
Ok(())
}
}