use crate::expression::{Binding, Block, Expression, ExpressionKind, Literal};
use crate::operator::{ArithmeticOperator, BinaryOperator, UnaryOperator};
use crate::types::{FloatWidth, IntWidth, Mutability, Signedness, Type};
use std::collections::HashSet;
use super::ResolutionContext;
/// Returns the type a literal has before any contextual coercion.
///
/// - Integer literals default to signed 32-bit.
/// - Float literals default to 64-bit.
/// - String literals are shared pointers to unsigned bytes.
/// - `null` is a mutable pointer to unsigned bytes.
pub(super) fn literal_type(literal: &Literal) -> Type {
    // Both pointer-typed literals point at unsigned 8-bit integers; only the
    // mutability differs.
    let byte_pointer = |mutability| {
        Type::Pointer(
            mutability,
            Box::new(Type::Int(IntWidth::W8, Signedness::Unsigned)),
        )
    };
    match literal {
        Literal::Bool(_) => Type::Bool,
        Literal::Integer(_) => Type::Int(IntWidth::W32, Signedness::Signed),
        Literal::Float(_) => Type::Float(FloatWidth::W64),
        Literal::String(_) => byte_pointer(Mutability::Shared),
        Literal::Null => byte_pointer(Mutability::Mutable),
    }
}
/// Returns `true` when `expression` is an integer literal whose resolved type
/// is still the default `Int(W32, Signed)` assigned by `literal_type`.
///
/// Such literals have not yet been coerced to a more specific integer type.
pub(super) fn is_default_integer(expression: &Expression) -> bool {
    if !matches!(
        expression.kind,
        ExpressionKind::Literal(Literal::Integer(_))
    ) {
        return false;
    }
    expression.resolved_type == Some(Type::Int(IntWidth::W32, Signedness::Signed))
}
/// Returns `true` when `expression` is a float literal whose resolved type is
/// still the default `Float(W64)` assigned by `literal_type`.
pub(super) fn is_default_float(expression: &Expression) -> bool {
    let ExpressionKind::Literal(Literal::Float(_)) = &expression.kind else {
        return false;
    };
    expression.resolved_type == Some(Type::Float(FloatWidth::W64))
}
pub(super) fn coerce_numeric_literals(left: &mut Expression, right: &mut Expression) {
if is_default_integer(left) && matches!(right.resolved_type, Some(Type::Int(..))) {
left.resolved_type.clone_from(&right.resolved_type);
} else if is_default_integer(right) && matches!(left.resolved_type, Some(Type::Int(..))) {
right.resolved_type.clone_from(&left.resolved_type);
} else if is_default_float(left) && matches!(right.resolved_type, Some(Type::Float(_))) {
left.resolved_type.clone_from(&right.resolved_type);
} else if is_default_float(right) && matches!(left.resolved_type, Some(Type::Float(_))) {
right.resolved_type.clone_from(&left.resolved_type);
}
}
/// Wraps `expression` in place in a `Convert` node targeting `base_type`.
///
/// The new node's resolved type is `base_type`.
pub(super) fn downcast_to_base(expression: &mut Expression, base_type: Type) {
    // A throwaway literal occupies the slot while the original node is moved
    // out; it is overwritten immediately below.
    let dummy = Expression::new(
        ExpressionKind::Literal(Literal::Bool(false)),
        Some(Type::Bool),
    );
    let original = std::mem::replace(expression, dummy);
    let conversion = ExpressionKind::Convert(Box::new(original), base_type.clone());
    *expression = Expression::new(conversion, Some(base_type));
}
/// Reconciles two operands whose types differ but share the same underlying
/// type (per `context.resolve_underlying`).
///
/// Every operand whose type is `Type::Named` is wrapped in a conversion to
/// that underlying type (see `downcast_to_base`). When both are named, the
/// right operand is wrapped first, then the left.
///
/// The function does nothing when:
/// - either operand is unresolved,
/// - the types are already equal, or
/// - the underlying types differ.
pub(super) fn coerce_mixed_newtypes(
    left: &mut Expression,
    right: &mut Expression,
    context: &ResolutionContext,
) {
    let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type) else {
        return;
    };
    if left_type == right_type {
        return;
    }
    let left_base = context.resolve_underlying(left_type);
    let right_base = context.resolve_underlying(right_type);
    if left_base != right_base {
        return;
    }
    let wrap_left = matches!(left_type, Type::Named(_));
    let wrap_right = matches!(right_type, Type::Named(_));
    if wrap_right {
        downcast_to_base(right, right_base);
    }
    if wrap_left {
        downcast_to_base(left, left_base);
    }
}
/// Returns the type a binding of kind `binding` gives to a value of type `inner`.
///
/// - `Variable` bindings are mutable pointers to `inner`.
/// - `Reference` bindings are shared pointers to `inner`.
/// - `Value` bindings are `inner` itself.
pub(super) fn wrap_binding_type(binding: Binding, inner: Type) -> Type {
    let mutability = match binding {
        Binding::Value => return inner,
        Binding::Variable => Mutability::Mutable,
        Binding::Reference => Mutability::Shared,
    };
    Type::Pointer(mutability, Box::new(inner))
}
/// Returns `true` when `expression` is a variable, or a chain of field/index
/// accesses rooted at a variable, for which `context.is_auto_ref` is true.
///
/// Any other expression kind at the root yields `false`.
pub(super) fn is_auto_ref_chain(expression: &Expression, context: &ResolutionContext) -> bool {
    let mut node = expression;
    // Walk down through field and index accesses to the chain's root.
    while let ExpressionKind::Field(base, _) | ExpressionKind::Index(base, _) = &node.kind {
        node = &**base;
    }
    matches!(&node.kind, ExpressionKind::Variable(name) if context.is_auto_ref(name))
}
/// Inserts an implicit dereference around an auto-ref chain.
///
/// If `expression` is an auto-ref chain (see `is_auto_ref_chain`) and its
/// resolved type is a pointer, it is replaced in place by a `Dereference` of
/// itself whose resolved type is the pointee. Otherwise it is left unchanged.
pub(super) fn try_auto_deref_in_place(expression: &mut Expression, context: &ResolutionContext) {
    if !is_auto_ref_chain(expression, context) {
        return;
    }
    let pointee = match &expression.resolved_type {
        Some(Type::Pointer(_, inner)) => (**inner).clone(),
        _ => return,
    };
    // Swap in a throwaway node so the pointer expression can be moved into the
    // new `Dereference` wrapper.
    let pointer_expression = std::mem::replace(
        expression,
        Expression::new(
            ExpressionKind::Literal(Literal::Bool(false)),
            Some(Type::Bool),
        ),
    );
    *expression = Expression::new(
        ExpressionKind::Dereference(Box::new(pointer_expression)),
        Some(pointee),
    );
}
/// Returns `true` when `expression` is a variable, or a chain of field/index
/// accesses rooted at a variable, for which `context.is_pointer_param` is true.
pub(super) fn is_pointer_param_chain(expression: &Expression, context: &ResolutionContext) -> bool {
    let mut cursor = expression;
    loop {
        match &cursor.kind {
            ExpressionKind::Variable(name) => return context.is_pointer_param(name),
            ExpressionKind::Field(base, _) | ExpressionKind::Index(base, _) => cursor = &**base,
            _ => return false,
        }
    }
}
/// Applies automatic dereferencing to `expression`, taking into account the
/// type the surrounding context expects.
///
/// When `expected_type` resolves (via `context.resolve_named`) to a pointer:
/// - a `Dereference` whose operand is a pointer-parameter chain (see
///   `is_pointer_param_chain`) is unwrapped, so the pointer itself is used;
/// - an expression that is itself a pointer-parameter chain is left as is.
///
/// In every other case this defers to `try_auto_deref_in_place`.
pub(super) fn try_auto_deref_for_context(
    expression: &mut Expression,
    expected_type: Option<&Type>,
    context: &ResolutionContext,
) {
    let wants_pointer = expected_type
        .is_some_and(|ty| matches!(context.resolve_named(ty), Type::Pointer(..)));
    if wants_pointer {
        let strips_deref = matches!(
            &expression.kind,
            ExpressionKind::Dereference(operand) if is_pointer_param_chain(operand, context)
        );
        if strips_deref {
            // Move the node out so its boxed operand can replace it.
            let placeholder = Expression::new(
                ExpressionKind::Literal(Literal::Bool(false)),
                Some(Type::Bool),
            );
            let taken = std::mem::replace(expression, placeholder);
            if let ExpressionKind::Dereference(operand) = taken.kind {
                *expression = *operand;
            }
            return;
        }
        if is_pointer_param_chain(expression, context) {
            return;
        }
    }
    try_auto_deref_in_place(expression, context);
}
/// Returns the result type of a binary operator applied to operands of type
/// `operand_type`.
///
/// Arithmetic and bitwise operators keep the operand type. Comparison and
/// logical operators always produce `Bool`.
pub(super) fn operator_result_type(operator: &BinaryOperator, operand_type: &Type) -> Type {
    match operator {
        BinaryOperator::Arithmetic(_) | BinaryOperator::Bitwise(_) => operand_type.clone(),
        BinaryOperator::Comparison(_) | BinaryOperator::Logical(_) => Type::Bool,
    }
}
/// Returns the result type of a unary operator applied to an operand of type
/// `operand_type`.
///
/// Negation and bitwise-not keep the operand type; logical-not produces `Bool`.
pub(super) fn unary_result_type(operator: &UnaryOperator, operand_type: &Type) -> Type {
    match operator {
        UnaryOperator::Negate | UnaryOperator::BitwiseNot => operand_type.clone(),
        UnaryOperator::LogicalNot => Type::Bool,
    }
}
pub(super) fn resolve_field_type(
object_type: &Type,
field_name: &str,
context: &ResolutionContext,
) -> Result<Type, String> {
let stripped = strip_pointers(object_type);
if let Type::Slice(..) = stripped {
if field_name == "length" {
return Ok(Type::Int(
crate::types::IntWidth::WSize,
crate::types::Signedness::Unsigned,
));
}
if field_name == "data"
&& let Type::Slice(mutability, element) = stripped
{
return Ok(Type::Pointer(*mutability, element.clone()));
}
}
let mut current = stripped.clone();
loop {
if let Type::Pointer(_, inner) = ¤t {
current = strip_pointers(inner).clone();
continue;
}
if let Type::Named(ref type_name) = current {
if let Some(field_type) = context.lookup_field(type_name, field_name) {
return Ok(field_type);
}
if let Some(inner) = context.newtypes.get(type_name) {
current = inner.clone();
continue;
}
}
break;
}
if field_name.chars().all(|c| c.is_ascii_digit()) {
Ok(object_type.clone())
} else {
Err(format!(
"type {object_type} has no field '{field_name}' — use '{field_name}: value' for numeric conversions"
))
}
}
/// Follows every `Type::Pointer` layer of `ty` and returns the innermost
/// non-pointer type. A non-pointer `ty` is returned as is.
pub(super) fn strip_pointers(ty: &Type) -> &Type {
    let mut pointee = ty;
    while let Type::Pointer(_, inner) = pointee {
        pointee = &**inner;
    }
    pointee
}
/// Prepares the left-hand side of a `:=` assignment so that it denotes a
/// writable location.
///
/// - A target with no resolved type is accepted unchanged.
/// - If `target` already is a mutable place (see `is_mutable_place`), it is
///   left untouched. One exception: a mutable pointer whose pointee resolves
///   to a shared pointer is rejected.
/// - Otherwise the target's type must be one or more shared pointers ending
///   in a mutable pointer (see `count_derefs_to_mutable`). `target` is wrapped
///   in one `Dereference` node per shared layer, so the result has the
///   mutable-pointer type.
///
/// # Errors
/// Returns a message when:
/// - the target points through a shared reference, or
/// - no mutable pointer is reachable by dereferencing shared pointers.
pub(super) fn auto_deref_replace_target(
    target: &mut Expression,
    context: &ResolutionContext,
) -> Result<(), String> {
    let Some(original_type) = target.resolved_type.clone() else {
        return Ok(());
    };
    if is_mutable_place(target, context) {
        let resolved = context.resolve_named(&original_type);
        if let Type::Pointer(Mutability::Mutable, inner) = &resolved {
            let inner_resolved = context.resolve_named(inner);
            if matches!(inner_resolved, Type::Pointer(Mutability::Shared, _)) {
                return Err(format!(
                    "left side of := points to shared reference (&T), cannot assign through {original_type}",
                ));
            }
        }
        return Ok(());
    }
    // `count_derefs_to_mutable` returns usize::MAX when no mutable pointer is
    // reachable through shared pointers.
    let deref_count = count_derefs_to_mutable(&original_type, context);
    if deref_count == 0 || deref_count == usize::MAX {
        return Err(format!(
            "left side of := must reach a mutable pointer (|T), got {original_type}",
        ));
    }
    for _ in 0..deref_count {
        let Some(current_type) = target.resolved_type.clone() else {
            break;
        };
        // Peel one pointer layer per iteration. If the type is no longer a
        // pointer, every later iteration would see the same type, so stop.
        let Type::Pointer(_, inner) = context.resolve_named(&current_type) else {
            break;
        };
        let placeholder = Expression::new(
            ExpressionKind::Literal(Literal::Bool(false)),
            Some(Type::Bool),
        );
        let old = std::mem::replace(target, placeholder);
        *target = Expression::new(ExpressionKind::Dereference(Box::new(old)), Some(*inner));
    }
    Ok(())
}
/// Counts how many shared-pointer layers must be dereferenced from `ty`
/// (resolving named types at each step) before a mutable pointer is reached.
///
/// - Returns `0` when `ty` itself resolves to a mutable pointer.
/// - Returns `usize::MAX` as a sentinel when the chain ends in a non-pointer
///   type without reaching a mutable pointer.
pub(super) fn count_derefs_to_mutable(ty: &Type, context: &ResolutionContext) -> usize {
    match context.resolve_named(ty) {
        Type::Pointer(Mutability::Mutable, _) => 0,
        // saturating_add leaves the usize::MAX sentinel intact and adds one
        // for any real count.
        Type::Pointer(Mutability::Shared, inner) => {
            count_derefs_to_mutable(&inner, context).saturating_add(1)
        }
        _ => usize::MAX,
    }
}
pub(super) fn is_mutable_place(expression: &Expression, context: &ResolutionContext) -> bool {
if let Some(ref ty) = expression.resolved_type {
let resolved = context.resolve_named(ty);
if matches!(resolved, Type::Pointer(Mutability::Mutable, _)) {
return true;
}
}
match &expression.kind {
ExpressionKind::Field(object, _) | ExpressionKind::Index(object, _) => {
is_mutable_place(object, context)
}
ExpressionKind::Dereference(inner) => {
if let Some(ref ty) = inner.resolved_type {
let resolved = context.resolve_named(ty);
matches!(resolved, Type::Pointer(Mutability::Mutable, _))
} else {
false
}
}
_ => false,
}
}
/// Returns the type produced by indexing a value of type `array_type`.
///
/// - Arrays, vectors and slices yield their element type.
/// - A pointer to an array or vector yields a pointer to the element, keeping
///   the pointer's mutability.
/// - Any other pointer yields its pointee.
/// - Any other type is returned (after named-type resolution) unchanged.
pub(super) fn resolve_index_type(array_type: &Type, context: &ResolutionContext) -> Type {
    match context.resolve_named(array_type) {
        Type::Array(element, _) | Type::Vector(element, _) | Type::Slice(_, element) => *element,
        Type::Pointer(mutability, pointee) => match context.resolve_named(&pointee) {
            Type::Array(element, _) | Type::Vector(element, _) => {
                Type::Pointer(mutability, element)
            }
            _ => *pointee,
        },
        other => other,
    }
}
/// Returns the resolved type of a block's trailing result expression.
///
/// Falls back to `Type::unit()` when the block has no result expression or
/// that expression has no resolved type.
pub(super) fn block_result_type(block: &Block) -> Type {
    match block
        .result
        .as_ref()
        .and_then(|tail| tail.resolved_type.as_ref())
    {
        Some(ty) => ty.clone(),
        None => Type::unit(),
    }
}
/// Returns `true` when `left_type <operator> right_type` is pointer
/// arithmetic, which is exempt from the "both operands share a type" check.
///
/// After resolving newtypes to their underlying types, the accepted forms are:
/// - `pointer + integer` and `integer + pointer`
/// - `pointer - integer`
///
/// `integer - pointer` is not accepted: subtracting a pointer from an integer
/// has no meaningful result, so it must go through the normal type check.
pub(super) fn is_pointer_arithmetic(
    context: &ResolutionContext,
    operator: &BinaryOperator,
    left_type: &Type,
    right_type: &Type,
) -> bool {
    let BinaryOperator::Arithmetic(arithmetic) = operator else {
        return false;
    };
    let is_add = matches!(arithmetic, ArithmeticOperator::Add);
    let is_subtract = matches!(arithmetic, ArithmeticOperator::Subtract);
    if !is_add && !is_subtract {
        return false;
    }
    let left_resolved = context.resolve_underlying(left_type);
    let right_resolved = context.resolve_underlying(right_type);
    let is_pointer = |t: &Type| matches!(t, Type::Pointer(..));
    let is_integer = |t: &Type| matches!(t, Type::Int(..));
    // `p + n` and `p - n` are both valid.
    let pointer_then_integer = is_pointer(&left_resolved) && is_integer(&right_resolved);
    // `n + p` is commutative addition; `n - p` is rejected.
    let integer_plus_pointer = is_add && is_integer(&left_resolved) && is_pointer(&right_resolved);
    pointer_then_integer || integer_plus_pointer
}
/// Checks that the two operands of a binary operator have compatible types.
///
/// The check always passes when:
/// - the operator is logical,
/// - either operand is unresolved,
/// - the types are identical,
/// - the operation is pointer arithmetic (see `is_pointer_arithmetic`), or
/// - both types share the same underlying type.
///
/// # Errors
/// Returns a "type mismatch" message naming both operand types otherwise.
pub(super) fn check_binary_operands(
    context: &ResolutionContext,
    operator: &BinaryOperator,
    left: &Expression,
    right: &Expression,
) -> Result<(), String> {
    if matches!(operator, BinaryOperator::Logical(_)) {
        return Ok(());
    }
    let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type) else {
        return Ok(());
    };
    // Cheapest checks first; the underlying-type resolution only runs if needed.
    let compatible = left_type == right_type
        || is_pointer_arithmetic(context, operator, left_type, right_type)
        || context.resolve_underlying(left_type) == context.resolve_underlying(right_type);
    if compatible {
        Ok(())
    } else {
        Err(format!(
            "type mismatch in '{operator}': left is {left_type}, right is {right_type}",
        ))
    }
}
pub(super) fn check_replace_types(target: &Expression, value: &Expression) -> Result<(), String> {
let Some(ref target_resolved) = target.resolved_type else {
return Ok(());
};
let Some(ref value_type) = value.resolved_type else {
return Ok(());
};
let target_type = match target_resolved {
Type::Pointer(_, inner) => inner.as_ref(),
other => other,
};
if target_type == value_type {
return Ok(());
}
if matches!(target_type, Type::Named(_)) && matches!(value_type, Type::Named(_)) {
return Err(format!(
"type mismatch in assignment: target is {target_type}, value is {value_type}",
));
}
Ok(())
}
pub(super) fn match_fields_by_type_or_position(
inner_type: &Type,
type_name: &str,
values: Vec<Expression>,
) -> Result<Vec<(String, Expression)>, String> {
let Type::Tuple(field_types) = inner_type else {
if values.len() != 1 {
return Err(format!(
"'{type_name}': expected 1 value for single-field type, got {}",
values.len()
));
}
let field_name = match inner_type {
Type::Named(name) => name.clone(),
_ => return Err(format!("'{type_name}': cannot construct non-tuple type")),
};
return Ok(vec![(field_name, values.into_iter().next().unwrap())]);
};
let field_names: Vec<String> = field_types
.iter()
.filter_map(|ft| {
if let Type::Named(name) = ft {
Some(name.clone())
} else {
None
}
})
.collect();
if values.len() != field_names.len() {
return Err(format!(
"'{type_name}': expected {} fields, got {}",
field_names.len(),
values.len()
));
}
let mut fields = Vec::new();
let mut used = vec![false; values.len()];
for field_name in &field_names {
let position = values.iter().enumerate().find_map(|(i, value)| {
if !used[i]
&& matches!(&value.resolved_type, Some(Type::Named(name)) if name == field_name)
{
Some(i)
} else {
None
}
});
if let Some(position) = position {
used[position] = true;
fields.push((field_name.clone(), values[position].clone()));
} else {
break;
}
}
if fields.len() == field_names.len() {
return Ok(fields);
}
let mut seen = std::collections::HashSet::new();
for value in &values {
if let Some(Type::Named(name)) = &value.resolved_type
&& field_names.contains(name)
&& !seen.insert(name.clone())
{
return Err(format!(
"duplicate field '{name}' in '{type_name}' construction"
));
}
}
Ok(field_names.into_iter().zip(values).collect())
}
/// Validates the named fields supplied when constructing the newtype
/// `type_name`.
///
/// The expected fields are:
/// - the `Type::Named` members of the newtype's tuple definition, or
/// - the single name when the definition is itself `Type::Named`.
///
/// Checks run in this order:
/// 1. unknown field names;
/// 2. duplicated field names;
/// 3. missing fields;
/// 4. value types. A value typed exactly as the field's named type always
///    passes. Otherwise, when the field name is itself a known newtype, its
///    resolved definition must be `types_compatible` with the value's
///    resolved type.
///
/// The check is skipped entirely when `type_name` is not a known newtype or
/// its definition is neither a tuple nor named.
///
/// # Errors
/// Returns a message describing the first violation found.
pub(super) fn check_construction_fields(
    context: &ResolutionContext,
    type_name: &str,
    fields: &[(String, Expression)],
) -> Result<(), String> {
    let Some(inner) = context.newtypes.get(type_name) else {
        return Ok(());
    };
    let expected_fields: Vec<&str> = if let Type::Tuple(members) = inner {
        members
            .iter()
            .filter_map(|member| {
                if let Type::Named(name) = member {
                    Some(name.as_str())
                } else {
                    None
                }
            })
            .collect()
    } else if let Type::Named(name) = inner {
        vec![name.as_str()]
    } else {
        return Ok(());
    };
    let mut seen: HashSet<&str> = HashSet::new();
    for (field_name, _) in fields {
        let name = field_name.as_str();
        if !expected_fields.contains(&name) {
            return Err(format!("'{type_name}' has no field '{field_name}'"));
        }
        if !seen.insert(name) {
            return Err(format!(
                "duplicate field '{field_name}' in '{type_name}' construction"
            ));
        }
    }
    if let Some(missing) = expected_fields
        .iter()
        .find(|expected| !seen.contains(*expected))
    {
        return Err(format!(
            "missing field '{missing}' in '{type_name}' construction"
        ));
    }
    for (field_name, value) in fields {
        let Some(value_type) = &value.resolved_type else {
            continue;
        };
        if matches!(value_type, Type::Named(n) if n == field_name) {
            continue;
        }
        let Some(expected) = context
            .newtypes
            .get(field_name)
            .map(|t| context.resolve_named(t))
        else {
            continue;
        };
        if !types_compatible(&expected, &context.resolve_named(value_type)) {
            return Err(format!(
                "field '{field_name}' in '{type_name}': expected type '{field_name}', got {value_type}"
            ));
        }
    }
    Ok(())
}
/// Reorders call `arguments` in place so that they line up with `params` by
/// type.
///
/// Two greedy passes over the parameters:
/// 1. Assign each parameter the first unused argument whose resolved type is
///    exactly equal to the parameter type.
/// 2. For each parameter still unassigned, take the first unused argument
///    whose underlying type (`context.resolve_underlying`) is
///    `types_compatible` with the parameter's underlying type.
///
/// `arguments` is left untouched when:
/// - the counts differ or there are no arguments,
/// - some parameter finds no argument, or
/// - the resulting order is already the identity.
pub(super) fn reorder_arguments_by_type(
    params: &[Type],
    arguments: &mut Vec<Expression>,
    context: &ResolutionContext,
) {
    if params.len() != arguments.len() || arguments.is_empty() {
        return;
    }
    let mut reordered_indices: Vec<Option<usize>> = vec![None; params.len()];
    let mut used = vec![false; arguments.len()];
    // Pass 1: exact type matches.
    for (param_index, param_type) in params.iter().enumerate() {
        let found = arguments.iter().enumerate().position(|(i, arg)| {
            if used[i] {
                return false;
            }
            let Some(ref arg_type) = arg.resolved_type else {
                return false;
            };
            arg_type == param_type
        });
        if let Some(index) = found {
            used[index] = true;
            reordered_indices[param_index] = Some(index);
        }
    }
    // Pass 2: compatible underlying types for the remaining parameters.
    for (param_index, param_type) in params.iter().enumerate() {
        if reordered_indices[param_index].is_some() {
            continue;
        }
        let param_resolved = context.resolve_underlying(param_type);
        let found = arguments.iter().enumerate().position(|(i, arg)| {
            if used[i] {
                return false;
            }
            let Some(ref arg_type) = arg.resolved_type else {
                return false;
            };
            let arg_resolved = context.resolve_underlying(arg_type);
            types_compatible(&param_resolved, &arg_resolved)
        });
        if let Some(index) = found {
            used[index] = true;
            reordered_indices[param_index] = Some(index);
        } else {
            return;
        }
    }
    let Some(indices) = reordered_indices.into_iter().collect::<Option<Vec<usize>>>() else {
        return;
    };
    let already_ordered = indices.iter().enumerate().all(|(i, &index)| i == index);
    if already_ordered {
        return;
    }
    let mut old_arguments: Vec<Option<Expression>> =
        std::mem::take(arguments).into_iter().map(Some).collect();
    *arguments = indices
        .into_iter()
        .map(|i| {
            old_arguments[i]
                .take()
                .expect("each argument index is assigned to exactly one parameter")
        })
        .collect();
}
/// Loose structural compatibility check between two types.
///
/// - Any two integer types are compatible; so are any two float types, and
///   `Bool` with `Bool`.
/// - Pointers are compatible when their pointees are, regardless of
///   mutability.
/// - A pointer is compatible with a non-pointer when its pointee is
///   compatible with that non-pointer (checked in either position).
/// - Arrays and vectors need equal lengths and compatible elements.
/// - Every other pair must be exactly equal.
pub(super) fn types_compatible(expected: &Type, actual: &Type) -> bool {
    match (expected, actual) {
        (Type::Int(..), Type::Int(..))
        | (Type::Float(..), Type::Float(..))
        | (Type::Bool, Type::Bool) => true,
        (Type::Pointer(_, lhs), Type::Pointer(_, rhs)) => types_compatible(lhs, rhs),
        (Type::Pointer(_, pointee), bare) | (bare, Type::Pointer(_, pointee)) => {
            types_compatible(pointee, bare)
        }
        (Type::Array(lhs, lhs_len), Type::Array(rhs, rhs_len)) => {
            lhs_len == rhs_len && types_compatible(lhs, rhs)
        }
        (Type::Vector(lhs, lhs_len), Type::Vector(rhs, rhs_len)) => {
            lhs_len == rhs_len && types_compatible(lhs, rhs)
        }
        _ => expected == actual,
    }
}