use syntax::ast::{BinaryOperator, Expression, Literal, Span, UnaryOperator};
use syntax::program::Definition;
use syntax::types::{Type, substitute};
use BinaryOperator::*;
use UnaryOperator::*;
use super::super::Checker;
impl Checker<'_, '_> {
pub(super) fn infer_unary(
&mut self,
operator: UnaryOperator,
operand: Box<Expression>,
expected_ty: &Type,
span: Span,
) -> Expression {
let operand_expected_ty = if operator == Negative && expected_ty.resolve().is_numeric() {
expected_ty.clone()
} else {
self.new_type_var()
};
if operator == Negative {
self.scopes.increment_negation_depth();
}
let new_expression =
self.with_value_context(|s| s.infer_expression(*operand, &operand_expected_ty));
if operator == Negative {
self.scopes.decrement_negation_depth();
}
self.check_not_temp_producing(&new_expression);
let operand_span = new_expression.get_span();
let expression_ty = match operator {
Negative => {
let resolved = operand_expected_ty.resolve();
if resolved.is_numeric() || resolved.underlying_numeric_type().is_some() {
let is_literal = is_numeric_literal(&new_expression);
if resolved.is_unsigned_int() && !is_literal {
let type_name = resolved.get_name().unwrap_or_default();
self.sink
.push(diagnostics::infer::cannot_negate_unsigned(type_name, span));
}
self.check_negative_literal_overflow(&new_expression, &resolved, span);
operand_expected_ty.clone()
} else {
self.sink
.push(diagnostics::infer::not_numeric(&resolved, operand_span));
operand_expected_ty.clone()
}
}
Not => {
let bool_ty = self.type_bool();
self.unify(&bool_ty, &operand_expected_ty, &span);
bool_ty
}
Deref => {
let inner_ty = self.new_type_var();
let ref_ty = self.type_reference(inner_ty.clone());
self.unify(&ref_ty, &operand_expected_ty, &span);
inner_ty
}
};
self.unify(expected_ty, &expression_ty, &span);
Expression::Unary {
operator,
expression: new_expression.into(),
ty: expression_ty,
span,
}
}
pub(super) fn infer_binary(
&mut self,
operator: BinaryOperator,
left_operand: Box<Expression>,
right_operand: Box<Expression>,
expected_ty: &Type,
span: Span,
) -> Expression {
if matches!(*left_operand, Expression::Binary { .. }) {
let mut stack = vec![(operator, right_operand, span)];
let mut current = *left_operand;
while let Expression::Binary {
operator: op,
left,
right,
span: s,
..
} = current
{
stack.push((op, right, s));
current = *left;
}
let mut left_ty = self.new_type_var();
let mut left_inferred = self.infer_expression(current, &left_ty);
while let Some((op, right, s)) = stack.pop() {
let result_ty = if stack.is_empty() {
expected_ty.clone()
} else {
self.new_type_var()
};
let (inferred, ty) =
self.infer_binary_with_left(op, left_inferred, left_ty, right, &result_ty, s);
left_inferred = inferred;
left_ty = ty;
}
return left_inferred;
}
self.infer_binary_impl(operator, left_operand, right_operand, expected_ty, span)
}
fn infer_binary_with_left(
&mut self,
operator: BinaryOperator,
left_inferred: Expression,
left_ty: Type,
right_operand: Box<Expression>,
expected_ty: &Type,
span: Span,
) -> (Expression, Type) {
if matches!(operator, Division | Remainder) {
let is_zero = match right_operand.unwrap_parens() {
Expression::Literal {
literal: Literal::Integer { value: 0, .. },
..
} => true,
Expression::Literal {
literal: Literal::Float { value, .. },
..
} => *value == 0.0,
_ => false,
};
if is_zero {
self.sink.push(diagnostics::infer::division_by_zero(span));
}
}
let left_operand_ty = left_ty;
let right_operand_ty = self.new_type_var();
let right_literal_kind = numeric_literal_kind(&right_operand);
let is_right_literal = !matches!(right_literal_kind, NumericLiteralKind::None);
let new_right_operand = self.with_value_context(|s| {
if is_right_literal {
let left_resolved = left_operand_ty.resolve();
if literal_can_adapt_to(&right_literal_kind, &left_resolved) {
let _ = s.try_unify(&right_operand_ty, &left_resolved, &span);
}
}
s.infer_expression(*right_operand, &right_operand_ty)
});
self.check_not_temp_producing(&left_inferred);
self.check_not_temp_producing(&new_right_operand);
if matches!(operator, And | Or)
&& let Some(span) = Checker::find_propagate(&new_right_operand)
{
self.sink
.push(diagnostics::infer::propagate_in_condition(span));
}
let left_span = left_inferred.get_span();
let right_span = new_right_operand.get_span();
let expression_ty = self.resolve_binary_type(
&operator,
&left_operand_ty,
&right_operand_ty,
&left_span,
&right_span,
span,
);
self.unify(expected_ty, &expression_ty, &span);
let result = Expression::Binary {
operator,
left: Box::new(left_inferred),
right: Box::new(new_right_operand),
ty: expression_ty.clone(),
span,
};
(result, expression_ty)
}
fn infer_binary_impl(
&mut self,
operator: BinaryOperator,
left_operand: Box<Expression>,
right_operand: Box<Expression>,
expected_ty: &Type,
span: Span,
) -> Expression {
if matches!(operator, Division | Remainder) {
let is_zero = match right_operand.unwrap_parens() {
Expression::Literal {
literal: Literal::Integer { value: 0, .. },
..
} => true,
Expression::Literal {
literal: Literal::Float { value, .. },
..
} => *value == 0.0,
_ => false,
};
if is_zero {
self.sink.push(diagnostics::infer::division_by_zero(span));
}
}
let left_operand_ty = self.new_type_var();
let right_operand_ty = self.new_type_var();
let left_literal_kind = numeric_literal_kind(&left_operand);
let right_literal_kind = numeric_literal_kind(&right_operand);
let is_left_literal = !matches!(left_literal_kind, NumericLiteralKind::None);
let is_right_literal = !matches!(right_literal_kind, NumericLiteralKind::None);
let (new_left_operand, new_right_operand) = self.with_value_context(|s| {
if is_left_literal && !is_right_literal {
let right = s.infer_expression(*right_operand, &right_operand_ty);
let right_resolved = right_operand_ty.resolve();
if literal_can_adapt_to(&left_literal_kind, &right_resolved) {
let _ = s.try_unify(&left_operand_ty, &right_resolved, &span);
}
let left = s.infer_expression(*left_operand, &left_operand_ty);
(left, right)
} else {
let left = s.infer_expression(*left_operand, &left_operand_ty);
if is_right_literal {
let left_resolved = left_operand_ty.resolve();
if literal_can_adapt_to(&right_literal_kind, &left_resolved) {
let _ = s.try_unify(&right_operand_ty, &left_resolved, &span);
}
}
let right = s.infer_expression(*right_operand, &right_operand_ty);
(left, right)
}
});
self.check_not_temp_producing(&new_left_operand);
self.check_not_temp_producing(&new_right_operand);
if matches!(operator, And | Or)
&& let Some(span) = Checker::find_propagate(&new_right_operand)
{
self.sink
.push(diagnostics::infer::propagate_in_condition(span));
}
let left_span = new_left_operand.get_span();
let right_span = new_right_operand.get_span();
let expression_ty = self.resolve_binary_type(
&operator,
&left_operand_ty,
&right_operand_ty,
&left_span,
&right_span,
span,
);
self.unify(expected_ty, &expression_ty, &span);
Expression::Binary {
operator,
left: new_left_operand.into(),
right: new_right_operand.into(),
ty: expression_ty,
span,
}
}
fn resolve_binary_type(
&mut self,
operator: &BinaryOperator,
left_operand_ty: &Type,
right_operand_ty: &Type,
left_span: &Span,
right_span: &Span,
span: Span,
) -> Type {
match operator {
Equal | NotEqual => {
let resolved_left_operand = left_operand_ty.resolve();
let resolved_right_operand = right_operand_ty.resolve();
let same_aliased_numeric = resolved_left_operand == resolved_right_operand
&& resolved_left_operand.is_aliased_numeric_type();
let different_but_compatible = resolved_left_operand != resolved_right_operand
&& resolved_left_operand.is_numeric_compatible_with(&resolved_right_operand);
if !same_aliased_numeric && !different_but_compatible {
self.unify_binary_operands(operator, left_operand_ty, right_operand_ty, &span);
}
self.ensure_comparable(left_operand_ty, left_span);
self.type_bool()
}
And | Or => {
let bool_ty = self.type_bool();
self.unify(left_operand_ty, &bool_ty, &span);
self.unify(right_operand_ty, &bool_ty, &span);
bool_ty
}
LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
let resolved_left_operand = left_operand_ty.resolve();
let resolved_right_operand = right_operand_ty.resolve();
let same_aliased_numeric = resolved_left_operand == resolved_right_operand
&& resolved_left_operand.is_aliased_numeric_type();
let different_but_compatible = resolved_left_operand != resolved_right_operand
&& resolved_left_operand.is_numeric_compatible_with(&resolved_right_operand);
if same_aliased_numeric || different_but_compatible {
self.type_bool()
} else {
self.ensure_orderable(left_operand_ty, left_span);
self.ensure_orderable(right_operand_ty, right_span);
self.unify_binary_operands(operator, left_operand_ty, right_operand_ty, &span);
self.type_bool()
}
}
Addition => {
let resolved_left_operand = left_operand_ty.resolve();
let resolved_right_operand = right_operand_ty.resolve();
if let Some(result_ty) = self.try_operation_with_numeric_alias(
operator,
&resolved_left_operand,
&resolved_right_operand,
&span,
) {
result_ty
} else {
let numeric_ok = if !resolved_left_operand.is_string()
&& !resolved_right_operand.is_string()
{
self.ensure_numeric_for_binary(operator, left_operand_ty, left_span)
& self.ensure_numeric_for_binary(operator, right_operand_ty, right_span)
} else {
true
};
if resolved_left_operand.is_complex() || resolved_right_operand.is_complex() {
self.type_complex128()
} else {
if numeric_ok {
self.unify_binary_operands(
operator,
left_operand_ty,
right_operand_ty,
&span,
);
}
left_operand_ty.clone()
}
}
}
Subtraction | Multiplication | Division | Remainder => {
let left_resolved = left_operand_ty.resolve();
let right_resolved = right_operand_ty.resolve();
if matches!(operator, Remainder)
&& (left_resolved.is_float() || right_resolved.is_float())
{
self.sink
.push(diagnostics::infer::float_modulo_not_supported(span));
}
if let Some(result_ty) = self.try_operation_with_numeric_alias(
operator,
&left_resolved,
&right_resolved,
&span,
) {
result_ty
} else if left_resolved.is_complex() || right_resolved.is_complex() {
self.type_complex128()
} else {
let left_ok =
self.ensure_numeric_for_binary(operator, left_operand_ty, left_span);
let right_ok =
self.ensure_numeric_for_binary(operator, right_operand_ty, right_span);
if left_ok && right_ok {
self.unify_binary_operands(
operator,
left_operand_ty,
right_operand_ty,
&span,
);
}
left_operand_ty.clone()
}
}
Pipeline => {
panic!("Pipeline operator should have been desugared before type inference")
}
}
}
fn ensure_numeric_for_binary(
&mut self,
operator: &BinaryOperator,
ty: &Type,
span: &Span,
) -> bool {
let resolved_ty = ty.resolve();
if matches!(resolved_ty, Type::Variable(_) | Type::Error) {
return true;
}
if matches!(resolved_ty, Type::Parameter(_)) {
self.sink
.push(diagnostics::infer::not_orderable(&resolved_ty, *span));
return false;
}
if !resolved_ty.is_numeric() {
self.sink.push(diagnostics::infer::not_numeric_for_binary(
operator,
&resolved_ty,
*span,
));
return false;
}
true
}
fn ensure_orderable(&mut self, ty: &Type, span: &Span) {
let resolved_ty = ty.resolve();
if resolved_ty.is_error() {
return;
}
if !resolved_ty.is_ordered() && !resolved_ty.is_string() && !resolved_ty.is_boolean() {
self.sink
.push(diagnostics::infer::not_orderable(&resolved_ty, *span));
}
}
fn ensure_comparable(&mut self, ty: &Type, span: &Span) {
let resolved = ty.resolve();
if resolved.is_error() {
return;
}
if let Some(reason) = self.check_not_comparable(&resolved) {
self.sink
.push(diagnostics::infer::not_comparable(&resolved, reason, *span));
}
}
fn check_not_comparable(&self, ty: &Type) -> Option<&'static str> {
if matches!(ty, Type::Function { .. }) {
return Some("functions");
}
if ty.has_name("Slice") {
return Some("slices");
}
if ty.has_name("Map") {
return Some("maps");
}
if ty.has_name("Ref") || ty.has_name("Channel") {
return None;
}
if matches!(ty, Type::Variable(_)) {
return None;
}
if matches!(ty, Type::Parameter(_)) {
return Some("type parameters (Go requires the `comparable` constraint)");
}
if let Some(name) = ty.get_qualified_id()
&& let Some(definition) = self.store.get_definition(name)
{
let type_args = ty.get_type_params().unwrap_or_default();
let generics = match &definition {
Definition::Struct { generics, .. } | Definition::Enum { generics, .. } => {
generics.as_slice()
}
_ => &[],
};
let sub_map = generics
.iter()
.map(|g| g.name.clone())
.zip(type_args.iter().cloned())
.collect();
match definition {
Definition::Struct { fields, .. } => {
for f in fields {
let field_ty = substitute(&f.ty.resolve(), &sub_map);
if self.check_not_comparable(&field_ty).is_some() {
return Some("a struct containing non-comparable fields");
}
}
}
Definition::Enum { variants, .. } => {
for v in variants {
for f in v.fields.iter() {
let field_ty = substitute(&f.ty.resolve(), &sub_map);
if self.check_not_comparable(&field_ty).is_some() {
return Some("an enum containing non-comparable fields");
}
}
}
}
_ => {}
}
}
if let Type::Tuple(elems) = ty {
for e in elems {
if self.check_not_comparable(&e.resolve()).is_some() {
return Some("a tuple containing non-comparable elements");
}
}
}
None
}
fn unify_binary_operands(
&mut self,
operator: &BinaryOperator,
left_operand_ty: &Type,
right_operand_ty: &Type,
span: &Span,
) {
if self
.try_unify(left_operand_ty, right_operand_ty, span)
.is_err()
{
self.sink
.push(diagnostics::infer::binary_operator_type_mismatch(
operator,
left_operand_ty,
right_operand_ty,
*span,
));
}
}
fn try_operation_with_numeric_alias(
&mut self,
operator: &BinaryOperator,
left_ty: &Type,
right_ty: &Type,
span: &Span,
) -> Option<Type> {
let left_underlying = left_ty.underlying_numeric_type();
let right_underlying = right_ty.underlying_numeric_type();
let (left_underlying, right_underlying) = match (left_underlying, right_underlying) {
(Some(l), Some(r)) => (l, r),
_ => return None,
};
let left_family = left_underlying.numeric_family()?;
let right_family = right_underlying.numeric_family()?;
if left_family != right_family {
return None;
}
let left_is_aliased = left_ty.is_aliased_numeric_type();
let right_is_aliased = right_ty.is_aliased_numeric_type();
match (left_is_aliased, right_is_aliased, operator) {
(true, true, _) if left_ty == right_ty => {
if matches!(operator, Division) {
Some(left_underlying)
} else {
Some(left_ty.clone())
}
}
(true, false, _) => Some(left_ty.clone()),
(false, true, Division | Remainder) => {
self.sink.push(diagnostics::infer::invalid_division_order(
operator, left_ty, right_ty, *span,
));
None
}
(false, true, _) => Some(right_ty.clone()),
(false, false, _) => None,
(true, true, _) => {
self.sink
.push(diagnostics::infer::incompatible_named_numeric_types(
&left_underlying,
*span,
));
None
}
}
}
pub(super) fn infer_range(
&mut self,
start: Option<Box<Expression>>,
end: Option<Box<Expression>>,
inclusive: bool,
span: Span,
expected_ty: &Type,
) -> Expression {
let element_ty = self.new_type_var();
let (new_start, new_end) = self.with_value_context(|s| {
let start =
start.map(|expression| Box::new(s.infer_expression(*expression, &element_ty)));
let end = end.map(|expression| Box::new(s.infer_expression(*expression, &element_ty)));
(start, end)
});
if let Some(s) = &new_start {
self.check_not_temp_producing(s);
}
if let Some(e) = &new_end {
self.check_not_temp_producing(e);
}
let range_ty = match (&new_start, &new_end, inclusive) {
(Some(_), Some(_), false) => self.type_range(element_ty.clone()),
(Some(_), Some(_), true) => self.type_range_inclusive(element_ty.clone()),
(Some(_), None, _) => self.type_range_from(element_ty.clone()),
(None, Some(_), false) => self.type_range_to(element_ty.clone()),
(None, Some(_), true) => self.type_range_to_inclusive(element_ty.clone()),
(None, None, _) => {
self.sink
.push(diagnostics::infer::range_full_not_valid_expression(span));
let error_ty = self.new_type_var();
self.type_range(error_ty)
}
};
self.unify(expected_ty, &range_ty, &span);
Expression::Range {
start: new_start,
end: new_end,
inclusive,
ty: range_ty,
span,
}
}
pub(super) fn infer_cast(
&mut self,
expression: Box<Expression>,
target_type: syntax::ast::Annotation,
span: Span,
expected_ty: &Type,
) -> Expression {
let target_ty = self.convert_to_type(&target_type, &span);
let source_ty_var = self.new_type_var();
let new_expression =
self.with_value_context(|s| s.infer_expression(*expression, &source_ty_var));
let source_ty = source_ty_var.resolve();
self.check_not_temp_producing(&new_expression);
if is_cast_expression(&new_expression) {
self.sink.push(diagnostics::infer::chained_cast(span));
}
if !self.check_redundant_cast(&source_ty, &target_ty, span) {
self.check_redundant_literal_cast(&new_expression, &target_ty, expected_ty, span);
}
self.check_cast_literal_overflow(&new_expression, &target_ty, span);
self.check_valid_cast(&source_ty, &target_ty, span);
if is_float_literal(&new_expression) && is_integer_type(&target_ty) {
self.sink
.push(diagnostics::infer::float_literal_int_cast(span));
}
self.unify(expected_ty, &target_ty, &span);
Expression::Cast {
expression: new_expression.into(),
target_type,
ty: target_ty,
span,
}
}
}
/// Returns `true` if `expression`, ignoring enclosing parentheses and any number of
/// unary negations, is a float literal.
fn is_float_literal(expression: &Expression) -> bool {
    let inner = expression.unwrap_parens();
    if let Expression::Unary {
        operator: Negative,
        expression: operand,
        ..
    } = inner
    {
        return is_float_literal(operand);
    }
    matches!(
        inner,
        Expression::Literal {
            literal: Literal::Float { .. },
            ..
        }
    )
}
/// Returns `true` if `ty` resolves to one of the built-in integer types, identified
/// by name (signed and unsigned widths, plus `byte` and `rune`).
fn is_integer_type(ty: &Type) -> bool {
    const INTEGER_TYPE_NAMES: [&str; 12] = [
        "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32",
        "uint64", "byte", "rune",
    ];
    let resolved = ty.resolve();
    resolved
        .get_name()
        .is_some_and(|name| INTEGER_TYPE_NAMES.contains(&name))
}
/// Returns `true` if `expression`, after peeling off any enclosing parentheses, is a
/// cast expression.
fn is_cast_expression(expression: &Expression) -> bool {
    let mut current = expression;
    while let Expression::Paren {
        expression: inner, ..
    } = current
    {
        current = &**inner;
    }
    matches!(current, Expression::Cast { .. })
}
/// Returns `true` if `expression`, after peeling off any enclosing parentheses, is an
/// integer or float literal. Negated literals do not count.
fn is_numeric_literal(expression: &Expression) -> bool {
    let mut current = expression;
    while let Expression::Paren {
        expression: inner, ..
    } = current
    {
        current = &**inner;
    }
    matches!(
        current,
        Expression::Literal {
            literal: Literal::Integer { .. } | Literal::Float { .. },
            ..
        }
    )
}
/// What kind of numeric literal an operand expression is, if any.
enum NumericLiteralKind {
    /// The expression is not a numeric literal.
    None,
    /// An integer literal.
    Integer,
    /// A floating-point literal.
    Float,
}
/// Classifies `expression` as an integer literal, a float literal, or neither.
///
/// Parentheses and unary negation are looked through, so `-2` and `(-1.5)` are
/// classified like `2` and `1.5`. This lets a negated literal operand adapt to the
/// other operand's numeric type just as an un-negated literal does. It also matches
/// how `is_float_literal` treats negation.
fn numeric_literal_kind(expression: &Expression) -> NumericLiteralKind {
    match expression {
        Expression::Literal {
            literal: Literal::Integer { .. },
            ..
        } => NumericLiteralKind::Integer,
        Expression::Literal {
            literal: Literal::Float { .. },
            ..
        } => NumericLiteralKind::Float,
        Expression::Paren { expression, .. } => numeric_literal_kind(expression),
        Expression::Unary {
            operator: Negative,
            expression,
            ..
        } => numeric_literal_kind(expression),
        _ => NumericLiteralKind::None,
    }
}
/// Returns `true` if a literal of the given kind may take on `target`'s type.
/// Float literals adapt only to float types, integer literals adapt to any numeric
/// type, and non-literals never adapt.
fn literal_can_adapt_to(kind: &NumericLiteralKind, target: &Type) -> bool {
    if matches!(kind, NumericLiteralKind::Float) {
        return target.is_float();
    }
    matches!(kind, NumericLiteralKind::Integer) && target.is_numeric()
}