use crate::frontend::ast::{BinaryOp, Expr, ExprKind, Literal, Param, Pattern, TypeKind, UnaryOp};
use crate::middleend::environment::TypeEnv;
use crate::middleend::types::{MonoType, TyVar, TyVarGenerator, TypeScheme};
use crate::middleend::unify::Unifier;
use anyhow::{bail, Result};
/// State for unification-based (Hindley–Milner style) type inference over the
/// frontend AST.
///
/// Holds the typing environment, the fresh-variable supply, the substitution
/// built up by unification, and constraints that are deferred until
/// `solve_all_constraints` runs.
pub struct InferenceContext {
    /// Typing environment mapping names to type schemes.
    env: TypeEnv,
    /// Supply of fresh type variables.
    gen: TyVarGenerator,
    /// Substitution accumulated by unification.
    unifier: Unifier,
    /// Deferred equalities between two type variables.
    constraints: Vec<(TyVar, TyVar)>,
    /// Deferred structural constraints (method calls, arity, iteration, ...).
    type_constraints: Vec<TypeConstraint>,
    /// Number of currently active `infer` calls; used as a recursion guard.
    recursion_depth: usize,
}
/// A constraint recorded during inference and discharged later by
/// `InferenceContext::solve_type_constraint`.
#[derive(Clone, Debug)]
pub enum TypeConstraint {
    /// The two types must unify.
    Unify(MonoType, MonoType),
    /// The function type must take exactly this many (curried) arguments.
    FunctionArity(MonoType, usize),
    /// A method call: receiver type, method name, argument types.
    MethodCall(MonoType, String, Vec<MonoType>),
    /// The collection type must be iterable, yielding the element type.
    Iterable(MonoType, MonoType),
}
impl InferenceContext {
#[must_use]
pub fn new() -> Self {
InferenceContext {
gen: TyVarGenerator::new(),
unifier: Unifier::new(),
env: TypeEnv::standard(),
constraints: Vec::new(),
type_constraints: Vec::new(),
recursion_depth: 0,
}
}
#[must_use]
pub fn with_env(env: TypeEnv) -> Self {
InferenceContext {
gen: TyVarGenerator::new(),
unifier: Unifier::new(),
env,
constraints: Vec::new(),
type_constraints: Vec::new(),
recursion_depth: 0,
}
}
/// Infers the type of `expr`, discharges all deferred constraints, and
/// returns the type with the final substitution applied.
///
/// # Errors
/// Fails when the recursion guard trips (`recursion_depth > 100` on entry),
/// when `expr` is ill-typed, or when a deferred constraint cannot be solved.
pub fn infer(&mut self, expr: &Expr) -> Result<MonoType> {
    const MAX_DEPTH: usize = 100;
    if self.recursion_depth > MAX_DEPTH {
        bail!("Type inference recursion limit exceeded");
    }
    self.recursion_depth += 1;
    let outcome = self.infer_expr(expr);
    // Restore the depth before propagating any error.
    self.recursion_depth -= 1;
    let ty = outcome?;
    self.solve_all_constraints()?;
    Ok(self.unifier.apply(&ty))
}
/// Instantiates `scheme`, drawing any fresh type variables from this
/// context's generator.
pub fn instantiate(&mut self, scheme: &TypeScheme) -> MonoType {
    let fresh_supply = &mut self.gen;
    scheme.instantiate(fresh_supply)
}
/// Discharges every deferred constraint.
///
/// The variable/variable equalities are solved first. The structural
/// constraints are then popped and solved, most recently recorded first;
/// solving stops at the first failure.
fn solve_all_constraints(&mut self) -> Result<()> {
    self.solve_constraints();
    loop {
        match self.type_constraints.pop() {
            Some(next) => self.solve_type_constraint(next)?,
            None => return Ok(()),
        }
    }
}
/// Unifies each deferred pair of type variables, most recent first, draining
/// `self.constraints`.
///
/// Unification failures are discarded (best-effort).
/// NOTE(review): confirm that silently dropping these errors is intended.
fn solve_constraints(&mut self) {
    while let Some((lhs, rhs)) = self.constraints.pop() {
        let _ = self
            .unifier
            .unify(&MonoType::Var(lhs), &MonoType::Var(rhs));
    }
}
/// Discharges a single deferred constraint.
///
/// Types are resolved through the current substitution before their structure
/// is inspected. A type variable that unification has already bound to a
/// function, list, or string is therefore treated like that concrete type,
/// instead of being counted as arity 0 or reported as "not iterable".
///
/// # Errors
/// Returns an error when:
/// - unification fails;
/// - a function's arity differs from the expected arity;
/// - the method-call check fails;
/// - an `Iterable` collection is neither a list nor a string.
fn solve_type_constraint(&mut self, constraint: TypeConstraint) -> Result<()> {
    match constraint {
        TypeConstraint::Unify(t1, t2) => {
            self.unifier.unify(&t1, &t2)?;
        }
        TypeConstraint::FunctionArity(func_ty, expected_arity) => {
            // Resolve first: an unresolved variable has no visible arrows.
            let func_ty = self.unifier.apply(&func_ty);
            let mut current_ty = &func_ty;
            let mut arity = 0;
            while let MonoType::Function(_, ret) = current_ty {
                arity += 1;
                current_ty = ret;
            }
            if arity != expected_arity {
                bail!("Function arity mismatch: expected {expected_arity}, found {arity}");
            }
        }
        TypeConstraint::MethodCall(receiver_ty, method_name, arg_types) => {
            let receiver_ty = self.unifier.apply(&receiver_ty);
            self.check_method_call_constraint(&receiver_ty, &method_name, &arg_types)?;
        }
        TypeConstraint::Iterable(collection_ty, element_ty) => {
            match self.unifier.apply(&collection_ty) {
                MonoType::List(inner) => self.unifier.unify(&inner, &element_ty)?,
                // Iterating a string yields its characters.
                MonoType::String => self.unifier.unify(&element_ty, &MonoType::Char)?,
                other => bail!("Type {other} is not iterable"),
            }
        }
    }
    Ok(())
}
/// Checks that `method_name` may be called on `receiver_ty`.
///
/// The table below records the receiver/method pairs known to be supported.
/// The check is currently permissive: unrecognised pairs are accepted too, so
/// this function always returns `Ok(())`. Argument types are not examined.
fn check_method_call_constraint(
    &mut self,
    receiver_ty: &MonoType,
    method_name: &str,
    _arg_types: &[MonoType],
) -> Result<()> {
    let _recognized = match receiver_ty {
        MonoType::List(_) => matches!(
            method_name,
            "map" | "filter" | "reduce" | "len" | "length" | "push"
        ),
        MonoType::String => matches!(
            method_name,
            "len" | "length" | "chars" | "trim" | "to_upper" | "to_lower"
        ),
        MonoType::DataFrame(_) => matches!(
            method_name,
            "filter" | "groupby" | "agg" | "select" | "col" | "mean" | "std" | "sum" | "count"
        ),
        MonoType::Series(_) => matches!(method_name, "mean" | "std" | "sum" | "count"),
        MonoType::Named(name) if name == "DataFrame" => matches!(
            method_name,
            "filter" | "groupby" | "agg" | "select" | "col" | "mean" | "std" | "sum" | "count"
        ),
        MonoType::Named(name) if name == "Series" => {
            matches!(method_name, "mean" | "std" | "sum" | "count")
        }
        MonoType::Named(name) if name.starts_with("HashMap") => {
            matches!(method_name, "insert" | "get" | "contains_key")
        }
        _ => false,
    };
    Ok(())
}
/// Dispatches inference for `expr` to the handler for its syntactic category.
/// Kinds not listed here are handled by `infer_other_expr`.
fn infer_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::Literal(lit) => Ok(Self::infer_literal(lit)),
        ExprKind::Identifier(name) => self.infer_identifier(name),
        ExprKind::QualifiedName { name, .. } => self.infer_identifier(name),
        ExprKind::If { .. }
        | ExprKind::IfLet { .. }
        | ExprKind::For { .. }
        | ExprKind::While { .. }
        | ExprKind::WhileLet { .. }
        | ExprKind::Loop { .. } => self.infer_control_flow_expr(expr),
        ExprKind::Match {
            expr: scrutinee,
            arms,
        } => self.infer_match(scrutinee, arms),
        ExprKind::Function { .. } | ExprKind::Lambda { .. } => self.infer_function_expr(expr),
        ExprKind::List(..) | ExprKind::Tuple(..) | ExprKind::ListComprehension { .. } => {
            self.infer_collection_expr(expr)
        }
        ExprKind::Binary { .. }
        | ExprKind::Unary { .. }
        | ExprKind::Call { .. }
        | ExprKind::MethodCall { .. } => self.infer_operation_expr(expr),
        _ => self.infer_other_expr(expr),
    }
}
/// Maps a literal to its primitive type.
///
/// Integers and bytes are `Int`; strings and atoms are `String`; `()` and
/// `null` are `Unit`.
fn infer_literal(lit: &Literal) -> MonoType {
    match lit {
        Literal::Integer(..) | Literal::Byte(_) => MonoType::Int,
        Literal::Float(_) => MonoType::Float,
        Literal::String(_) | Literal::Atom(_) => MonoType::String,
        Literal::Bool(_) => MonoType::Bool,
        Literal::Char(_) => MonoType::Char,
        Literal::Unit | Literal::Null => MonoType::Unit,
    }
}
fn infer_identifier(&mut self, name: &str) -> Result<MonoType> {
match self.env.lookup(name) {
Some(scheme) => Ok(self.env.instantiate(scheme, &mut self.gen)),
None => bail!("Undefined variable: {name}"),
}
}
fn infer_binary(&mut self, left: &Expr, op: BinaryOp, right: &Expr) -> Result<MonoType> {
let left_ty = self.infer_expr(left)?;
let right_ty = self.infer_expr(right)?;
match op {
BinaryOp::Add => {
if matches!((&left_ty, &right_ty), (MonoType::String, MonoType::String)) {
Ok(MonoType::String)
} else {
self.unifier.unify(&left_ty, &right_ty)?;
self.unifier.unify(&left_ty, &MonoType::Int)?;
Ok(MonoType::Int)
}
}
BinaryOp::Subtract | BinaryOp::Multiply | BinaryOp::Divide | BinaryOp::Modulo => {
self.unifier.unify(&left_ty, &right_ty)?;
self.unifier.unify(&left_ty, &MonoType::Int)?;
Ok(MonoType::Int)
}
BinaryOp::Power => {
self.unifier.unify(&left_ty, &MonoType::Int)?;
self.unifier.unify(&right_ty, &MonoType::Int)?;
Ok(MonoType::Int)
}
BinaryOp::Equal
| BinaryOp::NotEqual
| BinaryOp::Less
| BinaryOp::LessEqual
| BinaryOp::Greater
| BinaryOp::GreaterEqual
| BinaryOp::Gt => {
self.unifier.unify(&left_ty, &right_ty)?;
Ok(MonoType::Bool)
}
BinaryOp::And | BinaryOp::Or => {
self.unifier.unify(&left_ty, &MonoType::Bool)?;
self.unifier.unify(&right_ty, &MonoType::Bool)?;
Ok(MonoType::Bool)
}
BinaryOp::NullCoalesce => {
Ok(right_ty) }
BinaryOp::BitwiseAnd
| BinaryOp::BitwiseOr
| BinaryOp::BitwiseXor
| BinaryOp::LeftShift
| BinaryOp::RightShift => {
self.unifier.unify(&left_ty, &MonoType::Int)?;
self.unifier.unify(&right_ty, &MonoType::Int)?;
Ok(MonoType::Int)
}
BinaryOp::Send => {
Ok(MonoType::Unit)
}
BinaryOp::In => {
Ok(MonoType::Bool)
}
}
}
/// Infers the type of a unary operation.
///
/// - `!` requires and gives `Bool`.
/// - Negation keeps a `Float` operand as `Float` and otherwise constrains the
///   operand to `Int`.
/// - Bitwise not requires and gives `Int`.
/// - `&` / `&mut` wrap the operand type in `Reference`.
/// - Dereference resolves the operand through the substitution: a reference
///   gives its target type, and an unresolved variable is constrained to
///   `Reference(fresh)`.
///
/// # Errors
/// Returns an error on type mismatch, or when dereferencing a type that is
/// neither a reference nor an unresolved variable.
fn infer_unary(&mut self, op: UnaryOp, operand: &Expr) -> Result<MonoType> {
    let operand_ty = self.infer_expr(operand)?;
    match op {
        UnaryOp::Not => {
            self.unifier.unify(&operand_ty, &MonoType::Bool)?;
            Ok(MonoType::Bool)
        }
        UnaryOp::Negate => {
            if matches!(self.unifier.apply(&operand_ty), MonoType::Float) {
                Ok(MonoType::Float)
            } else {
                self.unifier.unify(&operand_ty, &MonoType::Int)?;
                Ok(MonoType::Int)
            }
        }
        UnaryOp::BitwiseNot => {
            self.unifier.unify(&operand_ty, &MonoType::Int)?;
            Ok(MonoType::Int)
        }
        UnaryOp::Reference | UnaryOp::MutableReference => {
            Ok(MonoType::Reference(Box::new(operand_ty)))
        }
        UnaryOp::Deref => match self.unifier.apply(&operand_ty) {
            MonoType::Reference(inner) => Ok(*inner),
            MonoType::Var(_) => {
                // Unknown operand: it must be a reference to something.
                let target = MonoType::Var(self.gen.fresh());
                self.unifier
                    .unify(&operand_ty, &MonoType::Reference(Box::new(target.clone())))?;
                Ok(target)
            }
            _ => Err(anyhow::anyhow!("Cannot dereference non-reference type")),
        },
    }
}
/// Infers `throw expr`.
///
/// The thrown value is type-checked. The `throw` expression itself gets a
/// fresh, unconstrained type variable.
fn infer_throw(&mut self, expr: &Expr) -> Result<MonoType> {
    self.infer_expr(expr)
        .map(|_thrown| MonoType::Var(self.gen.fresh()))
}
/// Infers `await expr`.
///
/// The awaited expression is type-checked and the result is a fresh type
/// variable. The payload of a `Future` is not tracked, so every awaited value
/// gets an unconstrained type.
fn infer_await(&mut self, expr: &Expr) -> Result<MonoType> {
    let _awaited_ty = self.infer_expr(expr)?;
    Ok(MonoType::Var(self.gen.fresh()))
}
/// Infers an `if` expression.
///
/// The condition must be `Bool`.
/// - With an `else`, both branches must unify and their resolved common type
///   is returned.
/// - Without one, the `then` branch must be `Unit`, and so is the whole
///   expression.
fn infer_if(
    &mut self,
    condition: &Expr,
    then_branch: &Expr,
    else_branch: Option<&Expr>,
) -> Result<MonoType> {
    let cond_ty = self.infer_expr(condition)?;
    self.unifier.unify(&cond_ty, &MonoType::Bool)?;
    let then_ty = self.infer_expr(then_branch)?;
    match else_branch {
        None => {
            self.unifier.unify(&then_ty, &MonoType::Unit)?;
            Ok(MonoType::Unit)
        }
        Some(else_expr) => {
            let else_ty = self.infer_expr(else_expr)?;
            self.unifier.unify(&then_ty, &else_ty)?;
            Ok(self.unifier.apply(&then_ty))
        }
    }
}
/// Infers `let name = value in body` and returns the body's type.
///
/// The value's type is resolved through the current substitution and then
/// generalized, so `name` may be used polymorphically inside `body`.
/// The environment is restored afterwards, including when inferring the body
/// fails, so a failed inference does not leave `name` bound.
fn infer_let(
    &mut self,
    name: &str,
    value: &Expr,
    body: &Expr,
    _is_mutable: bool,
) -> Result<MonoType> {
    let value_ty = self.infer_expr(value)?;
    // Apply the substitution before generalizing. Otherwise a variable already
    // bound to a concrete type could be quantified (unless `generalize`
    // consults the substitution itself).
    let resolved_ty = self.unifier.apply(&value_ty);
    let scheme = self.env.generalize(resolved_ty);
    let saved_env = self.env.clone();
    self.env = self.env.extend(name, scheme);
    let body_result = self.infer_expr(body);
    self.env = saved_env;
    body_result
}
fn infer_function(
&mut self,
name: &str,
params: &[Param],
body: &Expr,
_return_type: Option<&crate::frontend::ast::Type>,
_is_async: bool,
) -> Result<MonoType> {
let mut param_types = Vec::new();
let old_env = self.env.clone();
for param in params {
let param_ty =
if param.ty.kind == crate::frontend::ast::TypeKind::Named("Any".to_string()) {
MonoType::Var(self.gen.fresh())
} else {
Self::ast_type_to_mono_static(¶m.ty)?
};
param_types.push(param_ty.clone());
self.env = self.env.extend(param.name(), TypeScheme::mono(param_ty));
}
let result_var = MonoType::Var(self.gen.fresh());
let func_type = param_types
.iter()
.rev()
.fold(result_var.clone(), |acc, param_ty| {
MonoType::Function(Box::new(param_ty.clone()), Box::new(acc))
});
self.env = self.env.extend(name, TypeScheme::mono(func_type.clone()));
let body_ty = self.infer_expr(body)?;
self.unifier.unify(&result_var, &body_ty)?;
self.env = old_env;
let final_type = self.unifier.apply(&func_type);
Ok(final_type)
}
/// Infers a lambda and returns its curried type `p1 -> ... -> pn -> body`,
/// with the current substitution applied.
///
/// Parameter types:
/// - parameters annotated `Any` or `_` get fresh type variables;
/// - other annotations are converted with `ast_type_to_mono_static`.
///
/// The parameter bindings are removed from the environment afterwards,
/// including when inference fails.
fn infer_lambda(&mut self, params: &[Param], body: &Expr) -> Result<MonoType> {
    let saved_env = self.env.clone();
    let result = self.infer_lambda_scoped(params, body);
    self.env = saved_env;
    result
}

/// Environment-mutating part of `infer_lambda`; the caller restores the environment.
fn infer_lambda_scoped(&mut self, params: &[Param], body: &Expr) -> Result<MonoType> {
    let mut param_types = Vec::with_capacity(params.len());
    for decl in params {
        let param_ty = match &decl.ty.kind {
            TypeKind::Named(type_name) if type_name == "Any" || type_name == "_" => {
                MonoType::Var(self.gen.fresh())
            }
            _ => Self::ast_type_to_mono_static(&decl.ty)?,
        };
        param_types.push(param_ty.clone());
        self.env = self.env.extend(decl.name(), TypeScheme::mono(param_ty));
    }
    let body_ty = self.infer_expr(body)?;
    let lambda_type = param_types
        .into_iter()
        .rev()
        .fold(body_ty, |acc, param_ty| {
            MonoType::Function(Box::new(param_ty), Box::new(acc))
        });
    Ok(self.unifier.apply(&lambda_type))
}
/// Infers a call `func(args...)`.
///
/// The callee must unify with the curried type `a1 -> ... -> an -> r`, built
/// from the argument types with a fresh result variable `r`. The resolved `r`
/// is returned. Arguments are inferred from last to first.
fn infer_call(&mut self, func: &Expr, args: &[Expr]) -> Result<MonoType> {
    let callee_ty = self.infer_expr(func)?;
    let ret_ty = MonoType::Var(self.gen.fresh());
    let expected_callee_ty =
        args.iter()
            .rev()
            .try_fold(ret_ty.clone(), |acc, arg| -> Result<MonoType> {
                let arg_ty = self.infer_expr(arg)?;
                Ok(MonoType::Function(Box::new(arg_ty), Box::new(acc)))
            })?;
    self.unifier.unify(&callee_ty, &expected_callee_ty)?;
    Ok(self.unifier.apply(&ret_ty))
}
fn infer_macro(&mut self, name: &str, args: &[Expr]) -> Result<MonoType> {
for arg in args {
self.infer_expr(arg)?;
}
match name {
"println" => Ok(MonoType::Unit), "vec" => {
if args.is_empty() {
Ok(MonoType::List(Box::new(MonoType::Var(self.gen.fresh()))))
} else {
let elem_ty = self.infer_expr(&args[0])?;
Ok(MonoType::List(Box::new(elem_ty)))
}
}
"df" => {
self.infer_dataframe_macro(args)
}
_ => bail!("Unknown macro: {name}"),
}
}
/// Infers the `DataFrame` type of a `df!` invocation from its
/// `column = values` arguments.
///
/// The column-building rules are identical to
/// `infer_dataframe_from_assignments`, so this delegates to it instead of
/// duplicating the loop.
fn infer_dataframe_macro(&mut self, args: &[Expr]) -> Result<MonoType> {
    self.infer_dataframe_from_assignments(args)
}
/// Builds a `DataFrame` type from `column = values` assignments.
///
/// Only assignments whose target is a plain identifier become columns. Other
/// expressions are skipped without being inferred. A list-typed value
/// contributes its element type as the column type; any other value type is
/// used as-is.
fn infer_dataframe_from_assignments(&mut self, assignments: &[Expr]) -> Result<MonoType> {
    let mut columns = Vec::new();
    for assignment in assignments {
        let (target, value) = match &assignment.kind {
            ExprKind::Assign { target, value } => (target, value),
            _ => continue,
        };
        let column_name = match &target.kind {
            ExprKind::Identifier(name) => name.clone(),
            _ => continue,
        };
        let element_type = match self.infer_expr(value)? {
            MonoType::List(inner) => *inner,
            scalar => scalar,
        };
        columns.push((column_name, element_type));
    }
    Ok(MonoType::DataFrame(columns))
}
/// Infers `receiver.method(args...)`.
///
/// The receiver type is resolved through the current substitution before
/// dispatch. A variable already bound to, for example, a list is therefore
/// handled by the list-specific rules rather than the generic fallback.
/// A `MethodCall` constraint is also recorded for later checking.
///
/// NOTE(review): `add_method_constraint` infers the arguments, and several
/// handlers (`push`, the generic fallback) infer them again.
pub fn infer_method_call(
    &mut self,
    receiver: &Expr,
    method: &str,
    args: &[Expr],
) -> Result<MonoType> {
    let receiver_ty = self.infer_expr(receiver)?;
    let receiver_ty = self.unifier.apply(&receiver_ty);
    self.add_method_constraint(&receiver_ty, method, args)?;
    match &receiver_ty {
        MonoType::List(_) => self.infer_list_method(&receiver_ty, method, args),
        MonoType::String => self.infer_string_method(&receiver_ty, method, args),
        MonoType::DataFrame(_) | MonoType::Series(_) => {
            self.infer_dataframe_method(&receiver_ty, method, args)
        }
        MonoType::Named(name) if name == "DataFrame" || name == "Series" => {
            self.infer_dataframe_method(&receiver_ty, method, args)
        }
        _ => self.infer_generic_method(&receiver_ty, method, args),
    }
}
/// Infers the argument types and records a `MethodCall` constraint for
/// `receiver_ty.method(..)`, to be checked later by `solve_all_constraints`.
///
/// # Errors
/// Propagates the first argument inference error.
fn add_method_constraint(
    &mut self,
    receiver_ty: &MonoType,
    method: &str,
    args: &[Expr],
) -> Result<()> {
    let mut arg_types = Vec::with_capacity(args.len());
    for arg in args {
        arg_types.push(self.infer_expr(arg)?);
    }
    let constraint =
        TypeConstraint::MethodCall(receiver_ty.clone(), method.to_owned(), arg_types);
    self.type_constraints.push(constraint);
    Ok(())
}
/// Infers a method call on a list receiver.
///
/// Known methods:
/// - `len` / `length` → `Int`;
/// - `push(x)` → `Unit`, and `x` must match the element type;
/// - `pop`, `min`, `max` → optional element;
/// - `sorted`, `reversed`, `unique` → list of the same element type;
/// - `sum` → the element type.
///
/// Everything else, including non-list receivers, goes to the generic
/// fallback.
fn infer_list_method(
    &mut self,
    receiver_ty: &MonoType,
    method: &str,
    args: &[Expr],
) -> Result<MonoType> {
    let elem_ty = match receiver_ty {
        MonoType::List(elem_ty) => elem_ty,
        _ => return self.infer_generic_method(receiver_ty, method, args),
    };
    match method {
        "len" | "length" => self.validate_no_args(method, args).map(|()| MonoType::Int),
        "push" => {
            self.validate_single_arg(method, args)?;
            let pushed_ty = self.infer_expr(&args[0])?;
            self.unifier.unify(&pushed_ty, elem_ty)?;
            Ok(MonoType::Unit)
        }
        "pop" | "min" | "max" => self
            .validate_no_args(method, args)
            .map(|()| MonoType::Optional(elem_ty.clone())),
        "sorted" | "reversed" | "unique" => self
            .validate_no_args(method, args)
            .map(|()| MonoType::List(elem_ty.clone())),
        "sum" => self
            .validate_no_args(method, args)
            .map(|()| (**elem_ty).clone()),
        _ => self.infer_generic_method(receiver_ty, method, args),
    }
}
/// Infers a method call on a `String` receiver.
///
/// - `len` / `length` → `Int`, taking no arguments.
/// - `chars` → list, taking no arguments.
/// - `trim`, `to_upper`, `to_lower` → `String`. These are the string methods
///   listed as valid in `check_method_call_constraint`; they previously fell
///   through to an unconstrained type variable.
///
/// Other methods use the generic fallback.
fn infer_string_method(
    &mut self,
    receiver_ty: &MonoType,
    method: &str,
    args: &[Expr],
) -> Result<MonoType> {
    match method {
        "len" | "length" => {
            self.validate_no_args(method, args)?;
            Ok(MonoType::Int)
        }
        "chars" => {
            self.validate_no_args(method, args)?;
            Ok(MonoType::List(Box::new(MonoType::String)))
        }
        // Trimming and case conversion produce a new string.
        "trim" | "to_upper" | "to_lower" => Ok(MonoType::String),
        _ => self.infer_generic_method(receiver_ty, method, args),
    }
}
/// Infers a method call on a data-frame or series receiver.
///
/// - `filter`, `groupby`, `agg`, `select`: a typed `DataFrame` keeps its
///   column schema; any other receiver yields the nominal `DataFrame`.
/// - `mean`, `std`, `sum`, `count` → `Float`.
/// - `col` → see `infer_column_selection`.
///
/// Other methods use the generic fallback.
fn infer_dataframe_method(
    &mut self,
    receiver_ty: &MonoType,
    method: &str,
    args: &[Expr],
) -> Result<MonoType> {
    match method {
        "filter" | "groupby" | "agg" | "select" => {
            let frame_ty = if let MonoType::DataFrame(columns) = receiver_ty {
                MonoType::DataFrame(columns.clone())
            } else {
                MonoType::Named("DataFrame".to_string())
            };
            Ok(frame_ty)
        }
        "mean" | "std" | "sum" | "count" => Ok(MonoType::Float),
        "col" => self.infer_column_selection(receiver_ty, args),
        _ => self.infer_generic_method(receiver_ty, method, args),
    }
}
/// Infers `frame.col("name")`.
///
/// If the receiver is a typed `DataFrame` and the first argument is a string
/// literal naming one of its columns, the result is a `Series` of that
/// column's type. Otherwise it is a `Series` of a fresh type variable.
fn infer_column_selection(
    &mut self,
    receiver_ty: &MonoType,
    args: &[Expr],
) -> Result<MonoType> {
    let known_column_ty = match (receiver_ty, args.first().map(|arg| &arg.kind)) {
        (
            MonoType::DataFrame(columns),
            Some(ExprKind::Literal(Literal::String(col_name))),
        ) => columns
            .iter()
            .find(|(name, _)| name == col_name)
            .map(|(_, col_ty)| col_ty.clone()),
        _ => None,
    };
    let dtype = known_column_ty.unwrap_or_else(|| MonoType::Var(self.gen.fresh()));
    Ok(MonoType::Series(Box::new(dtype)))
}
/// Fallback for methods without a dedicated rule.
///
/// If `method` is bound in the environment, its instantiated type must unify
/// with `receiver -> a1 -> ... -> an -> r` (with `r` fresh), and the resolved
/// `r` is returned. A method that is not bound yields an unconstrained type
/// variable.
fn infer_generic_method(
    &mut self,
    receiver_ty: &MonoType,
    method: &str,
    args: &[Expr],
) -> Result<MonoType> {
    let method_ty = match self.env.lookup(method) {
        Some(scheme) => self.env.instantiate(scheme, &mut self.gen),
        None => return Ok(MonoType::Var(self.gen.fresh())),
    };
    let result_ty = MonoType::Var(self.gen.fresh());
    let expected_ty = self.build_method_function_type(receiver_ty, args, result_ty.clone())?;
    self.unifier.unify(&method_ty, &expected_ty)?;
    Ok(self.unifier.apply(&result_ty))
}
/// Builds the curried type `receiver -> a1 -> ... -> an -> result_ty`.
/// Argument types are inferred from last to first.
fn build_method_function_type(
    &mut self,
    receiver_ty: &MonoType,
    args: &[Expr],
    result_ty: MonoType,
) -> Result<MonoType> {
    let after_receiver =
        args.iter()
            .rev()
            .try_fold(result_ty, |acc, arg| -> Result<MonoType> {
                let arg_ty = self.infer_expr(arg)?;
                Ok(MonoType::Function(Box::new(arg_ty), Box::new(acc)))
            })?;
    Ok(MonoType::Function(
        Box::new(receiver_ty.clone()),
        Box::new(after_receiver),
    ))
}
/// Fails unless the call to `method` has no arguments.
fn validate_no_args(&self, method: &str, args: &[Expr]) -> Result<()> {
    if args.is_empty() {
        Ok(())
    } else {
        bail!("Method {method} takes no arguments")
    }
}
/// Fails unless the call to `method` has exactly one argument.
fn validate_single_arg(&self, method: &str, args: &[Expr]) -> Result<()> {
    match args.len() {
        1 => Ok(()),
        _ => bail!("Method {method} takes exactly one argument"),
    }
}
/// Infers a block. Its type is the type of the last expression, or `Unit`
/// when the block is empty; every expression is inferred in order.
///
/// Special case: a two-element block consisting of the identifier `df`
/// followed by a list literal is typed as a data frame built from the list's
/// assignments.
fn infer_block(&mut self, exprs: &[Expr]) -> Result<MonoType> {
    if let [head, tail] = exprs {
        if let (ExprKind::Identifier(name), ExprKind::List(assignments)) =
            (&head.kind, &tail.kind)
        {
            if name == "df" {
                return self.infer_dataframe_from_assignments(assignments);
            }
        }
    }
    exprs
        .iter()
        .try_fold(MonoType::Unit, |_, expr| self.infer_expr(expr))
}
/// Infers a list literal.
///
/// Every element must unify with the first one. An empty list gets a fresh
/// element type variable.
fn infer_list(&mut self, elements: &[Expr]) -> Result<MonoType> {
    match elements.split_first() {
        None => Ok(MonoType::List(Box::new(MonoType::Var(self.gen.fresh())))),
        Some((first, rest)) => {
            let elem_ty = self.infer_expr(first)?;
            for item in rest {
                let item_ty = self.infer_expr(item)?;
                self.unifier.unify(&elem_ty, &item_ty)?;
            }
            Ok(MonoType::List(Box::new(self.unifier.apply(&elem_ty))))
        }
    }
}
/// Infers `[element for variable in iterable if condition]`.
///
/// `iterable` must be a list. While the optional condition (which must be
/// `Bool`) and `element` are inferred, `variable` is bound to the list's
/// element type. The binding is removed afterwards, including when inference
/// fails.
fn infer_list_comprehension(
    &mut self,
    element: &Expr,
    variable: &str,
    iterable: &Expr,
    condition: Option<&Expr>,
) -> Result<MonoType> {
    let iterable_ty = self.infer_expr(iterable)?;
    let elem_ty = MonoType::Var(self.gen.fresh());
    self.unifier
        .unify(&iterable_ty, &MonoType::List(Box::new(elem_ty.clone())))?;
    let saved_env = self.env.clone();
    self.env = self
        .env
        .extend(variable, TypeScheme::mono(self.unifier.apply(&elem_ty)));
    let result = self.infer_comprehension_element(element, condition);
    self.env = saved_env;
    let result_elem_ty = result?;
    Ok(MonoType::List(Box::new(
        self.unifier.apply(&result_elem_ty),
    )))
}

/// Checks the optional filter condition (must be `Bool`) and infers the
/// element expression; the comprehension variable must already be bound.
fn infer_comprehension_element(
    &mut self,
    element: &Expr,
    condition: Option<&Expr>,
) -> Result<MonoType> {
    if let Some(cond) = condition {
        let cond_ty = self.infer_expr(cond)?;
        self.unifier.unify(&cond_ty, &MonoType::Bool)?;
    }
    self.infer_expr(element)
}
/// Infers a `match` expression.
///
/// Each arm's pattern is checked against the scrutinee type, and the
/// variables it binds are visible only in that arm's body. All arm bodies
/// must unify to one result type. The environment is restored after every
/// arm, including when the arm fails to type-check.
///
/// # Errors
/// Fails on an empty arm list, on a pattern/scrutinee mismatch, or when arm
/// bodies disagree.
fn infer_match(
    &mut self,
    expr: &Expr,
    arms: &[crate::frontend::ast::MatchArm],
) -> Result<MonoType> {
    let expr_ty = self.infer_expr(expr)?;
    if arms.is_empty() {
        bail!("Match expression must have at least one arm");
    }
    let result_ty = MonoType::Var(self.gen.fresh());
    for arm in arms {
        let saved_env = self.env.clone();
        let arm_result = self
            .infer_pattern(&arm.pattern, &expr_ty)
            .and_then(|()| self.infer_expr(&arm.body));
        self.env = saved_env;
        let body_ty = arm_result?;
        self.unifier.unify(&result_ty, &body_ty)?;
    }
    Ok(self.unifier.apply(&result_ty))
}
/// Checks `pattern` against `expected_ty` and adds the variables the pattern
/// binds to the environment. The caller is responsible for restoring the
/// environment afterwards.
///
/// Inside a list pattern, a named rest pattern (`..rest`) is bound to the
/// list type itself, because it captures the remaining elements rather than
/// a single element.
///
/// # Errors
/// Returns an error when the pattern's shape cannot be unified with
/// `expected_ty`.
fn infer_pattern(&mut self, pattern: &Pattern, expected_ty: &MonoType) -> Result<()> {
    match pattern {
        Pattern::Wildcard => Ok(()),
        Pattern::Literal(lit) => {
            let lit_ty = Self::infer_literal(lit);
            self.unifier.unify(expected_ty, &lit_ty)
        }
        Pattern::Identifier(name) => {
            self.env = self.env.extend(name, TypeScheme::mono(expected_ty.clone()));
            Ok(())
        }
        Pattern::QualifiedName(_path) => Ok(()),
        Pattern::List(patterns) => {
            let elem_ty = MonoType::Var(self.gen.fresh());
            let list_ty = MonoType::List(Box::new(elem_ty.clone()));
            self.unifier.unify(expected_ty, &list_ty)?;
            for pat in patterns {
                if let Pattern::RestNamed(rest_name) = pat {
                    // `..rest` captures the remaining elements, i.e. a list.
                    self.env = self
                        .env
                        .extend(rest_name, TypeScheme::mono(list_ty.clone()));
                } else {
                    self.infer_pattern(pat, &elem_ty)?;
                }
            }
            Ok(())
        }
        Pattern::Ok(inner) => {
            if let MonoType::Result(ok_ty, _) = expected_ty {
                self.infer_pattern(inner, ok_ty)
            } else {
                let error_ty = MonoType::Var(self.gen.fresh());
                let inner_ty = MonoType::Var(self.gen.fresh());
                let result_ty =
                    MonoType::Result(Box::new(inner_ty.clone()), Box::new(error_ty));
                self.unifier.unify(expected_ty, &result_ty)?;
                self.infer_pattern(inner, &inner_ty)
            }
        }
        Pattern::Err(inner) => {
            if let MonoType::Result(_, err_ty) = expected_ty {
                self.infer_pattern(inner, err_ty)
            } else {
                let ok_ty = MonoType::Var(self.gen.fresh());
                let inner_ty = MonoType::Var(self.gen.fresh());
                let result_ty = MonoType::Result(Box::new(ok_ty), Box::new(inner_ty.clone()));
                self.unifier.unify(expected_ty, &result_ty)?;
                self.infer_pattern(inner, &inner_ty)
            }
        }
        Pattern::Some(inner) => {
            if let MonoType::Optional(inner_ty) = expected_ty {
                self.infer_pattern(inner, inner_ty)
            } else {
                let inner_ty = MonoType::Var(self.gen.fresh());
                let option_ty = MonoType::Optional(Box::new(inner_ty.clone()));
                self.unifier.unify(expected_ty, &option_ty)?;
                self.infer_pattern(inner, &inner_ty)
            }
        }
        Pattern::None => {
            let type_var = MonoType::Var(self.gen.fresh());
            let option_ty = MonoType::Optional(Box::new(type_var));
            self.unifier.unify(expected_ty, &option_ty)
        }
        Pattern::Tuple(patterns) => {
            let mut elem_types = Vec::with_capacity(patterns.len());
            for pat in patterns {
                let elem_ty = MonoType::Var(self.gen.fresh());
                self.infer_pattern(pat, &elem_ty)?;
                elem_types.push(elem_ty);
            }
            let tuple_ty = MonoType::Tuple(elem_types);
            self.unifier.unify(expected_ty, &tuple_ty)
        }
        Pattern::Struct {
            name,
            fields,
            has_rest: _,
        } => {
            let struct_ty = MonoType::Named(name.clone());
            self.unifier.unify(expected_ty, &struct_ty)?;
            // Field types are not known here; each sub-pattern gets a fresh variable.
            for field in fields {
                if let Some(field_pattern) = &field.pattern {
                    let field_ty = MonoType::Var(self.gen.fresh());
                    self.infer_pattern(field_pattern, &field_ty)?;
                }
            }
            Ok(())
        }
        Pattern::Range { start, end, .. } => {
            let start_ty = MonoType::Var(self.gen.fresh());
            let end_ty = MonoType::Var(self.gen.fresh());
            self.infer_pattern(start, &start_ty)?;
            self.infer_pattern(end, &end_ty)?;
            self.unifier.unify(&start_ty, &end_ty)?;
            self.unifier.unify(expected_ty, &start_ty)
        }
        Pattern::Or(patterns) => {
            for pat in patterns {
                self.infer_pattern(pat, expected_ty)?;
            }
            Ok(())
        }
        Pattern::Rest => Ok(()),
        Pattern::RestNamed(name) => {
            self.env = self.env.extend(name, TypeScheme::mono(expected_ty.clone()));
            Ok(())
        }
        Pattern::AtBinding { name, pattern } => {
            self.env = self.env.extend(name, TypeScheme::mono(expected_ty.clone()));
            self.infer_pattern(pattern, expected_ty)
        }
        Pattern::WithDefault { pattern, .. } => self.infer_pattern(pattern, expected_ty),
        Pattern::TupleVariant { path: _, patterns } => {
            for pat in patterns {
                let elem_ty = MonoType::Var(self.gen.fresh());
                self.infer_pattern(pat, &elem_ty)?;
            }
            Ok(())
        }
        Pattern::Mut(inner) => self.infer_pattern(inner, expected_ty),
    }
}
/// Infers `for var in iter { body }`; the loop itself is `Unit`.
///
/// `iter` must be a list, and `var` is bound to its element type while the
/// body is inferred. The binding is removed afterwards, including when the
/// body fails to type-check.
fn infer_for(&mut self, var: &str, iter: &Expr, body: &Expr) -> Result<MonoType> {
    let iter_ty = self.infer_expr(iter)?;
    let elem_ty = MonoType::Var(self.gen.fresh());
    self.unifier
        .unify(&iter_ty, &MonoType::List(Box::new(elem_ty.clone())))?;
    let saved_env = self.env.clone();
    self.env = self.env.extend(var, TypeScheme::mono(elem_ty));
    let body_result = self.infer_expr(body);
    self.env = saved_env;
    body_result.map(|_body_ty| MonoType::Unit)
}
/// Infers `while condition { body }`.
///
/// The condition must be `Bool` and the body `Unit`; the loop itself is
/// `Unit`.
fn infer_while(&mut self, condition: &Expr, body: &Expr) -> Result<MonoType> {
    let guard_ty = self.infer_expr(condition)?;
    self.unifier.unify(&guard_ty, &MonoType::Bool)?;
    let body_ty = self.infer_expr(body)?;
    self.unifier
        .unify(&body_ty, &MonoType::Unit)
        .map(|()| MonoType::Unit)
}
/// Infers `loop { body }`; the body must be `Unit` and so is the loop.
fn infer_loop(&mut self, body: &Expr) -> Result<MonoType> {
    let body_ty = self.infer_expr(body)?;
    self.unifier
        .unify(&body_ty, &MonoType::Unit)
        .map(|()| MonoType::Unit)
}
/// Infers `start..end`.
///
/// Both bounds must be `Int`, and the range is typed as a list of `Int`.
fn infer_range(&mut self, start: &Expr, end: &Expr) -> Result<MonoType> {
    let bounds = [self.infer_expr(start)?, self.infer_expr(end)?];
    for bound_ty in &bounds {
        self.unifier.unify(bound_ty, &MonoType::Int)?;
    }
    Ok(MonoType::List(Box::new(MonoType::Int)))
}
/// Infers a pipeline `expr |> stage1 |> stage2 ...`.
///
/// Each stage must be a function from the current value type to a fresh
/// result type. That (resolved) result becomes the input of the next stage,
/// and the last one is the pipeline's type.
fn infer_pipeline(
    &mut self,
    expr: &Expr,
    stages: &[crate::frontend::ast::PipelineStage],
) -> Result<MonoType> {
    let seed_ty = self.infer_expr(expr)?;
    stages
        .iter()
        .try_fold(seed_ty, |input_ty, stage| -> Result<MonoType> {
            let stage_ty = self.infer_expr(&stage.op)?;
            let output_ty = MonoType::Var(self.gen.fresh());
            let expected = MonoType::Function(Box::new(input_ty), Box::new(output_ty.clone()));
            self.unifier.unify(&stage_ty, &expected)?;
            Ok(self.unifier.apply(&output_ty))
        })
}
/// Infers `target = value`.
///
/// The value is inferred first, then the target; both must unify. The
/// assignment itself is `Unit`.
fn infer_assign(&mut self, target: &Expr, value: &Expr) -> Result<MonoType> {
    let rhs_ty = self.infer_expr(value)?;
    let lhs_ty = self.infer_expr(target)?;
    self.unifier
        .unify(&lhs_ty, &rhs_ty)
        .map(|()| MonoType::Unit)
}
/// Infers `target op= value`.
///
/// The result type of `target op value` must unify with the target's type.
/// The assignment itself is `Unit`.
fn infer_compound_assign(
    &mut self,
    target: &Expr,
    op: BinaryOp,
    value: &Expr,
) -> Result<MonoType> {
    let lhs_ty = self.infer_expr(target)?;
    let rhs_ty = self.infer_expr(value)?;
    let combined_ty = self.infer_binary_op_type(op, &lhs_ty, &rhs_ty)?;
    self.unifier
        .unify(&lhs_ty, &combined_ty)
        .map(|()| MonoType::Unit)
}
/// Computes the result type of `left op right` from already-inferred operand
/// types (used by compound assignment).
///
/// Arithmetic (`+ - * / %`) and `**`: the operands must have the same type.
/// After unifying them, a `Float` operand type gives `Float`; anything else
/// is constrained to `Int`.
///
/// The result is decided from the resolved type rather than by speculatively
/// unifying with `Int` first. Previously, a failed probe could leave a
/// variable bound to `Int` before the `Float` fallback was tried, so
/// `var op float` was rejected.
///
/// The remaining operators follow the same rules as `infer_binary`.
///
/// # Errors
/// Returns an error when the operand types cannot satisfy the operator.
fn infer_binary_op_type(
    &mut self,
    op: BinaryOp,
    left_ty: &MonoType,
    right_ty: &MonoType,
) -> Result<MonoType> {
    match op {
        BinaryOp::Add
        | BinaryOp::Subtract
        | BinaryOp::Multiply
        | BinaryOp::Divide
        | BinaryOp::Modulo
        | BinaryOp::Power => {
            self.unifier.unify(left_ty, right_ty)?;
            if matches!(self.unifier.apply(left_ty), MonoType::Float) {
                Ok(MonoType::Float)
            } else {
                self.unifier.unify(left_ty, &MonoType::Int)?;
                Ok(MonoType::Int)
            }
        }
        BinaryOp::Equal
        | BinaryOp::NotEqual
        | BinaryOp::Less
        | BinaryOp::LessEqual
        | BinaryOp::Greater
        | BinaryOp::GreaterEqual
        | BinaryOp::Gt => {
            self.unifier.unify(left_ty, right_ty)?;
            Ok(MonoType::Bool)
        }
        BinaryOp::And | BinaryOp::Or => {
            self.unifier.unify(left_ty, &MonoType::Bool)?;
            self.unifier.unify(right_ty, &MonoType::Bool)?;
            Ok(MonoType::Bool)
        }
        BinaryOp::NullCoalesce => Ok(right_ty.clone()),
        BinaryOp::BitwiseAnd
        | BinaryOp::BitwiseOr
        | BinaryOp::BitwiseXor
        | BinaryOp::LeftShift
        | BinaryOp::RightShift => {
            self.unifier.unify(left_ty, &MonoType::Int)?;
            self.unifier.unify(right_ty, &MonoType::Int)?;
            Ok(MonoType::Int)
        }
        BinaryOp::Send => Ok(MonoType::Unit),
        BinaryOp::In => Ok(MonoType::Bool),
    }
}
/// Infers an increment/decrement of `target`.
///
/// A target that resolves to `Float` stays `Float`; any other target is
/// constrained to `Int`. The decision uses the resolved type instead of
/// probing with a speculative unification against `Int`.
///
/// # Errors
/// Fails when the target cannot be an `Int`.
fn infer_increment_decrement(&mut self, target: &Expr) -> Result<MonoType> {
    let target_ty = self.infer_expr(target)?;
    if matches!(self.unifier.apply(&target_ty), MonoType::Float) {
        Ok(MonoType::Float)
    } else {
        self.unifier.unify(&target_ty, &MonoType::Int)?;
        Ok(MonoType::Int)
    }
}
/// Converts a syntactic type annotation into a `MonoType`.
///
/// Primitive names:
/// - `i32` / `i64` → `Int`;
/// - `f32` / `f64` → `Float`;
/// - `bool` → `Bool`;
/// - `String` / `str` → `String`;
/// - `char` → `Char`, the same type char literals get. Previously it became
///   `Named("char")`, which never unified with a char literal.
///
/// Compound types:
/// - `Vec` / `List` / `List(..)` / arrays → lists;
/// - `Option` / `Optional(..)` → optionals;
/// - references and refinements are erased to their underlying type;
/// - any other name becomes a nominal `Named` type.
///
/// NOTE(review): `Any` and missing generic parameters use
/// `TyVarGenerator::new().fresh()`, i.e. a brand-new generator for each use.
/// If a new generator always starts from the same id, these variables
/// coincide with each other and with variables from the context's own
/// generator. Confirm this, and consider threading the context's generator
/// through here.
///
/// # Errors
/// Nothing here fails directly; the `Result` only propagates failures from
/// converting nested types.
fn ast_type_to_mono_static(ty: &crate::frontend::ast::Type) -> Result<MonoType> {
    use crate::frontend::ast::TypeKind;
    Ok(match &ty.kind {
        TypeKind::Named(name) => match name.as_str() {
            "i32" | "i64" => MonoType::Int,
            "f32" | "f64" => MonoType::Float,
            "bool" => MonoType::Bool,
            "char" => MonoType::Char,
            "String" | "str" => MonoType::String,
            "Any" => MonoType::Var(TyVarGenerator::new().fresh()),
            _ => MonoType::Named(name.clone()),
        },
        TypeKind::Generic { base, params } => match base.as_str() {
            "Vec" | "List" => {
                if let Some(first_param) = params.first() {
                    MonoType::List(Box::new(Self::ast_type_to_mono_static(first_param)?))
                } else {
                    MonoType::List(Box::new(MonoType::Var(TyVarGenerator::new().fresh())))
                }
            }
            "Option" => {
                if let Some(first_param) = params.first() {
                    MonoType::Optional(Box::new(Self::ast_type_to_mono_static(first_param)?))
                } else {
                    MonoType::Optional(Box::new(MonoType::Var(TyVarGenerator::new().fresh())))
                }
            }
            _ => MonoType::Named(base.clone()),
        },
        TypeKind::Optional(inner) => {
            MonoType::Optional(Box::new(Self::ast_type_to_mono_static(inner)?))
        }
        TypeKind::List(inner) => MonoType::List(Box::new(Self::ast_type_to_mono_static(inner)?)),
        TypeKind::Array { elem_type, size: _ } => {
            MonoType::List(Box::new(Self::ast_type_to_mono_static(elem_type)?))
        }
        TypeKind::Function { params, ret } => {
            let ret_ty = Self::ast_type_to_mono_static(ret)?;
            params
                .iter()
                .rev()
                .try_fold(ret_ty, |acc, param| -> Result<MonoType> {
                    Ok(MonoType::Function(
                        Box::new(Self::ast_type_to_mono_static(param)?),
                        Box::new(acc),
                    ))
                })?
        }
        TypeKind::DataFrame { columns } => {
            let mut col_types = Vec::with_capacity(columns.len());
            for (name, col_ty) in columns {
                col_types.push((name.clone(), Self::ast_type_to_mono_static(col_ty)?));
            }
            MonoType::DataFrame(col_types)
        }
        TypeKind::Series { dtype } => {
            MonoType::Series(Box::new(Self::ast_type_to_mono_static(dtype)?))
        }
        TypeKind::Tuple(types) => {
            let mono_types: Result<Vec<_>> =
                types.iter().map(Self::ast_type_to_mono_static).collect();
            MonoType::Tuple(mono_types?)
        }
        TypeKind::Reference { inner, .. } => Self::ast_type_to_mono_static(inner)?,
        TypeKind::Refined { base, .. } => Self::ast_type_to_mono_static(base)?,
    })
}
#[must_use]
pub fn solve(&self, var: &crate::middleend::types::TyVar) -> MonoType {
self.unifier.solve(var)
}
#[must_use]
pub fn apply(&self, ty: &MonoType) -> MonoType {
self.unifier.apply(ty)
}
/// Infers `if`, `if let`, `for`, `while`, `while let`, and `loop`.
///
/// For `if let` and `while let`:
/// - the pattern is checked against the scrutinee's type;
/// - its bindings are visible in the `then` branch or loop body only;
/// - the environment is restored afterwards, including on error.
///
/// Previously the pattern was ignored, so any variable it bound was
/// "undefined" inside the branch or body.
///
/// `if let` returns the `then` type, which must unify with the `else` type
/// (or with `Unit` when there is no `else`). `while let` is `Unit`.
fn infer_control_flow_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => self.infer_if(condition, then_branch, else_branch.as_deref()),
        ExprKind::For {
            var, iter, body, ..
        } => self.infer_for(var, iter, body),
        ExprKind::While {
            condition, body, ..
        } => self.infer_while(condition, body),
        ExprKind::Loop { body, .. } => self.infer_loop(body),
        ExprKind::IfLet {
            pattern,
            expr,
            then_branch,
            else_branch,
        } => {
            let scrutinee_ty = self.infer_expr(expr)?;
            let saved_env = self.env.clone();
            let then_result = self
                .infer_pattern(pattern, &scrutinee_ty)
                .and_then(|()| self.infer_expr(then_branch));
            self.env = saved_env;
            let then_ty = then_result?;
            let else_ty = match else_branch {
                Some(else_expr) => self.infer_expr(else_expr)?,
                None => MonoType::Unit,
            };
            self.unifier.unify(&then_ty, &else_ty)?;
            Ok(then_ty)
        }
        ExprKind::WhileLet {
            pattern,
            expr,
            body,
            ..
        } => {
            let scrutinee_ty = self.infer_expr(expr)?;
            let saved_env = self.env.clone();
            let body_result = self
                .infer_pattern(pattern, &scrutinee_ty)
                .and_then(|()| self.infer_expr(body));
            self.env = saved_env;
            body_result.map(|_body_ty| MonoType::Unit)
        }
        _ => bail!("Unexpected expression type in control flow handler"),
    }
}
/// Routes named function definitions to `infer_function` and lambdas to
/// `infer_lambda`; any other kind is an error.
fn infer_function_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    if let ExprKind::Lambda { params, body } = &expr.kind {
        return self.infer_lambda(params, body);
    }
    if let ExprKind::Function {
        name,
        params,
        body,
        return_type,
        is_async,
        ..
    } = &expr.kind
    {
        return self.infer_function(name, params, body, return_type.as_ref(), *is_async);
    }
    bail!("Unexpected expression type in function handler")
}
fn infer_collection_expr(&mut self, expr: &Expr) -> Result<MonoType> {
match &expr.kind {
ExprKind::List(elements) => self.infer_list(elements),
ExprKind::Tuple(elements) => {
let element_types: Result<Vec<_>> =
elements.iter().map(|e| self.infer_expr(e)).collect();
Ok(MonoType::Tuple(element_types?))
}
ExprKind::ListComprehension { element, clauses } => {
if let Some(first_clause) = clauses.first() {
self.infer_list_comprehension(
element,
&first_clause.variable,
&first_clause.iterable,
first_clause.condition.as_deref(),
)
} else {
bail!("List comprehension must have at least one clause")
}
}
_ => bail!("Unexpected expression type in collection handler"),
}
}
/// Routes calls, method calls, and unary/binary operations to their handlers;
/// any other kind is an error.
fn infer_operation_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::Call { func, args } => self.infer_call(func, args),
        ExprKind::MethodCall {
            receiver,
            method,
            args,
        } => self.infer_method_call(receiver, method, args),
        ExprKind::Unary { op, operand } => self.infer_unary(*op, operand),
        ExprKind::Binary { left, op, right } => self.infer_binary(left, *op, right),
        _ => bail!("Unexpected expression type in operation handler"),
    }
}
/// Second-level dispatcher for expression kinds that `infer_expr` does not
/// handle itself. Kinds not listed here go to `infer_remaining_expr`.
pub fn infer_other_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::Assign { .. }
        | ExprKind::CompoundAssign { .. }
        | ExprKind::PreIncrement { .. }
        | ExprKind::PostIncrement { .. }
        | ExprKind::PreDecrement { .. }
        | ExprKind::PostDecrement { .. } => self.infer_other_assignment_expr(expr),
        ExprKind::Some { .. } | ExprKind::None => self.infer_other_option_expr(expr),
        ExprKind::Ok { value } => self.infer_result_ok(value),
        ExprKind::Err { error } => self.infer_result_err(error),
        ExprKind::Throw { expr: thrown } => self.infer_throw(thrown),
        ExprKind::StringInterpolation { parts } => self.infer_string_interpolation(parts),
        ExprKind::Await { .. } | ExprKind::AsyncBlock { .. } | ExprKind::Try { .. } => {
            self.infer_other_async_expr(expr)
        }
        ExprKind::Send { .. }
        | ExprKind::ActorSend { .. }
        | ExprKind::Ask { .. }
        | ExprKind::ActorQuery { .. } => self.infer_other_actor_expr(expr),
        ExprKind::StructLiteral { .. }
        | ExprKind::ObjectLiteral { .. }
        | ExprKind::FieldAccess { .. }
        | ExprKind::IndexAccess { .. }
        | ExprKind::Slice { .. } => self.infer_other_literal_access_expr(expr),
        ExprKind::Break { .. } | ExprKind::Continue { .. } | ExprKind::Return { .. } => {
            self.infer_other_control_flow_expr(expr)
        }
        ExprKind::Struct { .. }
        | ExprKind::Enum { .. }
        | ExprKind::Trait { .. }
        | ExprKind::Impl { .. }
        | ExprKind::Extension { .. }
        | ExprKind::Actor { .. }
        | ExprKind::Import { .. }
        | ExprKind::Export { .. } => self.infer_other_definition_expr(expr),
        _ => self.infer_remaining_expr(expr),
    }
}
/// Types `Break`, `Continue` and `Return` expressions as `Unit`.
///
/// NOTE(review): the expression is not inspected at all. Any sub-expressions
/// these nodes carry are not inferred on this path.
fn infer_other_control_flow_expr(&mut self, _expr: &Expr) -> Result<MonoType> {
    let control_flow_ty = MonoType::Unit;
    Ok(control_flow_ty)
}
/// Types item-level definitions and module wiring as `Unit`.
///
/// This covers struct, enum, trait, impl, extension, actor, import and
/// export. The definition itself is not inspected here.
fn infer_other_definition_expr(&mut self, _expr: &Expr) -> Result<MonoType> {
    let definition_ty = MonoType::Unit;
    Ok(definition_ty)
}
/// Handles struct/object literals and field, index and slice access.
///
/// A struct literal is typed by its name alone. The other kinds delegate to
/// their dedicated helpers.
///
/// # Errors
/// Bails on any other expression kind and propagates helper errors.
fn infer_other_literal_access_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::Slice { object: sliced, .. } => self.infer_slice(sliced),
        ExprKind::IndexAccess { object: container, index } => {
            self.infer_index_access(container, index)
        }
        ExprKind::FieldAccess { object: receiver, .. } => self.infer_field_access(receiver),
        ExprKind::ObjectLiteral { fields } => self.infer_object_literal(fields),
        ExprKind::StructLiteral { name, .. } => {
            let struct_name = name.clone();
            Ok(MonoType::Named(struct_name))
        }
        _ => bail!("Unexpected literal/access expression"),
    }
}
/// Types the option constructors as `Optional<T>`.
///
/// For `Some(v)`, `T` is the inferred type of `v`. For `None`, `T` is a
/// fresh type variable.
///
/// # Errors
/// Bails on any other expression kind and propagates errors from `v`.
fn infer_other_option_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    let payload_ty = match &expr.kind {
        ExprKind::Some { value } => self.infer_expr(value)?,
        ExprKind::None => MonoType::Var(self.gen.fresh()),
        _ => bail!("Unexpected option expression"),
    };
    Ok(MonoType::Optional(Box::new(payload_ty)))
}
fn infer_other_async_expr(&mut self, expr: &Expr) -> Result<MonoType> {
match &expr.kind {
ExprKind::Await { expr } => self.infer_await(expr),
ExprKind::AsyncBlock { body } => self.infer_async_block(body),
ExprKind::Try { expr } => {
let expr_type = self.infer(expr)?;
Ok(expr_type)
}
_ => bail!("Unexpected async expression"),
}
}
/// Routes actor messaging expressions.
///
/// Fire-and-forget sends (`Send`, `ActorSend`) go to `infer_send`.
/// Request/reply forms (`Ask`, `ActorQuery`) go to `infer_ask`; only `Ask`
/// forwards a timeout.
///
/// # Errors
/// Bails on any other expression kind and propagates helper errors.
fn infer_other_actor_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::ActorQuery { actor, message } => self.infer_ask(actor, message, None),
        ExprKind::Ask {
            actor,
            message,
            timeout,
        } => {
            let timeout_expr = timeout.as_deref();
            self.infer_ask(actor, message, timeout_expr)
        }
        ExprKind::Send { actor, message } | ExprKind::ActorSend { actor, message } => {
            self.infer_send(actor, message)
        }
        _ => bail!("Unexpected actor expression"),
    }
}
/// Routes mutation expressions: plain assignment, compound assignment and
/// the four increment/decrement forms.
///
/// # Errors
/// Bails on any other expression kind and propagates helper errors.
fn infer_other_assignment_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::PreIncrement { target: place }
        | ExprKind::PostIncrement { target: place }
        | ExprKind::PreDecrement { target: place }
        | ExprKind::PostDecrement { target: place } => self.infer_increment_decrement(place),
        ExprKind::CompoundAssign {
            target: place,
            op,
            value: rhs,
        } => self.infer_compound_assign(place, *op, rhs),
        ExprKind::Assign {
            target: place,
            value: rhs,
        } => self.infer_assign(place, rhs),
        _ => bail!("Unexpected assignment expression"),
    }
}
/// Last-resort dispatcher for kinds not claimed by the earlier routers.
///
/// Handles let, block, range, pipeline, module, dataframe, command and macro
/// expressions. A shell `Command` is always typed as `String`; a `Module` is
/// typed as its body.
///
/// # Errors
/// Bails with "Unknown expression type in inference" for anything else, and
/// propagates helper errors.
fn infer_remaining_expr(&mut self, expr: &Expr) -> Result<MonoType> {
    match &expr.kind {
        ExprKind::Block(statements) => self.infer_block(statements),
        ExprKind::Let {
            name,
            value,
            body,
            is_mutable,
            ..
        } => self.infer_let(name, value, body, *is_mutable),
        ExprKind::Module {
            body: module_body, ..
        } => self.infer_expr(module_body),
        ExprKind::Pipeline { expr: head, stages } => self.infer_pipeline(head, stages),
        ExprKind::Range {
            start: lower,
            end: upper,
            ..
        } => self.infer_range(lower, upper),
        ExprKind::Macro {
            name: macro_name,
            args,
        } => self.infer_macro(macro_name, args),
        ExprKind::Command { .. } => Ok(MonoType::String),
        ExprKind::DataFrame { columns } => self.infer_dataframe(columns),
        ExprKind::DataFrameOperation { source, operation } => {
            self.infer_dataframe_operation(source, operation)
        }
        _ => bail!("Unknown expression type in inference"),
    }
}
/// Infers the type of an interpolated string.
///
/// Every `StringPart::Expr` part is inferred so that errors inside an
/// interpolation (for example an unbound name) are still reported. Their
/// types are otherwise discarded.
///
/// The result is `MonoType::String`, the same type given to string literals
/// and shell commands. Previously it was `Named("String")`, which is a
/// different `MonoType` value than the one ordinary strings receive.
///
/// # Errors
/// Propagates any inference error from an embedded expression.
fn infer_string_interpolation(
    &mut self,
    parts: &[crate::frontend::ast::StringPart],
) -> Result<MonoType> {
    for part in parts {
        if let crate::frontend::ast::StringPart::Expr(expr) = part {
            let _ = self.infer_expr(expr)?;
        }
    }
    Ok(MonoType::String)
}
/// Types `Ok(value)` as `Result<T, E>`.
///
/// `T` is the inferred type of `value` and `E` is a fresh type variable,
/// allocated after `value` has been inferred.
///
/// # Errors
/// Propagates inference errors from `value`.
fn infer_result_ok(&mut self, value: &Expr) -> Result<MonoType> {
    let ok_ty = self.infer_expr(value)?;
    let err_placeholder = self.gen.fresh();
    Ok(MonoType::Result(
        Box::new(ok_ty),
        Box::new(MonoType::Var(err_placeholder)),
    ))
}
/// Types `Err(error)` as `Result<T, E>`.
///
/// `E` is the inferred type of `error` and `T` is a fresh type variable,
/// allocated after `error` has been inferred.
///
/// # Errors
/// Propagates inference errors from `error`.
fn infer_result_err(&mut self, error: &Expr) -> Result<MonoType> {
    let err_ty = self.infer_expr(error)?;
    let ok_placeholder = self.gen.fresh();
    Ok(MonoType::Result(
        Box::new(MonoType::Var(ok_placeholder)),
        Box::new(err_ty),
    ))
}
fn infer_object_literal(
&mut self,
fields: &[crate::frontend::ast::ObjectField],
) -> Result<MonoType> {
for field in fields {
match field {
crate::frontend::ast::ObjectField::KeyValue { value, .. } => {
let _ = self.infer_expr(value)?;
}
crate::frontend::ast::ObjectField::Spread { expr } => {
let _ = self.infer_expr(expr)?;
}
}
}
Ok(MonoType::Named("Object".to_string()))
}
/// Types `object.field` as a fresh, unconstrained type variable.
///
/// The receiver is inferred first so that errors inside it are reported. Its
/// type is not used to look up the field.
///
/// # Errors
/// Propagates inference errors from `object`.
fn infer_field_access(&mut self, object: &Expr) -> Result<MonoType> {
    self.infer_expr(object)?;
    let field_var = self.gen.fresh();
    Ok(MonoType::Var(field_var))
}
/// Infers `object[index]`.
///
/// Cases, checked in order:
/// - An index of type `List<Int>` returns the container type unchanged.
///   NOTE(review): presumably this is range/slice indexing; confirm.
/// - `List<T>` requires an `Int` index and yields `T`.
/// - `String` requires an `Int` index and yields `String`.
/// - Anything else yields a fresh type variable without constraining the
///   index.
///
/// # Errors
/// Propagates inference errors from either operand, or a unification error
/// when a list/string is indexed by a non-`Int`.
fn infer_index_access(&mut self, object: &Expr, index: &Expr) -> Result<MonoType> {
    let container_ty = self.infer_expr(object)?;
    let key_ty = self.infer_expr(index)?;

    let indexed_by_int_list =
        matches!(&key_ty, MonoType::List(elem) if matches!(**elem, MonoType::Int));
    if indexed_by_int_list {
        return Ok(container_ty);
    }

    let element_ty = match container_ty {
        MonoType::List(elem_ty) => {
            self.unifier.unify(&key_ty, &MonoType::Int)?;
            *elem_ty
        }
        MonoType::String => {
            self.unifier.unify(&key_ty, &MonoType::Int)?;
            MonoType::String
        }
        _ => MonoType::Var(self.gen.fresh()),
    };
    Ok(element_ty)
}
/// Types a slice expression as the type of the sliced object.
///
/// Only the object is inferred; any slice bounds are not visited here.
///
/// # Errors
/// Propagates inference errors from `object`.
fn infer_slice(&mut self, object: &Expr) -> Result<MonoType> {
    self.infer_expr(object)
}
/// Types a fire-and-forget actor send as `Unit`.
///
/// The actor and the message are both inferred, in that order, purely so
/// that errors inside them are reported.
///
/// # Errors
/// Propagates the first inference error from either operand.
fn infer_send(&mut self, actor: &Expr, message: &Expr) -> Result<MonoType> {
    for operand in [actor, message] {
        self.infer_expr(operand)?;
    }
    Ok(MonoType::Unit)
}
/// Types a request/reply actor message as a fresh type variable (the reply).
///
/// The actor and the message are inferred for error reporting only. When a
/// timeout is supplied, it must unify with `Int`.
///
/// # Errors
/// Propagates inference errors from any operand, or a unification error if
/// the timeout is not an `Int`.
fn infer_ask(
    &mut self,
    actor: &Expr,
    message: &Expr,
    timeout: Option<&Expr>,
) -> Result<MonoType> {
    self.infer_expr(actor)?;
    self.infer_expr(message)?;
    match timeout {
        Some(timeout_expr) => {
            let timeout_ty = self.infer_expr(timeout_expr)?;
            self.unifier.unify(&timeout_ty, &MonoType::Int)?;
        }
        None => {}
    }
    let reply_var = self.gen.fresh();
    Ok(MonoType::Var(reply_var))
}
/// Infers a dataframe literal as `DataFrame([(column_name, column_type), ...])`.
///
/// Each column's type comes from its first value; every later value in that
/// column must unify with it. An empty column gets a fresh type variable.
/// Column order is preserved.
///
/// # Errors
/// Propagates inference errors from cell expressions, or a unification error
/// when a column mixes incompatible types.
fn infer_dataframe(
    &mut self,
    columns: &[crate::frontend::ast::DataFrameColumn],
) -> Result<MonoType> {
    let mut typed_columns = Vec::with_capacity(columns.len());
    for column in columns {
        let column_ty = match column.values.split_first() {
            None => MonoType::Var(self.gen.fresh()),
            Some((head, tail)) => {
                let head_ty = self.infer_expr(head)?;
                for cell in tail {
                    let cell_ty = self.infer_expr(cell)?;
                    self.unifier.unify(&head_ty, &cell_ty)?;
                }
                head_ty
            }
        };
        typed_columns.push((column.name.clone(), column_ty));
    }
    Ok(MonoType::DataFrame(typed_columns))
}
/// Infers the result of applying a dataframe operation to `source`.
///
/// - A `Named("DataFrame")` source yields `Named("DataFrame")` whatever the
///   operation.
/// - For a structured `DataFrame(columns)` source, `Select` projects the
///   requested columns in request order and silently skips names that do not
///   exist. Every other operation returns the source type unchanged.
///
/// # Errors
/// Bails when the source is not a dataframe type, and propagates inference
/// errors from `source`.
fn infer_dataframe_operation(
    &mut self,
    source: &Expr,
    operation: &crate::frontend::ast::DataFrameOp,
) -> Result<MonoType> {
    use crate::frontend::ast::DataFrameOp;
    let source_ty = self.infer_expr(source)?;
    let columns = match &source_ty {
        MonoType::DataFrame(columns) => columns,
        MonoType::Named(name) if name == "DataFrame" => {
            return Ok(MonoType::Named("DataFrame".to_string()));
        }
        _ => bail!("DataFrame operation on non-DataFrame type: {source_ty}"),
    };
    match operation {
        DataFrameOp::Select(selected_cols) => {
            let projected: Vec<_> = selected_cols
                .iter()
                .filter_map(|wanted| {
                    columns
                        .iter()
                        .find(|(name, _)| name == wanted)
                        .map(|(_, ty)| (wanted.clone(), ty.clone()))
                })
                .collect();
            Ok(MonoType::DataFrame(projected))
        }
        DataFrameOp::Filter(_)
        | DataFrameOp::GroupBy(_)
        | DataFrameOp::Aggregate(_)
        | DataFrameOp::Join { .. }
        | DataFrameOp::Sort { .. }
        | DataFrameOp::Limit(_)
        | DataFrameOp::Head(_)
        | DataFrameOp::Tail(_) => Ok(source_ty.clone()),
    }
}
/// Types an async block as a future of its body's type.
///
/// The future is encoded as a named type whose name is the rendered string
/// `Future<...>`, not as a structured type.
///
/// # Errors
/// Propagates inference errors from `body`.
fn infer_async_block(&mut self, body: &Expr) -> Result<MonoType> {
    let output_ty = self.infer_expr(body)?;
    let future_name = format!("Future<{output_ty}>");
    Ok(MonoType::Named(future_name))
}
}
impl Default for InferenceContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic)]
mod tests {
use super::*;
use crate::frontend::parser::Parser;
fn infer_str(input: &str) -> Result<MonoType> {
let mut parser = Parser::new(input);
let expr = parser.parse()?;
let mut ctx = InferenceContext::new();
ctx.infer(&expr)
}
#[test]
fn test_infer_literals() {
assert_eq!(
infer_str("42").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("3.15").expect("type inference should succeed in test"),
MonoType::Float
);
assert_eq!(
infer_str("true").expect("type inference should succeed in test"),
MonoType::Bool
);
assert_eq!(
infer_str("\"hello\"").expect("type inference should succeed in test"),
MonoType::String
);
}
#[test]
fn test_infer_arithmetic() {
assert_eq!(
infer_str("1 + 2").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("3 * 4").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("5 - 2").expect("type inference should succeed in test"),
MonoType::Int
);
}
#[test]
fn test_infer_comparison() {
assert_eq!(
infer_str("1 < 2").expect("type inference should succeed in test"),
MonoType::Bool
);
assert_eq!(
infer_str("3 == 3").expect("type inference should succeed in test"),
MonoType::Bool
);
assert_eq!(
infer_str("true != false").expect("type inference should succeed in test"),
MonoType::Bool
);
}
#[test]
fn test_infer_if() {
assert_eq!(
infer_str("if true { 1 } else { 2 }").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("if false { \"yes\" } else { \"no\" }")
.expect("type inference should succeed in test"),
MonoType::String
);
}
#[test]
fn test_infer_let() {
assert_eq!(
infer_str("let x = 42 in x + 1").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("let f = 3.15 in let g = 2.71 in f")
.expect("type inference should succeed in test"),
MonoType::Float
);
}
#[test]
fn test_infer_list() {
assert_eq!(
infer_str("[1, 2, 3]").expect("type inference should succeed in test"),
MonoType::List(Box::new(MonoType::Int))
);
assert_eq!(
infer_str("[true, false]").expect("type inference should succeed in test"),
MonoType::List(Box::new(MonoType::Bool))
);
}
#[test]
#[ignore = "DataFrame syntax not yet implemented"]
fn test_infer_dataframe() {
let df_str = r#"df![age = [25, 30, 35], name = ["Alice", "Bob", "Charlie"]]"#;
let result = infer_str(df_str).unwrap_or(MonoType::DataFrame(vec![]));
match result {
MonoType::DataFrame(columns) => {
assert_eq!(columns.len(), 2);
assert_eq!(columns[0].0, "age");
assert!(matches!(columns[0].1, MonoType::Int));
assert_eq!(columns[1].0, "name");
assert!(matches!(columns[1].1, MonoType::String));
}
_ => panic!("Expected DataFrame type, got {result:?}"),
}
}
#[test]
#[ignore = "DataFrame syntax not yet implemented"]
fn test_infer_dataframe_operations() {
let df_str = r"df![age = [25, 30, 35]]";
let result = infer_str(df_str).unwrap_or(MonoType::DataFrame(vec![]));
match result {
MonoType::DataFrame(columns) => {
assert_eq!(columns.len(), 1);
assert_eq!(columns[0].0, "age");
}
_ => panic!("Expected DataFrame type, got {result:?}"),
}
}
#[test]
fn test_infer_series() {
let col_str = r#"let df = DataFrame::new(); df.col("age")"#;
let result = infer_str(col_str).unwrap_or(MonoType::DataFrame(vec![]));
assert!(matches!(result, MonoType::Series(_)) || matches!(result, MonoType::DataFrame(_)));
let mean_str = r#"let df = DataFrame::new(); df.col("age").mean()"#;
let result = infer_str(mean_str).unwrap_or(MonoType::Float);
assert_eq!(result, MonoType::Float);
}
#[test]
fn test_infer_function() {
let result = infer_str("fun add(x: i32, y: i32) -> i32 { x + y }")
.expect("type inference should succeed in test");
match result {
MonoType::Function(first_arg, remaining) => {
assert!(matches!(first_arg.as_ref(), MonoType::Int));
match remaining.as_ref() {
MonoType::Function(second_arg, return_type) => {
assert!(matches!(second_arg.as_ref(), MonoType::Int));
assert!(matches!(return_type.as_ref(), MonoType::Int));
}
_ => panic!("Expected function type"),
}
}
_ => panic!("Expected function type"),
}
}
#[test]
fn test_type_errors() {
assert!(infer_str("1 + true").is_err());
assert!(infer_str("if 42 { 1 } else { 2 }").is_err());
assert!(infer_str("[1, true, 3]").is_err());
}
#[test]
fn test_infer_lambda() {
let result = infer_str("|x| x + 1").expect("type inference should succeed in test");
match result {
MonoType::Function(arg, ret) => {
assert!(matches!(arg.as_ref(), MonoType::Int));
assert!(matches!(ret.as_ref(), MonoType::Int));
}
_ => panic!("Expected function type for lambda"),
}
let result = infer_str("|x, y| x * y").expect("type inference should succeed in test");
match result {
MonoType::Function(first_arg, remaining) => {
assert!(matches!(first_arg.as_ref(), MonoType::Int));
match remaining.as_ref() {
MonoType::Function(second_arg, return_type) => {
assert!(matches!(second_arg.as_ref(), MonoType::Int));
assert!(matches!(return_type.as_ref(), MonoType::Int));
}
_ => panic!("Expected function type"),
}
}
_ => panic!("Expected function type for lambda"),
}
let result = infer_str("|| 42").expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
let result =
infer_str("let f = |x| x + 1 in f(5)").expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
}
#[test]
fn test_self_hosting_patterns() {
let result = infer_str("x => x * 2").expect("type inference should succeed in test");
match result {
MonoType::Function(arg, ret) => {
assert!(matches!(arg.as_ref(), MonoType::Int));
assert!(matches!(ret.as_ref(), MonoType::Int));
}
_ => panic!("Expected function type for fat arrow lambda"),
}
let result =
infer_str("let map = |f, xs| xs in let double = |x| x * 2 in map(double, [1, 2, 3])")
.expect("type inference should succeed in test");
assert!(matches!(result, MonoType::List(_)));
let result = infer_str(
"fun factorial(n: i32) -> i32 { if n <= 1 { 1 } else { n * factorial(n - 1) } }",
)
.expect("type inference should succeed in test");
match result {
MonoType::Function(arg, ret) => {
assert!(matches!(arg.as_ref(), MonoType::Int));
assert!(matches!(ret.as_ref(), MonoType::Int));
}
_ => panic!("Expected function type for recursive function"),
}
}
#[test]
fn test_compiler_data_structures() {
let result = infer_str("struct Token { kind: String, value: String }")
.expect("type inference should succeed in test");
assert_eq!(result, MonoType::Unit);
let result = infer_str("enum Expr { Literal, Binary, Function }")
.expect("type inference should succeed in test");
assert_eq!(result, MonoType::Unit);
let result = infer_str("[1, 2, 3]").expect("type inference should succeed in test");
assert!(matches!(result, MonoType::List(_)));
let result = infer_str("[1, 2, 3].len()").expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
}
#[test]
fn test_constraint_solving() {
let result = infer_str("[1, 2, 3].len()").expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
let result = infer_str("let id = |x| x in let n = id(42) in let s = id(\"hello\") in n")
.expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
let result =
infer_str("let f = |x| x + 1 in f").expect("type inference should succeed in test");
assert!(matches!(result, MonoType::Function(_, _)));
let result = infer_str("let compose = |f, g, x| f(g(x)) in compose")
.expect("type inference should succeed in test");
assert!(matches!(result, MonoType::Function(_, _)));
}
#[test]
#[ignore = "Unary operation type inference needs implementation"]
fn test_unary_operations() {
assert_eq!(
infer_str("-5").expect("type inference should succeed"),
MonoType::Int
);
assert_eq!(
infer_str("-3.15").expect("type inference should succeed"),
MonoType::Float
);
assert_eq!(
infer_str("!true").expect("type inference should succeed"),
MonoType::Bool
);
assert_eq!(
infer_str("!false").expect("type inference should succeed"),
MonoType::Bool
);
}
#[test]
fn test_logical_operations() {
assert_eq!(
infer_str("true && false").expect("type inference should succeed in test"),
MonoType::Bool
);
assert_eq!(
infer_str("true || false").expect("type inference should succeed in test"),
MonoType::Bool
);
assert_eq!(
infer_str("(1 < 2) && (3 > 2)").expect("type inference should succeed in test"),
MonoType::Bool
);
}
#[test]
fn test_block_expressions() {
assert_eq!(
infer_str("{ 42 }").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("{ 1; 2; 3 }").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("{ let x = 5; x + 1 }").expect("type inference should succeed in test"),
MonoType::Int
);
}
#[test]
fn test_tuple_types() {
let result = infer_str("(1, true)").expect("type inference should succeed in test");
match result {
MonoType::Tuple(types) => {
assert_eq!(types.len(), 2);
assert!(matches!(types[0], MonoType::Int));
assert!(matches!(types[1], MonoType::Bool));
}
_ => panic!("Expected tuple type"),
}
let result =
infer_str("(1, \"hello\", true)").expect("type inference should succeed in test");
match result {
MonoType::Tuple(types) => {
assert_eq!(types.len(), 3);
assert!(matches!(types[0], MonoType::Int));
assert!(matches!(types[1], MonoType::String));
assert!(matches!(types[2], MonoType::Bool));
}
_ => panic!("Expected tuple type"),
}
}
#[test]
fn test_match_expressions() {
let result = infer_str("match 5 { 0 => \"zero\", _ => \"other\" }")
.expect("type inference should succeed in test");
assert_eq!(result, MonoType::String);
let result = infer_str("match true { true => 1, false => 2 }")
.expect("type inference should succeed in test");
assert_eq!(result, MonoType::Int);
}
#[test]
#[ignore = "While loop type inference needs implementation"]
fn test_while_loop() {
assert_eq!(
infer_str("while false { 1 }").expect("type inference should succeed"),
MonoType::Unit
);
}
#[test]
fn test_for_loop() {
assert_eq!(
infer_str("for x in [1, 2, 3] { x }").expect("type inference should succeed in test"),
MonoType::Unit
);
}
#[test]
fn test_string_operations() {
assert_eq!(
infer_str("\"hello\" + \" world\"").expect("type inference should succeed in test"),
MonoType::String
);
}
#[test]
fn test_recursion_limit() {
let mut ctx = InferenceContext::new();
ctx.recursion_depth = 99;
let expr = Expr::new(
ExprKind::Literal(Literal::Integer(42, None)),
Default::default(),
);
let result = ctx.infer(&expr);
assert!(result.is_ok());
}
#[test]
fn test_type_environment() {
let mut env = TypeEnv::standard();
env.bind("custom_var", TypeScheme::mono(MonoType::Float));
let mut ctx = InferenceContext::with_env(env);
let expr = Expr::new(
ExprKind::Literal(Literal::Integer(42, None)),
Default::default(),
);
let result = ctx.infer(&expr);
assert_eq!(
result.expect("type inference should succeed in test"),
MonoType::Int
);
}
#[test]
fn test_constraint_types() {
let unify = TypeConstraint::Unify(MonoType::Int, MonoType::Int);
match unify {
TypeConstraint::Unify(a, b) => {
assert_eq!(a, MonoType::Int);
assert_eq!(b, MonoType::Int);
}
_ => panic!("Expected Unify constraint"),
}
let arity = TypeConstraint::FunctionArity(MonoType::Int, 2);
match arity {
TypeConstraint::FunctionArity(ty, n) => {
assert_eq!(ty, MonoType::Int);
assert_eq!(n, 2);
}
_ => panic!("Expected FunctionArity constraint"),
}
let method = TypeConstraint::MethodCall(MonoType::String, "len".to_string(), vec![]);
match method {
TypeConstraint::MethodCall(ty, name, args) => {
assert_eq!(ty, MonoType::String);
assert_eq!(name, "len");
assert!(args.is_empty());
}
_ => panic!("Expected MethodCall constraint"),
}
let iter = TypeConstraint::Iterable(MonoType::List(Box::new(MonoType::Int)), MonoType::Int);
match iter {
TypeConstraint::Iterable(container, elem) => {
assert!(matches!(container, MonoType::List(_)));
assert_eq!(elem, MonoType::Int);
}
_ => panic!("Expected Iterable constraint"),
}
}
#[test]
fn test_option_types() {
let result = infer_str("None");
assert!(result.is_ok() || result.is_err());
let result = infer_str("Some(42)");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_result_types() {
let result = infer_str("Ok(42)");
assert!(result.is_ok() || result.is_err());
let result = infer_str("Err(\"error\")");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_char_literal() {
assert_eq!(
infer_str("'a'").expect("type inference should succeed in test"),
MonoType::Char
);
assert_eq!(
infer_str("'\\n'").expect("type inference should succeed in test"),
MonoType::Char
);
}
#[test]
fn test_array_indexing() {
assert_eq!(
infer_str("[1, 2, 3][0]").expect("type inference should succeed in test"),
MonoType::Int
);
assert_eq!(
infer_str("[\"a\", \"b\"][1]").expect("type inference should succeed in test"),
MonoType::String
);
}
#[test]
fn test_field_access() {
let _ = infer_str("point.x");
}
#[test]
fn test_break_continue() {
let result = infer_str("loop { break }");
assert!(result.is_ok() || result.is_err());
let result = infer_str("loop { continue }");
assert!(result.is_ok() || result.is_err());
}
#[test]
#[ignore = "Function type inference needs implementation"]
fn test_return_statement() {
assert_eq!(
infer_str("fun test() { return 42 }").expect("type inference should succeed"),
MonoType::Function(Box::new(MonoType::Unit), Box::new(MonoType::Int))
);
}
#[test]
fn test_complex_nested_expression() {
let result = infer_str("if (1 + 2) > 2 { [1, 2, 3] } else { [4, 5] }")
.expect("type inference should succeed in test");
assert!(matches!(result, MonoType::List(_)));
}
#[test]
fn test_error_cases() {
let result = infer_str("undefined_var");
assert!(result.is_err());
let result = infer_str("if true { 1 } else { \"string\" }");
let _ = result;
let result = infer_str("[1, \"string\", true]");
let _ = result;
}
#[test]
fn test_nested_function_inference() {
let result = infer_str("fun outer(x) { fun inner(y) { x + y } inner }");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_polymorphic_function() {
let result = infer_str("let id = fun(x) { x } in id(42)");
if let Ok(ty) = result {
assert_eq!(ty, MonoType::Int);
}
let result2 = infer_str("let id = fun(x) { x } in id(true)");
if let Ok(ty) = result2 {
assert_eq!(ty, MonoType::Bool);
}
}
#[test]
fn test_tuple_inference() {
let result = infer_str("(1, \"hello\", true)");
if let Ok(ty) = result {
if let MonoType::Tuple(types) = ty {
assert_eq!(types.len(), 3);
assert_eq!(types[0], MonoType::Int);
assert_eq!(types[1], MonoType::String);
assert_eq!(types[2], MonoType::Bool);
}
}
}
#[test]
fn test_pattern_match_inference() {
let result = infer_str("match x { Some(v) => v, None => 0 }");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_recursive_type_inference() {
let result =
infer_str("let rec fact = fun(n) { if n == 0 { 1 } else { n * fact(n - 1) } } in fact");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_constraint_solving_comprehensive() {
let mut ctx = InferenceContext::new();
let tv1 = ctx.gen.fresh();
let tv2 = ctx.gen.fresh();
ctx.constraints.push((tv1, tv2));
let result = ctx.solve_all_constraints();
assert!(result.is_ok());
}
#[test]
fn test_method_call_inference() {
let result = infer_str("[1, 2, 3].map(fun(x) { x * 2 })");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_field_access_inference() {
let result = infer_str("point.x");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_array_indexing_inference() {
let result = infer_str("[1, 2, 3][0]");
if let Ok(ty) = result {
assert_eq!(ty, MonoType::Int);
}
}
#[test]
fn test_type_annotation_inference() {
let result = infer_str("let x: i32 = 42 in x");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_generic_instantiation() {
let mut ctx = InferenceContext::new();
let tv = ctx.gen.fresh();
let scheme = TypeScheme::generalize(&TypeEnv::new(), &MonoType::Var(tv));
let instantiated = ctx.instantiate(&scheme);
assert!(matches!(instantiated, MonoType::Var(_)));
}
#[test]
fn test_complex_unification() {
let mut ctx = InferenceContext::new();
let fn1 = MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool));
let fn2 = MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool));
let result = ctx.unifier.unify(&fn1, &fn2);
assert!(result.is_ok());
}
#[test]
fn test_type_environment_comprehensive() {
let mut env = TypeEnv::new();
let scheme = TypeScheme::mono(MonoType::Int);
env.bind("x", scheme.clone());
assert_eq!(env.lookup("x"), Some(&scheme));
assert_eq!(env.lookup("y"), None);
}
#[test]
fn test_error_recovery() {
let mut ctx = InferenceContext::new();
ctx.recursion_depth = 99;
let expr = Parser::new("42")
.parse()
.expect("type inference should succeed in test");
let result = ctx.infer(&expr);
assert!(result.is_ok());
}
#[test]
fn test_async_type_inference() {
let result = infer_str("async { await fetch() }");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_error_handling_inference() {
let result = infer_str("try { risky_op()? }");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_closure_inference() {
let result = infer_str("|x, y| x + y");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_range_inference() {
let result = infer_str("1..10");
assert!(result.is_ok() || result.is_err());
}
#[test]
fn test_context_initialization() {
let ctx = InferenceContext::new();
assert_eq!(ctx.recursion_depth, 0);
assert!(ctx.constraints.is_empty());
assert!(ctx.type_constraints.is_empty());
let env = TypeEnv::standard();
let ctx2 = InferenceContext::with_env(env);
assert_eq!(ctx2.recursion_depth, 0);
}
#[test]
fn test_type_constraint_handling() {
let mut ctx = InferenceContext::new();
ctx.type_constraints
.push(TypeConstraint::Unify(MonoType::Int, MonoType::Int));
ctx.type_constraints.push(TypeConstraint::FunctionArity(
MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool)),
1,
));
let result = ctx.solve_all_constraints();
assert!(result.is_ok());
}
#[test]
fn test_infer_integer_literal_r162() {
assert_eq!(infer_str("0").unwrap(), MonoType::Int);
assert_eq!(infer_str("-1").unwrap(), MonoType::Int);
assert_eq!(infer_str("999999").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_float_literal_r162() {
assert_eq!(infer_str("0.0").unwrap(), MonoType::Float);
assert_eq!(infer_str("3.14159").unwrap(), MonoType::Float);
}
#[test]
fn test_infer_string_literal_r162() {
assert_eq!(infer_str("\"\"").unwrap(), MonoType::String);
assert_eq!(infer_str("\"test string\"").unwrap(), MonoType::String);
}
#[test]
fn test_infer_bool_literal_r162() {
assert_eq!(infer_str("true").unwrap(), MonoType::Bool);
assert_eq!(infer_str("false").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_add_integers_r162() {
assert_eq!(infer_str("5 + 3").unwrap(), MonoType::Int);
assert_eq!(infer_str("0 + 0").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_subtract_r162() {
assert_eq!(infer_str("10 - 3").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_multiply_r162() {
assert_eq!(infer_str("4 * 5").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_divide_r162() {
assert_eq!(infer_str("20 / 4").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_modulo_r162() {
assert_eq!(infer_str("17 % 5").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_float_arithmetic_r162() {
let result1 = infer_str("1.5 + 2.5");
let result2 = infer_str("3.0 * 2.0");
assert!(result1.is_ok() || result1.is_err());
assert!(result2.is_ok() || result2.is_err());
}
#[test]
fn test_infer_less_than_r162() {
assert_eq!(infer_str("3 < 5").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_greater_than_r162() {
assert_eq!(infer_str("10 > 7").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_less_equal_r162() {
assert_eq!(infer_str("5 <= 5").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_greater_equal_r162() {
assert_eq!(infer_str("8 >= 3").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_equality_r162() {
assert_eq!(infer_str("42 == 42").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_inequality_r162() {
assert_eq!(infer_str("1 != 2").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_logical_and_r162() {
assert_eq!(infer_str("true && false").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_logical_or_r162() {
assert_eq!(infer_str("true || false").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_unary_neg_r162() {
assert_eq!(infer_str("-42").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_unary_not_r162() {
assert_eq!(infer_str("!true").unwrap(), MonoType::Bool);
assert_eq!(infer_str("!false").unwrap(), MonoType::Bool);
}
#[test]
fn test_infer_empty_list_r162() {
let result = infer_str("[]");
assert!(result.is_ok());
}
#[test]
fn test_infer_integer_list_r162() {
assert_eq!(
infer_str("[1, 2, 3, 4]").unwrap(),
MonoType::List(Box::new(MonoType::Int))
);
}
#[test]
fn test_infer_string_list_r162() {
assert_eq!(
infer_str("[\"a\", \"b\", \"c\"]").unwrap(),
MonoType::List(Box::new(MonoType::String))
);
}
#[test]
fn test_infer_bool_list_r162() {
assert_eq!(
infer_str("[true, false, true]").unwrap(),
MonoType::List(Box::new(MonoType::Bool))
);
}
#[test]
fn test_infer_if_else_int_r162() {
assert_eq!(
infer_str("if true { 10 } else { 20 }").unwrap(),
MonoType::Int
);
}
#[test]
fn test_infer_if_else_string_r162() {
assert_eq!(
infer_str("if false { \"yes\" } else { \"no\" }").unwrap(),
MonoType::String
);
}
#[test]
fn test_infer_if_else_bool_r162() {
assert_eq!(
infer_str("if true { true } else { false }").unwrap(),
MonoType::Bool
);
}
#[test]
fn test_infer_nested_if_r162() {
let result = infer_str("if true { if false { 1 } else { 2 } } else { 3 }");
assert_eq!(result.unwrap(), MonoType::Int);
}
#[test]
fn test_infer_let_integer_r162() {
assert_eq!(infer_str("let x = 10 in x").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_let_string_r162() {
assert_eq!(
infer_str("let s = \"hello\" in s").unwrap(),
MonoType::String
);
}
#[test]
fn test_infer_let_expression_r162() {
assert_eq!(infer_str("let x = 5 + 3 in x * 2").unwrap(), MonoType::Int);
}
#[test]
fn test_infer_nested_let_r162() {
assert_eq!(
infer_str("let x = 1 in let y = 2 in x + y").unwrap(),
MonoType::Int
);
}
#[test]
fn test_type_constraint_unify_r162() {
let constraint = TypeConstraint::Unify(MonoType::Int, MonoType::Int);
assert!(format!("{:?}", constraint).contains("Unify"));
}
#[test]
fn test_type_constraint_function_arity_r162() {
let constraint = TypeConstraint::FunctionArity(
MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool)),
1,
);
assert!(format!("{:?}", constraint).contains("FunctionArity"));
}
#[test]
fn test_type_constraint_method_call_r162() {
let constraint = TypeConstraint::MethodCall(MonoType::String, "len".to_string(), vec![]);
assert!(format!("{:?}", constraint).contains("MethodCall"));
}
#[test]
fn test_type_constraint_iterable_r162() {
let constraint =
TypeConstraint::Iterable(MonoType::List(Box::new(MonoType::Int)), MonoType::Int);
assert!(format!("{:?}", constraint).contains("Iterable"));
}
#[test]
fn test_infer_lambda_single_param() {
let result = infer_str("|x| x + 1");
assert!(result.is_ok(), "Lambda should infer type");
}
#[test]
fn test_infer_lambda_multiple_params() {
let result = infer_str("|x, y| x + y");
assert!(result.is_ok(), "Multi-param lambda should infer type");
}
#[test]
fn test_infer_lambda_no_params() {
let result = infer_str("|| 42");
assert!(result.is_ok(), "No-param lambda should infer type");
}
#[test]
fn test_infer_tuple() {
let result = infer_str("(1, \"hello\", true)");
assert!(result.is_ok(), "Tuple should infer type");
}
#[test]
fn test_infer_array_empty() {
let result = infer_str("[]");
assert!(result.is_ok(), "Empty array should infer type");
}
#[test]
fn test_infer_array_with_elements() {
let result = infer_str("[1, 2, 3]");
assert!(result.is_ok(), "Array with elements should infer type");
}
#[test]
fn test_infer_map_empty() {
let result = infer_str("{}");
assert!(result.is_ok(), "Empty map should infer type");
}
#[test]
fn test_infer_map_with_entries() {
let result = infer_str("{\"a\": 1, \"b\": 2}");
assert!(result.is_ok(), "Map with entries should infer type");
}
#[test]
fn test_infer_if_expression() {
let result = infer_str("if true { 1 } else { 0 }");
assert!(result.is_ok(), "If expression should infer type");
}
#[test]
fn test_infer_if_without_else() {
let result = infer_str("if true { 1 }");
let _ = result;
}
#[test]
fn test_infer_block() {
let result = infer_str("{ let x = 1; x + 1 }");
assert!(result.is_ok(), "Block should infer type");
}
#[test]
fn test_infer_let_binding() {
let result = infer_str("let x = 42");
assert!(result.is_ok(), "Let binding should infer type");
}
/// A call to `print` must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_function_call() {
    infer_str("print(\"hello\")").expect("Function call should infer type");
}
/// A `.len()` call on an array literal must infer; `expect` reports any error.
#[test]
fn test_infer_method_call() {
    infer_str("[1, 2, 3].len()").expect("Method call should infer type");
}
/// Indexing an array literal must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_index_access() {
    infer_str("[1, 2, 3][0]").expect("Index access should infer type");
}
/// Smoke test: inferring `.x` on a map literal must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_field_access() {
    let _ = infer_str("{\"x\": 1}.x");
}
/// Unary negation of an integer must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_unary_neg() {
    infer_str("-5").expect("Unary neg should infer type");
}
/// Logical negation of a boolean must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_unary_not() {
    infer_str("!true").expect("Unary not should infer type");
}
/// `&&` on booleans must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_binary_and() {
    infer_str("true && false").expect("Binary and should infer type");
}
/// `||` on booleans must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_binary_or() {
    infer_str("true || false").expect("Binary or should infer type");
}
/// `+` on two string literals must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_string_concat() {
    infer_str("\"hello\" + \" world\"").expect("String concat should infer type");
}
/// Smoke test: inferring the range `1..10` must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_range() {
    let _ = infer_str("1..10");
}
/// `Some(42)` must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_some() {
    infer_str("Some(42)").expect("Some should infer type");
}
/// A bare `None` must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_none() {
    infer_str("None").expect("None should infer type");
}
/// Smoke test: inferring `Ok(42)` must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_ok() {
    let _ = infer_str("Ok(42)");
}
/// Smoke test: inferring `Err("error")` must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_err() {
    let _ = infer_str("Err(\"error\")");
}
/// Smoke test: inferring a `while` loop must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_while_loop() {
    let _ = infer_str("while true { 1 }");
}
/// A `for` loop over an array literal must infer; `expect` reports any error.
#[test]
fn test_infer_for_loop() {
    infer_str("for x in [1, 2, 3] { x }").expect("For loop should infer type");
}
/// `break` inside a `while` loop must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_break() {
    infer_str("while true { break }").expect("Break should infer type");
}
/// `continue` inside a `while` loop must infer; on failure `expect` reports the error.
#[test]
fn test_infer_continue() {
    infer_str("while true { continue }").expect("Continue should infer type");
}
/// A function containing `return` must infer; on failure `expect` reports the inference error.
#[test]
fn test_infer_return() {
    infer_str("fun f() { return 42 }").expect("Return should infer type");
}
/// A `match` with a literal arm and a wildcard arm must infer; `expect` reports any error.
#[test]
fn test_infer_match() {
    infer_str("match 1 { 1 => \"one\", _ => \"other\" }").expect("Match should infer type");
}
/// Smoke test: inferring a `try`/`catch` expression must not panic.
/// Whether inference succeeds or returns an error is not asserted.
#[test]
fn test_infer_try_catch() {
    let _ = infer_str("try { 1 } catch e { 0 }");
}
/// Primitive `MonoType`s render via `Display` as the expected type names.
#[test]
fn test_monotype_display() {
    let cases = [
        (MonoType::Int, "i32"),
        (MonoType::Float, "f64"),
        (MonoType::Bool, "bool"),
        (MonoType::String, "String"),
        (MonoType::Unit, "()"),
        (MonoType::Char, "char"),
    ];
    for (ty, rendered) in &cases {
        assert_eq!(ty.to_string(), *rendered);
    }
}
/// Compound `MonoType`s include their component type names in `Display` output.
#[test]
fn test_monotype_complex_display() {
    let list_repr = MonoType::List(Box::new(MonoType::Int)).to_string();
    assert!(list_repr.contains("i32"), "List should contain i32");

    let tuple_repr = MonoType::Tuple(vec![MonoType::Int, MonoType::String]).to_string();
    assert!(tuple_repr.contains("i32"), "Tuple should contain i32");
    assert!(tuple_repr.contains("String"), "Tuple should contain String");

    let optional_repr = MonoType::Optional(Box::new(MonoType::Int)).to_string();
    assert!(optional_repr.contains("i32"), "Optional should contain i32");

    let result_repr =
        MonoType::Result(Box::new(MonoType::Int), Box::new(MonoType::String)).to_string();
    assert!(result_repr.contains("i32"), "Result should contain i32");
}
/// Three consecutive `TyVarGenerator::fresh` calls yield pairwise-distinct ids.
#[test]
fn test_tyvar_generator_fresh() {
    use super::super::types::TyVarGenerator;
    let mut generator = TyVarGenerator::new();
    // Array elements are evaluated left to right, preserving call order.
    let ids = [generator.fresh().0, generator.fresh().0, generator.fresh().0];
    assert!(ids[0] != ids[1]);
    assert!(ids[1] != ids[2]);
    assert!(ids[0] != ids[2]);
}
}
#[cfg(test)]
mod property_tests_infer {
    use super::InferenceContext;
    use proptest::proptest;
    proptest! {
        /// Constructing a fresh `InferenceContext` must never panic.
        ///
        /// `InferenceContext::new` takes no arguments, so the generated string is
        /// unused; it only drives the number of proptest cases. The input is not
        /// sliced: byte-index slicing such as `&s[..100]` panics when the index
        /// falls inside a multi-byte UTF-8 character.
        #[test]
        fn test_new_never_panics(_input: String) {
            // Run the constructor under `catch_unwind` and fail the case if it panics.
            let outcome = std::panic::catch_unwind(InferenceContext::new);
            assert!(outcome.is_ok(), "InferenceContext::new() panicked");
        }
    }
}