1#![allow(unused_assignments)]
26
27pub mod builtin;
28pub mod constraint;
29pub(crate) mod deferred;
30pub(crate) mod exhaustiveness;
31pub mod infer;
32pub mod narrowing;
33pub mod types;
34pub mod unify;
35
36use miette::Diagnostic;
37use mq_hir::{Hir, SymbolId};
38use rustc_hash::FxHashMap;
39use thiserror::Error;
40use types::TypeScheme;
41
42pub type Result<T> = std::result::Result<T, TypeError>;
44
45#[derive(Debug, Clone, Default)]
47pub struct TypeEnv(FxHashMap<SymbolId, TypeScheme>);
48
49impl TypeEnv {
50 pub fn insert(&mut self, symbol_id: SymbolId, scheme: TypeScheme) {
51 self.0.insert(symbol_id, scheme);
52 }
53
54 pub fn get(&self, symbol_id: &SymbolId) -> Option<&TypeScheme> {
55 self.0.get(symbol_id)
56 }
57
58 pub fn get_all(&self) -> &FxHashMap<SymbolId, TypeScheme> {
59 &self.0
60 }
61
62 pub fn len(&self) -> usize {
63 self.0.len()
64 }
65
66 pub fn is_empty(&self) -> bool {
67 self.0.is_empty()
68 }
69}
70
71impl<'a> IntoIterator for &'a TypeEnv {
72 type Item = (&'a SymbolId, &'a TypeScheme);
73 type IntoIter = std::collections::hash_map::Iter<'a, SymbolId, TypeScheme>;
74
75 fn into_iter(self) -> Self::IntoIter {
76 self.0.iter()
77 }
78}
79
/// Errors and diagnostics produced by the type checker.
///
/// Most variants carry both a miette `span` (rendered via `#[label]`) and an
/// `mq_lang::Range` `location`; [`TypeError::location`] returns the latter.
/// `TypeVarNotFound` and `Internal` carry neither.
#[derive(Debug, Error, Clone, Diagnostic)]
// NOTE(review): this `allow` (and the per-variant `allow(dead_code)`s below) is
// uncommented in the original; presumably it silences lint noise from the
// `thiserror`/`miette` derive expansions — confirm and record the exact cause.
#[allow(unused_assignments)]
pub enum TypeError {
    /// A type was required (`expected`) but a different one was seen (`found`).
    #[error("Type mismatch: expected {expected}, found {found}")]
    #[diagnostic(code(typechecker::type_mismatch))]
    #[allow(dead_code)]
    Mismatch {
        /// Expected type, as rendered in the message.
        expected: String,
        /// Type actually seen, as rendered in the message.
        found: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("type mismatch here")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
        /// Optional extra explanation, rendered as miette help text.
        #[help]
        context: Option<String>,
    },
    /// Two types could not be unified.
    #[error("Cannot unify types: {left} and {right}")]
    #[diagnostic(code(typechecker::unification_error))]
    #[allow(dead_code)]
    UnificationError {
        /// First type, as rendered in the message.
        left: String,
        /// Second type, as rendered in the message.
        right: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("cannot unify these types")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
        /// Optional extra explanation, rendered as miette help text.
        #[help]
        context: Option<String>,
    },
    /// Unification would build an infinite type: `var` occurs inside `ty`.
    #[error("Occurs check failed: type variable {var} occurs in {ty}")]
    #[diagnostic(code(typechecker::occurs_check))]
    #[allow(dead_code)]
    OccursCheck {
        /// The offending type variable, as rendered in the message.
        var: String,
        /// The type that contains `var`, as rendered in the message.
        ty: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("infinite type")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
    },
    /// A referenced name has no definition.
    #[error("Undefined symbol: {name}")]
    #[diagnostic(code(typechecker::undefined_symbol))]
    #[allow(dead_code)]
    UndefinedSymbol {
        /// The unresolved name.
        name: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("undefined symbol")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
    },
    /// The number of arguments differs from the number expected.
    #[error("Wrong number of arguments: expected {expected}, found {found}")]
    #[diagnostic(code(typechecker::wrong_arity))]
    #[allow(dead_code)]
    WrongArity {
        /// Expected argument count.
        expected: usize,
        /// Argument count actually supplied.
        found: usize,
        /// Span highlighted by the diagnostic label, if known.
        #[label("wrong number of arguments")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
        /// Optional extra explanation, rendered as miette help text.
        #[help]
        context: Option<String>,
    },
    /// A field was accessed that the record type does not define.
    #[error("Undefined field `{field}` in record type {record_ty}")]
    #[diagnostic(code(typechecker::undefined_field))]
    #[allow(dead_code)]
    UndefinedField {
        /// The missing field name.
        field: String,
        /// The record type searched, as rendered in the message.
        record_ty: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("field not found")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
    },
    /// An array's elements have mixed types.
    #[error("Heterogeneous array: elements have mixed types [{types}]")]
    #[diagnostic(code(typechecker::heterogeneous_array))]
    #[allow(dead_code)]
    HeterogeneousArray {
        /// The element types as one display string (shown inside `[...]`).
        types: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("mixed types in array")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
    },
    /// A type variable could not be found; carries no source location.
    #[error("Type variable not found: {0}")]
    #[diagnostic(code(typechecker::type_var_not_found))]
    TypeVarNotFound(String),
    /// Internal checker failure; carries no source location.
    #[error("Internal error: {0}")]
    #[diagnostic(code(typechecker::internal_error))]
    Internal(String),
    /// Operation `op` may yield `none` because argument `nullable_arg` may be `none`.
    #[error("Operation `{op}` may propagate `none`: argument `{nullable_arg}` may be `none`")]
    #[diagnostic(code(typechecker::nullable_propagation))]
    #[allow(dead_code)]
    NullablePropagation {
        /// The operation that may propagate `none`.
        op: String,
        /// The argument that may be `none`.
        nullable_arg: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("this argument may be `none`")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
        /// Optional extra explanation, rendered as miette help text.
        #[help]
        context: Option<String>,
    },
    /// A branch that can never execute.
    #[error("Unreachable code: this branch can never be executed")]
    #[diagnostic(code(typechecker::unreachable_code))]
    #[allow(dead_code)]
    UnreachableCode {
        /// Why the branch is unreachable.
        // NOTE(review): `reason` is neither interpolated into the `#[error]`
        // message nor marked `#[help]`, so miette output never shows it —
        // confirm this is intended.
        reason: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("this branch is unreachable")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
    },
    /// A `match` does not cover every case; `missing` names an uncovered one.
    #[error("Non-exhaustive patterns: missing case for {missing}")]
    #[diagnostic(code(typechecker::non_exhaustive_patterns))]
    #[allow(dead_code)]
    NonExhaustiveMatch {
        /// Description of the uncovered case.
        missing: String,
        /// Span highlighted by the diagnostic label, if known.
        #[label("match expression is not exhaustive")]
        span: Option<miette::SourceSpan>,
        /// Source range of the error, if known; returned by `location()`.
        location: Option<mq_lang::Range>,
        /// Optional extra explanation, rendered as miette help text.
        #[help]
        context: Option<String>,
    },
}
202
203impl TypeError {
204 pub fn location(&self) -> Option<mq_lang::Range> {
206 match self {
207 TypeError::Mismatch { location, .. }
208 | TypeError::UnificationError { location, .. }
209 | TypeError::OccursCheck { location, .. }
210 | TypeError::UndefinedSymbol { location, .. }
211 | TypeError::WrongArity { location, .. }
212 | TypeError::UndefinedField { location, .. }
213 | TypeError::HeterogeneousArray { location, .. }
214 | TypeError::NullablePropagation { location, .. }
215 | TypeError::UnreachableCode { location, .. }
216 | TypeError::NonExhaustiveMatch { location, .. } => *location,
217 _ => None,
218 }
219 }
220}
221
222pub(crate) fn walk_ancestors(
227 hir: &Hir,
228 start: mq_hir::SymbolId,
229) -> impl Iterator<Item = (mq_hir::SymbolId, &mq_hir::Symbol)> {
230 let mut current = hir.symbol(start).and_then(|s| s.parent);
231 let mut depth = 0usize;
232 std::iter::from_fn(move || {
233 let id = current?;
234 depth += 1;
235 if depth > 200 {
236 current = None;
237 return None;
238 }
239 let sym = hir.symbol(id)?;
240 current = sym.parent;
241 Some((id, sym))
242 })
243}
244
/// Configuration for [`TypeChecker`]. Every flag defaults to `false`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared or used
/// as map/set keys (e.g. to cache results per configuration).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TypeCheckerOptions {
    /// Passed to the inference context when checking.
    /// NOTE(review): presumably enables strict (homogeneous) array element
    /// typing, cf. `TypeError::HeterogeneousArray` — confirm.
    pub strict_array: bool,
    /// When `true`, `TypeChecker::check` skips match-exhaustiveness checking.
    pub no_exhaustive_patterns: bool,
}
254
255pub struct TypeChecker {
259 symbol_types: TypeEnv,
261 options: TypeCheckerOptions,
263}
264
265impl TypeChecker {
266 pub fn new() -> Self {
268 Self {
269 symbol_types: TypeEnv::default(),
270 options: TypeCheckerOptions::default(),
271 }
272 }
273
274 pub fn with_options(options: TypeCheckerOptions) -> Self {
276 Self {
277 symbol_types: TypeEnv::default(),
278 options,
279 }
280 }
281
282 pub fn check(&mut self, hir: &Hir) -> Vec<TypeError> {
286 let mut ctx = infer::InferenceContext::with_options(self.options.strict_array);
288
289 builtin::register_all(&mut ctx);
290
291 let children_index = constraint::generate_constraints(hir, &mut ctx);
294
295 unify::solve_constraints(&mut ctx);
297
298 let dead_branches = narrowing::resolve_type_narrowings(hir, &mut ctx);
302 for (reason, range) in dead_branches {
303 ctx.add_error(TypeError::UnreachableCode {
304 reason,
305 span: range.as_ref().map(unify::range_to_span),
306 location: range,
307 });
308 }
309
310 if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
312 unify::solve_constraints(&mut ctx);
313 }
314
315 if deferred::resolve_record_field_accesses(&mut ctx) {
319 unify::solve_constraints(&mut ctx);
321 }
322
323 deferred::resolve_selector_field_accesses(&mut ctx);
325
326 deferred::propagate_user_call_returns(&mut ctx);
332
333 deferred::resolve_deferred_overloads(&mut ctx);
337
338 if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
344 unify::solve_constraints(&mut ctx);
345 }
346
347 deferred::check_user_call_body_operators(hir, &mut ctx);
351
352 if !self.options.no_exhaustive_patterns {
354 for e in exhaustiveness::check_match_exhaustiveness(hir, &mut ctx, &children_index) {
355 ctx.add_error(e);
356 }
357 }
358
359 let errors = ctx.take_errors();
361
362 self.symbol_types = ctx.finalize();
364
365 errors
366 }
367
368 pub fn type_of(&self, symbol: SymbolId) -> Option<&TypeScheme> {
370 self.symbol_types.get(&symbol)
371 }
372
373 pub fn symbol_types(&self) -> &TypeEnv {
375 &self.symbol_types
376 }
377
378 pub fn symbol_type(&self, symbol_id: SymbolId) -> Option<&TypeScheme> {
380 self.symbol_types.get(&symbol_id)
381 }
382}
383
384impl Default for TypeChecker {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use mq_hir::SymbolKind;
    use rstest::rstest;

    /// Builds an `mq_lang::Range` from start `(line, column)` and end `(line, column)`.
    macro_rules! range {
        ($sl:expr, $sc:expr, $el:expr, $ec:expr) => {
            mq_lang::Range {
                start: mq_lang::Position { line: $sl, column: $sc },
                end: mq_lang::Position { line: $el, column: $ec },
            }
        };
    }

    #[test]
    fn test_typechecker_creation() {
        let checker = TypeChecker::new();
        assert_eq!(checker.symbol_types.len(), 0);
        assert!(checker.symbol_types().is_empty());
        assert!(!checker.options.strict_array);
        assert!(!checker.options.no_exhaustive_patterns);
    }

    #[test]
    fn test_typechecker_with_options() {
        let checker = TypeChecker::with_options(TypeCheckerOptions {
            strict_array: true,
            no_exhaustive_patterns: true,
        });
        assert!(checker.symbol_types().is_empty());
        assert!(checker.options.strict_array);
        assert!(checker.options.no_exhaustive_patterns);
    }

    #[test]
    fn test_typechecker_default_matches_new() {
        let checker = TypeChecker::default();
        assert!(checker.symbol_types().is_empty());
        assert!(!checker.options.strict_array);
        assert!(!checker.options.no_exhaustive_patterns);
    }

    #[test]
    fn test_type_env() {
        let mut env = TypeEnv::default();
        assert!(env.is_empty());
        assert_eq!(env.len(), 0);

        let mut hir = Hir::default();
        let (source_id, _) = hir.add_code(None, "42");
        let (symbol_id, _) = hir.symbols_for_source(source_id).next().unwrap();

        let scheme = TypeScheme::mono(types::Type::Number);
        env.insert(symbol_id, scheme.clone());

        assert!(!env.is_empty());
        assert_eq!(env.len(), 1);
        assert_eq!(env.get(&symbol_id), Some(&scheme));
        assert!(env.get_all().contains_key(&symbol_id));

        for (&id, s) in &env {
            assert_eq!(id, symbol_id);
            assert_eq!(s, &scheme);
        }
    }

    #[test]
    fn test_type_of_and_symbol_type_agree() {
        let mut hir = Hir::default();
        let (source_id, _) = hir.add_code(None, "42");
        let (symbol_id, _) = hir.symbols_for_source(source_id).next().unwrap();

        let mut checker = TypeChecker::new();
        let _ = checker.check(&hir);
        assert_eq!(checker.type_of(symbol_id), checker.symbol_type(symbol_id));
    }

    #[rstest]
    #[case(
        TypeError::Mismatch { expected: "n".into(), found: "s".into(), span: None, location: Some(range!(1, 5, 1, 6)), context: None },
        Some(range!(1, 5, 1, 6))
    )]
    #[case(
        TypeError::UnificationError { left: "l".into(), right: "r".into(), span: None, location: Some(range!(2, 10, 2, 11)), context: None },
        Some(range!(2, 10, 2, 11))
    )]
    #[case(
        TypeError::UnreachableCode { reason: "dead".into(), span: None, location: Some(range!(3, 1, 3, 4)) },
        Some(range!(3, 1, 3, 4))
    )]
    #[case(TypeError::UndefinedSymbol { name: "x".into(), span: None, location: None }, None)]
    #[case(TypeError::TypeVarNotFound("a".into()), None)]
    #[case(TypeError::Internal("boom".into()), None)]
    fn test_type_error_location(#[case] err: TypeError, #[case] expected: Option<mq_lang::Range>) {
        assert_eq!(err.location(), expected);
    }

    #[test]
    fn test_walk_ancestors() {
        let mut hir = Hir::default();
        let code = "def f(x): x + 1;";
        hir.add_code(None, code);

        let (num_id, _) = hir
            .symbols()
            .find(|(_, s)| matches!(s.kind, SymbolKind::Number))
            .unwrap();

        let ancestors: Vec<_> = walk_ancestors(&hir, num_id).collect();
        assert!(!ancestors.is_empty());

        // The start symbol itself is never yielded.
        assert!(ancestors.iter().all(|(id, _)| *id != num_id));

        let has_function = ancestors.iter().any(|(_, s)| matches!(s.kind, SymbolKind::Function(_)));
        assert!(has_function);
    }
}
454}