use std::collections::BTreeMap;
use crate::ast::*;
use crate::builtin_signatures;
use harn_lexer::Span;
/// A single message produced by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
/// Human-readable description of the problem.
pub message: String,
/// Whether this diagnostic is an error or a warning.
pub severity: DiagnosticSeverity,
/// Source span the diagnostic refers to, if one is attached.
pub span: Option<Span>,
/// Optional additional help text accompanying the message.
pub help: Option<String>,
}
/// Severity level attached to a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
/// The diagnostic is reported as an error.
Error,
/// The diagnostic is reported as a warning.
Warning,
}
/// Result of type inference: `Some(type)` when a type is known, `None` when it is not.
type InferredType = Option<TypeExpr>;
/// One lexical scope of the type checker.
///
/// Each scope owns its own tables and, optionally, a boxed copy of the
/// enclosing scope; the lookup helpers on this type consult the local table
/// first and then fall back to `parent`.
#[derive(Debug, Clone)]
struct TypeScope {
/// Variable name -> inferred or declared type (`None` when unknown).
vars: BTreeMap<String, InferredType>,
/// Function name -> signature.
functions: BTreeMap<String, FnSignature>,
/// Type alias name -> aliased type expression.
type_aliases: BTreeMap<String, TypeExpr>,
/// Enum name -> names of its variants.
enums: BTreeMap<String, Vec<String>>,
/// Interface name -> methods it declares.
interfaces: BTreeMap<String, Vec<InterfaceMethod>>,
/// Struct name -> `(field name, field type)` pairs.
structs: BTreeMap<String, Vec<(String, InferredType)>>,
/// Type name -> signatures of the methods defined in its `impl` block.
impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
/// Names of generic type parameters introduced in this scope.
generic_type_params: std::collections::BTreeSet<String>,
/// Generic type parameter name -> bound named by a `where` clause.
where_constraints: BTreeMap<String, String>,
/// Enclosing scope, or `None` for the root scope.
parent: Option<Box<TypeScope>>,
}
/// Signature of a method collected from an `impl` block.
///
/// Where this file builds these values, a parameter named `self` is excluded
/// from both `param_count` and `param_types`.
#[derive(Debug, Clone)]
struct ImplMethodSig {
/// Method name.
name: String,
/// Number of parameters.
param_count: usize,
/// Declared parameter types, `None` where a parameter has no annotation.
param_types: Vec<Option<TypeExpr>>,
/// Declared return type, if any.
return_type: Option<TypeExpr>,
}
/// Signature of a declared function, as recorded in a [`TypeScope`].
#[derive(Debug, Clone)]
struct FnSignature {
/// `(parameter name, declared type)` pairs in declaration order.
params: Vec<(String, InferredType)>,
/// Declared return type, if any.
return_type: InferredType,
/// Names of the function's generic type parameters.
type_param_names: Vec<String>,
/// Number of parameters that have no default value.
required_params: usize,
/// `(type parameter name, bound)` pairs taken from the `where` clauses.
where_clauses: Vec<(String, String)>,
}
impl TypeScope {
    /// Builds an empty root scope that has no parent.
    fn new() -> Self {
        Self {
            parent: None,
            where_constraints: BTreeMap::new(),
            generic_type_params: std::collections::BTreeSet::new(),
            impl_methods: BTreeMap::new(),
            structs: BTreeMap::new(),
            interfaces: BTreeMap::new(),
            enums: BTreeMap::new(),
            type_aliases: BTreeMap::new(),
            functions: BTreeMap::new(),
            vars: BTreeMap::new(),
        }
    }

    /// Builds an empty scope whose parent is a snapshot (clone) of `self`.
    fn child(&self) -> Self {
        Self {
            parent: Some(Box::new(self.clone())),
            ..Self::new()
        }
    }

    /// Iterates over this scope followed by each enclosing scope, innermost first.
    fn ancestors(&self) -> impl Iterator<Item = &TypeScope> {
        std::iter::successors(Some(self), |scope| scope.parent.as_deref())
    }

    /// Looks up a variable, innermost scope first.
    fn get_var(&self, name: &str) -> Option<&InferredType> {
        self.ancestors().find_map(|scope| scope.vars.get(name))
    }

    /// Looks up a function signature, innermost scope first.
    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
        self.ancestors().find_map(|scope| scope.functions.get(name))
    }

    /// Looks up the type expression behind a type alias, innermost scope first.
    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
        self.ancestors()
            .find_map(|scope| scope.type_aliases.get(name))
    }

    /// Returns `true` when `name` is a generic type parameter in this or any enclosing scope.
    fn is_generic_type_param(&self, name: &str) -> bool {
        self.ancestors()
            .any(|scope| scope.generic_type_params.contains(name))
    }

    /// Returns the `where` bound recorded for `type_param`, innermost scope first.
    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
        self.ancestors()
            .find_map(|scope| scope.where_constraints.get(type_param))
            .map(String::as_str)
    }

    /// Looks up the variant names of an enum, innermost scope first.
    fn get_enum(&self, name: &str) -> Option<&Vec<String>> {
        self.ancestors().find_map(|scope| scope.enums.get(name))
    }

    /// Looks up the methods of an interface, innermost scope first.
    fn get_interface(&self, name: &str) -> Option<&Vec<InterfaceMethod>> {
        self.ancestors()
            .find_map(|scope| scope.interfaces.get(name))
    }

    /// Looks up the fields of a struct, innermost scope first.
    fn get_struct(&self, name: &str) -> Option<&Vec<(String, InferredType)>> {
        self.ancestors().find_map(|scope| scope.structs.get(name))
    }

    /// Looks up the `impl` method signatures of a type, innermost scope first.
    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
        self.ancestors()
            .find_map(|scope| scope.impl_methods.get(name))
    }

    /// Defines (or overwrites) a variable in this scope only.
    fn define_var(&mut self, name: &str, ty: InferredType) {
        let key = name.to_string();
        self.vars.insert(key, ty);
    }

    /// Defines (or overwrites) a function signature in this scope only.
    fn define_fn(&mut self, name: &str, sig: FnSignature) {
        let key = name.to_string();
        self.functions.insert(key, sig);
    }
}
/// Returns the return type of the builtin `name`, as reported by
/// `builtin_signatures::builtin_return_type`.
fn builtin_return_type(name: &str) -> InferredType {
builtin_signatures::builtin_return_type(name)
}
/// Returns whether `name` is a builtin, as reported by `builtin_signatures::is_builtin`.
fn is_builtin(name: &str) -> bool {
builtin_signatures::is_builtin(name)
}
/// Static type checker: walks a program and collects [`TypeDiagnostic`]s.
pub struct TypeChecker {
/// Diagnostics accumulated so far.
diagnostics: Vec<TypeDiagnostic>,
/// Top-level scope holding program-wide declarations.
scope: TypeScope,
}
impl TypeChecker {
/// Creates a checker with no diagnostics and an empty root scope.
pub fn new() -> Self {
Self {
diagnostics: Vec::new(),
scope: TypeScope::new(),
}
}
/// Type-checks `program` and returns every diagnostic that was produced.
///
/// Declarations at the top level and directly inside pipeline bodies are
/// registered first, so they are visible while the top-level items are
/// checked in order afterwards.
pub fn check(mut self, program: &[SNode]) -> Vec<TypeDiagnostic> {
    // Pass 1: hoist declarations from the top level, then from each pipeline body.
    Self::register_declarations_into(&mut self.scope, program);
    let pipeline_bodies = program.iter().filter_map(|item| match &item.node {
        Node::Pipeline { body, .. } => Some(body),
        _ => None,
    });
    for pipeline_body in pipeline_bodies {
        Self::register_declarations_into(&mut self.scope, pipeline_body);
    }

    // Pass 2: check each top-level item.
    for item in program {
        match &item.node {
            Node::Pipeline { params, body, .. } => {
                // Pipeline parameters are untyped variables in the pipeline's own scope.
                let mut pipeline_scope = self.scope.child();
                params
                    .iter()
                    .for_each(|param| pipeline_scope.define_var(param, None));
                self.check_block(body, &mut pipeline_scope);
            }
            Node::FnDecl {
                name,
                type_params,
                params,
                return_type,
                where_clauses,
                body,
                ..
            } => {
                let signature = FnSignature {
                    params: params
                        .iter()
                        .map(|p| (p.name.clone(), p.type_expr.clone()))
                        .collect(),
                    return_type: return_type.clone(),
                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
                    required_params: params
                        .iter()
                        .filter(|p| p.default_value.is_none())
                        .count(),
                    where_clauses: where_clauses
                        .iter()
                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
                        .collect(),
                };
                // Register before checking the body so the function can call itself.
                self.scope.define_fn(name, signature);
                self.check_fn_body(type_params, params, return_type, body, where_clauses);
            }
            _ => {
                // Check against a scratch copy, then keep only variables that the
                // top-level scope does not already define.
                let mut scratch = self.scope.clone();
                self.check_node(item, &mut scratch);
                scratch.vars.into_iter().for_each(|(var_name, var_type)| {
                    self.scope.vars.entry(var_name).or_insert(var_type);
                });
            }
        }
    }
    self.diagnostics
}
/// Records the type-level declarations found directly in `nodes` into `scope`.
///
/// Handles type aliases, enums, interfaces, structs and `impl` blocks; every
/// other node kind is ignored. Nested blocks are not searched.
fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
    for decl in nodes {
        match &decl.node {
            Node::TypeDecl { name, type_expr } => {
                let aliased = type_expr.clone();
                scope.type_aliases.insert(name.clone(), aliased);
            }
            Node::EnumDecl { name, variants, .. } => {
                let mut variant_names: Vec<String> = Vec::with_capacity(variants.len());
                for variant in variants {
                    variant_names.push(variant.name.clone());
                }
                scope.enums.insert(name.clone(), variant_names);
            }
            Node::InterfaceDecl { name, methods, .. } => {
                let declared = methods.clone();
                scope.interfaces.insert(name.clone(), declared);
            }
            Node::StructDecl { name, fields, .. } => {
                let mut field_types: Vec<(String, InferredType)> =
                    Vec::with_capacity(fields.len());
                for field in fields {
                    field_types.push((field.name.clone(), field.type_expr.clone()));
                }
                scope.structs.insert(name.clone(), field_types);
            }
            Node::ImplBlock {
                type_name, methods, ..
            } => {
                let mut sigs: Vec<ImplMethodSig> = Vec::new();
                for method in methods {
                    // Only function declarations contribute a method signature.
                    let Node::FnDecl {
                        name,
                        params,
                        return_type,
                        ..
                    } = &method.node
                    else {
                        continue;
                    };
                    // The receiver is not part of the callable signature.
                    let param_types: Vec<Option<TypeExpr>> = params
                        .iter()
                        .filter(|p| p.name != "self")
                        .map(|p| p.type_expr.clone())
                        .collect();
                    sigs.push(ImplMethodSig {
                        name: name.clone(),
                        param_count: param_types.len(),
                        param_types,
                        return_type: return_type.clone(),
                    });
                }
                scope.impl_methods.insert(type_name.clone(), sigs);
            }
            _ => {}
        }
    }
}
/// Checks every statement of a block, in order, against the same `scope`.
fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
    stmts
        .iter()
        .for_each(|statement| self.check_node(statement, scope));
}
/// Defines every name bound by `pattern` in `scope`, each with an unknown type.
///
/// For dict patterns the alias is the bound name when present, otherwise the key.
fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope) {
    match pattern {
        BindingPattern::Identifier(name) => scope.define_var(name, None),
        BindingPattern::Dict(fields) => fields
            .iter()
            .map(|field| field.alias.as_deref().unwrap_or(&field.key))
            .for_each(|bound| scope.define_var(bound, None)),
        BindingPattern::List(elements) => elements
            .iter()
            .for_each(|element| scope.define_var(&element.name, None)),
    }
}
/// Type-checks one AST node, recursing into its children and recording diagnostics.
///
/// Bindings and declarations encountered directly are written into `scope`.
/// Constructs that carry a nested body (branches, loops, closures, blocks, ...)
/// are checked against a fresh child of `scope`, so names defined inside them
/// are not written back to `scope`.
fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
    let span = snode.span;
    match &snode.node {
        // `let` and `var` bindings are checked identically.
        Node::LetBinding {
            pattern,
            type_ann,
            value,
        }
        | Node::VarBinding {
            pattern,
            type_ann,
            value,
        } => {
            let inferred = self.infer_type(value, scope);
            if let BindingPattern::Identifier(name) = pattern {
                if let Some(expected) = type_ann {
                    if let Some(actual) = &inferred {
                        if !self.types_compatible(expected, actual, scope) {
                            let mut msg = format!(
                                "Type mismatch: '{}' declared as {}, but assigned {}",
                                name,
                                format_type(expected),
                                format_type(actual)
                            );
                            if let Some(detail) = shape_mismatch_detail(expected, actual) {
                                msg.push_str(&format!(" ({})", detail));
                            }
                            self.error_at(msg, span);
                        }
                    }
                }
                // The annotation wins over the inferred type when both exist.
                let ty = type_ann.clone().or(inferred);
                scope.define_var(name, ty);
            } else {
                // Destructuring patterns bind their names with unknown types.
                Self::define_pattern_vars(pattern, scope);
            }
        }
        Node::FnDecl {
            name,
            type_params,
            params,
            return_type,
            where_clauses,
            body,
            ..
        } => {
            let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
            let sig = FnSignature {
                params: params
                    .iter()
                    .map(|p| (p.name.clone(), p.type_expr.clone()))
                    .collect(),
                return_type: return_type.clone(),
                type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
                required_params,
                where_clauses: where_clauses
                    .iter()
                    .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
                    .collect(),
            };
            scope.define_fn(name, sig);
            // The function name is also usable as a value of unknown type.
            scope.define_var(name, None);
            self.check_fn_body(type_params, params, return_type, body, where_clauses);
        }
        Node::FunctionCall { name, args } => {
            self.check_call(name, args, scope, span);
        }
        Node::IfElse {
            condition,
            then_body,
            else_body,
        } => {
            self.check_node(condition, scope);
            let mut then_scope = scope.child();
            // `x != nil` narrows a union-typed `x` inside the then-branch only.
            if let Some((var_name, narrowed)) = Self::extract_nil_narrowing(condition, scope) {
                then_scope.define_var(&var_name, narrowed);
            }
            self.check_block(then_body, &mut then_scope);
            if let Some(else_body) = else_body {
                let mut else_scope = scope.child();
                self.check_block(else_body, &mut else_scope);
            }
        }
        Node::ForIn {
            pattern,
            iterable,
            body,
        } => {
            self.check_node(iterable, scope);
            let mut loop_scope = scope.child();
            if let BindingPattern::Identifier(variable) = pattern {
                // Iterating a typed list yields its element type; a string yields strings.
                let elem_type = match self.infer_type(iterable, scope) {
                    Some(TypeExpr::List(inner)) => Some(*inner),
                    Some(TypeExpr::Named(n)) if n == "string" => {
                        Some(TypeExpr::Named("string".into()))
                    }
                    _ => None,
                };
                loop_scope.define_var(variable, elem_type);
            } else {
                Self::define_pattern_vars(pattern, &mut loop_scope);
            }
            self.check_block(body, &mut loop_scope);
        }
        Node::WhileLoop { condition, body } => {
            self.check_node(condition, scope);
            let mut loop_scope = scope.child();
            self.check_block(body, &mut loop_scope);
        }
        Node::RequireStmt { condition, message } => {
            self.check_node(condition, scope);
            if let Some(message) = message {
                self.check_node(message, scope);
            }
        }
        Node::TryCatch {
            body,
            error_var,
            catch_body,
            finally_body,
            ..
        } => {
            let mut try_scope = scope.child();
            self.check_block(body, &mut try_scope);
            let mut catch_scope = scope.child();
            if let Some(var) = error_var {
                catch_scope.define_var(var, None);
            }
            self.check_block(catch_body, &mut catch_scope);
            if let Some(fb) = finally_body {
                let mut finally_scope = scope.child();
                self.check_block(fb, &mut finally_scope);
            }
        }
        Node::TryExpr { body } => {
            let mut try_scope = scope.child();
            self.check_block(body, &mut try_scope);
        }
        Node::ReturnStmt {
            value: Some(val), ..
        } => {
            self.check_node(val, scope);
        }
        Node::Assignment {
            target, value, op, ..
        } => {
            self.check_node(value, scope);
            // Only assignments to plain identifiers with a known type are verified.
            if let Node::Identifier(name) = &target.node {
                if let Some(Some(var_type)) = scope.get_var(name) {
                    let value_type = self.infer_type(value, scope);
                    // For compound assignment, check the result of the implied binary op.
                    let assigned = if let Some(op) = op {
                        let var_inferred = scope.get_var(name).cloned().flatten();
                        infer_binary_op_type(op, &var_inferred, &value_type)
                    } else {
                        value_type
                    };
                    if let Some(actual) = &assigned {
                        if !self.types_compatible(var_type, actual, scope) {
                            self.error_at(
                                format!(
                                    "Type mismatch: cannot assign {} to '{}' (declared as {})",
                                    format_type(actual),
                                    name,
                                    format_type(var_type)
                                ),
                                span,
                            );
                        }
                    }
                }
            }
        }
        // Type-level declarations are recorded by the shared registration helper.
        Node::TypeDecl { .. }
        | Node::EnumDecl { .. }
        | Node::StructDecl { .. }
        | Node::InterfaceDecl { .. } => {
            Self::register_declarations_into(scope, std::slice::from_ref(snode));
        }
        Node::ImplBlock { methods, .. } => {
            // Record the method signatures first, then check each method.
            Self::register_declarations_into(scope, std::slice::from_ref(snode));
            for method_sn in methods {
                self.check_node(method_sn, scope);
            }
        }
        Node::TryOperator { operand } => {
            self.check_node(operand, scope);
        }
        Node::MatchExpr { value, arms } => {
            self.check_node(value, scope);
            let value_type = self.infer_type(value, scope);
            for arm in arms {
                self.check_node(&arm.pattern, scope);
                if let Some(ref vt) = value_type {
                    let value_type_name = format_type(vt);
                    // Literal patterns must be compatible with the matched value's type;
                    // int and float literals are accepted for either numeric type.
                    let mismatch = match &arm.pattern.node {
                        Node::StringLiteral(_) => {
                            !self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
                        }
                        Node::IntLiteral(_) => {
                            !self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
                                && !self.types_compatible(
                                    vt,
                                    &TypeExpr::Named("float".into()),
                                    scope,
                                )
                        }
                        Node::FloatLiteral(_) => {
                            !self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
                                && !self.types_compatible(
                                    vt,
                                    &TypeExpr::Named("int".into()),
                                    scope,
                                )
                        }
                        Node::BoolLiteral(_) => {
                            !self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
                        }
                        _ => false,
                    };
                    if mismatch {
                        let pattern_type = match &arm.pattern.node {
                            Node::StringLiteral(_) => "string",
                            Node::IntLiteral(_) => "int",
                            Node::FloatLiteral(_) => "float",
                            Node::BoolLiteral(_) => "bool",
                            // `mismatch` is only ever true for the literal kinds above.
                            _ => unreachable!(),
                        };
                        self.warning_at(
                            format!(
                                "Match pattern type mismatch: matching {} against {} literal",
                                value_type_name, pattern_type
                            ),
                            arm.pattern.span,
                        );
                    }
                }
                let mut arm_scope = scope.child();
                self.check_block(&arm.body, &mut arm_scope);
            }
            self.check_match_exhaustiveness(value, arms, scope, span);
        }
        Node::BinaryOp { op, left, right } => {
            self.check_node(left, scope);
            self.check_node(right, scope);
            let lt = self.infer_type(left, scope);
            let rt = self.infer_type(right, scope);
            // Operator checks only apply when both operand types are known named types.
            if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (&lt, &rt) {
                match op.as_str() {
                    "-" | "*" | "/" | "%" => {
                        let numeric = ["int", "float"];
                        if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
                            self.warning_at(
                                format!(
                                    "Operator '{op}' may not be valid for types {} and {}",
                                    l, r
                                ),
                                span,
                            );
                        }
                    }
                    "+" => {
                        // Warns only when neither operand has a type that supports `+`.
                        let valid = ["int", "float", "string", "list", "dict"];
                        if !valid.contains(&l.as_str()) && !valid.contains(&r.as_str()) {
                            self.warning_at(
                                format!(
                                    "Operator '+' may not be valid for types {} and {}",
                                    l, r
                                ),
                                span,
                            );
                        }
                    }
                    _ => {}
                }
            }
        }
        Node::UnaryOp { operand, .. } => {
            self.check_node(operand, scope);
        }
        Node::MethodCall {
            object,
            method,
            args,
            ..
        }
        | Node::OptionalMethodCall {
            object,
            method,
            args,
            ..
        } => {
            self.check_node(object, scope);
            for arg in args {
                self.check_node(arg, scope);
            }
            // A call on a constrained generic parameter must name a method that the
            // constraining interface declares.
            if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
                if scope.is_generic_type_param(&type_name) {
                    if let Some(iface_name) = scope.get_where_constraint(&type_name) {
                        if let Some(iface_methods) = scope.get_interface(iface_name) {
                            let has_method = iface_methods.iter().any(|m| m.name == *method);
                            if !has_method {
                                self.warning_at(
                                    format!(
                                        "Method '{}' not found in interface '{}' (constraint on '{}')",
                                        method, iface_name, type_name
                                    ),
                                    span,
                                );
                            }
                        }
                    }
                }
            }
        }
        Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
            self.check_node(object, scope);
        }
        Node::SubscriptAccess { object, index } => {
            self.check_node(object, scope);
            self.check_node(index, scope);
        }
        Node::SliceAccess { object, start, end } => {
            self.check_node(object, scope);
            if let Some(s) = start {
                self.check_node(s, scope);
            }
            if let Some(e) = end {
                self.check_node(e, scope);
            }
        }
        Node::Ternary {
            condition,
            true_expr,
            false_expr,
        } => {
            self.check_node(condition, scope);
            self.check_node(true_expr, scope);
            self.check_node(false_expr, scope);
        }
        Node::ThrowStmt { value } => {
            self.check_node(value, scope);
        }
        Node::GuardStmt {
            condition,
            else_body,
        } => {
            self.check_node(condition, scope);
            let mut else_scope = scope.child();
            self.check_block(else_body, &mut else_scope);
        }
        Node::SpawnExpr { body } => {
            let mut spawn_scope = scope.child();
            self.check_block(body, &mut spawn_scope);
        }
        Node::Parallel {
            count,
            variable,
            body,
        } => {
            self.check_node(count, scope);
            let mut par_scope = scope.child();
            if let Some(var) = variable {
                par_scope.define_var(var, Some(TypeExpr::Named("int".into())));
            }
            self.check_block(body, &mut par_scope);
        }
        Node::ParallelMap {
            list,
            variable,
            body,
        }
        | Node::ParallelSettle {
            list,
            variable,
            body,
        } => {
            self.check_node(list, scope);
            let mut par_scope = scope.child();
            let elem_type = match self.infer_type(list, scope) {
                Some(TypeExpr::List(inner)) => Some(*inner),
                _ => None,
            };
            par_scope.define_var(variable, elem_type);
            self.check_block(body, &mut par_scope);
        }
        Node::SelectExpr {
            cases,
            timeout,
            default_body,
        } => {
            for case in cases {
                self.check_node(&case.channel, scope);
                let mut case_scope = scope.child();
                case_scope.define_var(&case.variable, None);
                self.check_block(&case.body, &mut case_scope);
            }
            if let Some((dur, body)) = timeout {
                self.check_node(dur, scope);
                let mut timeout_scope = scope.child();
                self.check_block(body, &mut timeout_scope);
            }
            if let Some(body) = default_body {
                let mut default_scope = scope.child();
                self.check_block(body, &mut default_scope);
            }
        }
        Node::DeadlineBlock { duration, body } => {
            self.check_node(duration, scope);
            let mut block_scope = scope.child();
            self.check_block(body, &mut block_scope);
        }
        Node::MutexBlock { body } => {
            let mut block_scope = scope.child();
            self.check_block(body, &mut block_scope);
        }
        Node::Retry { count, body } => {
            self.check_node(count, scope);
            let mut retry_scope = scope.child();
            self.check_block(body, &mut retry_scope);
        }
        Node::Closure { params, body, .. } => {
            let mut closure_scope = scope.child();
            for p in params {
                closure_scope.define_var(&p.name, p.type_expr.clone());
            }
            self.check_block(body, &mut closure_scope);
        }
        Node::ListLiteral(elements) => {
            for elem in elements {
                self.check_node(elem, scope);
            }
        }
        Node::DictLiteral(entries) | Node::AskExpr { fields: entries } => {
            for entry in entries {
                self.check_node(&entry.key, scope);
                self.check_node(&entry.value, scope);
            }
        }
        Node::RangeExpr { start, end, .. } => {
            self.check_node(start, scope);
            self.check_node(end, scope);
        }
        Node::Spread(inner) => {
            self.check_node(inner, scope);
        }
        Node::Block(stmts) => {
            let mut block_scope = scope.child();
            self.check_block(stmts, &mut block_scope);
        }
        Node::YieldExpr { value } => {
            if let Some(v) = value {
                self.check_node(v, scope);
            }
        }
        Node::StructConstruct {
            struct_name,
            fields,
        } => {
            for entry in fields {
                self.check_node(&entry.key, scope);
                self.check_node(&entry.value, scope);
            }
            // Field-name validation only runs for structs that have been declared.
            if let Some(declared_fields) = scope.get_struct(struct_name).cloned() {
                for entry in fields {
                    if let Node::StringLiteral(key) | Node::Identifier(key) = &entry.key.node {
                        if !declared_fields.iter().any(|(name, _)| name == key) {
                            self.warning_at(
                                format!("Unknown field '{}' in struct '{}'", key, struct_name),
                                entry.key.span,
                            );
                        }
                    }
                }
                let provided: Vec<String> = fields
                    .iter()
                    .filter_map(|e| match &e.key.node {
                        Node::StringLiteral(k) | Node::Identifier(k) => Some(k.clone()),
                        _ => None,
                    })
                    .collect();
                for (name, _) in &declared_fields {
                    if !provided.contains(name) {
                        self.warning_at(
                            format!(
                                "Missing field '{}' in struct '{}' construction",
                                name, struct_name
                            ),
                            span,
                        );
                    }
                }
            }
        }
        Node::EnumConstruct {
            enum_name,
            variant,
            args,
        } => {
            for arg in args {
                self.check_node(arg, scope);
            }
            if let Some(variants) = scope.get_enum(enum_name) {
                if !variants.contains(variant) {
                    self.warning_at(
                        format!("Unknown variant '{}' in enum '{}'", variant, enum_name),
                        span,
                    );
                }
            }
        }
        // Interpolated strings are not descended into here.
        Node::InterpolatedString(_) => {}
        // Leaves and statements with nothing to check.
        Node::StringLiteral(_)
        | Node::IntLiteral(_)
        | Node::FloatLiteral(_)
        | Node::BoolLiteral(_)
        | Node::NilLiteral
        | Node::Identifier(_)
        | Node::DurationLiteral(_)
        | Node::BreakStmt
        | Node::ContinueStmt
        | Node::ReturnStmt { value: None }
        | Node::ImportDecl { .. }
        | Node::SelectiveImport { .. } => {}
        Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
            let mut decl_scope = scope.child();
            self.check_block(body, &mut decl_scope);
        }
    }
}
fn check_fn_body(
&mut self,
type_params: &[TypeParam],
params: &[TypedParam],
return_type: &Option<TypeExpr>,
body: &[SNode],
where_clauses: &[WhereClause],
) {
let mut fn_scope = self.scope.child();
for tp in type_params {
fn_scope.generic_type_params.insert(tp.name.clone());
}
for wc in where_clauses {
fn_scope
.where_constraints
.insert(wc.type_name.clone(), wc.bound.clone());
}
for param in params {
fn_scope.define_var(¶m.name, param.type_expr.clone());
if let Some(default) = ¶m.default_value {
self.check_node(default, &mut fn_scope);
}
}
self.check_block(body, &mut fn_scope);
if let Some(ret_type) = return_type {
for stmt in body {
self.check_return_type(stmt, ret_type, &fn_scope);
}
}
}
/// Reports an error when a `return` statement's value does not match `expected`.
///
/// Only `snode` itself and, recursively, the branches of `if`/`else` statements
/// are inspected; other node kinds are ignored. Returns whose value type cannot
/// be inferred are not reported.
fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &TypeScope) {
    // Descend into both branches of a conditional.
    if let Node::IfElse {
        then_body,
        else_body,
        ..
    } = &snode.node
    {
        for branch_stmt in then_body {
            self.check_return_type(branch_stmt, expected, scope);
        }
        for branch_stmt in else_body.iter().flatten() {
            self.check_return_type(branch_stmt, expected, scope);
        }
        return;
    }

    let Node::ReturnStmt {
        value: Some(returned),
    } = &snode.node
    else {
        return;
    };
    let Some(actual) = self.infer_type(returned, scope) else {
        return;
    };
    if self.types_compatible(expected, &actual, scope) {
        return;
    }
    let message = format!(
        "Return type mismatch: expected {}, got {}",
        format_type(expected),
        format_type(&actual)
    );
    self.error_at(message, snode.span);
}
/// Returns `true` when `interface_mismatch_reason` finds no mismatch between
/// `type_name` and `interface_name`.
fn satisfies_interface(
    &self,
    type_name: &str,
    interface_name: &str,
    scope: &TypeScope,
) -> bool {
    let mismatch = self.interface_mismatch_reason(type_name, interface_name, scope);
    mismatch.is_none()
}
/// Explains why `type_name` does not satisfy `interface_name`, or returns `None`
/// when it does.
///
/// Each interface method must have an `impl` method of the same name and the same
/// number of non-`self` parameters. Parameter and return types are compared only
/// where both sides declare them. The first mismatch found is reported.
fn interface_mismatch_reason(
    &self,
    type_name: &str,
    interface_name: &str,
    scope: &TypeScope,
) -> Option<String> {
    let Some(interface_methods) = scope.get_interface(interface_name) else {
        return Some(format!("interface '{}' not found", interface_name));
    };
    let Some(impl_methods) = scope.get_impl_methods(type_name) else {
        // A type without any impl block still satisfies an empty interface.
        if interface_methods.is_empty() {
            return None;
        }
        let names: Vec<_> = interface_methods.iter().map(|m| m.name.as_str()).collect();
        return Some(format!("missing method(s): {}", names.join(", ")));
    };

    for required in interface_methods {
        // The receiver does not count towards the signature.
        let required_params: Vec<_> = required
            .params
            .iter()
            .filter(|p| p.name != "self")
            .collect();
        let required_count = required_params.len();

        let Some(provided) = impl_methods.iter().find(|im| im.name == required.name) else {
            return Some(format!("missing method '{}'", required.name));
        };
        if provided.param_count != required_count {
            return Some(format!(
                "method '{}' has {} parameter(s), expected {}",
                required.name, provided.param_count, required_count
            ));
        }

        for (position, wanted) in required_params.iter().enumerate() {
            let declared = wanted.type_expr.as_ref();
            let supplied = provided
                .param_types
                .get(position)
                .and_then(|t| t.as_ref());
            // Unannotated parameters on either side are accepted as-is.
            let (Some(expected), Some(actual)) = (declared, supplied) else {
                continue;
            };
            if !self.types_compatible(expected, actual, scope) {
                return Some(format!(
                    "method '{}' parameter {} has type '{}', expected '{}'",
                    required.name,
                    position + 1,
                    format_type(actual),
                    format_type(expected),
                ));
            }
        }

        let declared_ret = required.return_type.as_ref();
        let supplied_ret = provided.return_type.as_ref();
        if let (Some(expected_ret), Some(actual_ret)) = (declared_ret, supplied_ret) {
            if !self.types_compatible(expected_ret, actual_ret, scope) {
                return Some(format!(
                    "method '{}' returns '{}', expected '{}'",
                    required.name,
                    format_type(actual_ret),
                    format_type(expected_ret),
                ));
            }
        }
    }
    None
}
/// Collects bindings from generic type parameters to concrete type names by
/// walking `param_type` and `arg_type` in parallel.
///
/// A named parameter type listed in `type_params` is bound to the name of the
/// corresponding named argument type; the first binding for a parameter wins.
/// List element types and dict key/value types are walked recursively. Any
/// other combination contributes nothing.
fn extract_type_bindings(
    param_type: &TypeExpr,
    arg_type: &TypeExpr,
    type_params: &std::collections::BTreeSet<String>,
    bindings: &mut BTreeMap<String, String>,
) {
    match (param_type, arg_type) {
        (TypeExpr::Named(generic), TypeExpr::Named(concrete)) => {
            let is_unbound_param =
                type_params.contains(generic) && !bindings.contains_key(generic);
            if is_unbound_param {
                bindings.insert(generic.clone(), concrete.clone());
            }
        }
        (TypeExpr::List(param_elem), TypeExpr::List(arg_elem)) => {
            Self::extract_type_bindings(param_elem, arg_elem, type_params, bindings);
        }
        (TypeExpr::DictType(param_key, param_val), TypeExpr::DictType(arg_key, arg_val)) => {
            // Keys are walked before values, so a key binding takes precedence.
            for (param_part, arg_part) in [(param_key, arg_key), (param_val, arg_val)] {
                Self::extract_type_bindings(param_part, arg_part, type_params, bindings);
            }
        }
        _ => {}
    }
}
/// Recognises `x != nil` / `nil != x` on a union-typed variable and returns the
/// variable name with `nil` removed from its type.
///
/// Returns `None` when the condition has a different shape, the variable is not
/// known to have a union type, or nothing would remain after removing `nil`.
fn extract_nil_narrowing(
    condition: &SNode,
    scope: &TypeScope,
) -> Option<(String, InferredType)> {
    let Node::BinaryOp { op, left, right } = &condition.node else {
        return None;
    };
    if op != "!=" {
        return None;
    }
    // Whichever operand is not the `nil` literal is the candidate variable.
    let candidate = if matches!(right.node, Node::NilLiteral) {
        left
    } else if matches!(left.node, Node::NilLiteral) {
        right
    } else {
        return None;
    };
    let Node::Identifier(name) = &candidate.node else {
        return None;
    };
    let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) else {
        return None;
    };
    let mut remaining: Vec<TypeExpr> = members
        .iter()
        .filter(|member| !matches!(member, TypeExpr::Named(n) if n == "nil"))
        .cloned()
        .collect();
    match remaining.len() {
        0 => None,
        // A single remaining member collapses the union to that member.
        1 => Some((name.clone(), remaining.pop())),
        _ => Some((name.clone(), Some(TypeExpr::Union(remaining)))),
    }
}
/// Warns when a `match` over an enum value does not cover every variant.
///
/// The matched value is treated as an enum when its inferred type (or, for
/// `<expr>.variant`, the inferred type of `<expr>`) names a known enum. String
/// literal, enum-construct and property-access patterns each cover one variant
/// name; any other pattern kind is treated as a catch-all and suppresses the
/// warning.
fn check_match_exhaustiveness(
    &mut self,
    value: &SNode,
    arms: &[MatchArm],
    scope: &TypeScope,
    span: Span,
) {
    let subject_type = match &value.node {
        Node::PropertyAccess { object, property } if property == "variant" => {
            self.infer_type(object, scope)
        }
        _ => self.infer_type(value, scope),
    };
    let Some(TypeExpr::Named(enum_name)) = subject_type else {
        return;
    };
    let Some(variants) = scope.get_enum(&enum_name) else {
        return;
    };

    let mut covered: Vec<&str> = Vec::with_capacity(arms.len());
    for arm in arms {
        match &arm.pattern.node {
            Node::StringLiteral(literal) => covered.push(literal),
            Node::EnumConstruct { variant, .. } => covered.push(variant),
            Node::PropertyAccess { property, .. } => covered.push(property),
            // Identifiers (including `_`) and every other pattern act as a
            // wildcard, which makes the match exhaustive.
            _ => return,
        }
    }

    let missing: Vec<String> = variants
        .iter()
        .filter(|variant| !covered.contains(&variant.as_str()))
        .map(|variant| format!("\"{}\"", variant))
        .collect();
    if missing.is_empty() {
        return;
    }
    self.warning_at(
        format!(
            "Non-exhaustive match on enum {}: missing variants {}",
            enum_name,
            missing.join(", ")
        ),
        span,
    );
}
/// Checks a call to the function `name` with arguments `args`.
///
/// When a signature for `name` is known in `scope`, this verifies:
/// - the argument count (skipped for builtins and for calls containing a spread),
/// - each argument's inferred type against the declared parameter type,
/// - every `where` constraint whose type parameter can be bound from the arguments.
///
/// The arguments themselves are checked afterwards, whether or not a signature
/// was found.
fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
    let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
    if let Some(sig) = scope.get_fn(name).cloned() {
        if !has_spread
            && !is_builtin(name)
            && (args.len() < sig.required_params || args.len() > sig.params.len())
        {
            let expected = if sig.required_params == sig.params.len() {
                format!("{}", sig.params.len())
            } else {
                format!("{}-{}", sig.required_params, sig.params.len())
            };
            self.warning_at(
                format!(
                    "Function '{}' expects {} arguments, got {}",
                    name,
                    expected,
                    args.len()
                ),
                span,
            );
        }
        // Compatibility checks need the callee's type parameters to be known as
        // generic. A child scope is built only for generic callees; otherwise the
        // caller's scope is borrowed as-is rather than deep-cloned.
        let generic_scope;
        let call_scope: &TypeScope = if sig.type_param_names.is_empty() {
            &*scope
        } else {
            let mut s = scope.child();
            for tp_name in &sig.type_param_names {
                s.generic_type_params.insert(tp_name.clone());
            }
            generic_scope = s;
            &generic_scope
        };
        for (i, (arg, (param_name, param_type))) in
            args.iter().zip(sig.params.iter()).enumerate()
        {
            if let Some(expected) = param_type {
                // Argument types are inferred in the caller's scope.
                let actual = self.infer_type(arg, scope);
                if let Some(actual) = &actual {
                    if !self.types_compatible(expected, actual, call_scope) {
                        self.error_at(
                            format!(
                                "Argument {} ('{}'): expected {}, got {}",
                                i + 1,
                                param_name,
                                format_type(expected),
                                format_type(actual)
                            ),
                            arg.span,
                        );
                    }
                }
            }
        }
        if !sig.where_clauses.is_empty() {
            // Bind type parameters to concrete types from the arguments, then
            // verify each bound type against its interface constraint.
            let mut type_bindings: BTreeMap<String, String> = BTreeMap::new();
            let type_param_set: std::collections::BTreeSet<String> =
                sig.type_param_names.iter().cloned().collect();
            for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
                if let Some(param_ty) = param_type {
                    if let Some(arg_ty) = self.infer_type(arg, scope) {
                        Self::extract_type_bindings(
                            param_ty,
                            &arg_ty,
                            &type_param_set,
                            &mut type_bindings,
                        );
                    }
                }
            }
            for (type_param, bound) in &sig.where_clauses {
                if let Some(concrete_type) = type_bindings.get(type_param) {
                    if let Some(reason) =
                        self.interface_mismatch_reason(concrete_type, bound, scope)
                    {
                        self.warning_at(
                            format!(
                                "Type '{}' does not satisfy interface '{}': {} \
                                 (required by constraint `where {}: {}`)",
                                concrete_type, bound, reason, type_param, bound
                            ),
                            span,
                        );
                    }
                }
            }
        }
    }
    for arg in args {
        self.check_node(arg, scope);
    }
}
fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
match &snode.node {
Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
Node::StringLiteral(_) | Node::InterpolatedString(_) => {
Some(TypeExpr::Named("string".into()))
}
Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
Node::ListLiteral(_) => Some(TypeExpr::Named("list".into())),
Node::DictLiteral(entries) => {
let mut fields = Vec::new();
let mut all_string_keys = true;
for entry in entries {
if let Node::StringLiteral(key) = &entry.key.node {
let val_type = self
.infer_type(&entry.value, scope)
.unwrap_or(TypeExpr::Named("nil".into()));
fields.push(ShapeField {
name: key.clone(),
type_expr: val_type,
optional: false,
});
} else {
all_string_keys = false;
break;
}
}
if all_string_keys && !fields.is_empty() {
Some(TypeExpr::Shape(fields))
} else {
Some(TypeExpr::Named("dict".into()))
}
}
Node::Closure { params, body, .. } => {
let all_typed = params.iter().all(|p| p.type_expr.is_some());
if all_typed && !params.is_empty() {
let param_types: Vec<TypeExpr> =
params.iter().filter_map(|p| p.type_expr.clone()).collect();
let ret = body.last().and_then(|last| self.infer_type(last, scope));
if let Some(ret_type) = ret {
return Some(TypeExpr::FnType {
params: param_types,
return_type: Box::new(ret_type),
});
}
}
Some(TypeExpr::Named("closure".into()))
}
Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
Node::FunctionCall { name, .. } => {
if scope.get_struct(name).is_some() {
return Some(TypeExpr::Named(name.clone()));
}
if let Some(sig) = scope.get_fn(name) {
return sig.return_type.clone();
}
builtin_return_type(name)
}
Node::BinaryOp { op, left, right } => {
let lt = self.infer_type(left, scope);
let rt = self.infer_type(right, scope);
infer_binary_op_type(op, <, &rt)
}
Node::UnaryOp { op, operand } => {
let t = self.infer_type(operand, scope);
match op.as_str() {
"!" => Some(TypeExpr::Named("bool".into())),
"-" => t, _ => None,
}
}
Node::Ternary {
true_expr,
false_expr,
..
} => {
let tt = self.infer_type(true_expr, scope);
let ft = self.infer_type(false_expr, scope);
match (&tt, &ft) {
(Some(a), Some(b)) if a == b => tt,
(Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
(Some(_), None) => tt,
(None, Some(_)) => ft,
(None, None) => None,
}
}
Node::EnumConstruct { enum_name, .. } => Some(TypeExpr::Named(enum_name.clone())),
Node::PropertyAccess { object, property } => {
if let Node::Identifier(name) = &object.node {
if scope.get_enum(name).is_some() {
return Some(TypeExpr::Named(name.clone()));
}
}
if property == "variant" {
let obj_type = self.infer_type(object, scope);
if let Some(TypeExpr::Named(name)) = &obj_type {
if scope.get_enum(name).is_some() {
return Some(TypeExpr::Named("string".into()));
}
}
}
let obj_type = self.infer_type(object, scope);
if let Some(TypeExpr::Shape(fields)) = &obj_type {
if let Some(field) = fields.iter().find(|f| f.name == *property) {
return Some(field.type_expr.clone());
}
}
None
}
Node::SubscriptAccess { object, index } => {
let obj_type = self.infer_type(object, scope);
match &obj_type {
Some(TypeExpr::List(inner)) => Some(*inner.clone()),
Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
Some(TypeExpr::Shape(fields)) => {
if let Node::StringLiteral(key) = &index.node {
fields
.iter()
.find(|f| &f.name == key)
.map(|f| f.type_expr.clone())
} else {
None
}
}
Some(TypeExpr::Named(n)) if n == "list" => None,
Some(TypeExpr::Named(n)) if n == "dict" => None,
Some(TypeExpr::Named(n)) if n == "string" => {
Some(TypeExpr::Named("string".into()))
}
_ => None,
}
}
Node::SliceAccess { object, .. } => {
let obj_type = self.infer_type(object, scope);
match &obj_type {
Some(TypeExpr::List(_)) => obj_type,
Some(TypeExpr::Named(n)) if n == "list" => obj_type,
Some(TypeExpr::Named(n)) if n == "string" => {
Some(TypeExpr::Named("string".into()))
}
_ => None,
}
}
Node::MethodCall { object, method, .. }
| Node::OptionalMethodCall { object, method, .. } => {
let obj_type = self.infer_type(object, scope);
let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
|| matches!(&obj_type, Some(TypeExpr::DictType(..)))
|| matches!(&obj_type, Some(TypeExpr::Shape(_)));
match method.as_str() {
"contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
Some(TypeExpr::Named("bool".into()))
}
"count" | "index_of" => Some(TypeExpr::Named("int".into())),
"trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
| "pad_left" | "pad_right" | "repeat" | "join" => {
Some(TypeExpr::Named("string".into()))
}
"split" | "chars" => Some(TypeExpr::Named("list".into())),
"filter" => {
if is_dict {
Some(TypeExpr::Named("dict".into()))
} else {
Some(TypeExpr::Named("list".into()))
}
}
"map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
"reduce" | "find" | "first" | "last" => None,
"keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
"merge" | "map_values" | "rekey" | "map_keys" => {
if let Some(TypeExpr::DictType(_, v)) = &obj_type {
Some(TypeExpr::DictType(
Box::new(TypeExpr::Named("string".into())),
v.clone(),
))
} else {
Some(TypeExpr::Named("dict".into()))
}
}
"to_string" => Some(TypeExpr::Named("string".into())),
"to_int" => Some(TypeExpr::Named("int".into())),
"to_float" => Some(TypeExpr::Named("float".into())),
_ => None,
}
}
Node::TryOperator { operand } => {
match self.infer_type(operand, scope) {
Some(TypeExpr::Named(name)) if name == "Result" => None, _ => None,
}
}
_ => None,
}
}
/// Decides whether a value of type `actual` may be used where `expected` is required.
///
/// Checks run in this order:
/// 1. If either side names a generic type parameter of `scope`, the pair is accepted.
/// 2. Both sides are passed through `resolve_alias`.
/// 3. If `expected` names an interface, `actual` must be a named type for which
///    `satisfies_interface` holds; any other `actual` is rejected.
/// 4. Otherwise the structural rules in the `match` below apply; anything not
///    listed there is incompatible.
fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
    // Generic type parameters are accepted on either side.
    if let TypeExpr::Named(name) = expected {
        if scope.is_generic_type_param(name) {
            return true;
        }
    }
    if let TypeExpr::Named(name) = actual {
        if scope.is_generic_type_param(name) {
            return true;
        }
    }

    let expected = self.resolve_alias(expected, scope);
    let actual = self.resolve_alias(actual, scope);

    // An expected interface is only satisfiable by a named type.
    if let TypeExpr::Named(iface_name) = &expected {
        if scope.get_interface(iface_name).is_some() {
            if let TypeExpr::Named(type_name) = &actual {
                return self.satisfies_interface(type_name, iface_name, scope);
            }
            return false;
        }
    }

    match (&expected, &actual) {
        // Identical names match; an `int` is also accepted where `float` is expected.
        (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
        // An actual union must be decomposed BEFORE an expected union: every
        // alternative the value may take has to be accepted by `expected`.
        // Doing it the other way round would require one single expected member
        // to accept the whole actual union, which rejects e.g. `string | nil`
        // against itself.
        (expected_type, TypeExpr::Union(members)) => members
            .iter()
            .all(|m| self.types_compatible(expected_type, m, scope)),
        // A non-union actual only needs to fit one member of the expected union.
        (TypeExpr::Union(members), actual_type) => members
            .iter()
            .any(|m| self.types_compatible(m, actual_type, scope)),
        // The bare `dict` name and shapes are interchangeable.
        (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
        (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
        // Shape vs. shape: every non-optional expected field must exist in the
        // actual shape with a compatible type; extra actual fields are ignored.
        (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
            if expected_field.optional {
                return true;
            }
            af.iter().any(|actual_field| {
                actual_field.name == expected_field.name
                    && self.types_compatible(
                        &expected_field.type_expr,
                        &actual_field.type_expr,
                        scope,
                    )
            })
        }),
        // A shape fits a typed dict when the key type is `string` and every
        // field type is compatible with the dict's value type.
        (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
            let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
            keys_ok
                && af
                    .iter()
                    .all(|f| self.types_compatible(ev, &f.type_expr, scope))
        }
        // A typed dict is always accepted where a shape is expected.
        (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
        // Lists compare element types; the bare `list` name matches any list.
        (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
            self.types_compatible(expected_inner, actual_inner, scope)
        }
        (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
        (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
        // Typed dicts compare key and value types; bare `dict` matches any typed dict.
        (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
            self.types_compatible(ek, ak, scope) && self.types_compatible(ev, av, scope)
        }
        (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
        (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
        // Function types need equal arity, pairwise-compatible parameters and a
        // compatible return type.
        (
            TypeExpr::FnType {
                params: ep,
                return_type: er,
            },
            TypeExpr::FnType {
                params: ap,
                return_type: ar,
            },
        ) => {
            ep.len() == ap.len()
                && ep
                    .iter()
                    .zip(ap.iter())
                    .all(|(e, a)| self.types_compatible(e, a, scope))
                && self.types_compatible(er, ar, scope)
        }
        // The bare `closure` name is interchangeable with any function type.
        (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
        (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
        _ => false,
    }
}
/// Looks a named type up via `scope.resolve_type` and returns an owned copy of
/// the result. When `ty` is not a named type, or the lookup yields nothing,
/// an owned copy of `ty` itself is returned. Only a single lookup is performed.
fn resolve_alias<'a>(&self, ty: &'a TypeExpr, scope: &'a TypeScope) -> TypeExpr {
    match ty {
        TypeExpr::Named(name) => match scope.resolve_type(name) {
            Some(target) => target.clone(),
            None => ty.clone(),
        },
        _ => ty.clone(),
    }
}
/// Appends an error-severity diagnostic located at `span`, without help text.
fn error_at(&mut self, message: String, span: Span) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Error,
        span: Some(span),
        help: None,
        message,
    };
    self.diagnostics.push(diagnostic);
}
/// Appends an error-severity diagnostic located at `span`, carrying `help` as
/// its help text.
// The `dead_code` allowance silences the unused-item lint for this helper.
#[allow(dead_code)]
fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Error,
        span: Some(span),
        help: Some(help),
        message,
    };
    self.diagnostics.push(diagnostic);
}
/// Appends a warning-severity diagnostic located at `span`, without help text.
fn warning_at(&mut self, message: String, span: Span) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Warning,
        span: Some(span),
        help: None,
        message,
    };
    self.diagnostics.push(diagnostic);
}
/// Appends a warning-severity diagnostic located at `span`, carrying `help` as
/// its help text.
// The `dead_code` allowance silences the unused-item lint for this helper.
#[allow(dead_code)]
fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Warning,
        span: Some(span),
        help: Some(help),
        message,
    };
    self.diagnostics.push(diagnostic);
}
}
impl Default for TypeChecker {
fn default() -> Self {
Self::new()
}
}
/// Infers the result type of the binary operator `op` from the (possibly
/// unknown) types of its operands. `None` means "no type could be inferred".
///
/// * Comparison, logical and membership operators always yield `bool`.
/// * `+` on two named types: `int + int` is `int`; a mix of `int`/`float` is
///   `float`; `list + list` is `list`; `dict + dict` is `dict`; every other
///   pair of named types yields `string`. If either operand is not a named
///   type, nothing is inferred.
/// * `-`, `*`, `/`, `%` on two named types: `int` for two ints, `float` when
///   either side is `float`, otherwise nothing.
/// * `??`: when the left side is a union, `nil` members are removed from it;
///   otherwise the right operand's type is used.
/// * Any other operator (including `|>`) yields nothing.
fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
    let named = |name: &str| -> InferredType { Some(TypeExpr::Named(name.into())) };
    match op {
        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => named("bool"),
        "+" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    ("int", "int") => named("int"),
                    // Only numeric pairs promote to float. A wildcard here
                    // (`("float", _) | (_, "float")`) would swallow pairs such
                    // as `string + float` before they reach the string
                    // fallback below.
                    ("int", "float") | ("float", "int") | ("float", "float") => named("float"),
                    ("list", "list") => named("list"),
                    ("dict", "dict") => named("dict"),
                    // Everything else, including a `string` on either side,
                    // is inferred as `string`.
                    _ => named("string"),
                }
            }
            _ => None,
        },
        "-" | "*" | "/" | "%" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    ("int", "int") => named("int"),
                    ("float", _) | (_, "float") => named("float"),
                    _ => None,
                }
            }
            _ => None,
        },
        "??" => match (left, right) {
            (Some(TypeExpr::Union(members)), _) => {
                // Drop `nil` from the left-hand union.
                let non_nil: Vec<_> = members
                    .iter()
                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
                    .cloned()
                    .collect();
                if non_nil.len() == 1 {
                    // A single survivor is returned directly, not as a union.
                    Some(non_nil[0].clone())
                } else if non_nil.is_empty() {
                    // The union held nothing but `nil`: use the right operand.
                    right.clone()
                } else {
                    Some(TypeExpr::Union(non_nil))
                }
            }
            _ => right.clone(),
        },
        // `|>` and unrecognized operators: nothing can be inferred.
        _ => None,
    }
}
/// Describes how the shape `actual` falls short of the shape `expected`.
///
/// Returns `None` unless both arguments are shapes. For each non-optional
/// expected field, in declaration order, one entry is produced when the field
/// is absent from `actual`, or when the two field types render differently via
/// `format_type`. Entries are joined with `"; "`; if there are none the result
/// is `None`.
pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
    let (wanted_fields, given_fields) = match (expected, actual) {
        (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => (ef, af),
        _ => return None,
    };

    let problems: Vec<String> = wanted_fields
        .iter()
        // Optional fields may be missing, so they are never reported.
        .filter(|wanted| !wanted.optional)
        .filter_map(|wanted| {
            let wanted_ty = format_type(&wanted.type_expr);
            match given_fields.iter().find(|given| given.name == wanted.name) {
                None => Some(format!(
                    "missing field '{}' ({})",
                    wanted.name, wanted_ty
                )),
                Some(given) => {
                    // Types are compared by their rendered text.
                    let given_ty = format_type(&given.type_expr);
                    if given_ty == wanted_ty {
                        None
                    } else {
                        Some(format!(
                            "field '{}' has type {}, expected {}",
                            wanted.name, given_ty, wanted_ty
                        ))
                    }
                }
            }
        })
        .collect();

    if problems.is_empty() {
        None
    } else {
        Some(problems.join("; "))
    }
}
/// Renders a type expression as text.
///
/// * named type: its name
/// * union: members joined with `" | "`
/// * shape: `{field: type, optional?: type}`
/// * list: `list<elem>`
/// * typed dict: `dict<key, value>`
/// * function type: `fn(param, ...) -> ret`
pub fn format_type(ty: &TypeExpr) -> String {
    // Renders each type in `items` and joins the results with `sep`.
    let join_all = |items: &[TypeExpr], sep: &str| -> String {
        let rendered: Vec<String> = items.iter().map(format_type).collect();
        rendered.join(sep)
    };

    match ty {
        TypeExpr::Named(name) => name.clone(),
        TypeExpr::Union(members) => join_all(members, " | "),
        TypeExpr::Shape(fields) => {
            let mut rendered = Vec::with_capacity(fields.len());
            for field in fields {
                // Optional fields carry a `?` right after their name.
                let marker = if field.optional { "?" } else { "" };
                rendered.push(format!(
                    "{}{}: {}",
                    field.name,
                    marker,
                    format_type(&field.type_expr)
                ));
            }
            format!("{{{}}}", rendered.join(", "))
        }
        TypeExpr::List(element) => format!("list<{}>", format_type(element)),
        TypeExpr::DictType(key, value) => {
            format!("dict<{}, {}>", format_type(key), format_type(value))
        }
        TypeExpr::FnType {
            params,
            return_type,
        } => format!(
            "fn({}) -> {}",
            join_all(params, ", "),
            format_type(return_type)
        ),
    }
}
// Unit tests for the type checker. Each test lexes, parses and type-checks a
// small source snippet, then inspects the messages of the resulting diagnostics.
#[cfg(test)]
mod tests {
use super::*;
use crate::Parser;
use harn_lexer::Lexer;
/// Tokenizes and parses `source`, then returns every diagnostic produced by a
/// fresh `TypeChecker`. Panics (via `unwrap`) if lexing or parsing fails.
fn check_source(source: &str) -> Vec<TypeDiagnostic> {
let mut lexer = Lexer::new(source);
let tokens = lexer.tokenize().unwrap();
let mut parser = Parser::new(tokens);
let program = parser.parse().unwrap();
TypeChecker::new().check(&program)
}
/// Messages of the error-severity diagnostics reported for `source`.
fn errors(source: &str) -> Vec<String> {
check_source(source)
.into_iter()
.filter(|d| d.severity == DiagnosticSeverity::Error)
.map(|d| d.message)
.collect()
}
// --- Annotated `let` bindings, function arguments and return types ---
#[test]
fn test_no_errors_for_untyped_code() {
let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
assert!(errs.is_empty());
}
#[test]
fn test_correct_typed_let() {
let errs = errors("pipeline t(task) { let x: int = 42 }");
assert!(errs.is_empty());
}
#[test]
fn test_type_mismatch_let() {
let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Type mismatch"));
assert!(errs[0].contains("int"));
assert!(errs[0].contains("string"));
}
#[test]
fn test_correct_typed_fn() {
let errs = errors(
"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
);
assert!(errs.is_empty());
}
#[test]
fn test_fn_arg_type_mismatch() {
let errs = errors(
r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Argument 1"));
assert!(errs[0].contains("expected int"));
}
#[test]
fn test_return_type_mismatch() {
let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Return type mismatch"));
}
// --- Union annotations ---
#[test]
fn test_union_type_compatible() {
let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
assert!(errs.is_empty());
}
#[test]
fn test_union_type_mismatch() {
let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Type mismatch"));
}
// --- Inference from user-defined functions and builtins ---
#[test]
fn test_type_inference_propagation() {
let errs = errors(
r#"pipeline t(task) {
fn add(a: int, b: int) -> int { return a + b }
let result: string = add(1, 2)
}"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Type mismatch"));
assert!(errs[0].contains("string"));
assert!(errs[0].contains("int"));
}
#[test]
fn test_builtin_return_type_inference() {
let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("string"));
assert!(errs[0].contains("int"));
}
// Calls a broad set of workflow/artifact/transcript builtins with annotated
// result bindings and expects no error diagnostics at all.
#[test]
fn test_workflow_and_transcript_builtins_are_known() {
let errs = errors(
r#"pipeline t(task) {
let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
let report: dict = workflow_policy_report(flow, {tools: tool_registry(), capabilities: {workspace: ["read_text"]}})
let run: dict = workflow_execute("task", flow, [], {})
let tree: dict = load_run_tree("run.json")
let fixture: dict = run_record_fixture(run?.run)
let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
let diff: dict = run_record_diff(run?.run, run?.run)
let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
let suite_report: dict = eval_suite_run(manifest)
let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
let selection: dict = artifact_editor_selection("src/main.rs", "main")
let verify: dict = artifact_verification_result("verify", "ok")
let test_result: dict = artifact_test_result("tests", "pass")
let cmd: dict = artifact_command_result("cargo test", {status: 0})
let patch: dict = artifact_diff("src/main.rs", "old", "new")
let git: dict = artifact_git_diff("diff --git a b")
let review: dict = artifact_diff_review(patch, "review me")
let decision: dict = artifact_review_decision(review, "accepted")
let proposal: dict = artifact_patch_proposal(review, "*** Begin Patch")
let bundle: dict = artifact_verification_bundle("checks", [{name: "fmt", ok: true}])
let apply: dict = artifact_apply_intent(review, "apply")
let transcript = transcript_reset({metadata: {source: "test"}})
let visible: string = transcript_render_visible(transcript_archive(transcript))
let events: list = transcript_events(transcript)
let context: string = artifact_context([], {max_artifacts: 1})
println(report)
println(run)
println(tree)
println(fixture)
println(suite)
println(diff)
println(manifest)
println(suite_report)
println(wf)
println(snap)
println(selection)
println(verify)
println(test_result)
println(cmd)
println(patch)
println(git)
println(review)
println(decision)
println(proposal)
println(bundle)
println(apply)
println(visible)
println(events)
println(context)
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
// --- Operator result types and int -> float acceptance ---
#[test]
fn test_binary_op_type_inference() {
let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
assert_eq!(errs.len(), 1);
}
#[test]
fn test_comparison_returns_bool() {
let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
assert!(errs.is_empty());
}
#[test]
fn test_int_float_promotion() {
let errs = errors("pipeline t(task) { let x: float = 42 }");
assert!(errs.is_empty());
}
#[test]
fn test_untyped_code_no_errors() {
let errs = errors(
r#"pipeline t(task) {
fn process(data) {
let result = data + " processed"
return result
}
log(process("hello"))
}"#,
);
assert!(errs.is_empty());
}
// --- Type aliases and assignment to annotated variables ---
#[test]
fn test_type_alias() {
let errs = errors(
r#"pipeline t(task) {
type Name = string
let x: Name = "hello"
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_type_alias_mismatch() {
let errs = errors(
r#"pipeline t(task) {
type Name = string
let x: Name = 42
}"#,
);
assert_eq!(errs.len(), 1);
}
#[test]
fn test_assignment_type_check() {
let errs = errors(
r#"pipeline t(task) {
var x: int = 0
x = "hello"
}"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("cannot assign string"));
}
// --- An int is accepted where float is declared, but not the reverse ---
#[test]
fn test_covariance_int_to_float_in_fn() {
let errs = errors(
"pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
);
assert!(errs.is_empty());
}
#[test]
fn test_covariance_return_type() {
let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
assert!(errs.is_empty());
}
#[test]
fn test_no_contravariance_float_to_int() {
let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
assert_eq!(errs.len(), 1);
}
/// Messages of the warning-severity diagnostics reported for `source`.
fn warnings(source: &str) -> Vec<String> {
check_source(source)
.into_iter()
.filter(|d| d.severity == DiagnosticSeverity::Warning)
.map(|d| d.message)
.collect()
}
// --- Exhaustiveness warnings when matching on an enum's `.variant` ---
#[test]
fn test_exhaustive_match_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c = Color.Red
match c.variant {
"Red" -> { log("r") }
"Green" -> { log("g") }
"Blue" -> { log("b") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert!(exhaustive_warns.is_empty());
}
#[test]
fn test_non_exhaustive_match_warning() {
let warns = warnings(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c = Color.Red
match c.variant {
"Red" -> { log("r") }
"Green" -> { log("g") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert_eq!(exhaustive_warns.len(), 1);
assert!(exhaustive_warns[0].contains("Blue"));
}
#[test]
fn test_non_exhaustive_multiple_missing() {
let warns = warnings(
r#"pipeline t(task) {
enum Status { Active, Inactive, Pending }
let s = Status.Active
match s.variant {
"Active" -> { log("a") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert_eq!(exhaustive_warns.len(), 1);
assert!(exhaustive_warns[0].contains("Inactive"));
assert!(exhaustive_warns[0].contains("Pending"));
}
#[test]
fn test_enum_construct_type_inference() {
let errs = errors(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c: Color = Color.Red
}"#,
);
assert!(errs.is_empty());
}
// --- `??` on a nilable union, and shape mismatch details in error messages ---
#[test]
fn test_nil_coalescing_strips_nil() {
let errs = errors(
r#"pipeline t(task) {
let x: string | nil = nil
let y: string = x ?? "default"
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_shape_mismatch_detail_missing_field() {
let errs = errors(
r#"pipeline t(task) {
let x: {name: string, age: int} = {name: "hello"}
}"#,
);
assert_eq!(errs.len(), 1);
assert!(
errs[0].contains("missing field 'age'"),
"expected detail about missing field, got: {}",
errs[0]
);
}
#[test]
fn test_shape_mismatch_detail_wrong_type() {
let errs = errors(
r#"pipeline t(task) {
let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
);
assert_eq!(errs.len(), 1);
assert!(
errs[0].contains("field 'name' has type int, expected string"),
"expected detail about wrong type, got: {}",
errs[0]
);
}
// --- Warnings for literal match patterns whose type differs from the scrutinee ---
#[test]
fn test_match_pattern_string_against_int() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
"hello" -> { log("bad") }
42 -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching int against string literal"));
}
#[test]
fn test_match_pattern_int_against_string() {
let warns = warnings(
r#"pipeline t(task) {
let x: string = "hello"
match x {
42 -> { log("bad") }
"hello" -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching string against int literal"));
}
#[test]
fn test_match_pattern_bool_against_int() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
true -> { log("bad") }
42 -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching int against bool literal"));
}
#[test]
fn test_match_pattern_float_against_string() {
let warns = warnings(
r#"pipeline t(task) {
let x: string = "hello"
match x {
3.14 -> { log("bad") }
"hello" -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching string against float literal"));
}
// Mixing int and float between scrutinee and literal pattern produces no
// pattern-mismatch warning, in either direction.
#[test]
fn test_match_pattern_int_against_float_ok() {
let warns = warnings(
r#"pipeline t(task) {
let x: float = 3.14
match x {
42 -> { log("ok") }
_ -> { log("default") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_float_against_int_ok() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
3.14 -> { log("close") }
_ -> { log("default") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_correct_types_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
1 -> { log("one") }
2 -> { log("two") }
_ -> { log("other") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_wildcard_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
_ -> { log("catch all") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
// A scrutinee bound from an unknown function call produces no pattern warnings.
#[test]
fn test_match_pattern_untyped_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x = some_unknown_fn()
match x {
"hello" -> { log("string") }
42 -> { log("int") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
/// Warning messages for `source` that contain "does not satisfy interface".
fn iface_warns(source: &str) -> Vec<String> {
warnings(source)
.into_iter()
.filter(|w| w.contains("does not satisfy interface"))
.collect()
}
// --- `where T: Interface` constraints checked at generic call sites ---
#[test]
fn test_interface_constraint_return_type_mismatch() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn size(self) -> string { return "nope" }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
assert!(
warns[0].contains("method 'size' returns 'string', expected 'int'"),
"unexpected message: {}",
warns[0]
);
}
#[test]
fn test_interface_constraint_param_type_mismatch() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Processor {
fn process(self, x: int) -> string
}
struct MyProc { name: string }
impl MyProc {
fn process(self, x: string) -> string { return x }
}
fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
run_proc(MyProc({name: "a"}))
}"#,
);
assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
assert!(
warns[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
"unexpected message: {}",
warns[0]
);
}
#[test]
fn test_interface_constraint_missing_method() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn area(self) -> int { return self.width }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
assert!(
warns[0].contains("missing method 'size'"),
"unexpected message: {}",
warns[0]
);
}
#[test]
fn test_interface_constraint_param_count_mismatch() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Doubler {
fn double(self, x: int) -> int
}
struct Bad { v: int }
impl Bad {
fn double(self) -> int { return self.v * 2 }
}
fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
run_double(Bad({v: 5}))
}"#,
);
assert_eq!(warns.len(), 1, "expected 1 warning, got: {:?}", warns);
assert!(
warns[0].contains("method 'double' has 0 parameter(s), expected 1"),
"unexpected message: {}",
warns[0]
);
}
#[test]
fn test_interface_constraint_satisfied() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int, height: int }
impl Box {
fn size(self) -> int { return self.width * self.height }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3, height: 4}))
}"#,
);
assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
}
// An impl method without a declared return type does not trigger an
// interface warning.
#[test]
fn test_interface_constraint_untyped_impl_compatible() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn size(self) { return self.width }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
}
// An impl returning int is accepted where the interface declares float.
#[test]
fn test_interface_constraint_int_float_covariance() {
let warns = iface_warns(
r#"pipeline t(task) {
interface Measurable {
fn value(self) -> float
}
struct Gauge { v: int }
impl Gauge {
fn value(self) -> int { return self.v }
}
fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
read_val(Gauge({v: 42}))
}"#,
);
assert!(warns.is_empty(), "expected no warnings, got: {:?}", warns);
}
}