use std::collections::BTreeMap;
use crate::ast::*;
use crate::builtin_signatures;
use harn_lexer::{FixEdit, Span};
/// An editor inlay hint produced while checking a program.
///
/// The checker emits these for unannotated `let`/`var` bindings whose
/// inferred type is not obvious from the initializer.
#[derive(Debug, Clone)]
pub struct InlayHintInfo {
/// Line of the binding statement the hint belongs to.
pub line: usize,
/// Column at which the hint is rendered (just after the bound name).
pub column: usize,
/// Hint text, e.g. `": int"`.
pub label: String,
}
/// A single diagnostic (error or warning) reported by the type checker.
#[derive(Debug, Clone)]
pub struct TypeDiagnostic {
/// Human-readable description of the problem.
pub message: String,
/// Whether this is an error or a warning.
pub severity: DiagnosticSeverity,
/// Source location the diagnostic refers to, when known.
pub span: Option<Span>,
/// Optional follow-up advice shown alongside the message.
pub help: Option<String>,
/// Optional machine-applicable source edits that fix the problem.
pub fix: Option<Vec<FixEdit>>,
}
/// Severity level of a [`TypeDiagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
Error,
Warning,
}
// A type as tracked by the checker; `None` means no type is known (checks
// that need a concrete type are skipped for such values).
type InferredType = Option<TypeExpr>;
/// Polarity of a position inside a type, used when validating declared variance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Polarity {
    Covariant,
    Contravariant,
    Invariant,
}

impl Polarity {
    /// Polarity of a child position whose declared variance is `child`,
    /// nested inside a position with polarity `self`.
    ///
    /// An invariant parent or child forces `Invariant`. A covariant child
    /// keeps the parent's polarity. A contravariant child flips
    /// covariant <-> contravariant.
    fn compose(self, child: Variance) -> Polarity {
        match child {
            Variance::Invariant => Polarity::Invariant,
            Variance::Covariant => self,
            Variance::Contravariant => match self {
                Polarity::Covariant => Polarity::Contravariant,
                Polarity::Contravariant => Polarity::Covariant,
                Polarity::Invariant => Polarity::Invariant,
            },
        }
    }
}
/// Information recorded for an `enum` declaration (also used for the
/// built-in `Result<T, E>` seeded by `TypeScope::new`).
#[derive(Debug, Clone)]
struct EnumDeclInfo {
type_params: Vec<TypeParam>,
variants: Vec<EnumVariant>,
}
/// Information recorded for a `struct` declaration.
#[derive(Debug, Clone)]
struct StructDeclInfo {
type_params: Vec<TypeParam>,
fields: Vec<StructField>,
}
/// Information recorded for an `interface` declaration; `methods` is consulted
/// when checking method calls on generic parameters constrained by a `where` clause.
#[derive(Debug, Clone)]
struct InterfaceDeclInfo {
type_params: Vec<TypeParam>,
// (name, optional type) pairs copied verbatim from the declaration —
// presumably the optional type is a default/bound; TODO confirm against the parser.
associated_types: Vec<(String, Option<TypeExpr>)>,
methods: Vec<InterfaceMethod>,
}
/// Lexical scope used while type checking.
///
/// Lookups consult this scope's own tables first and then walk `parent`.
/// `TypeScope::child` creates a nested scope whose parent is a clone of the
/// current one.
#[derive(Debug, Clone)]
struct TypeScope {
// Variable name -> declared or inferred type (`None` = unknown).
vars: BTreeMap<String, InferredType>,
// Function name -> signature (from `fn` and `tool` declarations).
functions: BTreeMap<String, FnSignature>,
// Type alias name -> aliased type expression.
type_aliases: BTreeMap<String, TypeExpr>,
enums: BTreeMap<String, EnumDeclInfo>,
interfaces: BTreeMap<String, InterfaceDeclInfo>,
structs: BTreeMap<String, StructDeclInfo>,
// Type name -> method signatures collected from its `impl` blocks.
impl_methods: BTreeMap<String, Vec<ImplMethodSig>>,
// Names of generic type parameters visible in this scope.
generic_type_params: std::collections::BTreeSet<String>,
// Generic type parameter -> interface name it is constrained to by `where`.
where_constraints: BTreeMap<String, String>,
// Variables declared with `var` (assignable).
mutable_vars: std::collections::BTreeSet<String>,
// Variables whose type has been narrowed; the value is restored into `vars`
// on assignment, so it appears to hold the pre-narrowing type.
narrowed_vars: BTreeMap<String, InferredType>,
// Variable -> schema type derived from the value it was bound to.
schema_bindings: BTreeMap<String, InferredType>,
// Variable -> name of the untyped boundary builtin its value came from
// (strict mode). An empty string marks the variable as explicitly validated.
untyped_sources: BTreeMap<String, String>,
// Variable -> type names ruled out for it by earlier refinements.
unknown_ruled_out: BTreeMap<String, Vec<String>>,
// Enclosing scope, if any.
parent: Option<Box<TypeScope>>,
}
/// Signature of a method declared in an `impl` block.
#[derive(Debug, Clone)]
struct ImplMethodSig {
name: String,
// Number of parameters, not counting a parameter named `self`.
param_count: usize,
// Declared types of the non-`self` parameters, in order.
param_types: Vec<Option<TypeExpr>>,
return_type: Option<TypeExpr>,
}
/// Signature of a function or tool as seen by call-site checking.
#[derive(Debug, Clone)]
struct FnSignature {
// (name, declared type) for each parameter, in declaration order.
params: Vec<(String, InferredType)>,
return_type: InferredType,
type_param_names: Vec<String>,
// Count of parameters that have no default value.
required_params: usize,
// (type parameter, bound) pairs from the `where` clause.
where_clauses: Vec<(String, String)>,
// True when the last parameter is a rest parameter.
has_rest: bool,
}
impl TypeScope {
    /// Builds a scope whose tables are all empty and which has no parent.
    fn empty_scope() -> Self {
        Self {
            vars: BTreeMap::new(),
            functions: BTreeMap::new(),
            type_aliases: BTreeMap::new(),
            enums: BTreeMap::new(),
            interfaces: BTreeMap::new(),
            structs: BTreeMap::new(),
            impl_methods: BTreeMap::new(),
            generic_type_params: std::collections::BTreeSet::new(),
            where_constraints: BTreeMap::new(),
            mutable_vars: std::collections::BTreeSet::new(),
            narrowed_vars: BTreeMap::new(),
            schema_bindings: BTreeMap::new(),
            untyped_sources: BTreeMap::new(),
            unknown_ruled_out: BTreeMap::new(),
            parent: None,
        }
    }

    /// Creates the root scope.
    ///
    /// The root is pre-populated with the built-in generic enum `Result<T, E>`.
    /// Its variants are `Ok(value: T)` and `Err(error: E)`, and both type
    /// parameters are covariant.
    fn new() -> Self {
        let mut root = Self::empty_scope();
        let single_field = |field_name: &str, type_name: &str| TypedParam {
            name: field_name.into(),
            type_expr: Some(TypeExpr::Named(type_name.into())),
            default_value: None,
            rest: false,
        };
        let result_decl = EnumDeclInfo {
            type_params: vec![
                TypeParam {
                    name: "T".into(),
                    variance: Variance::Covariant,
                },
                TypeParam {
                    name: "E".into(),
                    variance: Variance::Covariant,
                },
            ],
            variants: vec![
                EnumVariant {
                    name: "Ok".into(),
                    fields: vec![single_field("value", "T")],
                },
                EnumVariant {
                    name: "Err".into(),
                    fields: vec![single_field("error", "E")],
                },
            ],
        };
        root.enums.insert("Result".into(), result_decl);
        root
    }

    /// Creates an empty nested scope whose parent is a snapshot (clone) of `self`.
    fn child(&self) -> Self {
        Self {
            parent: Some(Box::new(self.clone())),
            ..Self::empty_scope()
        }
    }

    /// Yields `self`, then every enclosing scope out to the root.
    fn scope_chain(&self) -> impl Iterator<Item = &TypeScope> {
        std::iter::successors(Some(self), |scope| scope.parent.as_deref())
    }

    /// Type of variable `name` in the innermost scope that defines it.
    fn get_var(&self, name: &str) -> Option<&InferredType> {
        self.scope_chain().find_map(|scope| scope.vars.get(name))
    }

    /// Records that `type_name` has been ruled out for `var_name` in this scope.
    ///
    /// The first time a variable is touched here, its list is seeded from the
    /// nearest enclosing scope so that inherited exclusions are kept.
    fn add_unknown_ruled_out(&mut self, var_name: &str, type_name: &str) {
        let parent = self.parent.as_deref();
        let ruled_out = self
            .unknown_ruled_out
            .entry(var_name.to_string())
            .or_insert_with(|| {
                parent
                    .map(|p| p.lookup_unknown_ruled_out(var_name))
                    .unwrap_or_default()
            });
        if !ruled_out.iter().any(|t| t == type_name) {
            ruled_out.push(type_name.to_string());
        }
    }

    /// Ruled-out type names for `var_name` from the innermost scope that has
    /// an entry for it (empty if none does).
    fn lookup_unknown_ruled_out(&self, var_name: &str) -> Vec<String> {
        self.scope_chain()
            .find_map(|scope| scope.unknown_ruled_out.get(var_name))
            .cloned()
            .unwrap_or_default()
    }

    /// Flattens ruled-out lists across the whole chain. Inner scopes override outer ones.
    fn collect_unknown_ruled_out(&self) -> BTreeMap<String, Vec<String>> {
        let mut merged = BTreeMap::new();
        self.collect_unknown_ruled_out_inner(&mut merged);
        merged
    }

    /// Inserts ruled-out lists into `acc`, outermost scope first, so entries
    /// from inner scopes overwrite those from outer scopes.
    fn collect_unknown_ruled_out_inner(&self, acc: &mut BTreeMap<String, Vec<String>>) {
        let chain: Vec<&TypeScope> = self.scope_chain().collect();
        for scope in chain.into_iter().rev() {
            acc.extend(
                scope
                    .unknown_ruled_out
                    .iter()
                    .map(|(name, list)| (name.clone(), list.clone())),
            );
        }
    }

    /// Resets the ruled-out list for `var_name` in this scope.
    ///
    /// The reset list is stored as an empty entry, which shadows any outer entry.
    fn clear_unknown_ruled_out(&mut self, var_name: &str) {
        self.unknown_ruled_out
            .insert(var_name.to_string(), Vec::new());
    }

    /// Signature of function `name` from the innermost scope that declares it.
    fn get_fn(&self, name: &str) -> Option<&FnSignature> {
        self.scope_chain().find_map(|scope| scope.functions.get(name))
    }

    /// Schema type bound to variable `name`, innermost first.
    fn get_schema_binding(&self, name: &str) -> Option<&InferredType> {
        self.scope_chain()
            .find_map(|scope| scope.schema_bindings.get(name))
    }

    /// Expansion of type alias `name`, innermost first.
    fn resolve_type(&self, name: &str) -> Option<&TypeExpr> {
        self.scope_chain().find_map(|scope| scope.type_aliases.get(name))
    }

    /// Whether `name` is a generic type parameter anywhere in the chain.
    fn is_generic_type_param(&self, name: &str) -> bool {
        self.scope_chain()
            .any(|scope| scope.generic_type_params.contains(name))
    }

    /// Interface bound placed on `type_param` by a `where` clause, innermost first.
    fn get_where_constraint(&self, type_param: &str) -> Option<&str> {
        self.scope_chain()
            .find_map(|scope| scope.where_constraints.get(type_param))
            .map(String::as_str)
    }

    fn get_enum(&self, name: &str) -> Option<&EnumDeclInfo> {
        self.scope_chain().find_map(|scope| scope.enums.get(name))
    }

    fn get_interface(&self, name: &str) -> Option<&InterfaceDeclInfo> {
        self.scope_chain().find_map(|scope| scope.interfaces.get(name))
    }

    fn get_struct(&self, name: &str) -> Option<&StructDeclInfo> {
        self.scope_chain().find_map(|scope| scope.structs.get(name))
    }

    fn get_impl_methods(&self, name: &str) -> Option<&Vec<ImplMethodSig>> {
        self.scope_chain().find_map(|scope| scope.impl_methods.get(name))
    }

    /// Declared variances of the type parameters of generic type `name`.
    ///
    /// Enums are searched first, then structs, then interfaces. Returns `None`
    /// if `name` is none of these.
    fn variance_of(&self, name: &str) -> Option<Vec<Variance>> {
        let type_params = self
            .get_enum(name)
            .map(|info| &info.type_params)
            .or_else(|| self.get_struct(name).map(|info| &info.type_params))
            .or_else(|| self.get_interface(name).map(|info| &info.type_params))?;
        Some(type_params.iter().map(|tp| tp.variance).collect())
    }

    /// Declares (or shadows) an immutable binding in this scope.
    fn define_var(&mut self, name: &str, ty: InferredType) {
        self.vars.insert(name.to_string(), ty);
    }

    /// Declares (or shadows) a `var` binding and marks it assignable.
    fn define_var_mutable(&mut self, name: &str, ty: InferredType) {
        self.define_var(name, ty);
        self.mutable_vars.insert(name.to_string());
    }

    fn define_schema_binding(&mut self, name: &str, ty: InferredType) {
        self.schema_bindings.insert(name.to_string(), ty);
    }

    /// Name of the untyped boundary builtin that `name`'s value came from.
    ///
    /// Only the innermost entry is considered. Returns `None` if that entry
    /// has been cleared (empty string) or if no scope has one.
    fn is_untyped_source(&self, name: &str) -> Option<&str> {
        self.scope_chain()
            .find_map(|scope| scope.untyped_sources.get(name))
            .filter(|source| !source.is_empty())
            .map(String::as_str)
    }

    fn mark_untyped_source(&mut self, name: &str, source: &str) {
        self.untyped_sources
            .insert(name.to_string(), source.to_string());
    }

    /// Marks `name` as validated in this scope.
    ///
    /// The empty-string sentinel shadows any outer untyped mark.
    fn clear_untyped_source(&mut self, name: &str) {
        self.untyped_sources.insert(name.to_string(), String::new());
    }

    /// Whether any scope in the chain lists `name` as a `var` binding.
    fn is_mutable(&self, name: &str) -> bool {
        self.scope_chain().any(|scope| scope.mutable_vars.contains(name))
    }

    fn define_fn(&mut self, name: &str, sig: FnSignature) {
        self.functions.insert(name.to_string(), sig);
    }
}
/// Refinements derived from a condition, split by the branch in which they hold.
#[derive(Debug, Clone, Default)]
struct Refinements {
    /// `(variable, type)` refinements applied when the condition is truthy.
    truthy: Vec<(String, InferredType)>,
    /// `(variable, type)` refinements applied when the condition is falsy.
    falsy: Vec<(String, InferredType)>,
    /// `(variable, type name)` pairs recorded as ruled out on the truthy side.
    truthy_ruled_out: Vec<(String, String)>,
    /// `(variable, type name)` pairs recorded as ruled out on the falsy side.
    falsy_ruled_out: Vec<(String, String)>,
}

impl Refinements {
    /// No refinements on either side.
    fn empty() -> Self {
        Refinements::default()
    }

    /// Swaps the truthy and falsy sides.
    fn inverted(self) -> Self {
        let Refinements {
            truthy,
            falsy,
            truthy_ruled_out,
            falsy_ruled_out,
        } = self;
        Refinements {
            truthy: falsy,
            falsy: truthy,
            truthy_ruled_out: falsy_ruled_out,
            falsy_ruled_out: truthy_ruled_out,
        }
    }

    /// Applies the truthy-side refinements and ruled-out types to `scope`.
    fn apply_truthy(&self, scope: &mut TypeScope) {
        apply_refinements(scope, &self.truthy);
        self.truthy_ruled_out
            .iter()
            .for_each(|(var, ty)| scope.add_unknown_ruled_out(var, ty));
    }

    /// Applies the falsy-side refinements and ruled-out types to `scope`.
    fn apply_falsy(&self, scope: &mut TypeScope) {
        apply_refinements(scope, &self.falsy);
        self.falsy_ruled_out
            .iter()
            .for_each(|(var, ty)| scope.add_unknown_ruled_out(var, ty));
    }
}
/// Return type of builtin `name`, as reported by `builtin_signatures`.
fn builtin_return_type(name: &str) -> InferredType {
builtin_signatures::builtin_return_type(name)
}
/// Whether `name` is a builtin known to `builtin_signatures`.
fn is_builtin(name: &str) -> bool {
builtin_signatures::is_builtin(name)
}
/// Static type checker for a parsed program.
///
/// A checker is consumed by one of its `check*` entry points. Each entry point
/// returns the collected diagnostics; `check_with_hints` also returns inlay hints.
pub struct TypeChecker {
// Diagnostics accumulated while checking.
diagnostics: Vec<TypeDiagnostic>,
// Root scope; top-level declarations and bindings are recorded here.
scope: TypeScope,
// Source text, when supplied by a `*_with_source` / `check_with_hints` entry point.
source: Option<String>,
// Inlay hints for unannotated `let`/`var` bindings.
hints: Vec<InlayHintInfo>,
// Enables tracking/warnings for values coming from untyped boundary builtins.
strict_types: bool,
// Nesting depth of enclosing pipeline/function bodies; `try*` at depth 0 is an error.
fn_depth: usize,
}
impl TypeChecker {
/// The placeholder type `_`, represented as a named type.
fn wildcard_type() -> TypeExpr {
    let placeholder = "_";
    TypeExpr::Named(placeholder.into())
}
/// Whether `ty` is the placeholder type `_`.
fn is_wildcard_type(ty: &TypeExpr) -> bool {
    if let TypeExpr::Named(name) = ty {
        name == "_"
    } else {
        false
    }
}
/// Head name of a named or applied type (`Foo` for both `Foo` and `Foo<T>`).
/// Returns `None` for any other kind of type.
fn base_type_name(ty: &TypeExpr) -> Option<&str> {
    match ty {
        TypeExpr::Named(name) | TypeExpr::Applied { name, .. } => Some(name.as_str()),
        _ => None,
    }
}
pub fn new() -> Self {
Self {
diagnostics: Vec::new(),
scope: TypeScope::new(),
source: None,
hints: Vec::new(),
strict_types: false,
fn_depth: 0,
}
}
pub fn with_strict_types(strict: bool) -> Self {
Self {
diagnostics: Vec::new(),
scope: TypeScope::new(),
source: None,
hints: Vec::new(),
strict_types: strict,
fn_depth: 0,
}
}
/// Type-checks `program` with its source text available.
/// Returns only the diagnostics; inlay hints are discarded.
pub fn check_with_source(self, program: &[SNode], source: &str) -> Vec<TypeDiagnostic> {
    let (diagnostics, _hints) = self.check_with_hints(program, source);
    diagnostics
}
/// Type-checks `program` in strict mode and returns its diagnostics.
///
/// Strict mode is forced on here. This is idempotent for checkers already
/// built with `with_strict_types(true)`. In strict mode, values coming from
/// untyped boundary builtins are tracked, and unvalidated property or
/// subscript access on them is reported.
pub fn check_strict_with_source(
    mut self,
    program: &[SNode],
    source: &str,
) -> Vec<TypeDiagnostic> {
    self.strict_types = true;
    self.source = Some(source.to_string());
    self.check_inner(program).0
}
/// Type-checks `program` without source text and returns only the diagnostics.
pub fn check(self, program: &[SNode]) -> Vec<TypeDiagnostic> {
    let (diagnostics, _hints) = self.check_inner(program);
    diagnostics
}
/// If `value` yields data from an untyped boundary, returns the builtin's name.
///
/// `value` may be a direct call to a boundary builtin, or an identifier that
/// is still marked as carrying such data. A call to `llm_call` /
/// `llm_completion` that supplies a schema in its options argument does not
/// count as untyped.
fn detect_boundary_source(value: &SNode, scope: &TypeScope) -> Option<String> {
    match &value.node {
        Node::Identifier(name) => scope.is_untyped_source(name).map(str::to_string),
        Node::FunctionCall { name, args }
            if builtin_signatures::is_untyped_boundary_source(name) =>
        {
            let is_llm = matches!(name.as_str(), "llm_call" | "llm_completion");
            let schema_validated =
                is_llm && Self::extract_llm_schema_from_options(args, scope).is_some();
            (!schema_validated).then(|| name.clone())
        }
        _ => None,
    }
}
/// Looks up the schema type in an LLM call's options argument.
///
/// The options argument is the third argument and must be a dict literal.
/// The first entry keyed `schema` or `output_schema` (as a string literal or
/// bare identifier) is the one used. Its value's schema type is returned,
/// even if that is `None`.
fn extract_llm_schema_from_options(args: &[SNode], scope: &TypeScope) -> Option<TypeExpr> {
    let Node::DictLiteral(entries) = &args.get(2)?.node else {
        return None;
    };
    let schema_entry = entries.iter().find(|entry| {
        matches!(
            &entry.key.node,
            Node::StringLiteral(k) | Node::Identifier(k) if k == "schema" || k == "output_schema"
        )
    })?;
    schema_type_expr_from_node(&schema_entry.value, scope)
}
/// Whether an annotation is specific enough to count as validating a value
/// from an untyped boundary.
///
/// The loose named types `dict`, `any` and `_` do not count.
fn is_concrete_type(ty: &TypeExpr) -> bool {
    match ty {
        TypeExpr::Named(n) => !matches!(n.as_str(), "dict" | "any" | "_"),
        TypeExpr::Shape(_)
        | TypeExpr::Applied { .. }
        | TypeExpr::FnType { .. }
        | TypeExpr::List(_)
        | TypeExpr::Iter(_)
        | TypeExpr::DictType(_, _) => true,
        _ => false,
    }
}
/// Type-checks `program` and returns both the diagnostics and the inlay hints.
pub fn check_with_hints(
    mut self,
    program: &[SNode],
    source: &str,
) -> (Vec<TypeDiagnostic>, Vec<InlayHintInfo>) {
    self.source.replace(source.to_owned());
    let results = self.check_inner(program);
    results
}
/// Shared driver for every `check*` entry point.
///
/// Declarations at the top level and inside each pipeline body are
/// registered first, so later statements can refer to them regardless of
/// order. Each top-level node is then checked:
/// - pipelines get a child scope with untyped parameters;
/// - top-level functions have their signature recorded in the root scope
///   before their body is checked;
/// - other statements run against a copy of the root scope. Any new
///   variables, and any mutability marks, are merged back afterwards.
///   Existing root variables are left untouched.
fn check_inner(mut self, program: &[SNode]) -> (Vec<TypeDiagnostic>, Vec<InlayHintInfo>) {
    Self::register_declarations_into(&mut self.scope, program);
    let pipeline_bodies = program.iter().filter_map(|snode| match &snode.node {
        Node::Pipeline { body, .. } => Some(body),
        _ => None,
    });
    for body in pipeline_bodies {
        Self::register_declarations_into(&mut self.scope, body);
    }

    for snode in program {
        match &snode.node {
            Node::Pipeline { params, body, .. } => {
                let mut pipeline_scope = self.scope.child();
                params
                    .iter()
                    .for_each(|p| pipeline_scope.define_var(p, None));
                self.fn_depth += 1;
                self.check_block(body, &mut pipeline_scope);
                self.fn_depth -= 1;
            }
            Node::FnDecl {
                name,
                type_params,
                params,
                return_type,
                where_clauses,
                body,
                ..
            } => {
                let signature = FnSignature {
                    params: params
                        .iter()
                        .map(|p| (p.name.clone(), p.type_expr.clone()))
                        .collect(),
                    return_type: return_type.clone(),
                    type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
                    required_params: params
                        .iter()
                        .filter(|p| p.default_value.is_none())
                        .count(),
                    where_clauses: where_clauses
                        .iter()
                        .map(|wc| (wc.type_name.clone(), wc.bound.clone()))
                        .collect(),
                    has_rest: params.last().is_some_and(|p| p.rest),
                };
                self.scope.define_fn(name, signature);
                self.check_fn_body(type_params, params, return_type, body, where_clauses);
            }
            _ => {
                let mut stmt_scope = self.scope.clone();
                self.check_node(snode, &mut stmt_scope);
                for (name, ty) in stmt_scope.vars {
                    self.scope.vars.entry(name).or_insert(ty);
                }
                self.scope.mutable_vars.extend(stmt_scope.mutable_vars);
            }
        }
    }
    (self.diagnostics, self.hints)
}
fn register_declarations_into(scope: &mut TypeScope, nodes: &[SNode]) {
for snode in nodes {
match &snode.node {
Node::TypeDecl {
name,
type_params: _,
type_expr,
} => {
scope.type_aliases.insert(name.clone(), type_expr.clone());
}
Node::EnumDecl {
name,
type_params,
variants,
..
} => {
scope.enums.insert(
name.clone(),
EnumDeclInfo {
type_params: type_params.clone(),
variants: variants.clone(),
},
);
}
Node::InterfaceDecl {
name,
type_params,
associated_types,
methods,
} => {
scope.interfaces.insert(
name.clone(),
InterfaceDeclInfo {
type_params: type_params.clone(),
associated_types: associated_types.clone(),
methods: methods.clone(),
},
);
}
Node::StructDecl {
name,
type_params,
fields,
..
} => {
scope.structs.insert(
name.clone(),
StructDeclInfo {
type_params: type_params.clone(),
fields: fields.clone(),
},
);
}
Node::ImplBlock {
type_name, methods, ..
} => {
let sigs: Vec<ImplMethodSig> = methods
.iter()
.filter_map(|m| {
if let Node::FnDecl {
name,
params,
return_type,
..
} = &m.node
{
let non_self: Vec<_> =
params.iter().filter(|p| p.name != "self").collect();
let param_count = non_self.len();
let param_types: Vec<Option<TypeExpr>> =
non_self.iter().map(|p| p.type_expr.clone()).collect();
Some(ImplMethodSig {
name: name.clone(),
param_count,
param_types,
return_type: return_type.clone(),
})
} else {
None
}
})
.collect();
scope.impl_methods.insert(type_name.clone(), sigs);
}
_ => {}
}
}
}
/// Checks a sequence of statements in order within `scope`.
///
/// After the first statement that definitely exits, the next statement (if
/// any) is reported as "unreachable code". Checking of the block then stops.
fn check_block(&mut self, stmts: &[SNode], scope: &mut TypeScope) {
    for (index, stmt) in stmts.iter().enumerate() {
        self.check_node(stmt, scope);
        if !Self::stmt_definitely_exits(stmt) {
            continue;
        }
        if let Some(unreachable) = stmts.get(index + 1) {
            self.warning_at("unreachable code".to_string(), unreachable.span);
        }
        return;
    }
}
/// Associated-function wrapper around the free function `stmt_definitely_exits`
/// (defined outside this view), so callers can write `Self::stmt_definitely_exits`.
fn stmt_definitely_exits(stmt: &SNode) -> bool {
stmt_definitely_exits(stmt)
}
/// Declares every name bound by a destructuring `pattern` in `scope`.
///
/// Each name gets an unknown type (`None`). `mutable` selects `var` (true)
/// or `let` (false) semantics for all of them. For dict fields, the alias is
/// bound if present; otherwise the key is bound.
fn define_pattern_vars(pattern: &BindingPattern, scope: &mut TypeScope, mutable: bool) {
    let mut bind = |name: &str| {
        if mutable {
            scope.define_var_mutable(name, None);
        } else {
            scope.define_var(name, None);
        }
    };
    match pattern {
        BindingPattern::Identifier(name) => bind(name),
        BindingPattern::Dict(fields) => fields
            .iter()
            .for_each(|field| bind(field.alias.as_deref().unwrap_or(&field.key))),
        BindingPattern::List(elements) => elements.iter().for_each(|elem| bind(&elem.name)),
        BindingPattern::Pair(first, second) => {
            bind(first);
            bind(second);
        }
    }
}
/// Runs binary-operator checks on the default-value expressions of a
/// destructuring pattern.
///
/// Only dict and list patterns can carry defaults.
fn check_pattern_defaults(&mut self, pattern: &BindingPattern, scope: &mut TypeScope) {
    match pattern {
        BindingPattern::Dict(fields) => {
            for default in fields.iter().filter_map(|f| f.default_value.as_ref()) {
                self.check_binops(default, scope);
            }
        }
        BindingPattern::List(elements) => {
            for default in elements.iter().filter_map(|e| e.default_value.as_ref()) {
                self.check_binops(default, scope);
            }
        }
        BindingPattern::Identifier(_) | BindingPattern::Pair(_, _) => {}
    }
}
fn check_node(&mut self, snode: &SNode, scope: &mut TypeScope) {
let span = snode.span;
match &snode.node {
Node::LetBinding {
pattern,
type_ann,
value,
} => {
self.check_binops(value, scope);
let inferred = self.infer_type(value, scope);
if let BindingPattern::Identifier(name) = pattern {
if let Some(expected) = type_ann {
if let Some(actual) = &inferred {
if !self.types_compatible(expected, actual, scope) {
let mut msg = format!(
"'{}' declared as {}, but assigned {}",
name,
format_type(expected),
format_type(actual)
);
if let Some(detail) = shape_mismatch_detail(expected, actual) {
msg.push_str(&format!(" ({})", detail));
}
self.error_at(msg, span);
}
}
}
if type_ann.is_none() {
if let Some(ref ty) = inferred {
if !is_obvious_type(value, ty) {
self.hints.push(InlayHintInfo {
line: span.line,
column: span.column + "let ".len() + name.len(),
label: format!(": {}", format_type(ty)),
});
}
}
}
let ty = type_ann.clone().or(inferred);
scope.define_var(name, ty);
scope.define_schema_binding(name, schema_type_expr_from_node(value, scope));
if self.strict_types {
if let Some(boundary) = Self::detect_boundary_source(value, scope) {
let has_concrete_ann =
type_ann.as_ref().is_some_and(Self::is_concrete_type);
if !has_concrete_ann {
scope.mark_untyped_source(name, &boundary);
}
}
}
} else {
self.check_pattern_defaults(pattern, scope);
Self::define_pattern_vars(pattern, scope, false);
}
}
Node::VarBinding {
pattern,
type_ann,
value,
} => {
self.check_binops(value, scope);
let inferred = self.infer_type(value, scope);
if let BindingPattern::Identifier(name) = pattern {
if let Some(expected) = type_ann {
if let Some(actual) = &inferred {
if !self.types_compatible(expected, actual, scope) {
let mut msg = format!(
"'{}' declared as {}, but assigned {}",
name,
format_type(expected),
format_type(actual)
);
if let Some(detail) = shape_mismatch_detail(expected, actual) {
msg.push_str(&format!(" ({})", detail));
}
self.error_at(msg, span);
}
}
}
if type_ann.is_none() {
if let Some(ref ty) = inferred {
if !is_obvious_type(value, ty) {
self.hints.push(InlayHintInfo {
line: span.line,
column: span.column + "var ".len() + name.len(),
label: format!(": {}", format_type(ty)),
});
}
}
}
let ty = type_ann.clone().or(inferred);
scope.define_var_mutable(name, ty);
scope.define_schema_binding(name, schema_type_expr_from_node(value, scope));
if self.strict_types {
if let Some(boundary) = Self::detect_boundary_source(value, scope) {
let has_concrete_ann =
type_ann.as_ref().is_some_and(Self::is_concrete_type);
if !has_concrete_ann {
scope.mark_untyped_source(name, &boundary);
}
}
}
} else {
self.check_pattern_defaults(pattern, scope);
Self::define_pattern_vars(pattern, scope, true);
}
}
Node::FnDecl {
name,
type_params,
params,
return_type,
where_clauses,
body,
..
} => {
let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
let sig = FnSignature {
params: params
.iter()
.map(|p| (p.name.clone(), p.type_expr.clone()))
.collect(),
return_type: return_type.clone(),
type_param_names: type_params.iter().map(|tp| tp.name.clone()).collect(),
required_params,
where_clauses: where_clauses
.iter()
.map(|wc| (wc.type_name.clone(), wc.bound.clone()))
.collect(),
has_rest: params.last().is_some_and(|p| p.rest),
};
scope.define_fn(name, sig.clone());
scope.define_var(name, None);
self.check_fn_decl_variance(type_params, params, return_type.as_ref(), name, span);
self.check_fn_body(type_params, params, return_type, body, where_clauses);
}
Node::ToolDecl {
name,
params,
return_type,
body,
..
} => {
let required_params = params.iter().filter(|p| p.default_value.is_none()).count();
let sig = FnSignature {
params: params
.iter()
.map(|p| (p.name.clone(), p.type_expr.clone()))
.collect(),
return_type: return_type.clone(),
type_param_names: Vec::new(),
required_params,
where_clauses: Vec::new(),
has_rest: params.last().is_some_and(|p| p.rest),
};
scope.define_fn(name, sig);
scope.define_var(name, None);
self.check_fn_body(&[], params, return_type, body, &[]);
}
Node::FunctionCall { name, args } => {
self.check_call(name, args, scope, span);
if self.strict_types && name == "schema_expect" && args.len() >= 2 {
if let Node::Identifier(var_name) = &args[0].node {
scope.clear_untyped_source(var_name);
if let Some(schema_type) = schema_type_expr_from_node(&args[1], scope) {
scope.define_var(var_name, Some(schema_type));
}
}
}
}
Node::IfElse {
condition,
then_body,
else_body,
} => {
self.check_node(condition, scope);
let refs = Self::extract_refinements(condition, scope);
let mut then_scope = scope.child();
refs.apply_truthy(&mut then_scope);
if self.strict_types {
if let Node::FunctionCall { name, args } = &condition.node {
if (name == "schema_is" || name == "is_type") && args.len() == 2 {
if let Node::Identifier(var_name) = &args[0].node {
then_scope.clear_untyped_source(var_name);
}
}
}
}
self.check_block(then_body, &mut then_scope);
if let Some(else_body) = else_body {
let mut else_scope = scope.child();
refs.apply_falsy(&mut else_scope);
self.check_block(else_body, &mut else_scope);
if Self::block_definitely_exits(then_body)
&& !Self::block_definitely_exits(else_body)
{
refs.apply_falsy(scope);
} else if Self::block_definitely_exits(else_body)
&& !Self::block_definitely_exits(then_body)
{
refs.apply_truthy(scope);
}
} else {
if Self::block_definitely_exits(then_body) {
refs.apply_falsy(scope);
}
}
}
Node::ForIn {
pattern,
iterable,
body,
} => {
self.check_node(iterable, scope);
let mut loop_scope = scope.child();
let iter_type = self.infer_type(iterable, scope);
if let BindingPattern::Identifier(variable) = pattern {
let elem_type = match iter_type {
Some(TypeExpr::List(inner)) => Some(*inner),
Some(TypeExpr::Iter(inner)) => Some(*inner),
Some(TypeExpr::Applied { ref name, ref args })
if name == "Iter" && args.len() == 1 =>
{
Some(args[0].clone())
}
Some(TypeExpr::Named(n)) if n == "string" => {
Some(TypeExpr::Named("string".into()))
}
Some(TypeExpr::Named(n)) if n == "range" => {
Some(TypeExpr::Named("int".into()))
}
_ => None,
};
loop_scope.define_var(variable, elem_type);
} else if let BindingPattern::Pair(a, b) = pattern {
let (ka, vb) = match &iter_type {
Some(TypeExpr::Iter(inner)) => {
if let TypeExpr::Applied { name, args } = inner.as_ref() {
if name == "Pair" && args.len() == 2 {
(Some(args[0].clone()), Some(args[1].clone()))
} else {
(None, None)
}
} else {
(None, None)
}
}
Some(TypeExpr::Applied { name, args })
if name == "Iter" && args.len() == 1 =>
{
if let TypeExpr::Applied { name: n2, args: a2 } = &args[0] {
if n2 == "Pair" && a2.len() == 2 {
(Some(a2[0].clone()), Some(a2[1].clone()))
} else {
(None, None)
}
} else {
(None, None)
}
}
_ => (None, None),
};
loop_scope.define_var(a, ka);
loop_scope.define_var(b, vb);
} else {
self.check_pattern_defaults(pattern, &mut loop_scope);
Self::define_pattern_vars(pattern, &mut loop_scope, false);
}
self.check_block(body, &mut loop_scope);
}
Node::WhileLoop { condition, body } => {
self.check_node(condition, scope);
let refs = Self::extract_refinements(condition, scope);
let mut loop_scope = scope.child();
refs.apply_truthy(&mut loop_scope);
self.check_block(body, &mut loop_scope);
}
Node::RequireStmt { condition, message } => {
self.check_node(condition, scope);
if let Some(message) = message {
self.check_node(message, scope);
}
}
Node::TryCatch {
body,
error_var,
error_type,
catch_body,
finally_body,
..
} => {
let mut try_scope = scope.child();
self.check_block(body, &mut try_scope);
let mut catch_scope = scope.child();
if let Some(var) = error_var {
catch_scope.define_var(var, error_type.clone());
}
self.check_block(catch_body, &mut catch_scope);
if let Some(fb) = finally_body {
let mut finally_scope = scope.child();
self.check_block(fb, &mut finally_scope);
}
}
Node::TryExpr { body } => {
let mut try_scope = scope.child();
self.check_block(body, &mut try_scope);
}
Node::TryStar { operand } => {
if self.fn_depth == 0 {
self.error_at(
"try* requires an enclosing function (fn, tool, or pipeline) so the rethrow has a target".to_string(),
span,
);
}
self.check_node(operand, scope);
}
Node::ReturnStmt {
value: Some(val), ..
} => {
self.check_node(val, scope);
}
Node::Assignment {
target, value, op, ..
} => {
self.check_node(value, scope);
if let Node::Identifier(name) = &target.node {
if scope.get_var(name).is_some() && !scope.is_mutable(name) {
self.warning_at(
format!(
"Cannot assign to '{}': variable is immutable (declared with 'let')",
name
),
span,
);
}
if let Some(Some(var_type)) = scope.get_var(name) {
let value_type = self.infer_type(value, scope);
let assigned = if let Some(op) = op {
let var_inferred = scope.get_var(name).cloned().flatten();
infer_binary_op_type(op, &var_inferred, &value_type)
} else {
value_type
};
if let Some(actual) = &assigned {
let check_type = scope
.narrowed_vars
.get(name)
.and_then(|t| t.as_ref())
.unwrap_or(var_type);
if !self.types_compatible(check_type, actual, scope) {
self.error_at(
format!(
"can't assign {} to '{}' (declared as {})",
format_type(actual),
name,
format_type(check_type)
),
span,
);
}
}
}
if let Some(original) = scope.narrowed_vars.remove(name) {
scope.define_var(name, original);
}
scope.define_schema_binding(name, None);
scope.clear_unknown_ruled_out(name);
}
}
Node::TypeDecl {
name,
type_params,
type_expr,
} => {
scope.type_aliases.insert(name.clone(), type_expr.clone());
self.check_type_alias_decl_variance(type_params, type_expr, name, span);
}
Node::EnumDecl {
name,
type_params,
variants,
..
} => {
scope.enums.insert(
name.clone(),
EnumDeclInfo {
type_params: type_params.clone(),
variants: variants.clone(),
},
);
self.check_enum_decl_variance(type_params, variants, name, span);
}
Node::StructDecl {
name,
type_params,
fields,
..
} => {
scope.structs.insert(
name.clone(),
StructDeclInfo {
type_params: type_params.clone(),
fields: fields.clone(),
},
);
self.check_struct_decl_variance(type_params, fields, name, span);
}
Node::InterfaceDecl {
name,
type_params,
associated_types,
methods,
} => {
scope.interfaces.insert(
name.clone(),
InterfaceDeclInfo {
type_params: type_params.clone(),
associated_types: associated_types.clone(),
methods: methods.clone(),
},
);
self.check_interface_decl_variance(type_params, methods, name, span);
}
Node::ImplBlock {
type_name, methods, ..
} => {
let sigs: Vec<ImplMethodSig> = methods
.iter()
.filter_map(|m| {
if let Node::FnDecl {
name,
params,
return_type,
..
} = &m.node
{
let non_self: Vec<_> =
params.iter().filter(|p| p.name != "self").collect();
let param_count = non_self.len();
let param_types: Vec<Option<TypeExpr>> =
non_self.iter().map(|p| p.type_expr.clone()).collect();
Some(ImplMethodSig {
name: name.clone(),
param_count,
param_types,
return_type: return_type.clone(),
})
} else {
None
}
})
.collect();
scope.impl_methods.insert(type_name.clone(), sigs);
for method_sn in methods {
self.check_node(method_sn, scope);
}
}
Node::TryOperator { operand } => {
self.check_node(operand, scope);
}
Node::MatchExpr { value, arms } => {
self.check_node(value, scope);
let value_type = self.infer_type(value, scope);
for arm in arms {
self.check_node(&arm.pattern, scope);
if let Some(ref vt) = value_type {
let value_type_name = format_type(vt);
let mismatch = match &arm.pattern.node {
Node::StringLiteral(_) => {
!self.types_compatible(vt, &TypeExpr::Named("string".into()), scope)
}
Node::IntLiteral(_) => {
!self.types_compatible(vt, &TypeExpr::Named("int".into()), scope)
&& !self.types_compatible(
vt,
&TypeExpr::Named("float".into()),
scope,
)
}
Node::FloatLiteral(_) => {
!self.types_compatible(vt, &TypeExpr::Named("float".into()), scope)
&& !self.types_compatible(
vt,
&TypeExpr::Named("int".into()),
scope,
)
}
Node::BoolLiteral(_) => {
!self.types_compatible(vt, &TypeExpr::Named("bool".into()), scope)
}
_ => false,
};
if mismatch {
let pattern_type = match &arm.pattern.node {
Node::StringLiteral(_) => "string",
Node::IntLiteral(_) => "int",
Node::FloatLiteral(_) => "float",
Node::BoolLiteral(_) => "bool",
_ => unreachable!(),
};
self.warning_at(
format!(
"Match pattern type mismatch: matching {} against {} literal",
value_type_name, pattern_type
),
arm.pattern.span,
);
}
}
let mut arm_scope = scope.child();
if let Node::Identifier(var_name) = &value.node {
if let Some(Some(TypeExpr::Union(members))) = scope.get_var(var_name) {
let narrowed = match &arm.pattern.node {
Node::NilLiteral => narrow_to_single(members, "nil"),
Node::StringLiteral(_) => narrow_to_single(members, "string"),
Node::IntLiteral(_) => narrow_to_single(members, "int"),
Node::FloatLiteral(_) => narrow_to_single(members, "float"),
Node::BoolLiteral(_) => narrow_to_single(members, "bool"),
_ => None,
};
if let Some(narrowed_type) = narrowed {
arm_scope.define_var(var_name, Some(narrowed_type));
}
}
}
if let Some(ref guard) = arm.guard {
self.check_node(guard, &mut arm_scope);
}
self.check_block(&arm.body, &mut arm_scope);
}
self.check_match_exhaustiveness(value, arms, scope, span);
}
Node::BinaryOp { op, left, right } => {
self.check_node(left, scope);
self.check_node(right, scope);
let lt = self.infer_type(left, scope);
let rt = self.infer_type(right, scope);
if let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (<, &rt) {
match op.as_str() {
"-" | "/" | "%" | "**" => {
let numeric = ["int", "float"];
if !numeric.contains(&l.as_str()) || !numeric.contains(&r.as_str()) {
self.error_at(
format!(
"can't use '{}' on {} and {} (needs numeric operands)",
op, l, r
),
span,
);
}
}
"*" => {
let numeric = ["int", "float"];
let is_numeric =
numeric.contains(&l.as_str()) && numeric.contains(&r.as_str());
let is_string_repeat =
(l == "string" && r == "int") || (l == "int" && r == "string");
if !is_numeric && !is_string_repeat {
self.error_at(
format!("can't multiply {} and {} (try string * int)", l, r),
span,
);
}
}
"+" => {
let valid = matches!(
(l.as_str(), r.as_str()),
("int" | "float", "int" | "float")
| ("string", "string")
| ("list", "list")
| ("dict", "dict")
);
if !valid {
let msg = format!("can't add {} and {}", l, r);
let fix = if l == "string" || r == "string" {
self.build_interpolation_fix(left, right, l == "string", span)
} else {
None
};
if let Some(fix) = fix {
self.error_at_with_fix(msg, span, fix);
} else {
self.error_at(msg, span);
}
}
}
"<" | ">" | "<=" | ">=" => {
let comparable = ["int", "float", "string"];
if !comparable.contains(&l.as_str())
|| !comparable.contains(&r.as_str())
{
self.warning_at(
format!(
"Comparison '{}' may not be meaningful for types {} and {}",
op, l, r
),
span,
);
} else if (l == "string") != (r == "string") {
self.warning_at(
format!(
"Comparing {} with {} using '{}' may give unexpected results",
l, r, op
),
span,
);
}
}
_ => {}
}
}
}
Node::UnaryOp { operand, .. } => {
self.check_node(operand, scope);
}
Node::MethodCall {
object,
method,
args,
..
}
| Node::OptionalMethodCall {
object,
method,
args,
..
} => {
self.check_node(object, scope);
for arg in args {
self.check_node(arg, scope);
}
if let Some(TypeExpr::Named(type_name)) = self.infer_type(object, scope) {
if scope.is_generic_type_param(&type_name) {
if let Some(iface_name) = scope.get_where_constraint(&type_name) {
if let Some(iface_methods) = scope.get_interface(iface_name) {
let has_method =
iface_methods.methods.iter().any(|m| m.name == *method);
if !has_method {
self.warning_at(
format!(
"Method '{}' not found in interface '{}' (constraint on '{}')",
method, iface_name, type_name
),
span,
);
}
}
}
}
}
}
Node::PropertyAccess { object, .. } | Node::OptionalPropertyAccess { object, .. } => {
if self.strict_types {
if let Node::FunctionCall { name, args } = &object.node {
if builtin_signatures::is_untyped_boundary_source(name) {
let has_schema = (name == "llm_call" || name == "llm_completion")
&& Self::extract_llm_schema_from_options(args, scope).is_some();
if !has_schema {
self.warning_at_with_help(
format!(
"Direct property access on unvalidated `{}()` result",
name
),
span,
"assign to a variable and validate with schema_expect() or a type annotation first".to_string(),
);
}
}
}
if let Node::Identifier(name) = &object.node {
if let Some(source) = scope.is_untyped_source(name) {
self.warning_at_with_help(
format!(
"Accessing property on unvalidated value '{}' from `{}`",
name, source
),
span,
"validate with schema_expect(), schema_is() in an if-condition, or add a shape type annotation".to_string(),
);
}
}
}
self.check_node(object, scope);
}
Node::SubscriptAccess { object, index } => {
if self.strict_types {
if let Node::FunctionCall { name, args } = &object.node {
if builtin_signatures::is_untyped_boundary_source(name) {
let has_schema = (name == "llm_call" || name == "llm_completion")
&& Self::extract_llm_schema_from_options(args, scope).is_some();
if !has_schema {
self.warning_at_with_help(
format!(
"Direct subscript access on unvalidated `{}()` result",
name
),
span,
"assign to a variable and validate with schema_expect() or a type annotation first".to_string(),
);
}
}
}
if let Node::Identifier(name) = &object.node {
if let Some(source) = scope.is_untyped_source(name) {
self.warning_at_with_help(
format!(
"Subscript access on unvalidated value '{}' from `{}`",
name, source
),
span,
"validate with schema_expect(), schema_is() in an if-condition, or add a shape type annotation".to_string(),
);
}
}
}
self.check_node(object, scope);
self.check_node(index, scope);
}
Node::SliceAccess { object, start, end } => {
self.check_node(object, scope);
if let Some(s) = start {
self.check_node(s, scope);
}
if let Some(e) = end {
self.check_node(e, scope);
}
}
Node::Ternary {
condition,
true_expr,
false_expr,
} => {
self.check_node(condition, scope);
let refs = Self::extract_refinements(condition, scope);
let mut true_scope = scope.child();
refs.apply_truthy(&mut true_scope);
self.check_node(true_expr, &mut true_scope);
let mut false_scope = scope.child();
refs.apply_falsy(&mut false_scope);
self.check_node(false_expr, &mut false_scope);
}
Node::ThrowStmt { value } => {
self.check_node(value, scope);
self.check_unknown_exhaustiveness(scope, snode.span, "throw");
}
Node::GuardStmt {
condition,
else_body,
} => {
self.check_node(condition, scope);
let refs = Self::extract_refinements(condition, scope);
let mut else_scope = scope.child();
refs.apply_falsy(&mut else_scope);
self.check_block(else_body, &mut else_scope);
refs.apply_truthy(scope);
}
Node::SpawnExpr { body } => {
let mut spawn_scope = scope.child();
self.check_block(body, &mut spawn_scope);
}
Node::Parallel {
mode,
expr,
variable,
body,
options,
} => {
self.check_node(expr, scope);
for (key, value) in options {
self.check_node(value, scope);
if key == "max_concurrent" {
if let Some(ty) = self.infer_type(value, scope) {
if !matches!(ty, TypeExpr::Named(ref n) if n == "int") {
self.error_at(
format!(
"`max_concurrent` on `parallel` must be int, got {ty:?}"
),
value.span,
);
}
}
}
}
let mut par_scope = scope.child();
if let Some(var) = variable {
let var_type = match mode {
ParallelMode::Count => Some(TypeExpr::Named("int".into())),
ParallelMode::Each | ParallelMode::Settle => {
match self.infer_type(expr, scope) {
Some(TypeExpr::List(inner)) => Some(*inner),
_ => None,
}
}
};
par_scope.define_var(var, var_type);
}
self.check_block(body, &mut par_scope);
}
Node::SelectExpr {
cases,
timeout,
default_body,
} => {
for case in cases {
self.check_node(&case.channel, scope);
let mut case_scope = scope.child();
case_scope.define_var(&case.variable, None);
self.check_block(&case.body, &mut case_scope);
}
if let Some((dur, body)) = timeout {
self.check_node(dur, scope);
let mut timeout_scope = scope.child();
self.check_block(body, &mut timeout_scope);
}
if let Some(body) = default_body {
let mut default_scope = scope.child();
self.check_block(body, &mut default_scope);
}
}
Node::DeadlineBlock { duration, body } => {
self.check_node(duration, scope);
let mut block_scope = scope.child();
self.check_block(body, &mut block_scope);
}
Node::MutexBlock { body } | Node::DeferStmt { body } => {
let mut block_scope = scope.child();
self.check_block(body, &mut block_scope);
}
Node::Retry { count, body } => {
self.check_node(count, scope);
let mut retry_scope = scope.child();
self.check_block(body, &mut retry_scope);
}
Node::Closure { params, body, .. } => {
let mut closure_scope = scope.child();
for p in params {
closure_scope.define_var(&p.name, p.type_expr.clone());
}
self.fn_depth += 1;
self.check_block(body, &mut closure_scope);
self.fn_depth -= 1;
}
Node::ListLiteral(elements) => {
for elem in elements {
self.check_node(elem, scope);
}
}
Node::DictLiteral(entries) => {
for entry in entries {
self.check_node(&entry.key, scope);
self.check_node(&entry.value, scope);
}
}
Node::RangeExpr { start, end, .. } => {
self.check_node(start, scope);
self.check_node(end, scope);
}
Node::Spread(inner) => {
self.check_node(inner, scope);
}
Node::Block(stmts) => {
let mut block_scope = scope.child();
self.check_block(stmts, &mut block_scope);
}
Node::YieldExpr { value } => {
if let Some(v) = value {
self.check_node(v, scope);
}
}
Node::StructConstruct {
struct_name,
fields,
} => {
for entry in fields {
self.check_node(&entry.key, scope);
self.check_node(&entry.value, scope);
}
if let Some(struct_info) = scope.get_struct(struct_name).cloned() {
let type_bindings = self.infer_struct_bindings(&struct_info, fields, scope);
for entry in fields {
if let Node::StringLiteral(key) | Node::Identifier(key) = &entry.key.node {
if !struct_info.fields.iter().any(|field| field.name == *key) {
self.warning_at(
format!("Unknown field '{}' in struct '{}'", key, struct_name),
entry.key.span,
);
}
}
}
let provided: Vec<String> = fields
.iter()
.filter_map(|e| match &e.key.node {
Node::StringLiteral(k) | Node::Identifier(k) => Some(k.clone()),
_ => None,
})
.collect();
for field in &struct_info.fields {
if !field.optional && !provided.contains(&field.name) {
self.warning_at(
format!(
"Missing field '{}' in struct '{}' construction",
field.name, struct_name
),
span,
);
}
}
for field in &struct_info.fields {
let Some(expected_type) = &field.type_expr else {
continue;
};
let Some(entry) = fields.iter().find(|entry| {
matches!(&entry.key.node, Node::StringLiteral(key) | Node::Identifier(key) if key == &field.name)
}) else {
continue;
};
let Some(actual_type) = self.infer_type(&entry.value, scope) else {
continue;
};
let expected = Self::apply_type_bindings(expected_type, &type_bindings);
if !self.types_compatible(&expected, &actual_type, scope) {
self.error_at(
format!(
"Field '{}' in struct '{}' expects {}, got {}",
field.name,
struct_name,
format_type(&expected),
format_type(&actual_type)
),
entry.value.span,
);
}
}
}
}
Node::EnumConstruct {
enum_name,
variant,
args,
} => {
for arg in args {
self.check_node(arg, scope);
}
if let Some(enum_info) = scope.get_enum(enum_name).cloned() {
let Some(enum_variant) = enum_info
.variants
.iter()
.find(|enum_variant| enum_variant.name == *variant)
else {
self.warning_at(
format!("Unknown variant '{}' in enum '{}'", variant, enum_name),
span,
);
return;
};
if args.len() != enum_variant.fields.len() {
self.warning_at(
format!(
"{}.{} expects {} argument(s), got {}",
enum_name,
variant,
enum_variant.fields.len(),
args.len()
),
span,
);
}
let type_param_set: std::collections::BTreeSet<String> = enum_info
.type_params
.iter()
.map(|tp| tp.name.clone())
.collect();
let mut type_bindings = BTreeMap::new();
for (field, arg) in enum_variant.fields.iter().zip(args.iter()) {
let Some(expected_type) = &field.type_expr else {
continue;
};
let Some(actual_type) = self.infer_type(arg, scope) else {
continue;
};
if let Err(message) = Self::extract_type_bindings(
expected_type,
&actual_type,
&type_param_set,
&mut type_bindings,
) {
self.error_at(message, arg.span);
}
}
for (field, arg) in enum_variant.fields.iter().zip(args.iter()) {
let Some(expected_type) = &field.type_expr else {
continue;
};
let Some(actual_type) = self.infer_type(arg, scope) else {
continue;
};
let expected = Self::apply_type_bindings(expected_type, &type_bindings);
if !self.types_compatible(&expected, &actual_type, scope) {
self.error_at(
format!(
"{}.{} expects {}: {}, got {}",
enum_name,
variant,
field.name,
format_type(&expected),
format_type(&actual_type)
),
arg.span,
);
}
}
}
}
Node::InterpolatedString(_) => {}
Node::StringLiteral(_)
| Node::RawStringLiteral(_)
| Node::IntLiteral(_)
| Node::FloatLiteral(_)
| Node::BoolLiteral(_)
| Node::NilLiteral
| Node::Identifier(_)
| Node::DurationLiteral(_)
| Node::BreakStmt
| Node::ContinueStmt
| Node::ReturnStmt { value: None }
| Node::ImportDecl { .. }
| Node::SelectiveImport { .. } => {}
Node::Pipeline { body, .. } | Node::OverrideDecl { body, .. } => {
let mut decl_scope = scope.child();
self.fn_depth += 1;
self.check_block(body, &mut decl_scope);
self.fn_depth -= 1;
}
}
}
/// Checks a function or method body one function-nesting level deeper.
///
/// `fn_depth` is raised for exactly the duration of
/// [`Self::check_fn_body_inner`] and lowered again once it returns.
fn check_fn_body(
    &mut self,
    type_params: &[TypeParam],
    params: &[TypedParam],
    return_type: &Option<TypeExpr>,
    body: &[SNode],
    where_clauses: &[WhereClause],
) {
    self.fn_depth += 1;
    self.check_fn_body_inner(type_params, params, return_type, body, where_clauses);
    self.fn_depth -= 1;
}
fn check_fn_body_inner(
&mut self,
type_params: &[TypeParam],
params: &[TypedParam],
return_type: &Option<TypeExpr>,
body: &[SNode],
where_clauses: &[WhereClause],
) {
let mut fn_scope = self.scope.child();
for tp in type_params {
fn_scope.generic_type_params.insert(tp.name.clone());
}
for wc in where_clauses {
fn_scope
.where_constraints
.insert(wc.type_name.clone(), wc.bound.clone());
}
for param in params {
fn_scope.define_var(¶m.name, param.type_expr.clone());
if let Some(default) = ¶m.default_value {
self.check_node(default, &mut fn_scope);
}
}
let ret_scope_base = if return_type.is_some() {
Some(fn_scope.child())
} else {
None
};
self.check_block(body, &mut fn_scope);
if let Some(ret_type) = return_type {
let mut ret_scope = ret_scope_base.unwrap();
for stmt in body {
self.check_return_type(stmt, ret_type, &mut ret_scope);
}
}
}
/// Reports `return` statements whose value type is incompatible with `expected`.
///
/// The walk descends only into direct `return`s and `if`/`else` branches. Each
/// branch is checked in a child scope narrowed by the condition's refinements.
/// When exactly one branch always exits, the opposite narrowing is applied to
/// `scope` for the statements that follow the `if`.
fn check_return_type(&mut self, snode: &SNode, expected: &TypeExpr, scope: &mut TypeScope) {
    match &snode.node {
        Node::ReturnStmt { value: Some(val) } => {
            let Some(actual) = self.infer_type(val, scope) else {
                return;
            };
            if self.types_compatible(expected, &actual, scope) {
                return;
            }
            self.error_at(
                format!(
                    "return type doesn't match: expected {}, got {}",
                    format_type(expected),
                    format_type(&actual)
                ),
                snode.span,
            );
        }
        Node::IfElse {
            condition,
            then_body,
            else_body,
        } => {
            let refs = Self::extract_refinements(condition, scope);

            let mut then_scope = scope.child();
            refs.apply_truthy(&mut then_scope);
            for stmt in then_body {
                self.check_return_type(stmt, expected, &mut then_scope);
            }
            let then_exits = Self::block_definitely_exits(then_body);

            match else_body {
                None => {
                    // Code after a `then`-only `if` that always exits runs
                    // only when the condition was falsy.
                    if then_exits {
                        refs.apply_falsy(scope);
                    }
                }
                Some(else_body) => {
                    let mut else_scope = scope.child();
                    refs.apply_falsy(&mut else_scope);
                    for stmt in else_body {
                        self.check_return_type(stmt, expected, &mut else_scope);
                    }
                    let else_exits = Self::block_definitely_exits(else_body);
                    if then_exits && !else_exits {
                        refs.apply_falsy(scope);
                    } else if else_exits && !then_exits {
                        refs.apply_truthy(scope);
                    }
                }
            }
        }
        _ => {}
    }
}
/// Returns `true` when `type_name` fully implements `interface_name` under the
/// given interface bindings.
///
/// This is the boolean view of [`Self::interface_mismatch_reason`]: any
/// reported mismatch means the interface is not satisfied.
fn satisfies_interface(
    &self,
    type_name: &str,
    interface_name: &str,
    interface_bindings: &BTreeMap<String, TypeExpr>,
    scope: &TypeScope,
) -> bool {
    let mismatch =
        self.interface_mismatch_reason(type_name, interface_name, interface_bindings, scope);
    mismatch.is_none()
}
/// Explains why `type_name` does not satisfy `interface_name`, or returns
/// `None` when it does.
///
/// Checks run in this order and stop at the first problem:
/// - the interface must exist in `scope`;
/// - the type must have impl methods registered (a type with none satisfies
///   only an interface that declares no methods);
/// - every interface method needs a same-named impl method with the same
///   parameter count (interface params named `self` are not counted);
/// - declared parameter and return types must be compatible, after unifying
///   the interface's associated types against the impl's concrete types;
/// - every associated type that has a default must be compatible with what it
///   resolved to.
///
/// `interface_bindings` seeds the substitution map (presumably the interface's
/// own type arguments; confirm at call sites). Associated-type bindings
/// inferred from earlier methods carry over to later ones.
fn interface_mismatch_reason(
    &self,
    type_name: &str,
    interface_name: &str,
    interface_bindings: &BTreeMap<String, TypeExpr>,
    scope: &TypeScope,
) -> Option<String> {
    let interface_info = match scope.get_interface(interface_name) {
        Some(info) => info,
        None => return Some(format!("interface '{}' not found", interface_name)),
    };
    let impl_methods = match scope.get_impl_methods(type_name) {
        Some(methods) => methods,
        None => {
            // No impl methods at all: only a method-less interface is satisfied.
            if interface_info.methods.is_empty() {
                return None;
            }
            let names: Vec<_> = interface_info
                .methods
                .iter()
                .map(|m| m.name.as_str())
                .collect();
            return Some(format!("missing method(s): {}", names.join(", ")));
        }
    };
    // Associated types act as the unification variables; the map starts from
    // the caller-supplied bindings and grows as methods are compared.
    let mut bindings = interface_bindings.clone();
    let associated_type_names: std::collections::BTreeSet<String> = interface_info
        .associated_types
        .iter()
        .map(|(name, _)| name.clone())
        .collect();
    for iface_method in &interface_info.methods {
        // The receiver is not part of the arity comparison.
        let iface_params: Vec<_> = iface_method
            .params
            .iter()
            .filter(|p| p.name != "self")
            .collect();
        let iface_param_count = iface_params.len();
        let matching_impl = impl_methods.iter().find(|im| im.name == iface_method.name);
        let impl_method = match matching_impl {
            Some(m) => m,
            None => {
                return Some(format!("missing method '{}'", iface_method.name));
            }
        };
        if impl_method.param_count != iface_param_count {
            return Some(format!(
                "method '{}' has {} parameter(s), expected {}",
                iface_method.name, impl_method.param_count, iface_param_count
            ));
        }
        // NOTE(review): `param_types` is indexed by the position among the
        // non-`self` interface params; this assumes the impl's `param_types`
        // also excludes `self`, matching `param_count` — confirm.
        for (i, iface_param) in iface_params.iter().enumerate() {
            // Only compare when both sides carry a type annotation.
            if let (Some(expected), Some(actual)) = (
                &iface_param.type_expr,
                impl_method.param_types.get(i).and_then(|t| t.as_ref()),
            ) {
                if let Err(message) = Self::extract_type_bindings(
                    expected,
                    actual,
                    &associated_type_names,
                    &mut bindings,
                ) {
                    return Some(message);
                }
                let expected = Self::apply_type_bindings(expected, &bindings);
                if !self.types_compatible(&expected, actual, scope) {
                    // Parameters are reported 1-based.
                    return Some(format!(
                        "method '{}' parameter {} has type '{}', expected '{}'",
                        iface_method.name,
                        i + 1,
                        format_type(actual),
                        format_type(&expected),
                    ));
                }
            }
        }
        if let (Some(expected_ret), Some(actual_ret)) =
            (&iface_method.return_type, &impl_method.return_type)
        {
            if let Err(message) = Self::extract_type_bindings(
                expected_ret,
                actual_ret,
                &associated_type_names,
                &mut bindings,
            ) {
                return Some(message);
            }
            let expected_ret = Self::apply_type_bindings(expected_ret, &bindings);
            if !self.types_compatible(&expected_ret, actual_ret, scope) {
                return Some(format!(
                    "method '{}' returns '{}', expected '{}'",
                    iface_method.name,
                    format_type(actual_ret),
                    format_type(&expected_ret),
                ));
            }
        }
    }
    // Associated types with a declared default must agree with what the
    // method signatures resolved them to; unresolved ones are not checked.
    for (assoc_name, default_type) in &interface_info.associated_types {
        if let (Some(default_type), Some(actual)) = (default_type, bindings.get(assoc_name)) {
            let expected = Self::apply_type_bindings(default_type, &bindings);
            if !self.types_compatible(&expected, actual, scope) {
                return Some(format!(
                    "associated type '{}' resolves to '{}', expected '{}'",
                    assoc_name,
                    format_type(actual),
                    format_type(&expected),
                ));
            }
        }
    }
    None
}
/// Records that the generic parameter `param_name` stands for `concrete`.
///
/// The first binding wins. Re-binding to an equal type is a no-op; binding to
/// a different type returns an error message naming both inferred types and
/// leaves the existing binding untouched.
fn bind_type_param(
    param_name: &str,
    concrete: &TypeExpr,
    bindings: &mut BTreeMap<String, TypeExpr>,
) -> Result<(), String> {
    match bindings.get(param_name) {
        None => {
            bindings.insert(param_name.to_string(), concrete.clone());
            Ok(())
        }
        Some(existing) if existing == concrete => Ok(()),
        Some(existing) => Err(format!(
            "type parameter '{}' was inferred as both {} and {}",
            param_name,
            format_type(existing),
            format_type(concrete)
        )),
    }
}
/// Unifies a declared type `param_type` against an inferred `arg_type`.
///
/// A binding is recorded for every type parameter it meets, where a type
/// parameter is a `Named` type whose name is in `type_params`. The two types
/// are walked in parallel through:
/// - lists and iterators;
/// - dicts (key and value);
/// - applied generics with the same name and arity;
/// - shapes, matching fields by name;
/// - function types.
///
/// Structural mismatches are not errors here; callers check compatibility
/// separately with `types_compatible`, so a mismatch simply contributes no
/// bindings.
///
/// # Errors
/// Returns the message from [`Self::bind_type_param`] when one parameter would
/// be bound to two different types.
fn extract_type_bindings(
    param_type: &TypeExpr,
    arg_type: &TypeExpr,
    type_params: &std::collections::BTreeSet<String>,
    bindings: &mut BTreeMap<String, TypeExpr>,
) -> Result<(), String> {
    match (param_type, arg_type) {
        (TypeExpr::Named(param_name), concrete) if type_params.contains(param_name) => {
            Self::bind_type_param(param_name, concrete, bindings)
        }
        // `iter<T>` is a single-element container just like `list<T>`.
        // `apply_type_bindings` already substitutes through it, so it must
        // also be walked here or its type parameters are never inferred.
        (TypeExpr::List(p_inner), TypeExpr::List(a_inner))
        | (TypeExpr::Iter(p_inner), TypeExpr::Iter(a_inner)) => {
            Self::extract_type_bindings(p_inner, a_inner, type_params, bindings)
        }
        (TypeExpr::DictType(pk, pv), TypeExpr::DictType(ak, av)) => {
            Self::extract_type_bindings(pk, ak, type_params, bindings)?;
            Self::extract_type_bindings(pv, av, type_params, bindings)
        }
        (
            TypeExpr::Applied {
                name: p_name,
                args: p_args,
            },
            TypeExpr::Applied {
                name: a_name,
                args: a_args,
            },
        ) if p_name == a_name && p_args.len() == a_args.len() => {
            for (param, arg) in p_args.iter().zip(a_args.iter()) {
                Self::extract_type_bindings(param, arg, type_params, bindings)?;
            }
            Ok(())
        }
        (TypeExpr::Shape(param_fields), TypeExpr::Shape(arg_fields)) => {
            // Fields absent from the argument shape contribute nothing.
            for param_field in param_fields {
                if let Some(arg_field) = arg_fields
                    .iter()
                    .find(|field| field.name == param_field.name)
                {
                    Self::extract_type_bindings(
                        &param_field.type_expr,
                        &arg_field.type_expr,
                        type_params,
                        bindings,
                    )?;
                }
            }
            Ok(())
        }
        (
            TypeExpr::FnType {
                params: p_params,
                return_type: p_ret,
            },
            TypeExpr::FnType {
                params: a_params,
                return_type: a_ret,
            },
        ) => {
            // `zip` stops at the shorter list; arity is not validated here.
            for (param, arg) in p_params.iter().zip(a_params.iter()) {
                Self::extract_type_bindings(param, arg, type_params, bindings)?;
            }
            Self::extract_type_bindings(p_ret, a_ret, type_params, bindings)
        }
        _ => Ok(()),
    }
}
fn apply_type_bindings(ty: &TypeExpr, bindings: &BTreeMap<String, TypeExpr>) -> TypeExpr {
match ty {
TypeExpr::Named(name) => bindings
.get(name)
.cloned()
.unwrap_or_else(|| TypeExpr::Named(name.clone())),
TypeExpr::Union(items) => TypeExpr::Union(
items
.iter()
.map(|item| Self::apply_type_bindings(item, bindings))
.collect(),
),
TypeExpr::Shape(fields) => TypeExpr::Shape(
fields
.iter()
.map(|field| ShapeField {
name: field.name.clone(),
type_expr: Self::apply_type_bindings(&field.type_expr, bindings),
optional: field.optional,
})
.collect(),
),
TypeExpr::List(inner) => {
TypeExpr::List(Box::new(Self::apply_type_bindings(inner, bindings)))
}
TypeExpr::Iter(inner) => {
TypeExpr::Iter(Box::new(Self::apply_type_bindings(inner, bindings)))
}
TypeExpr::DictType(key, value) => TypeExpr::DictType(
Box::new(Self::apply_type_bindings(key, bindings)),
Box::new(Self::apply_type_bindings(value, bindings)),
),
TypeExpr::Applied { name, args } => TypeExpr::Applied {
name: name.clone(),
args: args
.iter()
.map(|arg| Self::apply_type_bindings(arg, bindings))
.collect(),
},
TypeExpr::FnType {
params,
return_type,
} => TypeExpr::FnType {
params: params
.iter()
.map(|param| Self::apply_type_bindings(param, bindings))
.collect(),
return_type: Box::new(Self::apply_type_bindings(return_type, bindings)),
},
TypeExpr::Never => TypeExpr::Never,
TypeExpr::LitString(s) => TypeExpr::LitString(s.clone()),
TypeExpr::LitInt(v) => TypeExpr::LitInt(*v),
}
}
/// Builds the type for `name` instantiated with `args`.
///
/// Returns a plain `Named` type when there are no type arguments, and an
/// `Applied` type otherwise.
fn applied_type_or_name(name: &str, args: Vec<TypeExpr>) -> TypeExpr {
    if !args.is_empty() {
        return TypeExpr::Applied {
            name: name.to_owned(),
            args,
        };
    }
    TypeExpr::Named(name.to_owned())
}
/// Infers bindings for a struct's generic parameters from a construction
/// literal.
///
/// For each declared field with a type annotation, the matching entry in
/// `fields` is located, keyed by a string literal or a bare identifier. Its
/// inferred type is then unified against the annotation. Unification errors
/// are discarded, so the result is a best-effort binding map.
fn infer_struct_bindings(
    &self,
    struct_info: &StructDeclInfo,
    fields: &[DictEntry],
    scope: &TypeScope,
) -> BTreeMap<String, TypeExpr> {
    let generics: std::collections::BTreeSet<String> =
        struct_info.type_params.iter().map(|tp| tp.name.clone()).collect();
    let mut bindings = BTreeMap::new();
    for declared_field in &struct_info.fields {
        let Some(declared) = declared_field.type_expr.as_ref() else {
            continue;
        };
        let provided = fields.iter().find(|entry| match &entry.key.node {
            Node::StringLiteral(key) | Node::Identifier(key) => key == &declared_field.name,
            _ => false,
        });
        let Some(actual) = provided.and_then(|entry| self.infer_type(&entry.value, scope)) else {
            continue;
        };
        let _ = Self::extract_type_bindings(declared, &actual, &generics, &mut bindings);
    }
    bindings
}
/// Computes the type of a struct-construction expression.
///
/// Each generic parameter takes the type inferred from the supplied field
/// values. Parameters that could not be inferred fall back to
/// `Self::wildcard_type()`. Non-generic structs yield a plain `Named` type.
fn infer_struct_type(
    &self,
    struct_name: &str,
    struct_info: &StructDeclInfo,
    fields: &[DictEntry],
    scope: &TypeScope,
) -> TypeExpr {
    let bindings = self.infer_struct_bindings(struct_info, fields, scope);
    let mut type_args = Vec::with_capacity(struct_info.type_params.len());
    for tp in &struct_info.type_params {
        let arg = match bindings.get(&tp.name) {
            Some(bound) => bound.clone(),
            None => Self::wildcard_type(),
        };
        type_args.push(arg);
    }
    Self::applied_type_or_name(struct_name, type_args)
}
fn infer_enum_type(
&self,
enum_name: &str,
enum_info: &EnumDeclInfo,
variant_name: &str,
args: &[SNode],
scope: &TypeScope,
) -> TypeExpr {
let type_param_set: std::collections::BTreeSet<String> = enum_info
.type_params
.iter()
.map(|tp| tp.name.clone())
.collect();
let mut bindings = BTreeMap::new();
if let Some(variant) = enum_info
.variants
.iter()
.find(|variant| variant.name == variant_name)
{
for (field, arg) in variant.fields.iter().zip(args.iter()) {
let Some(expected_type) = &field.type_expr else {
continue;
};
let Some(actual_type) = self.infer_type(arg, scope) else {
continue;
};
let _ = Self::extract_type_bindings(
expected_type,
&actual_type,
&type_param_set,
&mut bindings,
);
}
}
let args = enum_info
.type_params
.iter()
.map(|tp| {
bindings
.get(&tp.name)
.cloned()
.unwrap_or_else(Self::wildcard_type)
})
.collect();
Self::applied_type_or_name(enum_name, args)
}
/// Collects the error types a `try` body can produce.
///
/// Contributions come from:
/// - `throw` statements (the thrown value's type);
/// - statement-level `?` applied to a `Result<_, E>` (contributing `E`);
/// - recursion into `if`/`else` branches and the block-like statements
///   listed below.
///
/// Returns `None` when nothing was inferred, otherwise the simplified union.
///
/// NOTE(review): this also descends into nested `try`, `spawn`, pipeline and
/// override bodies. Whether errors raised there actually escape to the
/// enclosing `try` depends on runtime semantics not visible here — confirm.
fn infer_try_error_type(&self, stmts: &[SNode], scope: &TypeScope) -> InferredType {
    let mut collected: Vec<TypeExpr> = Vec::new();
    for stmt in stmts {
        match &stmt.node {
            Node::ThrowStmt { value } => collected.extend(self.infer_type(value, scope)),
            Node::TryOperator { operand } => {
                if let Some(TypeExpr::Applied { name, mut args }) =
                    self.infer_type(operand, scope)
                {
                    // `Result<T, E>`: with exactly two args, `pop` yields `E`.
                    if name == "Result" && args.len() == 2 {
                        collected.extend(args.pop());
                    }
                }
            }
            Node::IfElse {
                then_body,
                else_body,
                ..
            } => {
                collected.extend(self.infer_try_error_type(then_body, scope));
                if let Some(else_body) = else_body {
                    collected.extend(self.infer_try_error_type(else_body, scope));
                }
            }
            Node::Block(body)
            | Node::TryExpr { body }
            | Node::SpawnExpr { body }
            | Node::Retry { body, .. }
            | Node::WhileLoop { body, .. }
            | Node::DeferStmt { body }
            | Node::MutexBlock { body }
            | Node::DeadlineBlock { body, .. }
            | Node::Pipeline { body, .. }
            | Node::OverrideDecl { body, .. } => {
                collected.extend(self.infer_try_error_type(body, scope));
            }
            _ => {}
        }
    }
    if collected.is_empty() {
        return None;
    }
    Some(simplify_union(collected))
}
/// Infers the type of a list literal from its elements.
///
/// Element types are merged left to right. Identical types collapse into one,
/// and differing types accumulate into a union with no duplicate members. If
/// the list is empty, or any element's type cannot be inferred, the result is
/// the untyped `list`.
fn infer_list_literal_type(&self, items: &[SNode], scope: &TypeScope) -> TypeExpr {
    let mut element: Option<TypeExpr> = None;
    for item in items {
        let Some(next) = self.infer_type(item, scope) else {
            return TypeExpr::Named("list".into());
        };
        let merged = match element.take() {
            None => next,
            Some(prev) if prev == next => prev,
            Some(TypeExpr::Union(mut members)) => {
                if !members.contains(&next) {
                    members.push(next);
                }
                TypeExpr::Union(members)
            }
            Some(prev) => TypeExpr::Union(vec![prev, next]),
        };
        element = Some(merged);
    }
    match element {
        Some(item_type) => TypeExpr::List(Box::new(item_type)),
        None => TypeExpr::Named("list".into()),
    }
}
fn extract_refinements(condition: &SNode, scope: &TypeScope) -> Refinements {
match &condition.node {
Node::BinaryOp { op, left, right } if op == "!=" || op == "==" => {
let nil_ref = Self::extract_nil_refinements(op, left, right, scope);
if !nil_ref.truthy.is_empty() || !nil_ref.falsy.is_empty() {
return nil_ref;
}
let typeof_ref = Self::extract_typeof_refinements(op, left, right, scope);
if !typeof_ref.truthy.is_empty() || !typeof_ref.falsy.is_empty() {
return typeof_ref;
}
Refinements::empty()
}
Node::BinaryOp { op, left, right } if op == "&&" => {
let left_ref = Self::extract_refinements(left, scope);
let right_ref = Self::extract_refinements(right, scope);
let mut truthy = left_ref.truthy;
truthy.extend(right_ref.truthy);
let mut truthy_ruled_out = left_ref.truthy_ruled_out;
truthy_ruled_out.extend(right_ref.truthy_ruled_out);
Refinements {
truthy,
falsy: vec![],
truthy_ruled_out,
falsy_ruled_out: vec![],
}
}
Node::BinaryOp { op, left, right } if op == "||" => {
let left_ref = Self::extract_refinements(left, scope);
let right_ref = Self::extract_refinements(right, scope);
let mut falsy = left_ref.falsy;
falsy.extend(right_ref.falsy);
let mut falsy_ruled_out = left_ref.falsy_ruled_out;
falsy_ruled_out.extend(right_ref.falsy_ruled_out);
Refinements {
truthy: vec![],
falsy,
truthy_ruled_out: vec![],
falsy_ruled_out,
}
}
Node::UnaryOp { op, operand } if op == "!" => {
Self::extract_refinements(operand, scope).inverted()
}
Node::Identifier(name) => {
if let Some(Some(TypeExpr::Union(members))) = scope.get_var(name) {
if members
.iter()
.any(|m| matches!(m, TypeExpr::Named(n) if n == "nil"))
{
if let Some(narrowed) = remove_from_union(members, "nil") {
return Refinements {
truthy: vec![(name.clone(), Some(narrowed))],
falsy: vec![(name.clone(), Some(TypeExpr::Named("nil".into())))],
truthy_ruled_out: vec![],
falsy_ruled_out: vec![],
};
}
}
}
Refinements::empty()
}
Node::MethodCall {
object,
method,
args,
} if method == "has" && args.len() == 1 => {
Self::extract_has_refinements(object, args, scope)
}
Node::FunctionCall { name, args }
if (name == "schema_is" || name == "is_type") && args.len() == 2 =>
{
Self::extract_schema_refinements(args, scope)
}
_ => Refinements::empty(),
}
}
/// Derives refinements from `x == nil` / `x != nil`, with `nil` on either side.
///
/// Only bare identifiers are refined:
/// - a union type splits into its non-`nil` members and `nil`;
/// - a variable already typed `nil` becomes `never` on the branch where it is
///   not nil.
///
/// `op` decides which half belongs to the truthy branch. Anything else yields
/// no refinements.
fn extract_nil_refinements(
    op: &str,
    left: &SNode,
    right: &SNode,
    scope: &TypeScope,
) -> Refinements {
    let var_node = match (&left.node, &right.node) {
        (_, Node::NilLiteral) => left,
        (Node::NilLiteral, _) => right,
        _ => return Refinements::empty(),
    };
    let Node::Identifier(name) = &var_node.node else {
        return Refinements::empty();
    };
    match scope.get_var(name).cloned().flatten() {
        Some(TypeExpr::Union(members)) => {
            let Some(narrowed) = remove_from_union(&members, "nil") else {
                return Refinements::empty();
            };
            let when_not_nil = Refinements {
                truthy: vec![(name.clone(), Some(narrowed))],
                falsy: vec![(name.clone(), Some(TypeExpr::Named("nil".into())))],
                ..Refinements::default()
            };
            if op == "!=" {
                when_not_nil
            } else {
                when_not_nil.inverted()
            }
        }
        Some(TypeExpr::Named(n)) if n == "nil" => {
            let when_nil = Refinements {
                truthy: vec![(name.clone(), Some(TypeExpr::Named("nil".into())))],
                falsy: vec![(name.clone(), Some(TypeExpr::Never))],
                ..Refinements::default()
            };
            if op == "==" {
                when_nil
            } else {
                when_nil.inverted()
            }
        }
        _ => Refinements::empty(),
    }
}
/// Derives refinements from type-of comparisons such as `type_of(x) == "T"`
/// or `!=`, with the string literal on either side.
///
/// The variable is taken from `extract_type_of_var` (presumably recognising a
/// `type_of(x)` call on a bare identifier; confirm). Only type names in
/// `KNOWN_TYPES` are considered. The outcome depends on the variable's current
/// type:
/// - a union is split with `narrow_to_single` (equal branch) and
///   `remove_from_union` (other branch);
/// - a variable already typed `T` becomes `never` on the not-equal branch;
/// - an `unknown` variable becomes `T` on the equal branch, and `T` is
///   recorded in `falsy_ruled_out`, presumably what `collect_unknown_ruled_out`
///   later reports.
///
/// The branch sets are built for `==`; `op` being `!=` inverts them.
fn extract_typeof_refinements(
    op: &str,
    left: &SNode,
    right: &SNode,
    scope: &TypeScope,
) -> Refinements {
    // Accept both `type_of(x) == "T"` and `"T" == type_of(x)`.
    let (var_name, type_name) = if let (Some(var), Node::StringLiteral(tn)) =
        (extract_type_of_var(left), &right.node)
    {
        (var, tn.clone())
    } else if let (Node::StringLiteral(tn), Some(var)) =
        (&left.node, extract_type_of_var(right))
    {
        (var, tn.clone())
    } else {
        return Refinements::empty();
    };
    // Keep in sync with `UNKNOWN_CONCRETE_TYPES`, which lists the same names.
    const KNOWN_TYPES: &[&str] = &[
        "int", "string", "float", "bool", "nil", "list", "dict", "closure",
    ];
    if !KNOWN_TYPES.contains(&type_name.as_str()) {
        return Refinements::empty();
    }
    let var_type = scope.get_var(&var_name).cloned().flatten();
    match var_type {
        Some(TypeExpr::Union(ref members)) => {
            let narrowed = narrow_to_single(members, &type_name);
            let remaining = remove_from_union(members, &type_name);
            // Either half alone is still a useful refinement.
            if narrowed.is_some() || remaining.is_some() {
                let eq_refs = Refinements {
                    truthy: narrowed
                        .map(|n| vec![(var_name.clone(), Some(n))])
                        .unwrap_or_default(),
                    falsy: remaining
                        .map(|r| vec![(var_name.clone(), Some(r))])
                        .unwrap_or_default(),
                    ..Refinements::default()
                };
                return if op == "==" {
                    eq_refs
                } else {
                    eq_refs.inverted()
                };
            }
        }
        Some(TypeExpr::Named(ref n)) if n == &type_name => {
            // Already exactly `T`: the mismatching branch is unreachable.
            let eq_refs = Refinements {
                truthy: vec![(var_name.clone(), Some(TypeExpr::Named(type_name)))],
                falsy: vec![(var_name.clone(), Some(TypeExpr::Never))],
                ..Refinements::default()
            };
            return if op == "==" {
                eq_refs
            } else {
                eq_refs.inverted()
            };
        }
        Some(TypeExpr::Named(ref n)) if n == "unknown" => {
            // `unknown` cannot be subtracted from, so the other branch only
            // records that `T` has been ruled out.
            let eq_refs = Refinements {
                truthy: vec![(var_name.clone(), Some(TypeExpr::Named(type_name.clone())))],
                falsy: vec![],
                truthy_ruled_out: vec![],
                falsy_ruled_out: vec![(var_name.clone(), type_name)],
            };
            return if op == "==" {
                eq_refs
            } else {
                eq_refs.inverted()
            };
        }
        _ => {}
    }
    Refinements::empty()
}
/// Derives refinements from `x.has("key")` where `x` is a shape whose `key`
/// field is optional.
///
/// On the truthy branch `x` is narrowed to the same shape with `key` marked
/// required. The falsy branch is left unrefined. Callers guarantee exactly one
/// argument.
fn extract_has_refinements(object: &SNode, args: &[SNode], scope: &TypeScope) -> Refinements {
    let Node::Identifier(var_name) = &object.node else {
        return Refinements::empty();
    };
    let Node::StringLiteral(key) = &args[0].node else {
        return Refinements::empty();
    };
    let Some(Some(TypeExpr::Shape(fields))) = scope.get_var(var_name) else {
        return Refinements::empty();
    };
    let key_is_optional = fields.iter().any(|f| f.name == *key && f.optional);
    if !key_is_optional {
        return Refinements::empty();
    }
    let narrowed_fields: Vec<ShapeField> = fields
        .iter()
        .cloned()
        .map(|mut field| {
            if field.name == *key {
                field.optional = false;
            }
            field
        })
        .collect();
    Refinements {
        truthy: vec![(var_name.clone(), Some(TypeExpr::Shape(narrowed_fields)))],
        falsy: vec![],
        ..Refinements::default()
    }
}
fn extract_schema_refinements(args: &[SNode], scope: &TypeScope) -> Refinements {
let Node::Identifier(var_name) = &args[0].node else {
return Refinements::empty();
};
let Some(schema_type) = schema_type_expr_from_node(&args[1], scope) else {
return Refinements::empty();
};
let Some(Some(var_type)) = scope.get_var(var_name).cloned() else {
return Refinements::empty();
};
let truthy = intersect_types(&var_type, &schema_type)
.map(|ty| vec![(var_name.clone(), Some(ty))])
.unwrap_or_default();
let falsy = subtract_type(&var_type, &schema_type)
.map(|ty| vec![(var_name.clone(), Some(ty))])
.unwrap_or_default();
Refinements {
truthy,
falsy,
..Refinements::default()
}
}
/// Associated-function alias for the free function `block_definitely_exits`,
/// so checker code can call it as `Self::block_definitely_exits`.
///
/// Presumably reports whether `stmts` always leave the enclosing block (for
/// example via `return`/`throw`) — see the free function for the exact rules.
fn block_definitely_exits(stmts: &[SNode]) -> bool {
    // Unqualified, this resolves to the free function, not to this method.
    block_definitely_exits(stmts)
}
/// Warns when a `match` over an enum value does not cover every variant.
///
/// The enum is identified from the inferred type of the scrutinee, or of `x`
/// when matching on `x.variant`. Both plain `Named` enum types and generic
/// enum types are recognised; `infer_enum_type` produces the latter as
/// `Applied { name, .. }`. When the scrutinee is not an enum, the union-type
/// check runs instead.
///
/// An arm covers a variant when its pattern is one of:
/// - a string literal;
/// - an enum construction;
/// - a property access (`E.Variant`);
/// - a bare identifier naming a variant.
///
/// `_`, any other identifier, and any other pattern count as a wildcard, which
/// suppresses the warning.
fn check_match_exhaustiveness(
    &mut self,
    value: &SNode,
    arms: &[MatchArm],
    scope: &TypeScope,
    span: Span,
) {
    // For `match x.variant { ... }` the enum type is that of `x`.
    let subject: &SNode = match &value.node {
        Node::PropertyAccess { object, property } if property == "variant" => object,
        _ => value,
    };
    let enum_name = match self.infer_type(subject, scope) {
        Some(TypeExpr::Named(name)) | Some(TypeExpr::Applied { name, .. })
            if scope.get_enum(&name).is_some() =>
        {
            Some(name)
        }
        _ => None,
    };
    let Some(enum_name) = enum_name else {
        self.check_match_exhaustiveness_union(value, arms, scope, span);
        return;
    };
    let Some(variants) = scope.get_enum(&enum_name) else {
        return;
    };
    let is_variant = |name: &str| variants.variants.iter().any(|variant| variant.name == name);
    let mut covered: Vec<String> = Vec::new();
    let mut has_wildcard = false;
    for arm in arms {
        match &arm.pattern.node {
            Node::StringLiteral(s) => covered.push(s.clone()),
            Node::Identifier(name) if name == "_" || !is_variant(name.as_str()) => {
                has_wildcard = true;
            }
            // A bare identifier that names a variant covers that variant.
            // Without this arm it would fall through to the catch-all below
            // and silently disable the exhaustiveness check.
            Node::Identifier(name) => covered.push(name.clone()),
            Node::EnumConstruct { variant, .. } => covered.push(variant.clone()),
            Node::PropertyAccess { property, .. } => covered.push(property.clone()),
            _ => {
                has_wildcard = true;
            }
        }
    }
    if has_wildcard {
        return;
    }
    let missing: Vec<&String> = variants
        .variants
        .iter()
        .map(|variant| &variant.name)
        .filter(|variant| !covered.contains(variant))
        .collect();
    if !missing.is_empty() {
        let missing_str = missing
            .iter()
            .map(|s| format!("\"{}\"", s))
            .collect::<Vec<_>>()
            .join(", ");
        self.warning_at(
            format!(
                "Non-exhaustive match on enum {}: missing variants {}",
                enum_name, missing_str
            ),
            span,
        );
    }
}
/// Warns when a `match` over a union of named types has no arm for some member.
///
/// Only unions made entirely of `Named` members are checked. A literal
/// pattern (`nil`, bool, int, float, string) marks its whole type as handled.
/// Any other pattern, including `_`, counts as a wildcard and suppresses the
/// warning.
fn check_match_exhaustiveness_union(
    &mut self,
    value: &SNode,
    arms: &[MatchArm],
    scope: &TypeScope,
    span: Span,
) {
    let Some(TypeExpr::Union(members)) = self.infer_type(value, scope) else {
        return;
    };
    if members.iter().any(|m| !matches!(m, TypeExpr::Named(_))) {
        return;
    }
    let mut handled: Vec<&str> = Vec::new();
    for arm in arms {
        let literal_type = match &arm.pattern.node {
            Node::NilLiteral => "nil",
            Node::BoolLiteral(_) => "bool",
            Node::IntLiteral(_) => "int",
            Node::FloatLiteral(_) => "float",
            Node::StringLiteral(_) => "string",
            // `_`, bindings and every other pattern act as a wildcard.
            _ => return,
        };
        if !handled.contains(&literal_type) {
            handled.push(literal_type);
        }
    }
    let missing: Vec<&str> = members
        .iter()
        .filter_map(|m| match m {
            TypeExpr::Named(n) => Some(n.as_str()),
            _ => None,
        })
        .filter(|name| !handled.contains(name))
        .collect();
    if missing.is_empty() {
        return;
    }
    let missing_str = missing.join(", ");
    self.warning_at(
        format!(
            "Non-exhaustive match on union type: missing {}",
            missing_str
        ),
        span,
    );
}
/// Concrete runtime type names an `unknown` value may hold. A variable typed
/// `unknown` counts as fully narrowed once each of these has been ruled out.
/// Keep in sync with the `KNOWN_TYPES` list in `extract_typeof_refinements`.
// `'static` is spelled out because eliding lifetimes in associated constants
// triggers rustc's `elided_lifetimes_in_associated_constant` lint.
const UNKNOWN_CONCRETE_TYPES: &'static [&'static str] = &[
    "int", "string", "float", "bool", "nil", "list", "dict", "closure",
];
/// Warns at a never-returning site (`throw`, `unreachable()`, a `never`
/// function) when a variable typed `unknown` has been partially, but not
/// fully, narrowed by type-of checks.
///
/// For each variable with at least one ruled-out type that is still typed
/// `unknown` in `scope`, the remaining entries of `UNKNOWN_CONCRETE_TYPES`
/// are listed. `site_label` names the site in the message.
fn check_unknown_exhaustiveness(&mut self, scope: &TypeScope, span: Span, site_label: &str) {
    for (var_name, covered) in scope.collect_unknown_ruled_out() {
        if covered.is_empty() {
            continue;
        }
        let still_unknown = matches!(
            scope.get_var(&var_name),
            Some(Some(TypeExpr::Named(n))) if n == "unknown"
        );
        if !still_unknown {
            continue;
        }
        let uncovered: Vec<&str> = Self::UNKNOWN_CONCRETE_TYPES
            .iter()
            .copied()
            .filter(|ty| covered.iter().all(|c| c != ty))
            .collect();
        if uncovered.is_empty() {
            continue;
        }
        let missing_str = uncovered.join(", ");
        self.warning_at(
            format!(
                "`{site}` reached but `{var}: unknown` was not fully narrowed — uncovered concrete type(s): {missing}",
                site = site_label,
                var = var_name,
                missing = missing_str,
            ),
            span,
        );
    }
}
fn check_call(&mut self, name: &str, args: &[SNode], scope: &mut TypeScope, span: Span) {
if name == "unreachable" {
if let Some(arg) = args.first() {
if matches!(&arg.node, Node::Identifier(_)) {
let arg_type = self.infer_type(arg, scope);
if let Some(ref ty) = arg_type {
if !matches!(ty, TypeExpr::Never) {
self.error_at(
format!(
"unreachable() argument has type `{}` — not all cases are handled",
format_type(ty)
),
span,
);
}
}
}
}
self.check_unknown_exhaustiveness(scope, span, "unreachable()");
for arg in args {
self.check_node(arg, scope);
}
return;
}
if let Some(sig) = scope.get_fn(name).cloned() {
if matches!(sig.return_type, Some(TypeExpr::Never)) {
self.check_unknown_exhaustiveness(scope, span, &format!("{}()", name));
}
}
let has_spread = args.iter().any(|a| matches!(&a.node, Node::Spread(_)));
if let Some(sig) = scope.get_fn(name).cloned() {
if !has_spread
&& !is_builtin(name)
&& !sig.has_rest
&& (args.len() < sig.required_params || args.len() > sig.params.len())
{
let expected = if sig.required_params == sig.params.len() {
format!("{}", sig.params.len())
} else {
format!("{}-{}", sig.required_params, sig.params.len())
};
self.warning_at(
format!(
"Function '{}' expects {} arguments, got {}",
name,
expected,
args.len()
),
span,
);
}
let call_scope = if sig.type_param_names.is_empty() {
scope.clone()
} else {
let mut s = scope.child();
for tp_name in &sig.type_param_names {
s.generic_type_params.insert(tp_name.clone());
}
s
};
let mut type_bindings: BTreeMap<String, TypeExpr> = BTreeMap::new();
let type_param_set: std::collections::BTreeSet<String> =
sig.type_param_names.iter().cloned().collect();
for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
if let Some(param_ty) = param_type {
if let Some(arg_ty) = self.infer_type(arg, scope) {
if let Err(message) = Self::extract_type_bindings(
param_ty,
&arg_ty,
&type_param_set,
&mut type_bindings,
) {
self.error_at(message, arg.span);
}
}
}
}
for (i, (arg, (param_name, param_type))) in
args.iter().zip(sig.params.iter()).enumerate()
{
if let Some(expected) = param_type {
let actual = self.infer_type(arg, scope);
if let Some(actual) = &actual {
let expected = Self::apply_type_bindings(expected, &type_bindings);
if !self.types_compatible(&expected, actual, &call_scope) {
self.error_at(
format!(
"Argument {} ('{}'): expected {}, got {}",
i + 1,
param_name,
format_type(&expected),
format_type(actual)
),
arg.span,
);
}
}
}
}
if !sig.where_clauses.is_empty() {
for (type_param, bound) in &sig.where_clauses {
if let Some(concrete_type) = type_bindings.get(type_param) {
let concrete_name = format_type(concrete_type);
let Some(base_type_name) = Self::base_type_name(concrete_type) else {
self.error_at(
format!(
"Type '{}' does not satisfy interface '{}': only named types can satisfy interfaces (required by constraint `where {}: {}`)",
concrete_name, bound, type_param, bound
),
span,
);
continue;
};
if let Some(reason) = self.interface_mismatch_reason(
base_type_name,
bound,
&BTreeMap::new(),
scope,
) {
self.error_at(
format!(
"Type '{}' does not satisfy interface '{}': {} \
(required by constraint `where {}: {}`)",
concrete_name, bound, reason, type_param, bound
),
span,
);
}
}
}
}
}
for arg in args {
self.check_node(arg, scope);
}
}
fn infer_type(&self, snode: &SNode, scope: &TypeScope) -> InferredType {
match &snode.node {
Node::IntLiteral(_) => Some(TypeExpr::Named("int".into())),
Node::FloatLiteral(_) => Some(TypeExpr::Named("float".into())),
Node::StringLiteral(_) | Node::InterpolatedString(_) => {
Some(TypeExpr::Named("string".into()))
}
Node::BoolLiteral(_) => Some(TypeExpr::Named("bool".into())),
Node::NilLiteral => Some(TypeExpr::Named("nil".into())),
Node::ListLiteral(items) => Some(self.infer_list_literal_type(items, scope)),
Node::RangeExpr { .. } => Some(TypeExpr::Named("range".into())),
Node::DictLiteral(entries) => {
let mut fields = Vec::new();
for entry in entries {
let key = match &entry.key.node {
Node::StringLiteral(key) | Node::Identifier(key) => key.clone(),
_ => return Some(TypeExpr::Named("dict".into())),
};
let val_type = self
.infer_type(&entry.value, scope)
.unwrap_or(TypeExpr::Named("nil".into()));
fields.push(ShapeField {
name: key,
type_expr: val_type,
optional: false,
});
}
if !fields.is_empty() {
Some(TypeExpr::Shape(fields))
} else {
Some(TypeExpr::Named("dict".into()))
}
}
Node::Closure { params, body, .. } => {
let all_typed = params.iter().all(|p| p.type_expr.is_some());
if all_typed && !params.is_empty() {
let param_types: Vec<TypeExpr> =
params.iter().filter_map(|p| p.type_expr.clone()).collect();
let ret = body.last().and_then(|last| self.infer_type(last, scope));
if let Some(ret_type) = ret {
return Some(TypeExpr::FnType {
params: param_types,
return_type: Box::new(ret_type),
});
}
}
Some(TypeExpr::Named("closure".into()))
}
Node::Identifier(name) => scope.get_var(name).cloned().flatten(),
Node::FunctionCall { name, args } => {
if let Some(struct_info) = scope.get_struct(name) {
return Some(Self::applied_type_or_name(
name,
struct_info
.type_params
.iter()
.map(|_| Self::wildcard_type())
.collect(),
));
}
if name == "Ok" {
let ok_type = args
.first()
.and_then(|arg| self.infer_type(arg, scope))
.unwrap_or_else(Self::wildcard_type);
return Some(TypeExpr::Applied {
name: "Result".into(),
args: vec![ok_type, Self::wildcard_type()],
});
}
if name == "Err" {
let err_type = args
.first()
.and_then(|arg| self.infer_type(arg, scope))
.unwrap_or_else(Self::wildcard_type);
return Some(TypeExpr::Applied {
name: "Result".into(),
args: vec![Self::wildcard_type(), err_type],
});
}
if let Some(sig) = scope.get_fn(name) {
let mut return_type = sig.return_type.clone();
if let Some(ty) = return_type.take() {
if sig.type_param_names.is_empty() {
return Some(ty);
}
let mut bindings = BTreeMap::new();
let type_param_set: std::collections::BTreeSet<String> =
sig.type_param_names.iter().cloned().collect();
for (arg, (_param_name, param_type)) in args.iter().zip(sig.params.iter()) {
if let Some(param_ty) = param_type {
if let Some(arg_ty) = self.infer_type(arg, scope) {
let _ = Self::extract_type_bindings(
param_ty,
&arg_ty,
&type_param_set,
&mut bindings,
);
}
}
}
return Some(Self::apply_type_bindings(&ty, &bindings));
}
return None;
}
if name == "schema_expect" && args.len() >= 2 {
if let Some(schema_type) = schema_type_expr_from_node(&args[1], scope) {
return Some(schema_type);
}
}
if (name == "schema_check" || name == "schema_parse") && args.len() >= 2 {
if let Some(schema_type) = schema_type_expr_from_node(&args[1], scope) {
return Some(TypeExpr::Applied {
name: "Result".into(),
args: vec![schema_type, TypeExpr::Named("string".into())],
});
}
}
if (name == "llm_call" || name == "llm_completion") && args.len() >= 3 {
if let Some(schema_type) = Self::extract_llm_schema_from_options(args, scope) {
return Some(TypeExpr::Shape(vec![
ShapeField {
name: "text".into(),
type_expr: TypeExpr::Named("string".into()),
optional: false,
},
ShapeField {
name: "model".into(),
type_expr: TypeExpr::Named("string".into()),
optional: false,
},
ShapeField {
name: "provider".into(),
type_expr: TypeExpr::Named("string".into()),
optional: false,
},
ShapeField {
name: "input_tokens".into(),
type_expr: TypeExpr::Named("int".into()),
optional: false,
},
ShapeField {
name: "output_tokens".into(),
type_expr: TypeExpr::Named("int".into()),
optional: false,
},
ShapeField {
name: "data".into(),
type_expr: schema_type,
optional: false,
},
ShapeField {
name: "visible_text".into(),
type_expr: TypeExpr::Named("string".into()),
optional: true,
},
ShapeField {
name: "tool_calls".into(),
type_expr: TypeExpr::Named("list".into()),
optional: true,
},
ShapeField {
name: "thinking".into(),
type_expr: TypeExpr::Named("string".into()),
optional: true,
},
ShapeField {
name: "stop_reason".into(),
type_expr: TypeExpr::Named("string".into()),
optional: true,
},
]));
}
}
builtin_return_type(name)
}
Node::BinaryOp { op, left, right } => {
let lt = self.infer_type(left, scope);
let rt = self.infer_type(right, scope);
infer_binary_op_type(op, <, &rt)
}
Node::UnaryOp { op, operand } => {
let t = self.infer_type(operand, scope);
match op.as_str() {
"!" => Some(TypeExpr::Named("bool".into())),
"-" => t, _ => None,
}
}
Node::Ternary {
condition,
true_expr,
false_expr,
} => {
let refs = Self::extract_refinements(condition, scope);
let mut true_scope = scope.child();
refs.apply_truthy(&mut true_scope);
let tt = self.infer_type(true_expr, &true_scope);
let mut false_scope = scope.child();
refs.apply_falsy(&mut false_scope);
let ft = self.infer_type(false_expr, &false_scope);
match (&tt, &ft) {
(Some(a), Some(b)) if a == b => tt,
(Some(a), Some(b)) => Some(TypeExpr::Union(vec![a.clone(), b.clone()])),
(Some(_), None) => tt,
(None, Some(_)) => ft,
(None, None) => None,
}
}
Node::EnumConstruct {
enum_name,
variant,
args,
} => {
if let Some(enum_info) = scope.get_enum(enum_name) {
Some(self.infer_enum_type(enum_name, enum_info, variant, args, scope))
} else {
Some(TypeExpr::Named(enum_name.clone()))
}
}
Node::PropertyAccess { object, property } => {
if let Node::Identifier(name) = &object.node {
if let Some(enum_info) = scope.get_enum(name) {
return Some(self.infer_enum_type(name, enum_info, property, &[], scope));
}
}
if property == "variant" {
let obj_type = self.infer_type(object, scope);
if let Some(name) = obj_type.as_ref().and_then(Self::base_type_name) {
if scope.get_enum(name).is_some() {
return Some(TypeExpr::Named("string".into()));
}
}
}
let obj_type = self.infer_type(object, scope);
if let Some(TypeExpr::Applied { name, args }) = &obj_type {
if name == "Pair" && args.len() == 2 {
if property == "first" {
return Some(args[0].clone());
} else if property == "second" {
return Some(args[1].clone());
}
}
}
if let Some(TypeExpr::Shape(fields)) = &obj_type {
if let Some(field) = fields.iter().find(|f| f.name == *property) {
return Some(field.type_expr.clone());
}
}
None
}
Node::SubscriptAccess { object, index } => {
let obj_type = self.infer_type(object, scope);
match &obj_type {
Some(TypeExpr::List(inner)) => Some(*inner.clone()),
Some(TypeExpr::DictType(_, v)) => Some(*v.clone()),
Some(TypeExpr::Shape(fields)) => {
if let Node::StringLiteral(key) = &index.node {
fields
.iter()
.find(|f| &f.name == key)
.map(|f| f.type_expr.clone())
} else {
None
}
}
Some(TypeExpr::Named(n)) if n == "list" => None,
Some(TypeExpr::Named(n)) if n == "dict" => None,
Some(TypeExpr::Named(n)) if n == "string" => {
Some(TypeExpr::Named("string".into()))
}
_ => None,
}
}
Node::SliceAccess { object, .. } => {
let obj_type = self.infer_type(object, scope);
match &obj_type {
Some(TypeExpr::List(_)) => obj_type,
Some(TypeExpr::Named(n)) if n == "list" => obj_type,
Some(TypeExpr::Named(n)) if n == "string" => {
Some(TypeExpr::Named("string".into()))
}
_ => None,
}
}
Node::MethodCall {
object,
method,
args,
}
| Node::OptionalMethodCall {
object,
method,
args,
} => {
if let Node::Identifier(name) = &object.node {
if let Some(enum_info) = scope.get_enum(name) {
return Some(self.infer_enum_type(name, enum_info, method, args, scope));
}
if name == "Result" && (method == "Ok" || method == "Err") {
let ok_type = if method == "Ok" {
args.first()
.and_then(|arg| self.infer_type(arg, scope))
.unwrap_or_else(Self::wildcard_type)
} else {
Self::wildcard_type()
};
let err_type = if method == "Err" {
args.first()
.and_then(|arg| self.infer_type(arg, scope))
.unwrap_or_else(Self::wildcard_type)
} else {
Self::wildcard_type()
};
return Some(TypeExpr::Applied {
name: "Result".into(),
args: vec![ok_type, err_type],
});
}
}
let obj_type = self.infer_type(object, scope);
let iter_elem_type: Option<TypeExpr> = match &obj_type {
Some(TypeExpr::Iter(inner)) => Some((**inner).clone()),
Some(TypeExpr::Named(n)) if n == "iter" => Some(TypeExpr::Named("any".into())),
_ => None,
};
if let Some(t) = iter_elem_type {
let pair = |k: TypeExpr, v: TypeExpr| TypeExpr::Applied {
name: "Pair".into(),
args: vec![k, v],
};
let iter_of = |ty: TypeExpr| TypeExpr::Iter(Box::new(ty));
match method.as_str() {
"iter" => return Some(iter_of(t)),
"map" | "flat_map" => {
return Some(TypeExpr::Named("iter".into()));
}
"filter" | "take" | "skip" | "take_while" | "skip_while" => {
return Some(iter_of(t));
}
"zip" => {
return Some(iter_of(pair(t, TypeExpr::Named("any".into()))));
}
"enumerate" => {
return Some(iter_of(pair(TypeExpr::Named("int".into()), t)));
}
"chain" => return Some(iter_of(t)),
"chunks" | "windows" => {
return Some(iter_of(TypeExpr::List(Box::new(t))));
}
"to_list" => return Some(TypeExpr::List(Box::new(t))),
"to_set" => {
return Some(TypeExpr::Applied {
name: "set".into(),
args: vec![t],
})
}
"to_dict" => return Some(TypeExpr::Named("dict".into())),
"count" => return Some(TypeExpr::Named("int".into())),
"sum" => {
return Some(TypeExpr::Union(vec![
TypeExpr::Named("int".into()),
TypeExpr::Named("float".into()),
]))
}
"min" | "max" | "first" | "last" | "find" => {
return Some(TypeExpr::Union(vec![t, TypeExpr::Named("nil".into())]));
}
"any" | "all" => return Some(TypeExpr::Named("bool".into())),
"for_each" => return Some(TypeExpr::Named("nil".into())),
"reduce" => return None,
_ => {}
}
}
if method == "iter" {
match &obj_type {
Some(TypeExpr::List(inner)) => {
return Some(TypeExpr::Iter(Box::new((**inner).clone())));
}
Some(TypeExpr::DictType(k, v)) => {
return Some(TypeExpr::Iter(Box::new(TypeExpr::Applied {
name: "Pair".into(),
args: vec![(**k).clone(), (**v).clone()],
})));
}
Some(TypeExpr::Named(n))
if n == "list" || n == "dict" || n == "set" || n == "string" =>
{
return Some(TypeExpr::Named("iter".into()));
}
_ => {}
}
}
let is_dict = matches!(&obj_type, Some(TypeExpr::Named(n)) if n == "dict")
|| matches!(&obj_type, Some(TypeExpr::DictType(..)))
|| matches!(&obj_type, Some(TypeExpr::Shape(_)));
match method.as_str() {
"contains" | "starts_with" | "ends_with" | "empty" | "has" | "any" | "all" => {
Some(TypeExpr::Named("bool".into()))
}
"count" | "index_of" => Some(TypeExpr::Named("int".into())),
"trim" | "lowercase" | "uppercase" | "reverse" | "replace" | "substring"
| "pad_left" | "pad_right" | "repeat" | "join" => {
Some(TypeExpr::Named("string".into()))
}
"split" | "chars" => Some(TypeExpr::Named("list".into())),
"filter" => {
if is_dict {
Some(TypeExpr::Named("dict".into()))
} else {
Some(TypeExpr::Named("list".into()))
}
}
"map" | "flat_map" | "sort" => Some(TypeExpr::Named("list".into())),
"window" | "each_cons" | "sliding_window" => match &obj_type {
Some(TypeExpr::List(inner)) => Some(TypeExpr::List(Box::new(
TypeExpr::List(Box::new((**inner).clone())),
))),
_ => Some(TypeExpr::Named("list".into())),
},
"reduce" | "find" | "first" | "last" => None,
"keys" | "values" | "entries" => Some(TypeExpr::Named("list".into())),
"merge" | "map_values" | "rekey" | "map_keys" => {
if let Some(TypeExpr::DictType(_, v)) = &obj_type {
Some(TypeExpr::DictType(
Box::new(TypeExpr::Named("string".into())),
v.clone(),
))
} else {
Some(TypeExpr::Named("dict".into()))
}
}
"to_string" => Some(TypeExpr::Named("string".into())),
"to_int" => Some(TypeExpr::Named("int".into())),
"to_float" => Some(TypeExpr::Named("float".into())),
_ => None,
}
}
Node::TryOperator { operand } => match self.infer_type(operand, scope) {
Some(TypeExpr::Applied { name, args }) if name == "Result" && args.len() == 2 => {
Some(args[0].clone())
}
Some(TypeExpr::Named(name)) if name == "Result" => None,
_ => None,
},
Node::ThrowStmt { .. }
| Node::ReturnStmt { .. }
| Node::BreakStmt
| Node::ContinueStmt => Some(TypeExpr::Never),
Node::IfElse {
then_body,
else_body,
..
} => {
let then_type = self.infer_block_type(then_body, scope);
let else_type = else_body
.as_ref()
.and_then(|eb| self.infer_block_type(eb, scope));
match (then_type, else_type) {
(Some(TypeExpr::Never), Some(TypeExpr::Never)) => Some(TypeExpr::Never),
(Some(TypeExpr::Never), Some(other)) | (Some(other), Some(TypeExpr::Never)) => {
Some(other)
}
(Some(t), Some(e)) if t == e => Some(t),
(Some(t), Some(e)) => Some(simplify_union(vec![t, e])),
(Some(t), None) => Some(t),
(None, _) => None,
}
}
Node::TryExpr { body } => {
let ok_type = self
.infer_block_type(body, scope)
.unwrap_or_else(Self::wildcard_type);
let err_type = self
.infer_try_error_type(body, scope)
.unwrap_or_else(Self::wildcard_type);
Some(TypeExpr::Applied {
name: "Result".into(),
args: vec![ok_type, err_type],
})
}
Node::TryStar { operand } => self.infer_type(operand, scope),
Node::StructConstruct {
struct_name,
fields,
} => scope
.get_struct(struct_name)
.map(|struct_info| self.infer_struct_type(struct_name, struct_info, fields, scope))
.or_else(|| Some(TypeExpr::Named(struct_name.clone()))),
_ => None,
}
}
/// Infers the type a block of statements evaluates to.
///
/// A block that definitely exits (per `block_definitely_exits`) has type
/// `never`. Otherwise the block has the type of its last statement, or is
/// unknown when the block is empty.
fn infer_block_type(&self, stmts: &[SNode], scope: &TypeScope) -> InferredType {
    if Self::block_definitely_exits(stmts) {
        Some(TypeExpr::Never)
    } else {
        let tail = stmts.last()?;
        self.infer_type(tail, scope)
    }
}
/// Returns whether a value of type `actual` may be used where `expected` is
/// required, checked in an ordinary (covariant) position.
fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr, scope: &TypeScope) -> bool {
    let top_level = Polarity::Covariant;
    self.types_compatible_at(top_level, expected, actual, scope)
}
/// Structural compatibility check: can a value of type `actual` be used where
/// `expected` is required, under `polarity`?
///
/// `Contravariant` swaps the two sides. `Invariant` requires compatibility in
/// both directions. The rest of the function defines the covariant relation:
///
/// - Wildcards and generic type parameters are compatible with anything.
/// - Aliases are resolved first.
/// - An interface-typed `expected` is satisfied through `satisfies_interface`.
/// - Other types are compared structurally, honouring the declared variance
///   of applied type arguments.
fn types_compatible_at(
    &self,
    polarity: Polarity,
    expected: &TypeExpr,
    actual: &TypeExpr,
    scope: &TypeScope,
) -> bool {
    match polarity {
        Polarity::Covariant => {}
        Polarity::Contravariant => {
            return self.types_compatible_at(Polarity::Covariant, actual, expected, scope);
        }
        Polarity::Invariant => {
            return self.types_compatible_at(Polarity::Covariant, expected, actual, scope)
                && self.types_compatible_at(Polarity::Covariant, actual, expected, scope);
        }
    }
    if Self::is_wildcard_type(expected) || Self::is_wildcard_type(actual) {
        return true;
    }
    if let TypeExpr::Named(name) = expected {
        if scope.is_generic_type_param(name) {
            return true;
        }
    }
    if let TypeExpr::Named(name) = actual {
        if scope.is_generic_type_param(name) {
            return true;
        }
    }
    let expected = self.resolve_alias(expected, scope);
    let actual = self.resolve_alias(actual, scope);
    if let Some(iface_name) = Self::base_type_name(&expected) {
        if let Some(interface_info) = scope.get_interface(iface_name) {
            match &actual {
                // `never` and `any` are compatible with every other expected
                // type below, so accept them for interfaces too rather than
                // routing them through `satisfies_interface`.
                TypeExpr::Never => return true,
                TypeExpr::Named(n) if n == "any" => return true,
                // A union satisfies the interface when every member does.
                TypeExpr::Union(members) => {
                    return members
                        .iter()
                        .all(|member| self.types_compatible(&expected, member, scope));
                }
                _ => {}
            }
            let mut interface_bindings = BTreeMap::new();
            if let TypeExpr::Applied { args, .. } = &expected {
                for (type_param, arg) in interface_info.type_params.iter().zip(args.iter()) {
                    interface_bindings.insert(type_param.name.clone(), arg.clone());
                }
            }
            if let Some(type_name) = Self::base_type_name(&actual) {
                return self.satisfies_interface(
                    type_name,
                    iface_name,
                    &interface_bindings,
                    scope,
                );
            }
            return false;
        }
    }
    match (&expected, &actual) {
        (_, TypeExpr::Never) => true,
        (TypeExpr::Never, _) => false,
        (TypeExpr::Named(n), _) if n == "any" => true,
        (_, TypeExpr::Named(n)) if n == "any" => true,
        (TypeExpr::Named(n), _) if n == "unknown" => true,
        // `int` widens to `float`.
        (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b || (a == "float" && b == "int"),
        // A bare generic name is compatible with any application of it.
        (TypeExpr::Named(a), TypeExpr::Applied { name: b, .. })
        | (TypeExpr::Applied { name: a, .. }, TypeExpr::Named(b)) => a == b,
        (
            TypeExpr::Applied {
                name: expected_name,
                args: expected_args,
            },
            TypeExpr::Applied {
                name: actual_name,
                args: actual_args,
            },
        ) => {
            if expected_name != actual_name || expected_args.len() != actual_args.len() {
                return false;
            }
            // Type arguments without a declared variance are invariant.
            let variances = scope.variance_of(expected_name);
            for (idx, (expected_arg, actual_arg)) in
                expected_args.iter().zip(actual_args.iter()).enumerate()
            {
                let child_variance = variances
                    .as_ref()
                    .and_then(|v| v.get(idx).copied())
                    .unwrap_or(Variance::Invariant);
                let arg_polarity = Polarity::Covariant.compose(child_variance);
                if !self.types_compatible_at(arg_polarity, expected_arg, actual_arg, scope) {
                    return false;
                }
            }
            true
        }
        (TypeExpr::Union(exp_members), TypeExpr::Union(act_members)) => {
            act_members.iter().all(|am| {
                exp_members
                    .iter()
                    .any(|em| self.types_compatible(em, am, scope))
            })
        }
        (TypeExpr::Union(members), actual_type) => members
            .iter()
            .any(|m| self.types_compatible(m, actual_type, scope)),
        (expected_type, TypeExpr::Union(members)) => members
            .iter()
            .all(|m| self.types_compatible(expected_type, m, scope)),
        (TypeExpr::Shape(_), TypeExpr::Named(n)) if n == "dict" => true,
        (TypeExpr::Named(n), TypeExpr::Shape(_)) if n == "dict" => true,
        // Every required field of the expected shape must be present with a
        // compatible type; optional expected fields are not checked.
        (TypeExpr::Shape(ef), TypeExpr::Shape(af)) => ef.iter().all(|expected_field| {
            if expected_field.optional {
                return true;
            }
            af.iter().any(|actual_field| {
                actual_field.name == expected_field.name
                    && self.types_compatible(
                        &expected_field.type_expr,
                        &actual_field.type_expr,
                        scope,
                    )
            })
        }),
        (TypeExpr::DictType(ek, ev), TypeExpr::Shape(af)) => {
            let keys_ok = matches!(ek.as_ref(), TypeExpr::Named(n) if n == "string");
            keys_ok
                && af
                    .iter()
                    .all(|f| self.types_compatible(ev, &f.type_expr, scope))
        }
        (TypeExpr::Shape(_), TypeExpr::DictType(_, _)) => true,
        // Lists and dicts are invariant in their type arguments; iterators are
        // covariant.
        (TypeExpr::List(expected_inner), TypeExpr::List(actual_inner)) => {
            self.types_compatible_at(Polarity::Invariant, expected_inner, actual_inner, scope)
        }
        (TypeExpr::Named(n), TypeExpr::List(_)) if n == "list" => true,
        (TypeExpr::List(_), TypeExpr::Named(n)) if n == "list" => true,
        (TypeExpr::Iter(expected_inner), TypeExpr::Iter(actual_inner)) => {
            self.types_compatible(expected_inner, actual_inner, scope)
        }
        (TypeExpr::Named(n), TypeExpr::Iter(_)) if n == "iter" => true,
        (TypeExpr::Iter(_), TypeExpr::Named(n)) if n == "iter" => true,
        (TypeExpr::DictType(ek, ev), TypeExpr::DictType(ak, av)) => {
            self.types_compatible_at(Polarity::Invariant, ek, ak, scope)
                && self.types_compatible_at(Polarity::Invariant, ev, av, scope)
        }
        (TypeExpr::Named(n), TypeExpr::DictType(_, _)) if n == "dict" => true,
        (TypeExpr::DictType(_, _), TypeExpr::Named(n)) if n == "dict" => true,
        // Function types: parameters are contravariant, the return type is
        // covariant.
        (
            TypeExpr::FnType {
                params: ep,
                return_type: er,
            },
            TypeExpr::FnType {
                params: ap,
                return_type: ar,
            },
        ) => {
            ep.len() == ap.len()
                && ep.iter().zip(ap.iter()).all(|(e, a)| {
                    self.types_compatible_at(Polarity::Contravariant, e, a, scope)
                })
                && self.types_compatible(er, ar, scope)
        }
        (TypeExpr::FnType { .. }, TypeExpr::Named(n)) if n == "closure" => true,
        (TypeExpr::Named(n), TypeExpr::FnType { .. }) if n == "closure" => true,
        (TypeExpr::LitString(a), TypeExpr::LitString(b)) => a == b,
        (TypeExpr::LitInt(a), TypeExpr::LitInt(b)) => a == b,
        (TypeExpr::Named(n), TypeExpr::LitString(_)) if n == "string" => true,
        (TypeExpr::Named(n), TypeExpr::LitInt(_)) if n == "int" || n == "float" => true,
        (TypeExpr::LitString(_), TypeExpr::Named(n)) if n == "string" => true,
        (TypeExpr::LitInt(_), TypeExpr::Named(n)) if n == "int" => true,
        _ => false,
    }
}
/// Expands type aliases in `ty`, recursing into every compound type.
///
/// An alias name that is already being expanded on the current path (for
/// example the inner `Tree` in `type Tree = {kids: list<Tree>}`) is left as
/// the bare name. This makes recursive aliases terminate instead of
/// overflowing the stack. Non-recursive aliases expand exactly as before.
fn resolve_alias(&self, ty: &TypeExpr, scope: &TypeScope) -> TypeExpr {
    let mut expanding: Vec<String> = Vec::new();
    self.resolve_alias_guarded(ty, scope, &mut expanding)
}
/// Worker for [`Self::resolve_alias`]. `expanding` holds the alias names
/// currently being expanded on the path from the root type, and it is
/// restored to its input state before returning.
fn resolve_alias_guarded(
    &self,
    ty: &TypeExpr,
    scope: &TypeScope,
    expanding: &mut Vec<String>,
) -> TypeExpr {
    match ty {
        TypeExpr::Named(name) => {
            if expanding.iter().any(|n| n == name) {
                return ty.clone();
            }
            let Some(resolved) = scope.resolve_type(name) else {
                return ty.clone();
            };
            expanding.push(name.clone());
            let expanded = self.resolve_alias_guarded(resolved, scope, expanding);
            expanding.pop();
            expanded
        }
        TypeExpr::Union(types) => TypeExpr::Union(
            types
                .iter()
                .map(|member| self.resolve_alias_guarded(member, scope, expanding))
                .collect(),
        ),
        TypeExpr::Shape(fields) => TypeExpr::Shape(
            fields
                .iter()
                .map(|field| ShapeField {
                    name: field.name.clone(),
                    type_expr: self.resolve_alias_guarded(&field.type_expr, scope, expanding),
                    optional: field.optional,
                })
                .collect(),
        ),
        TypeExpr::List(inner) => {
            TypeExpr::List(Box::new(self.resolve_alias_guarded(inner, scope, expanding)))
        }
        TypeExpr::Iter(inner) => {
            TypeExpr::Iter(Box::new(self.resolve_alias_guarded(inner, scope, expanding)))
        }
        TypeExpr::DictType(key, value) => TypeExpr::DictType(
            Box::new(self.resolve_alias_guarded(key, scope, expanding)),
            Box::new(self.resolve_alias_guarded(value, scope, expanding)),
        ),
        TypeExpr::FnType {
            params,
            return_type,
        } => TypeExpr::FnType {
            params: params
                .iter()
                .map(|param| self.resolve_alias_guarded(param, scope, expanding))
                .collect(),
            return_type: Box::new(self.resolve_alias_guarded(return_type, scope, expanding)),
        },
        TypeExpr::Applied { name, args } => TypeExpr::Applied {
            name: name.clone(),
            args: args
                .iter()
                .map(|arg| self.resolve_alias_guarded(arg, scope, expanding))
                .collect(),
        },
        TypeExpr::Never => TypeExpr::Never,
        TypeExpr::LitString(s) => TypeExpr::LitString(s.clone()),
        TypeExpr::LitInt(v) => TypeExpr::LitInt(*v),
    }
}
/// Checks a function's variance-annotated type parameters against where they
/// are used.
///
/// Parameter types are contravariant positions and the return type is a
/// covariant position.
fn check_fn_decl_variance(
    &mut self,
    type_params: &[TypeParam],
    params: &[TypedParam],
    return_type: Option<&TypeExpr>,
    name: &str,
    span: Span,
) {
    let param_positions = params
        .iter()
        .filter_map(|param| param.type_expr.as_ref())
        .map(|te| (te, Polarity::Contravariant));
    let return_position = return_type.map(|rt| (rt, Polarity::Covariant));
    let positions: Vec<(&TypeExpr, Polarity)> = param_positions.chain(return_position).collect();
    self.check_decl_variance(&format!("function '{name}'"), type_params, &positions, span);
}
/// Checks a type alias's variance-annotated type parameters; the aliased type
/// itself is a covariant position.
fn check_type_alias_decl_variance(
    &mut self,
    type_params: &[TypeParam],
    type_expr: &TypeExpr,
    name: &str,
    span: Span,
) {
    let decl_kind = format!("type alias '{name}'");
    self.check_decl_variance(
        &decl_kind,
        type_params,
        &[(type_expr, Polarity::Covariant)],
        span,
    );
}
/// Checks an enum's variance-annotated type parameters. Every typed payload
/// field of every variant is checked as a covariant position.
fn check_enum_decl_variance(
    &mut self,
    type_params: &[TypeParam],
    variants: &[EnumVariant],
    name: &str,
    span: Span,
) {
    let positions: Vec<(&TypeExpr, Polarity)> = variants
        .iter()
        .flat_map(|variant| &variant.fields)
        .filter_map(|field| field.type_expr.as_ref())
        .map(|te| (te, Polarity::Covariant))
        .collect();
    let decl_kind = format!("enum '{name}'");
    self.check_decl_variance(&decl_kind, type_params, &positions, span);
}
/// Checks a struct's variance-annotated type parameters. Every typed field is
/// checked as an invariant position, so only unannotated type parameters may
/// appear in fields.
fn check_struct_decl_variance(
    &mut self,
    type_params: &[TypeParam],
    fields: &[StructField],
    name: &str,
    span: Span,
) {
    let mut positions: Vec<(&TypeExpr, Polarity)> = Vec::with_capacity(fields.len());
    for field in fields {
        if let Some(te) = field.type_expr.as_ref() {
            positions.push((te, Polarity::Invariant));
        }
    }
    let decl_kind = format!("struct '{name}'");
    self.check_decl_variance(&decl_kind, type_params, &positions, span);
}
/// Checks an interface's variance-annotated type parameters.
///
/// For each method in declaration order, its parameter types are
/// contravariant positions and its return type is a covariant position.
fn check_interface_decl_variance(
    &mut self,
    type_params: &[TypeParam],
    methods: &[InterfaceMethod],
    name: &str,
    span: Span,
) {
    let positions: Vec<(&TypeExpr, Polarity)> = methods
        .iter()
        .flat_map(|method| {
            let params = method
                .params
                .iter()
                .filter_map(|param| param.type_expr.as_ref())
                .map(|te| (te, Polarity::Contravariant));
            let ret = method
                .return_type
                .as_ref()
                .map(|rt| (rt, Polarity::Covariant));
            params.chain(ret)
        })
        .collect();
    let decl_kind = format!("interface '{name}'");
    self.check_decl_variance(&decl_kind, type_params, &positions, span);
}
/// Walks every `(type, polarity)` position of a declaration and reports type
/// parameters used against their declared variance.
///
/// This does nothing when no type parameter carries an `in`/`out` annotation.
fn check_decl_variance(
    &mut self,
    decl_kind: &str,
    type_params: &[TypeParam],
    positions: &[(&TypeExpr, Polarity)],
    span: Span,
) {
    let any_annotated = type_params
        .iter()
        .any(|tp| tp.variance != Variance::Invariant);
    if !any_annotated {
        return;
    }
    let mut declared: BTreeMap<String, Variance> = BTreeMap::new();
    for tp in type_params {
        declared.insert(tp.name.clone(), tp.variance);
    }
    for &(ty, polarity) in positions {
        self.walk_variance(decl_kind, ty, polarity, &declared, span);
    }
}
/// Reports every use of a variance-annotated type parameter (`out` / `in`) in
/// `ty` that sits at a position incompatible with its annotation.
///
/// `polarity` is the position of `ty` itself. Polarity propagates as follows:
///
/// - List elements, dict keys and dict values are invariant.
/// - Iterator elements, shape fields and union members keep the current
///   polarity.
/// - Function parameters flip it; the function's return type keeps it.
/// - Arguments of an applied enum/struct/interface compose it with that type
///   parameter's declared variance. Unknown types count as invariant.
#[allow(clippy::only_used_in_recursion)]
fn walk_variance(
    &mut self,
    decl_kind: &str,
    ty: &TypeExpr,
    polarity: Polarity,
    declared: &BTreeMap<String, Variance>,
    span: Span,
) {
    match ty {
        TypeExpr::Named(name) => {
            let Some(&declared_variance) = declared.get(name) else {
                return;
            };
            let (marker, declared_word) = match (declared_variance, polarity) {
                // Unannotated parameters, and annotated ones used in a
                // matching position, are fine.
                (Variance::Invariant, _)
                | (Variance::Covariant, Polarity::Covariant)
                | (Variance::Contravariant, Polarity::Contravariant) => return,
                (Variance::Covariant, _) => ("out", "covariant"),
                (Variance::Contravariant, _) => ("in", "contravariant"),
            };
            let position_word = match polarity {
                Polarity::Covariant => "covariant",
                Polarity::Contravariant => "contravariant",
                Polarity::Invariant => "invariant",
            };
            self.error_at(
                format!(
                    "type parameter '{name}' is declared '{marker}' \
                     ({declared_word}) but appears in a \
                     {position_word} position in {decl_kind}"
                ),
                span,
            );
        }
        TypeExpr::List(element) => {
            self.walk_variance(decl_kind, element, Polarity::Invariant, declared, span);
        }
        TypeExpr::Iter(element) => {
            self.walk_variance(decl_kind, element, polarity, declared, span);
        }
        TypeExpr::DictType(key, value) => {
            self.walk_variance(decl_kind, key, Polarity::Invariant, declared, span);
            self.walk_variance(decl_kind, value, Polarity::Invariant, declared, span);
        }
        TypeExpr::Shape(fields) => {
            for field in fields {
                self.walk_variance(decl_kind, &field.type_expr, polarity, declared, span);
            }
        }
        TypeExpr::Union(members) => {
            for member in members {
                self.walk_variance(decl_kind, member, polarity, declared, span);
            }
        }
        TypeExpr::FnType {
            params,
            return_type,
        } => {
            let flipped = polarity.compose(Variance::Contravariant);
            for param in params {
                self.walk_variance(decl_kind, param, flipped, declared, span);
            }
            self.walk_variance(decl_kind, return_type, polarity, declared, span);
        }
        TypeExpr::Applied { name, args } => {
            let variances: Option<Vec<Variance>> =
                if let Some(info) = self.scope.get_enum(name) {
                    Some(info.type_params.iter().map(|tp| tp.variance).collect())
                } else if let Some(info) = self.scope.get_struct(name) {
                    Some(info.type_params.iter().map(|tp| tp.variance).collect())
                } else {
                    self.scope
                        .get_interface(name)
                        .map(|info| info.type_params.iter().map(|tp| tp.variance).collect())
                };
            for (idx, arg) in args.iter().enumerate() {
                let arg_variance = variances
                    .as_ref()
                    .and_then(|v| v.get(idx).copied())
                    .unwrap_or(Variance::Invariant);
                self.walk_variance(decl_kind, arg, polarity.compose(arg_variance), declared, span);
            }
        }
        TypeExpr::Never | TypeExpr::LitString(_) | TypeExpr::LitInt(_) => {}
    }
}
/// Records an error diagnostic at `span`.
fn error_at(&mut self, message: String, span: Span) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Error,
        message,
        span: Some(span),
        help: None,
        fix: None,
    };
    self.diagnostics.push(diagnostic);
}
// Not called from the code visible here; kept alongside the other helpers.
#[allow(dead_code)]
/// Records an error diagnostic at `span` with an extra `help` note.
fn error_at_with_help(&mut self, message: String, span: Span, help: String) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Error,
        message,
        span: Some(span),
        help: Some(help),
        fix: None,
    };
    self.diagnostics.push(diagnostic);
}
/// Records an error diagnostic at `span` carrying the quick-fix edits `fix`.
fn error_at_with_fix(&mut self, message: String, span: Span, fix: Vec<FixEdit>) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Error,
        message,
        span: Some(span),
        help: None,
        fix: Some(fix),
    };
    self.diagnostics.push(diagnostic);
}
/// Records a warning diagnostic at `span`.
fn warning_at(&mut self, message: String, span: Span) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Warning,
        message,
        span: Some(span),
        help: None,
        fix: None,
    };
    self.diagnostics.push(diagnostic);
}
// Not called from the code visible here; kept alongside the other helpers.
#[allow(dead_code)]
/// Records a warning diagnostic at `span` with an extra `help` note.
fn warning_at_with_help(&mut self, message: String, span: Span, help: String) {
    let diagnostic = TypeDiagnostic {
        severity: DiagnosticSeverity::Warning,
        message,
        span: Some(span),
        help: Some(help),
        fix: None,
    };
    self.diagnostics.push(diagnostic);
}
/// Recursively checks arithmetic operators whose operand types are both known
/// simple named types.
///
/// - `+` accepts two numbers, or two `string`s, `list`s or `dict`s. Mixing a
///   string with something else is an error carrying an interpolation
///   quick-fix when one can be built.
/// - `-`, `/`, `%` and `**` require numeric (`int` / `float`) operands.
/// - `*` accepts two numbers, or `string` with `int` in either order.
///
/// Only `BinaryOp` and `UnaryOp` nodes are descended into.
fn check_binops(&mut self, snode: &SNode, scope: &mut TypeScope) {
    match &snode.node {
        Node::UnaryOp { operand, .. } => self.check_binops(operand, scope),
        Node::BinaryOp { op, left, right } => {
            self.check_binops(left, scope);
            self.check_binops(right, scope);
            let operand_types = (self.infer_type(left, scope), self.infer_type(right, scope));
            let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = operand_types else {
                return;
            };
            let span = snode.span;
            let is_numeric = |type_name: &str| type_name == "int" || type_name == "float";
            let both_numeric = is_numeric(l.as_str()) && is_numeric(r.as_str());
            match op.as_str() {
                "+" => {
                    let same_container =
                        l == r && matches!(l.as_str(), "string" | "list" | "dict");
                    if both_numeric || same_container {
                        return;
                    }
                    let msg = format!("can't add {l} and {r}");
                    let suggestion = if l == "string" || r == "string" {
                        self.build_interpolation_fix(left, right, l == "string", span)
                    } else {
                        None
                    };
                    match suggestion {
                        Some(fix) => self.error_at_with_fix(msg, span, fix),
                        None => self.error_at(msg, span),
                    }
                }
                "-" | "/" | "%" | "**" => {
                    if !both_numeric {
                        self.error_at(
                            format!("can't use '{op}' on {l} and {r} (needs numeric operands)"),
                            span,
                        );
                    }
                }
                "*" => {
                    let string_repeat =
                        (l == "string" && r == "int") || (l == "int" && r == "string");
                    if !both_numeric && !string_repeat {
                        self.error_at(
                            format!("can't multiply {l} and {r} (try string * int)"),
                            span,
                        );
                    }
                }
                _ => {}
            }
        }
        _ => {}
    }
}
fn build_interpolation_fix(
&self,
left: &SNode,
right: &SNode,
left_is_string: bool,
expr_span: Span,
) -> Option<Vec<FixEdit>> {
let src = self.source.as_ref()?;
let (str_node, other_node) = if left_is_string {
(left, right)
} else {
(right, left)
};
let str_text = src.get(str_node.span.start..str_node.span.end)?;
let other_text = src.get(other_node.span.start..other_node.span.end)?;
let inner = str_text.strip_prefix('"')?.strip_suffix('"')?;
if other_text.contains('}') || other_text.contains('"') {
return None;
}
let replacement = if left_is_string {
format!("\"{inner}${{{other_text}}}\"")
} else {
format!("\"${{{other_text}}}{inner}\"")
};
Some(vec![FixEdit {
span: expr_span,
replacement,
}])
}
}
impl Default for TypeChecker {
fn default() -> Self {
Self::new()
}
}
/// Infers the result type of the binary expression `left op right` from its operand types.
///
/// Comparison and logical operators always yield `bool`. The arithmetic operators
/// (`+ - * / % **`) yield:
/// - `int` for int/int,
/// - `float` for any other int/float mix,
/// - `string`, `list` or `dict` for `+` on two operands of that same type,
/// - `string` for `string * int` (either order).
///
/// `??` strips `nil` from the left type and falls back to the right type. Every other
/// combination, including `|>`, yields `None` (unknown). In particular, an invalid mix such as
/// `float + string` is not reported as `float`; `check_binops` already reports it, and a bogus
/// numeric type would only cause follow-on mismatch errors.
fn infer_binary_op_type(op: &str, left: &InferredType, right: &InferredType) -> InferredType {
    match op {
        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
            Some(TypeExpr::Named("bool".into()))
        }
        "+" | "-" | "*" | "/" | "%" | "**" => {
            let (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) = (left, right) else {
                return None;
            };
            let (l, r) = (l.as_str(), r.as_str());
            let is_numeric = |name: &str| matches!(name, "int" | "float");
            let result = match (op, l, r) {
                (_, "int", "int") => "int",
                // Any remaining int/float mix promotes to float.
                _ if is_numeric(l) && is_numeric(r) => "float",
                ("+", "string", "string") => "string",
                ("+", "list", "list") => "list",
                ("+", "dict", "dict") => "dict",
                ("*", "string", "int") | ("*", "int", "string") => "string",
                _ => return None,
            };
            Some(TypeExpr::Named(result.into()))
        }
        "??" => match (left, right) {
            (Some(TypeExpr::Union(members)), _) => {
                let non_nil: Vec<_> = members
                    .iter()
                    .filter(|m| !matches!(m, TypeExpr::Named(n) if n == "nil"))
                    .cloned()
                    .collect();
                if non_nil.len() == 1 {
                    Some(non_nil[0].clone())
                } else if non_nil.is_empty() {
                    right.clone()
                } else {
                    Some(TypeExpr::Union(non_nil))
                }
            }
            (Some(TypeExpr::Named(n)), _) if n == "nil" => right.clone(),
            (Some(l), _) => Some(l.clone()),
            (None, _) => right.clone(),
        },
        // `|>` and anything else: the result type is not inferred here.
        _ => None,
    }
}
/// Explains why an `actual` shape type does not satisfy an `expected` shape type.
///
/// For each required (non-optional) field of `expected`, it reports either that the field is
/// missing from `actual`, or that the field's formatted type differs. Findings are joined with
/// `"; "`. Returns `None` when either side is not a shape or nothing differs.
///
/// NOTE(review): types are compared by their `format_type` text, so e.g. an `int` field
/// against an expected `float` is listed even though int->float promotion is accepted
/// elsewhere in the checker.
pub fn shape_mismatch_detail(expected: &TypeExpr, actual: &TypeExpr) -> Option<String> {
    let (TypeExpr::Shape(expected_fields), TypeExpr::Shape(actual_fields)) = (expected, actual)
    else {
        return None;
    };
    let findings: Vec<String> = expected_fields
        .iter()
        .filter(|field| !field.optional)
        .filter_map(|field| {
            let wanted = format_type(&field.type_expr);
            match actual_fields
                .iter()
                .find(|candidate| candidate.name == field.name)
            {
                None => Some(format!("missing field '{}' ({})", field.name, wanted)),
                Some(found) => {
                    let got = format_type(&found.type_expr);
                    (got != wanted).then(|| {
                        format!(
                            "field '{}' has type {}, expected {}",
                            field.name, got, wanted
                        )
                    })
                }
            }
        })
        .collect();
    (!findings.is_empty()).then(|| findings.join("; "))
}
/// Returns `true` when `value` is a literal whose type is self-evident from the source text,
/// so displaying its inferred type would be redundant (presumably used to suppress inlay
/// hints — confirm at the call site).
///
/// The decision depends only on the syntactic form of `value`; `_ty` is currently unused.
/// Raw string literals count as obvious, just like plain ones; the rest of this file treats
/// `StringLiteral` and `RawStringLiteral` interchangeably.
fn is_obvious_type(value: &SNode, _ty: &TypeExpr) -> bool {
    matches!(
        &value.node,
        Node::IntLiteral(_)
            | Node::FloatLiteral(_)
            | Node::StringLiteral(_)
            | Node::RawStringLiteral(_)
            | Node::BoolLiteral(_)
            | Node::NilLiteral
            | Node::ListLiteral(_)
            | Node::DictLiteral(_)
            | Node::InterpolatedString(_)
    )
}
/// Returns `true` when executing `stmt` can never fall through to the next statement.
///
/// That is the case for `return`, `throw`, `break` and `continue`. It is also the case for an
/// `if`/`else` whose branches both definitely exit. An `if` without `else` may always fall
/// through, and every other statement kind is conservatively treated as falling through.
pub fn stmt_definitely_exits(stmt: &SNode) -> bool {
    if let Node::IfElse {
        then_body,
        else_body: Some(else_body),
        ..
    } = &stmt.node
    {
        return block_definitely_exits(then_body) && block_definitely_exits(else_body);
    }
    matches!(
        &stmt.node,
        Node::ReturnStmt { .. } | Node::ThrowStmt { .. } | Node::BreakStmt | Node::ContinueStmt
    )
}
pub fn block_definitely_exits(stmts: &[SNode]) -> bool {
stmts.iter().any(stmt_definitely_exits)
}
/// Renders a type expression in the surface syntax used in diagnostics.
///
/// Examples of the output: `int | nil`, `{name: string, age?: int}`, `list<T>`,
/// `dict<K, V>`, `Name<A, B>`, `fn(A, B) -> R`, `never`. String literal types are re-quoted
/// with `\` and `"` escaped.
pub fn format_type(ty: &TypeExpr) -> String {
    // Formats each type and joins the results with `separator`.
    fn join_types(types: &[TypeExpr], separator: &str) -> String {
        types
            .iter()
            .map(format_type)
            .collect::<Vec<_>>()
            .join(separator)
    }
    match ty {
        TypeExpr::Named(name) => name.clone(),
        TypeExpr::Union(members) => join_types(members, " | "),
        TypeExpr::Shape(fields) => {
            let rendered = fields
                .iter()
                .map(|field| {
                    let marker = if field.optional { "?" } else { "" };
                    format!("{}{}: {}", field.name, marker, format_type(&field.type_expr))
                })
                .collect::<Vec<_>>()
                .join(", ");
            format!("{{{rendered}}}")
        }
        TypeExpr::List(element) => format!("list<{}>", format_type(element)),
        TypeExpr::Iter(element) => format!("iter<{}>", format_type(element)),
        TypeExpr::DictType(key, value) => {
            format!("dict<{}, {}>", format_type(key), format_type(value))
        }
        TypeExpr::Applied { name, args } => format!("{}<{}>", name, join_types(args, ", ")),
        TypeExpr::FnType {
            params,
            return_type,
        } => format!(
            "fn({}) -> {}",
            join_types(params, ", "),
            format_type(return_type)
        ),
        TypeExpr::Never => String::from("never"),
        TypeExpr::LitString(text) => {
            let escaped = text.replace('\\', "\\\\").replace('"', "\\\"");
            format!("\"{escaped}\"")
        }
        TypeExpr::LitInt(value) => value.to_string(),
    }
}
/// Builds a union from `members` after discarding `never` members.
///
/// Collapses to `never` when nothing is left, and to the sole member when exactly one is left.
fn simplify_union(members: Vec<TypeExpr>) -> TypeExpr {
    let mut kept = members;
    kept.retain(|member| !matches!(member, TypeExpr::Never));
    if kept.len() > 1 {
        return TypeExpr::Union(kept);
    }
    kept.pop().unwrap_or(TypeExpr::Never)
}
/// Removes every `Named(to_remove)` member from a union's `members`.
///
/// Returns `never` when nothing remains, the single survivor when one remains, and a union
/// otherwise. The result is always `Some`.
fn remove_from_union(members: &[TypeExpr], to_remove: &str) -> InferredType {
    let mut remaining = Vec::with_capacity(members.len());
    for member in members {
        let is_target = matches!(member, TypeExpr::Named(name) if name == to_remove);
        if !is_target {
            remaining.push(member.clone());
        }
    }
    let narrowed = if remaining.len() > 1 {
        TypeExpr::Union(remaining)
    } else {
        remaining.pop().unwrap_or(TypeExpr::Never)
    };
    Some(narrowed)
}
/// Narrows a union to the named type `target`.
///
/// Returns `Named(target)` if `target` is one of the union's `members`, otherwise `None`.
fn narrow_to_single(members: &[TypeExpr], target: &str) -> InferredType {
    let contains_target = members
        .iter()
        .any(|member| matches!(member, TypeExpr::Named(name) if name == target));
    contains_target.then(|| TypeExpr::Named(target.to_string()))
}
/// Returns the variable name `v` when `node` is exactly the call `type_of(v)` with a single
/// bare identifier argument; otherwise returns `None`.
fn extract_type_of_var(node: &SNode) -> Option<String> {
    let Node::FunctionCall { name, args } = &node.node else {
        return None;
    };
    match (name.as_str(), &args[..]) {
        ("type_of", [only_arg]) => match &only_arg.node {
            Node::Identifier(var) => Some(var.clone()),
            _ => None,
        },
        _ => None,
    }
}
/// Resolves a schema expression node to the type it describes.
///
/// Handles three forms:
/// - An identifier: its recorded schema binding first, then a type of that name.
/// - A dict literal: parsed by `schema_type_expr_from_dict`.
/// - `schema_of(Alias)`: resolves `Alias` as a type.
///
/// Anything else yields `None`.
fn schema_type_expr_from_node(node: &SNode, scope: &TypeScope) -> Option<TypeExpr> {
    match &node.node {
        Node::Identifier(name) => scope
            .get_schema_binding(name)
            .cloned()
            .flatten()
            .or_else(|| scope.resolve_type(name).cloned()),
        Node::DictLiteral(entries) => schema_type_expr_from_dict(entries, scope),
        Node::FunctionCall { name, args } if name == "schema_of" && args.len() == 1 => {
            match &args[0].node {
                Node::Identifier(alias) => scope.resolve_type(alias).cloned(),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Converts a JSON-Schema-like dict literal (`{type: ..., properties: ..., ...}`) into a
/// `TypeExpr`, or returns `None` when the literal cannot be understood statically.
///
/// Recognised keys:
/// - `type`: a type name, or a list of names forming a union.
/// - `properties` + `required`: a shape; fields absent from `required` become optional.
/// - `items`: a list element schema.
/// - `union` / `oneOf` / `anyOf`: alternative schemas.
/// - `nullable`: when `true`, adds `nil` to the resulting type.
/// - `additional_properties` / `additionalProperties`: the value schema of a `dict`.
///
/// When several shape-defining keys are present, precedence is:
/// a `type` list > union keys > `properties` > `items` > a scalar `type` name.
/// Any key that is not an identifier or string makes the whole schema opaque (`None`).
fn schema_type_expr_from_dict(entries: &[DictEntry], scope: &TypeScope) -> Option<TypeExpr> {
    let mut type_name: Option<String> = None;
    // Type described by a list-valued `type` (e.g. `type: ["string", "int"]`).
    let mut type_list: Option<TypeExpr> = None;
    let mut properties: Option<&SNode> = None;
    let mut required: Option<Vec<String>> = None;
    let mut items: Option<&SNode> = None;
    let mut union: Option<&SNode> = None;
    let mut nullable = false;
    let mut additional_properties: Option<&SNode> = None;
    for entry in entries {
        let key = schema_entry_key(&entry.key)?;
        match key.as_str() {
            "type" => match &entry.value.node {
                Node::StringLiteral(text) | Node::RawStringLiteral(text) => {
                    type_name = Some(normalize_schema_type_name(text));
                }
                Node::ListLiteral(items_list) => {
                    let mut members = items_list
                        .iter()
                        .filter_map(|item| match &item.node {
                            Node::StringLiteral(text) | Node::RawStringLiteral(text) => {
                                Some(TypeExpr::Named(normalize_schema_type_name(text)))
                            }
                            _ => None,
                        })
                        .collect::<Vec<_>>();
                    // Recorded (not returned early) so that `nullable` and the other keys
                    // still apply; a single-name list collapses to that name, as in
                    // `schema_union_type_expr`.
                    match members.len() {
                        0 => {}
                        1 => type_list = members.pop(),
                        _ => type_list = Some(TypeExpr::Union(members)),
                    }
                }
                _ => {}
            },
            "properties" => properties = Some(&entry.value),
            "required" => {
                required = schema_required_names(&entry.value);
            }
            "items" => items = Some(&entry.value),
            "union" | "oneOf" | "anyOf" => union = Some(&entry.value),
            "nullable" => {
                nullable = matches!(entry.value.node, Node::BoolLiteral(true));
            }
            "additional_properties" | "additionalProperties" => {
                additional_properties = Some(&entry.value);
            }
            _ => {}
        }
    }
    let schema_type = if let Some(list_type) = type_list {
        list_type
    } else if let Some(union_node) = union {
        schema_union_type_expr(union_node, scope)?
    } else if let Some(properties_node) = properties {
        let Node::DictLiteral(property_entries) = &properties_node.node else {
            return None;
        };
        let required_names = required.unwrap_or_default();
        let mut fields = Vec::new();
        for entry in property_entries {
            let field_name = schema_entry_key(&entry.key)?;
            let field_type = schema_type_expr_from_node(&entry.value, scope)?;
            fields.push(ShapeField {
                optional: !required_names.contains(&field_name),
                name: field_name,
                type_expr: field_type,
            });
        }
        TypeExpr::Shape(fields)
    } else if let Some(item_node) = items {
        TypeExpr::List(Box::new(schema_type_expr_from_node(item_node, scope)?))
    } else if let Some(type_name) = type_name {
        // A dict whose `additional_properties` is itself a schema becomes
        // `dict<string, Value>`. A boolean flag, or an unresolvable schema, leaves plain `dict`.
        let value_type = if type_name == "dict" {
            additional_properties.and_then(|extra| match &extra.node {
                Node::BoolLiteral(_) => None,
                _ => schema_type_expr_from_node(extra, scope),
            })
        } else {
            None
        };
        match value_type {
            Some(value_type) => TypeExpr::DictType(
                Box::new(TypeExpr::Named("string".into())),
                Box::new(value_type),
            ),
            None => TypeExpr::Named(type_name),
        }
    } else {
        return None;
    };
    if !nullable {
        return Some(schema_type);
    }
    let nil = || TypeExpr::Named("nil".into());
    let is_nil = |ty: &TypeExpr| matches!(ty, TypeExpr::Named(name) if name == "nil");
    Some(match schema_type {
        TypeExpr::Union(mut members) => {
            if !members.iter().any(is_nil) {
                members.push(nil());
            }
            TypeExpr::Union(members)
        }
        // Already exactly `nil` (e.g. `type: "null"`): adding it again would give `nil | nil`.
        other if is_nil(&other) => other,
        other => TypeExpr::Union(vec![other, nil()]),
    })
}
/// Resolves a list literal of alternative schemas into a union type.
///
/// Alternatives that cannot be resolved are skipped. Returns `None` when none resolve, the
/// single type when exactly one resolves, and a union otherwise. A non-list node yields `None`.
fn schema_union_type_expr(node: &SNode, scope: &TypeScope) -> Option<TypeExpr> {
    let Node::ListLiteral(alternatives) = &node.node else {
        return None;
    };
    let mut resolved: Vec<TypeExpr> = alternatives
        .iter()
        .filter_map(|alternative| schema_type_expr_from_node(alternative, scope))
        .collect();
    if resolved.len() > 1 {
        Some(TypeExpr::Union(resolved))
    } else {
        resolved.pop()
    }
}
/// Reads a schema's `required` list into field names.
///
/// String, raw-string and bare-identifier elements are accepted; other elements are silently
/// skipped. Returns `None` when `node` is not a list literal.
fn schema_required_names(node: &SNode) -> Option<Vec<String>> {
    match &node.node {
        Node::ListLiteral(elements) => Some(
            elements
                .iter()
                .filter_map(|element| match &element.node {
                    Node::StringLiteral(name)
                    | Node::RawStringLiteral(name)
                    | Node::Identifier(name) => Some(name.clone()),
                    _ => None,
                })
                .collect(),
        ),
        _ => None,
    }
}
/// Extracts a dict-literal key as text.
///
/// Identifier keys and (raw) string keys are accepted; any other key expression yields `None`.
fn schema_entry_key(node: &SNode) -> Option<String> {
    if let Node::Identifier(key) | Node::StringLiteral(key) | Node::RawStringLiteral(key) =
        &node.node
    {
        Some(key.clone())
    } else {
        None
    }
}
/// Maps JSON-Schema type names to the checker's own names: `object`→`dict`, `array`→`list`,
/// `integer`→`int`, `number`→`float`, `boolean`→`bool` and `null`→`nil`.
///
/// Any other name is returned unchanged.
fn normalize_schema_type_name(text: &str) -> String {
    const ALIASES: [(&str, &str); 6] = [
        ("object", "dict"),
        ("array", "list"),
        ("integer", "int"),
        ("number", "float"),
        ("boolean", "bool"),
        ("null", "nil"),
    ];
    ALIASES
        .iter()
        .find(|&&(schema_name, _)| schema_name == text)
        .map_or(text, |&(_, native_name)| native_name)
        .to_string()
}
fn intersect_types(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
match (current, schema_type) {
(TypeExpr::LitString(a), TypeExpr::LitString(b)) if a == b => {
Some(TypeExpr::LitString(a.clone()))
}
(TypeExpr::LitInt(a), TypeExpr::LitInt(b)) if a == b => Some(TypeExpr::LitInt(*a)),
(TypeExpr::LitString(s), TypeExpr::Named(n))
| (TypeExpr::Named(n), TypeExpr::LitString(s))
if n == "string" =>
{
Some(TypeExpr::LitString(s.clone()))
}
(TypeExpr::LitInt(v), TypeExpr::Named(n)) | (TypeExpr::Named(n), TypeExpr::LitInt(v))
if n == "int" || n == "float" =>
{
Some(TypeExpr::LitInt(*v))
}
(TypeExpr::Union(members), other) => {
let kept = members
.iter()
.filter_map(|member| intersect_types(member, other))
.collect::<Vec<_>>();
match kept.len() {
0 => None,
1 => kept.into_iter().next(),
_ => Some(TypeExpr::Union(kept)),
}
}
(other, TypeExpr::Union(members)) => {
let kept = members
.iter()
.filter_map(|member| intersect_types(other, member))
.collect::<Vec<_>>();
match kept.len() {
0 => None,
1 => kept.into_iter().next(),
_ => Some(TypeExpr::Union(kept)),
}
}
(TypeExpr::Named(left), TypeExpr::Named(right)) if left == right => {
Some(TypeExpr::Named(left.clone()))
}
(TypeExpr::Named(name), TypeExpr::Shape(fields)) if name == "dict" => {
Some(TypeExpr::Shape(fields.clone()))
}
(TypeExpr::Shape(fields), TypeExpr::Named(name)) if name == "dict" => {
Some(TypeExpr::Shape(fields.clone()))
}
(TypeExpr::Named(name), TypeExpr::List(inner)) if name == "list" => {
Some(TypeExpr::List(inner.clone()))
}
(TypeExpr::List(inner), TypeExpr::Named(name)) if name == "list" => {
Some(TypeExpr::List(inner.clone()))
}
(TypeExpr::Named(name), TypeExpr::DictType(key, value)) if name == "dict" => {
Some(TypeExpr::DictType(key.clone(), value.clone()))
}
(TypeExpr::DictType(key, value), TypeExpr::Named(name)) if name == "dict" => {
Some(TypeExpr::DictType(key.clone(), value.clone()))
}
(TypeExpr::Shape(_), TypeExpr::Shape(fields)) => Some(TypeExpr::Shape(fields.clone())),
(TypeExpr::List(current_inner), TypeExpr::List(schema_inner)) => {
intersect_types(current_inner, schema_inner)
.map(|inner| TypeExpr::List(Box::new(inner)))
}
(
TypeExpr::DictType(current_key, current_value),
TypeExpr::DictType(schema_key, schema_value),
) => {
let key = intersect_types(current_key, schema_key)?;
let value = intersect_types(current_value, schema_value)?;
Some(TypeExpr::DictType(Box::new(key), Box::new(value)))
}
_ => None,
}
}
/// Removes from `current` everything that overlaps `schema_type` (per `intersect_types`).
///
/// A union keeps only its non-overlapping members: none left → `None`, one → that member,
/// several → a union. A non-union type is either dropped entirely (`None`) if it overlaps, or
/// returned unchanged.
fn subtract_type(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
    let overlaps = |ty: &TypeExpr| intersect_types(ty, schema_type).is_some();
    let TypeExpr::Union(members) = current else {
        return (!overlaps(current)).then(|| current.clone());
    };
    let mut survivors: Vec<TypeExpr> = members
        .iter()
        .filter(|&member| !overlaps(member))
        .cloned()
        .collect();
    if survivors.len() > 1 {
        Some(TypeExpr::Union(survivors))
    } else {
        survivors.pop()
    }
}
/// Applies flow-sensitive refinements, redefining each variable in `scope` with its narrowed
/// type.
///
/// The first time a variable is narrowed in this scope, its current type is saved in
/// `scope.narrowed_vars`; presumably the caller uses it to restore the type later. Later
/// refinements do not overwrite that saved type. Variables unknown to `scope.get_var` are
/// defined without a saved type.
fn apply_refinements(scope: &mut TypeScope, refinements: &[(String, InferredType)]) {
    for (name, refined) in refinements.iter() {
        let first_narrowing = !scope.narrowed_vars.contains_key(name);
        if first_narrowing {
            if let Some(previous) = scope.get_var(name).cloned() {
                scope.narrowed_vars.insert(name.to_owned(), previous);
            }
        }
        scope.define_var(name, refined.clone());
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Parser;
use harn_lexer::Lexer;
fn check_source(source: &str) -> Vec<TypeDiagnostic> {
let mut lexer = Lexer::new(source);
let tokens = lexer.tokenize().unwrap();
let mut parser = Parser::new(tokens);
let program = parser.parse().unwrap();
TypeChecker::new().check(&program)
}
fn errors(source: &str) -> Vec<String> {
check_source(source)
.into_iter()
.filter(|d| d.severity == DiagnosticSeverity::Error)
.map(|d| d.message)
.collect()
}
#[test]
fn test_no_errors_for_untyped_code() {
let errs = errors("pipeline t(task) { let x = 42\nlog(x) }");
assert!(errs.is_empty());
}
#[test]
fn test_correct_typed_let() {
let errs = errors("pipeline t(task) { let x: int = 42 }");
assert!(errs.is_empty());
}
#[test]
fn test_type_mismatch_let() {
let errs = errors(r#"pipeline t(task) { let x: int = "hello" }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("declared as int"));
assert!(errs[0].contains("assigned string"));
}
#[test]
fn test_correct_typed_fn() {
let errs = errors(
"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }\nadd(1, 2) }",
);
assert!(errs.is_empty());
}
#[test]
fn test_fn_arg_type_mismatch() {
let errs = errors(
r#"pipeline t(task) { fn add(a: int, b: int) -> int { return a + b }
add("hello", 2) }"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("Argument 1"));
assert!(errs[0].contains("expected int"));
}
#[test]
fn test_return_type_mismatch() {
let errs = errors(r#"pipeline t(task) { fn get() -> int { return "hello" } }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("return type doesn't match"));
}
#[test]
fn test_union_type_compatible() {
let errs = errors(r#"pipeline t(task) { let x: string | nil = nil }"#);
assert!(errs.is_empty());
}
#[test]
fn test_union_type_mismatch() {
let errs = errors(r#"pipeline t(task) { let x: string | nil = 42 }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("declared as"));
}
#[test]
fn test_type_inference_propagation() {
let errs = errors(
r#"pipeline t(task) {
fn add(a: int, b: int) -> int { return a + b }
let result: string = add(1, 2)
}"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("declared as"));
assert!(errs[0].contains("string"));
assert!(errs[0].contains("int"));
}
#[test]
fn test_generic_return_type_instantiates_from_callsite() {
let errs = errors(
r#"pipeline t(task) {
fn identity<T>(x: T) -> T { return x }
fn first<T>(items: list<T>) -> T { return items[0] }
let n: int = identity(42)
let s: string = first(["a", "b"])
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
#[test]
fn test_generic_type_param_must_bind_consistently() {
let errs = errors(
r#"pipeline t(task) {
fn keep<T>(a: T, b: T) -> T { return a }
keep(1, "x")
}"#,
);
assert_eq!(errs.len(), 2, "expected 2 errors, got: {:?}", errs);
assert!(
errs.iter()
.any(|err| err.contains("type parameter 'T' was inferred as both int and string")),
"missing generic binding conflict error: {:?}",
errs
);
assert!(
errs.iter()
.any(|err| err.contains("Argument 2 ('b'): expected int, got string")),
"missing instantiated argument mismatch error: {:?}",
errs
);
}
#[test]
fn test_generic_list_binding_propagates_element_type() {
let errs = errors(
r#"pipeline t(task) {
fn first<T>(items: list<T>) -> T { return items[0] }
let bad: string = first([1, 2, 3])
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(errs[0].contains("declared as string, but assigned int"));
}
#[test]
fn test_generic_struct_literal_instantiates_type_arguments() {
let errs = errors(
r#"pipeline t(task) {
struct Pair<A, B> {
first: A
second: B
}
let pair: Pair<int, string> = Pair { first: 1, second: "two" }
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
#[test]
fn test_generic_enum_construct_instantiates_type_arguments() {
let errs = errors(
r#"pipeline t(task) {
enum Option<T> {
Some(value: T),
None
}
let value: Option<int> = Option.Some(42)
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
#[test]
fn test_result_generic_type_compatibility() {
let errs = errors(
r#"pipeline t(task) {
let ok: Result<int, string> = Result.Ok(42)
let err: Result<int, string> = Result.Err("oops")
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
#[test]
fn test_result_generic_type_mismatch_reports_error() {
let errs = errors(
r#"pipeline t(task) {
let bad: Result<int, string> = Result.Err(42)
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {errs:?}");
assert!(errs[0].contains("Result<int, string>"));
assert!(errs[0].contains("Result<_, int>"));
}
#[test]
fn test_builtin_return_type_inference() {
let errs = errors(r#"pipeline t(task) { let x: string = to_int("42") }"#);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("string"));
assert!(errs[0].contains("int"));
}
#[test]
fn test_workflow_and_transcript_builtins_are_known() {
let errs = errors(
r#"pipeline t(task) {
let flow = workflow_graph({name: "demo", entry: "act", nodes: {act: {kind: "stage"}}})
let report: dict = workflow_policy_report(flow, {tools: tool_registry(), capabilities: {workspace: ["read_text"]}})
let run: dict = workflow_execute("task", flow, [], {})
let tree: dict = load_run_tree("run.json")
let fixture: dict = run_record_fixture(run?.run)
let suite: dict = run_record_eval_suite([{run: run?.run, fixture: fixture}])
let diff: dict = run_record_diff(run?.run, run?.run)
let manifest: dict = eval_suite_manifest({cases: [{run_path: "run.json"}]})
let suite_report: dict = eval_suite_run(manifest)
let wf: dict = artifact_workspace_file("src/main.rs", "fn main() {}", {source: "host"})
let snap: dict = artifact_workspace_snapshot(["src/main.rs"], "snapshot")
let selection: dict = artifact_editor_selection("src/main.rs", "main")
let verify: dict = artifact_verification_result("verify", "ok")
let test_result: dict = artifact_test_result("tests", "pass")
let cmd: dict = artifact_command_result("cargo test", {status: 0})
let patch: dict = artifact_diff("src/main.rs", "old", "new")
let git: dict = artifact_git_diff("diff --git a b")
let review: dict = artifact_diff_review(patch, "review me")
let decision: dict = artifact_review_decision(review, "accepted")
let proposal: dict = artifact_patch_proposal(review, "*** Begin Patch")
let bundle: dict = artifact_verification_bundle("checks", [{name: "fmt", ok: true}])
let apply: dict = artifact_apply_intent(review, "apply")
let transcript = transcript_reset({metadata: {source: "test"}})
let visible: string = transcript_render_visible(transcript_archive(transcript))
let events: list = transcript_events(transcript)
let context: string = artifact_context([], {max_artifacts: 1})
println(report)
println(run)
println(tree)
println(fixture)
println(suite)
println(diff)
println(manifest)
println(suite_report)
println(wf)
println(snap)
println(selection)
println(verify)
println(test_result)
println(cmd)
println(patch)
println(git)
println(review)
println(decision)
println(proposal)
println(bundle)
println(apply)
println(visible)
println(events)
println(context)
}"#,
);
assert!(errs.is_empty(), "unexpected type errors: {errs:?}");
}
#[test]
fn test_binary_op_type_inference() {
let errs = errors("pipeline t(task) { let x: string = 1 + 2 }");
assert_eq!(errs.len(), 1);
}
#[test]
fn test_exponentiation_requires_numeric_operands() {
let errs = errors(r#"pipeline t(task) { let x = "nope" ** 2 }"#);
assert!(
errs.iter().any(|err| err.contains("can't use '**'")),
"missing exponentiation type error: {errs:?}"
);
}
#[test]
fn test_comparison_returns_bool() {
let errs = errors("pipeline t(task) { let x: bool = 1 < 2 }");
assert!(errs.is_empty());
}
#[test]
fn test_int_float_promotion() {
let errs = errors("pipeline t(task) { let x: float = 42 }");
assert!(errs.is_empty());
}
#[test]
fn test_untyped_code_no_errors() {
let errs = errors(
r#"pipeline t(task) {
fn process(data) {
let result = data + " processed"
return result
}
log(process("hello"))
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_type_alias() {
let errs = errors(
r#"pipeline t(task) {
type Name = string
let x: Name = "hello"
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_type_alias_mismatch() {
let errs = errors(
r#"pipeline t(task) {
type Name = string
let x: Name = 42
}"#,
);
assert_eq!(errs.len(), 1);
}
#[test]
fn test_assignment_type_check() {
let errs = errors(
r#"pipeline t(task) {
var x: int = 0
x = "hello"
}"#,
);
assert_eq!(errs.len(), 1);
assert!(errs[0].contains("can't assign string"));
}
#[test]
fn test_covariance_int_to_float_in_fn() {
let errs = errors(
"pipeline t(task) { fn scale(x: float) -> float { return x * 2.0 }\nscale(42) }",
);
assert!(errs.is_empty());
}
#[test]
fn test_covariance_return_type() {
let errs = errors("pipeline t(task) { fn get() -> float { return 42 } }");
assert!(errs.is_empty());
}
#[test]
fn test_no_contravariance_float_to_int() {
let errs = errors("pipeline t(task) { fn add(a: int) -> int { return a + 1 }\nadd(3.14) }");
assert_eq!(errs.len(), 1);
}
#[test]
fn test_fn_param_contravariance_positive() {
let errs = errors(
r#"pipeline t(task) {
let wide = fn(x: float) { return 0 }
let cb: fn(int) -> int = wide
}"#,
);
assert!(
errs.is_empty(),
"expected fn(float)->int to satisfy fn(int)->int, got: {errs:?}"
);
}
#[test]
fn test_fn_param_contravariance_negative() {
let errs = errors(
r#"pipeline t(task) {
let narrow = fn(x: int) { return 0 }
let cb: fn(float) -> int = narrow
}"#,
);
assert!(
!errs.is_empty(),
"expected fn(int)->int NOT to satisfy fn(float)->int, but type-check passed"
);
}
#[test]
fn test_list_invariant_int_to_float_rejected() {
let errs = errors(
r#"pipeline t(task) {
let xs: list<int> = [1, 2, 3]
let ys: list<float> = xs
}"#,
);
assert!(
!errs.is_empty(),
"expected list<int> NOT to flow into list<float>, but type-check passed"
);
}
#[test]
fn test_iter_covariant_int_to_float_accepted() {
let errs = errors(
r#"pipeline t(task) {
fn sink(ys: iter<float>) -> int { return 0 }
fn pipe(xs: iter<int>) -> int { return sink(xs) }
}"#,
);
assert!(
errs.is_empty(),
"expected iter<int> to flow into iter<float>, got: {errs:?}"
);
}
#[test]
fn test_decl_site_out_used_in_contravariant_position_rejected() {
let errs = errors(
r#"pipeline t(task) {
type Box<out T> = fn(T) -> int
}"#,
);
assert!(
errs.iter().any(|e| e.contains("declared 'out'")),
"expected 'out T' misuse diagnostic, got: {errs:?}"
);
}
#[test]
fn test_decl_site_in_used_in_covariant_position_rejected() {
let errs = errors(
r#"pipeline t(task) {
interface Producer<in T> { fn next() -> T }
}"#,
);
assert!(
errs.iter().any(|e| e.contains("declared 'in'")),
"expected 'in T' misuse diagnostic, got: {errs:?}"
);
}
#[test]
fn test_decl_site_out_in_covariant_position_ok() {
let errs = errors(
r#"pipeline t(task) {
type Reader<out T> = fn() -> T
}"#,
);
assert!(
errs.iter().all(|e| !e.contains("declared 'out'")),
"unexpected variance diagnostic: {errs:?}"
);
}
#[test]
fn test_dict_invariant_int_to_float_rejected() {
let errs = errors(
r#"pipeline t(task) {
let d: dict<string, int> = {"a": 1}
let e: dict<string, float> = d
}"#,
);
assert!(
!errs.is_empty(),
"expected dict<string, int> NOT to flow into dict<string, float>"
);
}
fn warnings(source: &str) -> Vec<String> {
check_source(source)
.into_iter()
.filter(|d| d.severity == DiagnosticSeverity::Warning)
.map(|d| d.message)
.collect()
}
#[test]
fn test_exhaustive_match_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c = Color.Red
match c.variant {
"Red" -> { log("r") }
"Green" -> { log("g") }
"Blue" -> { log("b") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert!(exhaustive_warns.is_empty());
}
#[test]
fn test_non_exhaustive_match_warning() {
let warns = warnings(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c = Color.Red
match c.variant {
"Red" -> { log("r") }
"Green" -> { log("g") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert_eq!(exhaustive_warns.len(), 1);
assert!(exhaustive_warns[0].contains("Blue"));
}
#[test]
fn test_non_exhaustive_multiple_missing() {
let warns = warnings(
r#"pipeline t(task) {
enum Status { Active, Inactive, Pending }
let s = Status.Active
match s.variant {
"Active" -> { log("a") }
}
}"#,
);
let exhaustive_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive"))
.collect();
assert_eq!(exhaustive_warns.len(), 1);
assert!(exhaustive_warns[0].contains("Inactive"));
assert!(exhaustive_warns[0].contains("Pending"));
}
fn exhaustive_warns(source: &str) -> Vec<String> {
warnings(source)
.into_iter()
.filter(|w| w.contains("was not fully narrowed"))
.collect()
}
#[test]
fn test_unknown_exhaustive_unreachable_happy_path() {
let source = r#"pipeline t(task) {
fn describe(v: unknown) -> string {
if type_of(v) == "string" { return "s" }
if type_of(v) == "int" { return "i" }
if type_of(v) == "float" { return "f" }
if type_of(v) == "bool" { return "b" }
if type_of(v) == "nil" { return "n" }
if type_of(v) == "list" { return "l" }
if type_of(v) == "dict" { return "d" }
if type_of(v) == "closure" { return "c" }
unreachable("unknown type_of variant")
}
log(describe(1))
}"#;
assert!(exhaustive_warns(source).is_empty());
}
#[test]
fn test_unknown_exhaustive_unreachable_incomplete_warns() {
let source = r#"pipeline t(task) {
fn describe(v: unknown) -> string {
if type_of(v) == "string" { return "s" }
if type_of(v) == "int" { return "i" }
unreachable("unknown type_of variant")
}
log(describe(1))
}"#;
let warns = exhaustive_warns(source);
assert_eq!(warns.len(), 1, "expected one warning, got: {:?}", warns);
let w = &warns[0];
for missing in &["float", "bool", "nil", "list", "dict", "closure"] {
assert!(w.contains(missing), "missing {missing} in: {w}");
}
assert!(!w.contains("int"));
assert!(!w.contains("string"));
assert!(w.contains("unreachable"));
assert!(w.contains("v: unknown"));
}
#[test]
fn test_unknown_incomplete_normal_return_no_warning() {
let source = r#"pipeline t(task) {
fn describe(v: unknown) -> string {
if type_of(v) == "string" { return "s" }
if type_of(v) == "int" { return "i" }
return "other"
}
log(describe(1))
}"#;
assert!(exhaustive_warns(source).is_empty());
}
#[test]
fn test_unknown_exhaustive_throw_incomplete_warns() {
let source = r#"pipeline t(task) {
fn describe(v: unknown) -> string {
if type_of(v) == "string" { return "s" }
throw "unhandled"
}
log(describe("x"))
}"#;
let warns = exhaustive_warns(source);
assert_eq!(warns.len(), 1, "expected one warning, got: {:?}", warns);
assert!(warns[0].contains("throw"));
assert!(warns[0].contains("int"));
}
#[test]
fn test_unknown_throw_without_narrowing_no_warning() {
let source = r#"pipeline t(task) {
fn crash(v: unknown) -> string {
throw "nope"
}
log(crash(1))
}"#;
assert!(exhaustive_warns(source).is_empty());
}
#[test]
fn test_unknown_exhaustive_nested_branch() {
let source = r#"pipeline t(task) {
fn describe(v: unknown, x: int) -> string {
if type_of(v) == "string" {
if x > 0 { return v.upper() } else { return "s" }
}
if type_of(v) == "int" { return "i" }
unreachable("unknown type_of variant")
}
log(describe(1, 1))
}"#;
let warns = exhaustive_warns(source);
assert_eq!(warns.len(), 1, "expected one warning, got: {:?}", warns);
assert!(warns[0].contains("float"));
}
#[test]
fn test_unknown_exhaustive_negated_check() {
let source = r#"pipeline t(task) {
fn describe(v: unknown) -> string {
if type_of(v) != "string" {
// v still unknown here, but "string" is NOT ruled out on this path
return "non-string"
}
// v: string here
return v.upper()
}
log(describe("x"))
}"#;
assert!(exhaustive_warns(source).is_empty());
}
#[test]
fn test_enum_construct_type_inference() {
let errs = errors(
r#"pipeline t(task) {
enum Color { Red, Green, Blue }
let c: Color = Color.Red
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_nil_coalescing_strips_nil() {
let errs = errors(
r#"pipeline t(task) {
let x: string | nil = nil
let y: string = x ?? "default"
}"#,
);
assert!(errs.is_empty());
}
#[test]
fn test_shape_mismatch_detail_missing_field() {
let errs = errors(
r#"pipeline t(task) {
let x: {name: string, age: int} = {name: "hello"}
}"#,
);
assert_eq!(errs.len(), 1);
assert!(
errs[0].contains("missing field 'age'"),
"expected detail about missing field, got: {}",
errs[0]
);
}
#[test]
fn test_shape_mismatch_detail_wrong_type() {
let errs = errors(
r#"pipeline t(task) {
let x: {name: string, age: int} = {name: 42, age: 10}
}"#,
);
assert_eq!(errs.len(), 1);
assert!(
errs[0].contains("field 'name' has type int, expected string"),
"expected detail about wrong type, got: {}",
errs[0]
);
}
#[test]
fn test_match_pattern_string_against_int() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
"hello" -> { log("bad") }
42 -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching int against string literal"));
}
#[test]
fn test_match_pattern_int_against_string() {
let warns = warnings(
r#"pipeline t(task) {
let x: string = "hello"
match x {
42 -> { log("bad") }
"hello" -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching string against int literal"));
}
#[test]
fn test_match_pattern_bool_against_int() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
true -> { log("bad") }
42 -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching int against bool literal"));
}
#[test]
fn test_match_pattern_float_against_string() {
let warns = warnings(
r#"pipeline t(task) {
let x: string = "hello"
match x {
3.14 -> { log("bad") }
"hello" -> { log("ok") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert_eq!(pattern_warns.len(), 1);
assert!(pattern_warns[0].contains("matching string against float literal"));
}
#[test]
fn test_match_pattern_int_against_float_ok() {
let warns = warnings(
r#"pipeline t(task) {
let x: float = 3.14
match x {
42 -> { log("ok") }
_ -> { log("default") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_float_against_int_ok() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
3.14 -> { log("close") }
_ -> { log("default") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_correct_types_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
1 -> { log("one") }
2 -> { log("two") }
_ -> { log("other") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_wildcard_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: int = 42
match x {
_ -> { log("catch all") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
#[test]
fn test_match_pattern_untyped_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x = some_unknown_fn()
match x {
"hello" -> { log("string") }
42 -> { log("int") }
}
}"#,
);
let pattern_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Match pattern type mismatch"))
.collect();
assert!(pattern_warns.is_empty());
}
/// Type-checks `source` and returns only the error messages that report an
/// unsatisfied interface constraint ("does not satisfy interface").
fn iface_errors(source: &str) -> Vec<String> {
let mut constraint_errors = errors(source);
constraint_errors.retain(|message| message.contains("does not satisfy interface"));
constraint_errors
}
/// An impl method returning `string` where the interface declares `int`
/// fails the `where T: Sizable` constraint with one error.
#[test]
fn test_interface_constraint_return_type_mismatch() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn size(self) -> string { return "nope" }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(
errs[0].contains("method 'size' returns 'string', expected 'int'"),
"unexpected message: {}",
errs[0]
);
}
/// A parameter typed `string` where the interface declares `int` fails the
/// constraint and names the offending parameter position.
#[test]
fn test_interface_constraint_param_type_mismatch() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Processor {
fn process(self, x: int) -> string
}
struct MyProc { name: string }
impl MyProc {
fn process(self, x: string) -> string { return x }
}
fn run_proc<T>(p: T) where T: Processor { log(p.process(42)) }
run_proc(MyProc({name: "a"}))
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(
errs[0].contains("method 'process' parameter 1 has type 'string', expected 'int'"),
"unexpected message: {}",
errs[0]
);
}
/// A type whose impl lacks the interface method fails with "missing method".
#[test]
fn test_interface_constraint_missing_method() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn area(self) -> int { return self.width }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(
errs[0].contains("missing method 'size'"),
"unexpected message: {}",
errs[0]
);
}
/// An impl method with fewer parameters than the interface (excluding
/// `self`) fails with a parameter-count error.
#[test]
fn test_interface_constraint_param_count_mismatch() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Doubler {
fn double(self, x: int) -> int
}
struct Bad { v: int }
impl Bad {
fn double(self) -> int { return self.v * 2 }
}
fn run_double<T>(d: T) where T: Doubler { log(d.double(3)) }
run_double(Bad({v: 5}))
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(
errs[0].contains("method 'double' has 0 parameter(s), expected 1"),
"unexpected message: {}",
errs[0]
);
}
/// A matching impl satisfies the constraint with no interface errors.
#[test]
fn test_interface_constraint_satisfied() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int, height: int }
impl Box {
fn size(self) -> int { return self.width * self.height }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3, height: 4}))
}"#,
);
assert!(errs.is_empty(), "expected no errors, got: {:?}", errs);
}
/// An impl method without a return annotation is accepted by the constraint.
#[test]
fn test_interface_constraint_untyped_impl_compatible() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Sizable {
fn size(self) -> int
}
struct Box { width: int }
impl Box {
fn size(self) { return self.width }
}
fn measure<T>(item: T) where T: Sizable { log(item.size()) }
measure(Box({width: 3}))
}"#,
);
assert!(errs.is_empty(), "expected no errors, got: {:?}", errs);
}
/// An `int` return satisfies an interface method declared to return `float`.
#[test]
fn test_interface_constraint_int_float_covariance() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Measurable {
fn value(self) -> float
}
struct Gauge { v: int }
impl Gauge {
fn value(self) -> int { return self.v }
}
fn read_val<T>(g: T) where T: Measurable { log(g.value()) }
read_val(Gauge({v: 42}))
}"#,
);
assert!(errs.is_empty(), "expected no errors, got: {:?}", errs);
}
/// An associated type without a default is satisfied by whatever type the
/// impl method returns.
#[test]
fn test_interface_associated_type_constraint_satisfied() {
let errs = iface_errors(
r#"pipeline t(task) {
interface Collection {
type Item
fn get(self, index: int) -> Item
}
struct Names {}
impl Names {
fn get(self, index: int) -> string { return "ada" }
}
fn first<C>(collection: C) where C: Collection {
log(collection.get(0))
}
first(Names {})
}"#,
);
assert!(errs.is_empty(), "expected no errors, got: {:?}", errs);
}
/// An associated type defaulted to `int` rejects an impl resolving it to
/// `string`.
#[test]
fn test_interface_associated_type_default_mismatch() {
let errs = iface_errors(
r#"pipeline t(task) {
interface IntCollection {
type Item = int
fn get(self, index: int) -> Item
}
struct Labels {}
impl Labels {
fn get(self, index: int) -> string { return "oops" }
}
fn first<C>(collection: C) where C: IntCollection {
log(collection.get(0))
}
first(Labels {})
}"#,
);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
assert!(
errs[0].contains("associated type 'Item' resolves to 'string', expected 'int'"),
"unexpected message: {}",
errs[0]
);
}
/// Inside `if name != nil`, a `string | nil` parameter is usable as `string`.
#[test]
fn test_nil_narrowing_then_branch() {
let source = r#"pipeline t(task) {
fn greet(name: string | nil) {
if name != nil {
let s: string = name
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `x != nil` narrows to `string` in the then-branch and `nil` in the else.
#[test]
fn test_nil_narrowing_else_branch() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
if x != nil {
let s: string = x
} else {
let n: nil = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `x == nil` narrows to `nil` in the then-branch and `string` in the else.
#[test]
fn test_nil_equality_narrows_both() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
if x == nil {
let n: nil = x
} else {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// A bare truthiness test `if x` strips `nil` inside the branch.
#[test]
fn test_truthiness_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
if x {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `if !x` narrows to `nil` in the then-branch and `string` in the else.
#[test]
fn test_negation_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
if !x {
let n: nil = x
} else {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `type_of(x) == "string"` narrows a `string | int` parameter to `string`.
#[test]
fn test_typeof_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | int) {
if type_of(x) == "string" {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// The else-branch of a `type_of` equality gets the remaining union member.
#[test]
fn test_typeof_narrowing_else() {
let source = r#"pipeline t(task) {
fn check(x: string | int) {
if type_of(x) == "string" {
let s: string = x
} else {
let i: int = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `type_of(x) != "string"` narrows the branches in the opposite order.
#[test]
fn test_typeof_neq_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | int) {
if type_of(x) != "string" {
let i: int = x
} else {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// Narrowings from both operands of `&&` apply inside the then-branch.
#[test]
fn test_and_combines_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | int | nil) {
if x != nil && type_of(x) == "string" {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// When `x || y` is falsy, both operands are narrowed to `nil` in the else.
#[test]
fn test_or_falsy_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | nil, y: int | nil) {
if x || y {
// conservative: can't narrow
} else {
let xn: nil = x
let yn: nil = y
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// A `guard ... else { return }` narrows the variable for following statements.
#[test]
fn test_guard_narrows_outer_scope() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
guard x != nil else { return }
let s: string = x
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// A `while` condition narrows the variable inside the loop body.
#[test]
fn test_while_narrows_body() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
while x != nil {
let s: string = x
break
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// An `if x == nil { return ... }` narrows `x` to `string` after the `if`.
#[test]
fn test_early_return_narrows_after_if() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) -> string {
if x == nil {
return "default"
}
let s: string = x
return s
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// A `throw` in the nil branch narrows `x` after the `if` like a return does.
#[test]
fn test_early_throw_narrows_after_if() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
if x == nil {
throw "missing"
}
let s: string = x
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// An unannotated parameter produces no errors when used as `string`.
#[test]
fn test_no_narrowing_unknown_type() {
let source = r#"pipeline t(task) {
fn check(x) {
if x != nil {
let s: string = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// Reassigning `y = nil` inside a narrowed branch drops the narrowing, so the
/// later `let s2: string = y` reports exactly one "declared as" mismatch.
#[test]
fn test_reassignment_invalidates_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | nil) {
var y: string | nil = x
if y != nil {
let s: string = y
y = nil
let s2: string = y
}
}
}"#;
let errs = errors(source);
assert_eq!(errs.len(), 1, "expected 1 error, got: {:?}", errs);
let only = &errs[0];
assert!(
only.contains("declared as"),
"expected type mismatch, got: {}",
only
);
}
/// Reassigning a `let` binding yields a warning-severity diagnostic whose
/// message mentions immutability.
#[test]
fn test_let_immutable_warning() {
let all = check_source(
r#"pipeline t(task) {
let x = 42
x = 43
}"#,
);
// Named `warning_diags` (not `warnings`) so the local does not shadow the
// module-level `warnings(source)` helper used by other tests.
let warning_diags: Vec<_> = all
.iter()
.filter(|d| d.severity == DiagnosticSeverity::Warning)
.collect();
assert!(
warning_diags.iter().any(|w| w.message.contains("immutable")),
"expected immutability warning, got: {:?}",
warning_diags
);
}
/// Narrowings compose across nested `if`s: `!= nil` then `type_of == "int"`.
#[test]
fn test_nested_narrowing() {
let source = r#"pipeline t(task) {
fn check(x: string | int | nil) {
if x != nil {
if type_of(x) == "int" {
let i: int = x
}
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// A literal pattern in a `match` arm narrows the subject inside that arm.
#[test]
fn test_match_narrows_arms() {
let source = r#"pipeline t(task) {
fn check(x: string | int) {
match x {
"hello" -> {
let s: string = x
}
42 -> {
let i: int = x
}
_ -> {}
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// `x.has("name")` turns the optional `name?` field into a required one.
#[test]
fn test_has_narrows_optional_field() {
let source = r#"pipeline t(task) {
fn check(x: {name?: string, age: int}) {
if x.has("name") {
let n: {name: string, age: int} = x
}
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "got: {:?}", errs);
}
/// Lexes, parses and type-checks `source` via `check_with_source`, which also
/// receives the raw text (the fix tests below rely on this to get `fix` edits).
///
/// # Panics
/// If the fixture fails to lex or parse — that is a broken test, not a
/// checker result, so the panic message says so.
fn check_source_with_source(source: &str) -> Vec<TypeDiagnostic> {
let mut lexer = Lexer::new(source);
let tokens = lexer.tokenize().expect("test fixture should tokenize");
let mut parser = Parser::new(tokens);
let program = parser.parse().expect("test fixture should parse");
TypeChecker::new().check_with_source(&program, source)
}
/// `"hello " + 42` gets exactly one fix with a single edit that rewrites the
/// concatenation as an interpolated string.
#[test]
fn test_fix_string_plus_int_literal() {
let source = "pipeline t(task) {\n let x = \"hello \" + 42\n log(x)\n}";
let diags = check_source_with_source(source);
let fixable: Vec<_> = diags.iter().filter(|d| d.fix.is_some()).collect();
assert_eq!(fixable.len(), 1, "expected 1 fixable diagnostic, got: {diags:?}");
let fix = fixable[0].fix.as_ref().expect("filtered on fix.is_some()");
assert_eq!(fix.len(), 1, "expected a single edit, got: {fix:?}");
assert_eq!(fix[0].replacement, "\"hello ${42}\"");
}
/// `42 + "hello"` is rewritten with the interpolation placed first.
#[test]
fn test_fix_int_plus_string_literal() {
let source = "pipeline t(task) {\n let x = 42 + \"hello\"\n log(x)\n}";
let diags = check_source_with_source(source);
let fixable: Vec<_> = diags.iter().filter(|d| d.fix.is_some()).collect();
assert_eq!(fixable.len(), 1, "expected 1 fixable diagnostic, got: {diags:?}");
let fix = fixable[0].fix.as_ref().expect("filtered on fix.is_some()");
let edit = fix.first().expect("fix should contain at least one edit");
assert_eq!(edit.replacement, "\"${42}hello\"");
}
/// A string literal plus an `int` variable interpolates the variable name.
#[test]
fn test_fix_string_plus_variable() {
let source = "pipeline t(task) {\n let n: int = 5\n let x = \"count: \" + n\n log(x)\n}";
let diags = check_source_with_source(source);
let fixable: Vec<_> = diags.iter().filter(|d| d.fix.is_some()).collect();
assert_eq!(fixable.len(), 1, "expected 1 fixable diagnostic, got: {diags:?}");
let fix = fixable[0].fix.as_ref().expect("filtered on fix.is_some()");
let edit = fix.first().expect("fix should contain at least one edit");
assert_eq!(edit.replacement, "\"count: ${n}\"");
}
/// Purely numeric arithmetic gets no string-interpolation fix. Despite the
/// name, the fixture exercises `int - float` (`x - y`).
#[test]
fn test_no_fix_int_plus_int() {
let source = "pipeline t(task) {\n let x: int = 5\n let y: float = 3.0\n let z = x - y\n log(z)\n}";
let diags = check_source_with_source(source);
let fixable: Vec<_> = diags.iter().filter(|d| d.fix.is_some()).collect();
assert!(
fixable.is_empty(),
"no fix expected for numeric ops, got: {fixable:?}"
);
}
/// Plain `check` (no source text supplied) never produces fix edits.
#[test]
fn test_no_fix_without_source() {
let source = "pipeline t(task) {\n let x = \"hello \" + 42\n log(x)\n}";
let diags = check_source(source);
let fixable: Vec<_> = diags.iter().filter(|d| d.fix.is_some()).collect();
assert!(
fixable.is_empty(),
"without source, no fix should be generated, got: {fixable:?}"
);
}
/// A match covering every member of `string | int | nil` raises no
/// non-exhaustive-union warning.
#[test]
fn test_union_exhaustive_match_no_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: string | int | nil = nil
match x {
"hello" -> { log("s") }
42 -> { log("i") }
nil -> { log("n") }
}
}"#,
);
let union_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive match on union"))
.collect();
assert!(union_warns.is_empty(), "unexpected union warnings: {union_warns:?}");
}
/// Omitting the `nil` arm yields one non-exhaustive warning naming `nil`.
#[test]
fn test_union_non_exhaustive_match_warning() {
let warns = warnings(
r#"pipeline t(task) {
let x: string | int | nil = nil
match x {
"hello" -> { log("s") }
42 -> { log("i") }
}
}"#,
);
let union_warns: Vec<_> = warns
.iter()
.filter(|w| w.contains("Non-exhaustive match on union"))
.collect();
assert_eq!(union_warns.len(), 1, "got: {union_warns:?}");
assert!(
union_warns[0].contains("nil"),
"warning should name the missing nil case: {}",
union_warns[0]
);
}
/// `??` on a non-optional `int` keeps the left operand's type.
#[test]
fn test_nil_coalesce_non_union_preserves_left_type() {
let errs = errors(
r#"pipeline t(task) {
let x: int = 42
let y: int = x ?? 0
}"#,
);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// `nil ?? "fallback"` takes the right operand's type (`string`).
#[test]
fn test_nil_coalesce_nil_returns_right_type() {
let errs = errors(
r#"pipeline t(task) {
let x: string = nil ?? "fallback"
}"#,
);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// `never` as the actual type is compatible with every expected type tried
/// here, including a union.
#[test]
fn test_never_is_subtype_of_everything() {
let tc = TypeChecker::new();
let scope = TypeScope::new();
let named = |name: &str| TypeExpr::Named(name.into());
let expected_types = [
named("string"),
named("int"),
TypeExpr::Union(vec![named("string"), named("nil")]),
];
for expected in &expected_types {
assert!(tc.types_compatible(expected, &TypeExpr::Never, &scope));
}
}
/// No concrete actual type is compatible with an expected `never`.
#[test]
fn test_nothing_is_subtype_of_never() {
let tc = TypeChecker::new();
let scope = TypeScope::new();
let named = |name: &str| TypeExpr::Named(name.into());
for actual in [named("string"), named("int")] {
assert!(!tc.types_compatible(&TypeExpr::Never, &actual, &scope));
}
}
/// `never` is compatible with itself.
#[test]
fn test_never_never_compatible() {
let tc = TypeChecker::new();
let scope = TypeScope::new();
let never = TypeExpr::Never;
assert!(tc.types_compatible(&never, &never, &scope));
}
/// `any` accepts every actual type and is accepted by every expected type.
#[test]
fn test_any_is_top_type_bidirectional() {
let tc = TypeChecker::new();
let scope = TypeScope::new();
let named = |name: &str| TypeExpr::Named(name.into());
let any = named("any");
let accepted_by_any = [
named("string"),
named("int"),
named("nil"),
TypeExpr::List(Box::new(named("int"))),
];
for actual in &accepted_by_any {
assert!(tc.types_compatible(&any, actual, &scope));
}
for expected in [named("string"), named("nil")] {
assert!(tc.types_compatible(&expected, &any, &scope));
}
}
/// `unknown` accepts any actual type but is only assignable to `unknown`
/// or `any`.
#[test]
fn test_unknown_is_safe_top_one_way() {
let tc = TypeChecker::new();
let scope = TypeScope::new();
let named = |name: &str| TypeExpr::Named(name.into());
let unknown = named("unknown");
let accepted_by_unknown = [
named("string"),
named("nil"),
TypeExpr::List(Box::new(named("int"))),
];
for actual in &accepted_by_unknown {
assert!(tc.types_compatible(&unknown, actual, &scope));
}
for expected in [named("string"), named("int")] {
assert!(!tc.types_compatible(&expected, &unknown, &scope));
}
assert!(tc.types_compatible(&unknown, &unknown, &scope));
assert!(tc.types_compatible(&named("any"), &unknown, &scope));
}
/// A `type_of(v) == "string"` guard lets an `unknown` value be returned as
/// `string`.
#[test]
fn test_unknown_narrows_via_type_of() {
let source = r#"pipeline t(task) {
fn f(v: unknown) -> string {
if type_of(v) == "string" {
return v
}
return "other"
}
log(f("hi"))
}"#;
let errs = errors(source);
assert!(
errs.is_empty(),
"unknown should narrow to string inside type_of guard: {errs:?}"
);
}
/// Without narrowing, assigning `unknown` to `string` is an error that
/// mentions `unknown`.
#[test]
fn test_unknown_without_narrowing_errors() {
let source = r#"pipeline t(task) {
let u: unknown = "hello"
let s: string = u
}"#;
let errs = errors(source);
let mentions_unknown = errs.iter().any(|e| e.contains("unknown"));
assert!(
mentions_unknown,
"expected an error mentioning unknown, got: {errs:?}"
);
}
/// `simplify_union` drops `never` members and collapses single-member or
/// all-`never` unions.
#[test]
fn test_simplify_union_removes_never() {
let named = |name: &str| TypeExpr::Named(name.into());
assert_eq!(
simplify_union(vec![TypeExpr::Never, named("string")]),
named("string"),
);
assert_eq!(
simplify_union(vec![TypeExpr::Never, TypeExpr::Never]),
TypeExpr::Never,
);
let mixed = vec![named("string"), TypeExpr::Never, named("int")];
let expected = TypeExpr::Union(vec![named("string"), named("int")]);
assert_eq!(simplify_union(mixed), expected);
}
/// Removing the only member of a union yields `Some(never)`.
#[test]
fn test_remove_from_union_exhausted_returns_never() {
let members = [TypeExpr::Named("string".into())];
let remaining = remove_from_union(&members, "string");
assert_eq!(remaining, Some(TypeExpr::Never));
}
/// When one `if` branch throws, the expression takes the other branch's type.
#[test]
fn test_if_else_one_branch_throws_infers_other() {
let source = r#"pipeline t(task) {
fn foo(x: bool) -> int {
let result: int = if x { 42 } else { throw "err" }
return result
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// When both branches throw, the `if` expression is assignable to `string`.
#[test]
fn test_if_else_both_branches_throw_infers_never() {
let source = r#"pipeline t(task) {
fn foo(x: bool) -> string {
let result: string = if x { throw "a" } else { throw "b" }
return result
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// A statement after `return` is reported as unreachable.
#[test]
fn test_unreachable_after_return() {
let source = r#"pipeline t(task) {
fn foo() -> int {
return 1
let x = 2
}
}"#;
let warns = warnings(source);
let flagged = warns.iter().any(|w| w.contains("unreachable"));
assert!(flagged, "expected unreachable warning: {warns:?}");
}
/// A statement after `throw` is reported as unreachable.
#[test]
fn test_unreachable_after_throw() {
let source = r#"pipeline t(task) {
fn foo() {
throw "err"
let x = 2
}
}"#;
let warns = warnings(source);
let flagged = warns.iter().any(|w| w.contains("unreachable"));
assert!(flagged, "expected unreachable warning: {warns:?}");
}
/// An `if`/`else` whose branches both exit makes the next statement
/// unreachable.
#[test]
fn test_unreachable_after_composite_exit() {
let source = r#"pipeline t(task) {
fn foo(x: bool) {
if x { return 1 } else { throw "err" }
let y = 2
}
}"#;
let warns = warnings(source);
let flagged = warns.iter().any(|w| w.contains("unreachable"));
assert!(flagged, "expected unreachable warning: {warns:?}");
}
/// An `if` without `else` that returns leaves following code reachable.
#[test]
fn test_no_unreachable_warning_when_reachable() {
let source = r#"pipeline t(task) {
fn foo(x: bool) {
if x { return 1 }
let y = 2
}
}"#;
let warns = warnings(source);
let flagged = warns.iter().any(|w| w.contains("unreachable"));
assert!(!flagged, "unexpected unreachable warning: {warns:?}");
}
/// `catch (e: AppError)` binds `e` with the annotated enum type.
#[test]
fn test_catch_typed_error_variable() {
let source = r#"pipeline t(task) {
enum AppError { NotFound, Timeout }
try {
throw AppError.NotFound
} catch (e: AppError) {
let x: AppError = e
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// After every union member has been returned on, `unreachable(x)` is
/// accepted without an unreachable-related error.
#[test]
fn test_unreachable_with_never_arg_no_error() {
let source = r#"pipeline t(task) {
fn foo(x: string | int) {
if type_of(x) == "string" { return }
if type_of(x) == "int" { return }
unreachable(x)
}
}"#;
let errs = errors(source);
let complained = errs.iter().any(|e| e.contains("unreachable"));
assert!(!complained, "unexpected unreachable error: {errs:?}");
}
/// `unreachable(x)` with union members still possible reports that not all
/// cases were handled.
#[test]
fn test_unreachable_with_remaining_types_errors() {
let source = r#"pipeline t(task) {
fn foo(x: string | int | nil) {
if type_of(x) == "string" { return }
unreachable(x)
}
}"#;
let errs = errors(source);
let reported = errs
.iter()
.any(|e| e.contains("unreachable") && e.contains("not all cases"));
assert!(
reported,
"expected unreachable error about remaining types: {errs:?}"
);
}
/// A bare `unreachable()` call is not subject to the exhaustiveness check.
#[test]
fn test_unreachable_no_args_no_compile_error() {
let source = r#"pipeline t(task) {
fn foo() {
unreachable()
}
}"#;
let errs = errors(source);
let reported = errs
.iter()
.any(|e| e.contains("unreachable") && e.contains("not all cases"));
assert!(
!reported,
"unreachable() with no args should not produce type error: {errs:?}"
);
}
/// A `-> never` return annotation is accepted for an always-throwing fn.
#[test]
fn test_never_type_annotation_parses() {
let source = r#"pipeline t(task) {
fn foo() -> never {
throw "always throws"
}
}"#;
let errs = errors(source);
assert!(errs.is_empty(), "unexpected errors: {errs:?}");
}
/// `format_type` renders the bottom type as `never`.
#[test]
fn test_format_type_never() {
let rendered = format_type(&TypeExpr::Never);
assert_eq!(rendered, "never");
}
/// Lexes, parses and type-checks `source` with `strict_types` enabled.
///
/// # Panics
/// If the fixture fails to lex or parse — that indicates a broken test
/// fixture rather than a checker result.
fn check_source_strict(source: &str) -> Vec<TypeDiagnostic> {
let mut lexer = Lexer::new(source);
let tokens = lexer.tokenize().expect("test fixture should tokenize");
let mut parser = Parser::new(tokens);
let program = parser.parse().expect("test fixture should parse");
TypeChecker::with_strict_types(true).check(&program)
}
/// Messages of the warning-severity diagnostics from a strict-mode check.
fn strict_warnings(source: &str) -> Vec<String> {
check_source_strict(source)
.into_iter()
.filter(|d| d.severity == DiagnosticSeverity::Warning)
.map(|d| d.message)
.collect()
}
/// Strict mode flags property access on a `json_parse` result held in a
/// variable.
#[test]
fn test_strict_types_json_parse_property_access() {
let source = r#"pipeline t(task) {
let data = json_parse("{}")
log(data.name)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(flagged, "expected unvalidated warning, got: {warns:?}");
}
/// Strict mode flags property access chained directly on `json_parse(...)`.
#[test]
fn test_strict_types_direct_chain_access() {
let source = r#"pipeline t(task) {
log(json_parse("{}").name)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("Direct property access"));
assert!(flagged, "expected direct access warning, got: {warns:?}");
}
/// A preceding `schema_expect(data, schema)` clears the unvalidated warning.
#[test]
fn test_strict_types_schema_expect_clears() {
let source = r#"pipeline t(task) {
let my_schema = {type: "object", properties: {name: {type: "string"}}}
let data = json_parse("{}")
schema_expect(data, my_schema)
log(data.name)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(
!flagged,
"expected no unvalidated warning after schema_expect, got: {warns:?}"
);
}
/// Access inside an `if schema_is(data, schema)` guard is not flagged.
#[test]
fn test_strict_types_schema_is_if_guard() {
let source = r#"pipeline t(task) {
let my_schema = {type: "object", properties: {name: {type: "string"}}}
let data = json_parse("{}")
if schema_is(data, my_schema) {
log(data.name)
}
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(
!flagged,
"expected no unvalidated warning inside schema_is guard, got: {warns:?}"
);
}
/// A shape type annotation on the binding counts as validation.
#[test]
fn test_strict_types_shape_annotation_clears() {
let source = r#"pipeline t(task) {
let data: {name: string, age: int} = json_parse("{}")
log(data.name)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(!flagged, "expected no warning with shape annotation, got: {warns:?}");
}
/// Copying an unvalidated value to another binding carries the taint along,
/// and the warning names the new binding `'x'`.
#[test]
fn test_strict_types_propagation() {
let source = r#"pipeline t(task) {
let data = json_parse("{}")
let x = data
log(x.name)
}"#;
let warns = strict_warnings(source);
let flagged = warns
.iter()
.any(|w| w.contains("unvalidated") && w.contains("'x'"));
assert!(flagged, "expected propagation warning for x, got: {warns:?}");
}
/// Results of non-boundary builtins such as `len` are not flagged.
#[test]
fn test_strict_types_non_boundary_no_warning() {
let source = r#"pipeline t(task) {
let x = len("hello")
log(x)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(
!flagged,
"non-boundary function should not be flagged, got: {warns:?}"
);
}
/// Subscript access on an unvalidated value is flagged like property access.
#[test]
fn test_strict_types_subscript_access() {
let source = r#"pipeline t(task) {
let data = json_parse("{}")
log(data["name"])
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(flagged, "expected subscript warning, got: {warns:?}");
}
/// The default (non-strict) checker emits no unvalidated diagnostics.
#[test]
fn test_strict_types_disabled_by_default() {
let source = r#"pipeline t(task) {
let data = json_parse("{}")
log(data.name)
}"#;
let diags = check_source(source);
let flagged = diags.iter().any(|d| d.message.contains("unvalidated"));
assert!(
!flagged,
"strict types should be off by default, got: {diags:?}"
);
}
/// An `llm_call` without a `schema` option returns an unvalidated value.
#[test]
fn test_strict_types_llm_call_without_schema() {
let source = r#"pipeline t(task) {
let result = llm_call("prompt", "system")
log(result.text)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(flagged, "llm_call without schema should warn, got: {warns:?}");
}
/// An `llm_call` given an inline `schema` option is treated as validated.
#[test]
fn test_strict_types_llm_call_with_schema_clean() {
let source = r#"pipeline t(task) {
let result = llm_call("prompt", "system", {
schema: {type: "object", properties: {name: {type: "string"}}}
})
log(result.data)
log(result.text)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(!flagged, "llm_call with schema should not warn, got: {warns:?}");
}
/// The value returned by `schema_expect(...)` is itself treated as typed.
#[test]
fn test_strict_types_schema_expect_result_typed() {
let source = r#"pipeline t(task) {
let my_schema = {type: "object", properties: {name: {type: "string"}}}
let validated = schema_expect(json_parse("{}"), my_schema)
log(validated.name)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(!flagged, "schema_expect result should be typed, got: {warns:?}");
}
/// A realistic pipeline mixing a schema-aware `llm_call` and
/// `json_parse` + `schema_expect` produces no unvalidated warnings.
#[test]
fn test_strict_types_realistic_orchestration() {
let source = r#"pipeline t(task) {
let payload_schema = {type: "object", properties: {
name: {type: "string"},
steps: {type: "list", items: {type: "string"}}
}}
// Good: schema-aware llm_call
let result = llm_call("generate a workflow", "system", {
schema: payload_schema
})
let workflow_name = result.data.name
// Good: validate then access
let raw = json_parse("{}")
schema_expect(raw, payload_schema)
let steps = raw.steps
log(workflow_name)
log(steps)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(
!flagged,
"validated orchestration should be clean, got: {warns:?}"
);
}
/// A `schema` option given as a variable also counts as validation.
#[test]
fn test_strict_types_llm_call_with_schema_via_variable() {
let source = r#"pipeline t(task) {
let my_schema = {type: "object", properties: {score: {type: "float"}}}
let result = llm_call("rate this", "system", {
schema: my_schema
})
log(result.data.score)
}"#;
let warns = strict_warnings(source);
let flagged = warns.iter().any(|w| w.contains("unvalidated"));
assert!(
!flagged,
"llm_call with schema variable should not warn, got: {warns:?}"
);
}
}