1#![allow(unused_assignments)]
26
27pub mod builtin;
28pub mod constraint;
29pub(crate) mod deferred;
30pub mod infer;
31pub mod narrowing;
32pub mod types;
33pub mod unify;
34
35use miette::Diagnostic;
36use mq_hir::{Hir, SymbolId};
37use rustc_hash::FxHashMap;
38use thiserror::Error;
39use types::TypeScheme;
40
/// Convenience alias for results whose error type is [`TypeError`].
pub type Result<T> = std::result::Result<T, TypeError>;
43
/// Type environment: maps each [`SymbolId`] to the [`TypeScheme`] recorded for it.
///
/// This is a newtype over `FxHashMap<SymbolId, TypeScheme>`; the `Default`
/// value is an empty environment.
#[derive(Debug, Clone, Default)]
pub struct TypeEnv(FxHashMap<SymbolId, TypeScheme>);
47
48impl TypeEnv {
49 pub fn insert(&mut self, symbol_id: SymbolId, scheme: TypeScheme) {
50 self.0.insert(symbol_id, scheme);
51 }
52
53 pub fn get(&self, symbol_id: &SymbolId) -> Option<&TypeScheme> {
54 self.0.get(symbol_id)
55 }
56
57 pub fn get_all(&self) -> &FxHashMap<SymbolId, TypeScheme> {
58 &self.0
59 }
60
61 pub fn len(&self) -> usize {
62 self.0.len()
63 }
64
65 pub fn is_empty(&self) -> bool {
66 self.0.is_empty()
67 }
68}
69
70impl<'a> IntoIterator for &'a TypeEnv {
71 type Item = (&'a SymbolId, &'a TypeScheme);
72 type IntoIter = std::collections::hash_map::Iter<'a, SymbolId, TypeScheme>;
73
74 fn into_iter(self) -> Self::IntoIter {
75 self.0.iter()
76 }
77}
78
/// Errors produced by the type checker.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute and its diagnostic code from `#[diagnostic(code(...))]`.
///
/// Fields that recur across variants:
/// - `span`: optional source span carrying the variant's `#[label]` text.
/// - `location`: optional source range; this is what [`TypeError::location`]
///   returns.
/// - `context`: optional free-form text exposed as the diagnostic's `#[help]`.
///
/// NOTE(review): the `#[allow(unused_assignments)]` / `#[allow(dead_code)]`
/// attributes presumably silence warnings triggered by the derive-generated
/// code — confirm before removing them.
#[derive(Debug, Error, Clone, Diagnostic)]
#[allow(unused_assignments)]
pub enum TypeError {
    /// The type found differs from the type that was expected.
    #[error("Type mismatch: expected {expected}, found {found}")]
    #[diagnostic(code(typechecker::type_mismatch))]
    #[allow(dead_code)]
    Mismatch {
        /// Rendered form of the expected type.
        expected: String,
        /// Rendered form of the type actually found.
        found: String,
        #[label("type mismatch here")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
        #[help]
        context: Option<String>,
    },
    /// Two types could not be unified.
    #[error("Cannot unify types: {left} and {right}")]
    #[diagnostic(code(typechecker::unification_error))]
    #[allow(dead_code)]
    UnificationError {
        /// Rendered form of the first type.
        left: String,
        /// Rendered form of the second type.
        right: String,
        #[label("cannot unify these types")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
        #[help]
        context: Option<String>,
    },
    /// A type variable occurs inside the type it would be bound to
    /// (labelled "infinite type").
    #[error("Occurs check failed: type variable {var} occurs in {ty}")]
    #[diagnostic(code(typechecker::occurs_check))]
    #[allow(dead_code)]
    OccursCheck {
        /// Rendered form of the type variable.
        var: String,
        /// Rendered form of the type containing `var`.
        ty: String,
        #[label("infinite type")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
    },
    /// A referenced symbol has no definition.
    #[error("Undefined symbol: {name}")]
    #[diagnostic(code(typechecker::undefined_symbol))]
    #[allow(dead_code)]
    UndefinedSymbol {
        /// Name of the symbol that could not be found.
        name: String,
        #[label("undefined symbol")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
    },
    /// A call supplied a different number of arguments than expected.
    #[error("Wrong number of arguments: expected {expected}, found {found}")]
    #[diagnostic(code(typechecker::wrong_arity))]
    #[allow(dead_code)]
    WrongArity {
        /// Number of arguments expected.
        expected: usize,
        /// Number of arguments actually supplied.
        found: usize,
        #[label("wrong number of arguments")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
        #[help]
        context: Option<String>,
    },
    /// A field was accessed that the record type does not contain.
    #[error("Undefined field `{field}` in record type {record_ty}")]
    #[diagnostic(code(typechecker::undefined_field))]
    #[allow(dead_code)]
    UndefinedField {
        /// Name of the missing field.
        field: String,
        /// Rendered form of the record type that was accessed.
        record_ty: String,
        #[label("field not found")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
    },
    /// An array's elements do not all share one type.
    #[error("Heterogeneous array: elements have mixed types [{types}]")]
    #[diagnostic(code(typechecker::heterogeneous_array))]
    #[allow(dead_code)]
    HeterogeneousArray {
        /// Rendered element types, placed between `[` and `]` in the message.
        types: String,
        #[label("mixed types in array")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
    },
    /// A type variable could not be found; the payload is interpolated into
    /// the message. Carries no source location.
    #[error("Type variable not found: {0}")]
    #[diagnostic(code(typechecker::type_var_not_found))]
    TypeVarNotFound(String),
    /// An internal checker error; the payload is the message text. Carries no
    /// source location.
    #[error("Internal error: {0}")]
    #[diagnostic(code(typechecker::internal_error))]
    Internal(String),
    /// An argument that may be `none` is passed to an operation, so the
    /// operation may propagate `none`.
    #[error("Operation `{op}` may propagate `none`: argument `{nullable_arg}` may be `none`")]
    #[diagnostic(code(typechecker::nullable_propagation))]
    #[allow(dead_code)]
    NullablePropagation {
        /// Name of the operation.
        op: String,
        /// The argument that may be `none`.
        nullable_arg: String,
        #[label("this argument may be `none`")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
        #[help]
        context: Option<String>,
    },
    /// A branch that can never be executed.
    #[error("Unreachable code: this branch can never be executed")]
    #[diagnostic(code(typechecker::unreachable_code))]
    #[allow(dead_code)]
    UnreachableCode {
        /// Why the branch is unreachable. Stored on the error but not
        /// interpolated into the message and not marked `#[help]`, so it is
        /// only reachable through the field itself (or `Debug`).
        reason: String,
        #[label("this branch is unreachable")]
        span: Option<miette::SourceSpan>,
        location: Option<mq_lang::Range>,
    },
}
189
190impl TypeError {
191 pub fn location(&self) -> Option<mq_lang::Range> {
193 match self {
194 TypeError::Mismatch { location, .. }
195 | TypeError::UnificationError { location, .. }
196 | TypeError::OccursCheck { location, .. }
197 | TypeError::UndefinedSymbol { location, .. }
198 | TypeError::WrongArity { location, .. }
199 | TypeError::UndefinedField { location, .. }
200 | TypeError::HeterogeneousArray { location, .. }
201 | TypeError::NullablePropagation { location, .. }
202 | TypeError::UnreachableCode { location, .. } => *location,
203 _ => None,
204 }
205 }
206}
207
208pub(crate) fn walk_ancestors(
213 hir: &Hir,
214 start: mq_hir::SymbolId,
215) -> impl Iterator<Item = (mq_hir::SymbolId, &mq_hir::Symbol)> {
216 let mut current = hir.symbol(start).and_then(|s| s.parent);
217 let mut depth = 0usize;
218 std::iter::from_fn(move || {
219 let id = current?;
220 depth += 1;
221 if depth > 200 {
222 current = None;
223 return None;
224 }
225 let sym = hir.symbol(id)?;
226 current = sym.parent;
227 Some((id, sym))
228 })
229}
230
/// Options that configure a [`TypeChecker`].
///
/// The derived `Default` leaves every flag `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeCheckerOptions {
    /// Forwarded to `infer::InferenceContext::with_options` at the start of
    /// [`TypeChecker::check`].
    ///
    /// NOTE(review): presumably enables strict checking of array element
    /// types (see `TypeError::HeterogeneousArray`) — confirm against the
    /// inference context.
    pub strict_array: bool,
}
238
/// Entry point of the type checker.
///
/// Holds the options used for checking and the symbol types produced by the
/// most recent call to `check`.
pub struct TypeChecker {
    /// Symbol types from the last `check` run; empty until `check` is called.
    symbol_types: TypeEnv,
    /// Options applied on every `check` run.
    options: TypeCheckerOptions,
}
248
249impl TypeChecker {
250 pub fn new() -> Self {
252 Self {
253 symbol_types: TypeEnv::default(),
254 options: TypeCheckerOptions::default(),
255 }
256 }
257
258 pub fn with_options(options: TypeCheckerOptions) -> Self {
260 Self {
261 symbol_types: TypeEnv::default(),
262 options,
263 }
264 }
265
266 pub fn check(&mut self, hir: &Hir) -> Vec<TypeError> {
270 let mut ctx = infer::InferenceContext::with_options(self.options.strict_array);
272
273 builtin::register_all(&mut ctx);
274
275 constraint::generate_constraints(hir, &mut ctx);
277
278 unify::solve_constraints(&mut ctx);
280
281 let dead_branches = narrowing::resolve_type_narrowings(hir, &mut ctx);
285 for (reason, range) in dead_branches {
286 ctx.add_error(TypeError::UnreachableCode {
287 reason,
288 span: range.as_ref().map(unify::range_to_span),
289 location: range,
290 });
291 }
292
293 if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
295 unify::solve_constraints(&mut ctx);
296 }
297
298 if deferred::resolve_record_field_accesses(&mut ctx) {
302 unify::solve_constraints(&mut ctx);
304 }
305
306 deferred::resolve_selector_field_accesses(&mut ctx);
308
309 deferred::propagate_user_call_returns(&mut ctx);
315
316 deferred::resolve_deferred_overloads(&mut ctx);
320
321 if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
327 unify::solve_constraints(&mut ctx);
328 }
329
330 deferred::check_user_call_body_operators(hir, &mut ctx);
334
335 let errors = ctx.take_errors();
337
338 self.symbol_types = ctx.finalize();
340
341 errors
342 }
343
344 pub fn type_of(&self, symbol: SymbolId) -> Option<&TypeScheme> {
346 self.symbol_types.get(&symbol)
347 }
348
349 pub fn symbol_types(&self) -> &TypeEnv {
351 &self.symbol_types
352 }
353
354 pub fn symbol_type(&self, symbol_id: SymbolId) -> Option<&TypeScheme> {
356 self.symbol_types.get(&symbol_id)
357 }
358}
359
360impl Default for TypeChecker {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use mq_hir::SymbolKind;
    use rstest::rstest;

    /// A new checker starts without any recorded symbol types.
    #[test]
    fn test_typechecker_creation() {
        let fresh = TypeChecker::new();
        assert!(fresh.symbol_types.is_empty());
    }

    /// Exercises insert / get / len / is_empty / get_all and `&TypeEnv` iteration.
    #[test]
    fn test_type_env() {
        let mut env = TypeEnv::default();
        assert_eq!(env.len(), 0);
        assert!(env.is_empty());

        // Use a real symbol id obtained from a tiny program.
        let mut hir = Hir::default();
        let (source_id, _) = hir.add_code(None, "42");
        let (symbol_id, _) = hir.symbols_for_source(source_id).next().unwrap();

        let scheme = TypeScheme::mono(types::Type::Number);
        env.insert(symbol_id, scheme.clone());

        assert_eq!(env.len(), 1);
        assert!(!env.is_empty());
        assert!(env.get_all().contains_key(&symbol_id));
        assert_eq!(env.get(&symbol_id), Some(&scheme));

        for (id, stored) in &env {
            assert_eq!(*id, symbol_id);
            assert_eq!(stored, &scheme);
        }
    }

    /// `location()` returns the stored range for located variants and `None`
    /// for variants without a location field.
    #[rstest]
    #[case(
        TypeError::Mismatch {
            expected: "n".into(),
            found: "s".into(),
            span: None,
            location: Some(mq_lang::Range {
                start: mq_lang::Position { line: 1, column: 5 },
                end: mq_lang::Position { line: 1, column: 6 },
            }),
            context: None,
        },
        Some(mq_lang::Range {
            start: mq_lang::Position { line: 1, column: 5 },
            end: mq_lang::Position { line: 1, column: 6 },
        })
    )]
    #[case(
        TypeError::UnificationError {
            left: "l".into(),
            right: "r".into(),
            span: None,
            location: Some(mq_lang::Range {
                start: mq_lang::Position { line: 2, column: 10 },
                end: mq_lang::Position { line: 2, column: 11 },
            }),
            context: None,
        },
        Some(mq_lang::Range {
            start: mq_lang::Position { line: 2, column: 10 },
            end: mq_lang::Position { line: 2, column: 11 },
        })
    )]
    #[case(TypeError::TypeVarNotFound("a".into()), None)]
    fn test_type_error_location(#[case] err: TypeError, #[case] expected: Option<mq_lang::Range>) {
        assert_eq!(err.location(), expected);
    }

    /// The number literal inside a function body has the function among its ancestors.
    #[test]
    fn test_walk_ancestors() {
        let mut hir = Hir::default();
        hir.add_code(None, "def f(x): x + 1;");

        let (literal_id, _) = hir
            .symbols()
            .find(|(_, sym)| matches!(sym.kind, SymbolKind::Number))
            .unwrap();

        let chain: Vec<_> = walk_ancestors(&hir, literal_id).collect();
        assert!(!chain.is_empty());

        let reaches_function = chain
            .iter()
            .any(|(_, sym)| matches!(sym.kind, SymbolKind::Function(_)));
        assert!(reaches_function);
    }
}