Skip to main content

mq_check/
lib.rs

1//! Type inference engine for mq language using Hindley-Milner type inference.
2//!
3//! This crate provides static type checking and type inference capabilities for mq.
4//! It implements a Hindley-Milner style type inference algorithm with support for:
5//! - Automatic type inference (no type annotations required)
6//! - Polymorphic functions (generics)
7//! - Type constraints and unification
8//! - Integration with mq-hir for symbol and scope information
9//! - Error location reporting with source spans
10//!
11//! ## Error Location Reporting
12//!
13//! Type errors include location information (line and column numbers) extracted from
14//! the HIR symbols. This information is converted to `miette::SourceSpan` for diagnostic
15//! display. The span information helps users identify exactly where type errors occur
16//! in their source code.
17//!
18//! Example error output:
19//! ```text
20//! Error: Type mismatch: expected number, found string
21//!   Span: SourceSpan { offset: 42, length: 6 }
22//! ```
23
24// Suppress false-positive warnings for fields used in thiserror/miette macros
25#![allow(unused_assignments)]
26
27pub mod builtin;
28pub mod constraint;
29pub(crate) mod deferred;
30pub(crate) mod exhaustiveness;
31pub mod infer;
32pub mod narrowing;
33pub mod types;
34pub mod unify;
35
36use miette::Diagnostic;
37use mq_hir::{Hir, SymbolId};
38use rustc_hash::FxHashMap;
39use thiserror::Error;
40use types::TypeScheme;
41
42/// Result type for type checking operations
43pub type Result<T> = std::result::Result<T, TypeError>;
44
45/// Type environment mapping symbol IDs to their inferred type schemes
46#[derive(Debug, Clone, Default)]
47pub struct TypeEnv(FxHashMap<SymbolId, TypeScheme>);
48
49impl TypeEnv {
50    pub fn insert(&mut self, symbol_id: SymbolId, scheme: TypeScheme) {
51        self.0.insert(symbol_id, scheme);
52    }
53
54    pub fn get(&self, symbol_id: &SymbolId) -> Option<&TypeScheme> {
55        self.0.get(symbol_id)
56    }
57
58    pub fn get_all(&self) -> &FxHashMap<SymbolId, TypeScheme> {
59        &self.0
60    }
61
62    pub fn len(&self) -> usize {
63        self.0.len()
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.0.is_empty()
68    }
69}
70
71impl<'a> IntoIterator for &'a TypeEnv {
72    type Item = (&'a SymbolId, &'a TypeScheme);
73    type IntoIter = std::collections::hash_map::Iter<'a, SymbolId, TypeScheme>;
74
75    fn into_iter(self) -> Self::IntoIter {
76        self.0.iter()
77    }
78}
79
80/// Type checking errors
81#[derive(Debug, Error, Clone, Diagnostic)]
82#[allow(unused_assignments)]
83pub enum TypeError {
84    #[error("Type mismatch: expected {expected}, found {found}")]
85    #[diagnostic(code(typechecker::type_mismatch))]
86    #[allow(dead_code)]
87    Mismatch {
88        expected: String,
89        found: String,
90        #[label("type mismatch here")]
91        span: Option<miette::SourceSpan>,
92        location: Option<mq_lang::Range>,
93        #[help]
94        context: Option<String>,
95    },
96    #[error("Cannot unify types: {left} and {right}")]
97    #[diagnostic(code(typechecker::unification_error))]
98    #[allow(dead_code)]
99    UnificationError {
100        left: String,
101        right: String,
102        #[label("cannot unify these types")]
103        span: Option<miette::SourceSpan>,
104        location: Option<mq_lang::Range>,
105        #[help]
106        context: Option<String>,
107    },
108    #[error("Occurs check failed: type variable {var} occurs in {ty}")]
109    #[diagnostic(code(typechecker::occurs_check))]
110    #[allow(dead_code)]
111    OccursCheck {
112        var: String,
113        ty: String,
114        #[label("infinite type")]
115        span: Option<miette::SourceSpan>,
116        location: Option<mq_lang::Range>,
117    },
118    #[error("Undefined symbol: {name}")]
119    #[diagnostic(code(typechecker::undefined_symbol))]
120    #[allow(dead_code)]
121    UndefinedSymbol {
122        name: String,
123        #[label("undefined symbol")]
124        span: Option<miette::SourceSpan>,
125        location: Option<mq_lang::Range>,
126    },
127    #[error("Wrong number of arguments: expected {expected}, found {found}")]
128    #[diagnostic(code(typechecker::wrong_arity))]
129    #[allow(dead_code)]
130    WrongArity {
131        expected: usize,
132        found: usize,
133        #[label("wrong number of arguments")]
134        span: Option<miette::SourceSpan>,
135        location: Option<mq_lang::Range>,
136        #[help]
137        context: Option<String>,
138    },
139    #[error("Undefined field `{field}` in record type {record_ty}")]
140    #[diagnostic(code(typechecker::undefined_field))]
141    #[allow(dead_code)]
142    UndefinedField {
143        field: String,
144        record_ty: String,
145        #[label("field not found")]
146        span: Option<miette::SourceSpan>,
147        location: Option<mq_lang::Range>,
148    },
149    #[error("Heterogeneous array: elements have mixed types [{types}]")]
150    #[diagnostic(code(typechecker::heterogeneous_array))]
151    #[allow(dead_code)]
152    HeterogeneousArray {
153        types: String,
154        #[label("mixed types in array")]
155        span: Option<miette::SourceSpan>,
156        location: Option<mq_lang::Range>,
157    },
158    #[error("Type variable not found: {0}")]
159    #[diagnostic(code(typechecker::type_var_not_found))]
160    TypeVarNotFound(String),
161    #[error("Internal error: {0}")]
162    #[diagnostic(code(typechecker::internal_error))]
163    Internal(String),
164    /// Emitted when an operation receives a nullable argument and may silently
165    /// propagate `none` instead of producing the expected concrete type.
166    #[error("Operation `{op}` may propagate `none`: argument `{nullable_arg}` may be `none`")]
167    #[diagnostic(code(typechecker::nullable_propagation))]
168    #[allow(dead_code)]
169    NullablePropagation {
170        op: String,
171        nullable_arg: String,
172        #[label("this argument may be `none`")]
173        span: Option<miette::SourceSpan>,
174        location: Option<mq_lang::Range>,
175        #[help]
176        context: Option<String>,
177    },
178    /// Emitted when a branch is statically determined to be unreachable based on
179    /// the known type of a variable and the type predicate in the condition.
180    #[error("Unreachable code: this branch can never be executed")]
181    #[diagnostic(code(typechecker::unreachable_code))]
182    #[allow(dead_code)]
183    UnreachableCode {
184        reason: String,
185        #[label("this branch is unreachable")]
186        span: Option<miette::SourceSpan>,
187        location: Option<mq_lang::Range>,
188    },
189    /// Emitted when a match expression does not cover all possible values of the matched type.
190    #[error("Non-exhaustive patterns: missing case for {missing}")]
191    #[diagnostic(code(typechecker::non_exhaustive_patterns))]
192    #[allow(dead_code)]
193    NonExhaustiveMatch {
194        missing: String,
195        #[label("match expression is not exhaustive")]
196        span: Option<miette::SourceSpan>,
197        location: Option<mq_lang::Range>,
198        #[help]
199        context: Option<String>,
200    },
201}
202
203impl TypeError {
204    /// Returns the location (range) of the error, if available.
205    pub fn location(&self) -> Option<mq_lang::Range> {
206        match self {
207            TypeError::Mismatch { location, .. }
208            | TypeError::UnificationError { location, .. }
209            | TypeError::OccursCheck { location, .. }
210            | TypeError::UndefinedSymbol { location, .. }
211            | TypeError::WrongArity { location, .. }
212            | TypeError::UndefinedField { location, .. }
213            | TypeError::HeterogeneousArray { location, .. }
214            | TypeError::NullablePropagation { location, .. }
215            | TypeError::UnreachableCode { location, .. }
216            | TypeError::NonExhaustiveMatch { location, .. } => *location,
217            _ => None,
218        }
219    }
220}
221
222/// Walks the HIR parent chain from `start`, yielding `(SymbolId, &Symbol)` pairs.
223///
224/// Begins with the parent of `start` and follows the chain upward.
225/// Includes a depth limit to prevent infinite loops on cyclic structures.
226pub(crate) fn walk_ancestors(
227    hir: &Hir,
228    start: mq_hir::SymbolId,
229) -> impl Iterator<Item = (mq_hir::SymbolId, &mq_hir::Symbol)> {
230    let mut current = hir.symbol(start).and_then(|s| s.parent);
231    let mut depth = 0usize;
232    std::iter::from_fn(move || {
233        let id = current?;
234        depth += 1;
235        if depth > 200 {
236            current = None;
237            return None;
238        }
239        let sym = hir.symbol(id)?;
240        current = sym.parent;
241        Some((id, sym))
242    })
243}
244
/// Switches that adjust how strict the type checker is.
///
/// Every flag defaults to `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TypeCheckerOptions {
    /// Reject heterogeneous arrays. When `true`, an array literal whose elements
    /// have different types, such as `[1, "hello"]`, is reported as a type error.
    pub strict_array: bool,
    /// Skip exhaustiveness checking of pattern `match` expressions when `true`.
    pub no_exhaustive_patterns: bool,
}
254
255/// Type checker for mq programs
256///
257/// Provides type inference and checking capabilities based on HIR information.
258pub struct TypeChecker {
259    /// Symbol type mappings
260    symbol_types: TypeEnv,
261    /// Type checker options
262    options: TypeCheckerOptions,
263}
264
265impl TypeChecker {
266    /// Creates a new type checker with default options
267    pub fn new() -> Self {
268        Self {
269            symbol_types: TypeEnv::default(),
270            options: TypeCheckerOptions::default(),
271        }
272    }
273
274    /// Creates a new type checker with the given options
275    pub fn with_options(options: TypeCheckerOptions) -> Self {
276        Self {
277            symbol_types: TypeEnv::default(),
278            options,
279        }
280    }
281
282    /// Runs type inference on the given HIR
283    ///
284    /// Returns a list of type errors found. An empty list means no errors.
285    pub fn check(&mut self, hir: &Hir) -> Vec<TypeError> {
286        // Create inference context with options
287        let mut ctx = infer::InferenceContext::with_options(self.options.strict_array);
288
289        builtin::register_all(&mut ctx);
290
291        // Generate constraints from HIR (collects errors internally).
292        // Returns the children index so it can be reused by later passes.
293        let children_index = constraint::generate_constraints(hir, &mut ctx);
294
295        // Solve constraints through unification (collects errors internally)
296        unify::solve_constraints(&mut ctx);
297
298        // Apply type narrowings from type predicate conditions (e.g., is_string(x))
299        // in if/elif branches. This overrides Ref types within narrowed branches.
300        // Returns dead branch ranges for unreachable code detection.
301        let dead_branches = narrowing::resolve_type_narrowings(hir, &mut ctx);
302        for (reason, range) in dead_branches {
303            ctx.add_error(TypeError::UnreachableCode {
304                reason,
305                span: range.as_ref().map(unify::range_to_span),
306                location: range,
307            });
308        }
309
310        // Resolve deferred tuple index accesses now that variable types are known.
311        if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
312            unify::solve_constraints(&mut ctx);
313        }
314
315        // Resolve deferred record field accesses now that variable types are known.
316        // This binds bracket access return types (e.g., v[:key]) to specific field types
317        // from Record types, enabling type error detection for subsequent operations.
318        if deferred::resolve_record_field_accesses(&mut ctx) {
319            // Re-run unification to propagate newly resolved record field types
320            unify::solve_constraints(&mut ctx);
321        }
322
323        // Resolve deferred selector field accesses (.field on records)
324        deferred::resolve_selector_field_accesses(&mut ctx);
325
326        // Propagate return types from user-defined function calls.
327        // After unification, the original function's return type may be concrete,
328        // allowing us to connect it to the fresh return type at each call site.
329        // This must run BEFORE deferred overload resolution so that operators using
330        // return types have concrete operands.
331        deferred::propagate_user_call_returns(&mut ctx);
332
333        // Process deferred overload resolutions (operators with type variable operands)
334        // After return type propagation + unification, operand types may now be resolved.
335        // Unresolved overloads are stored back for later processing.
336        deferred::resolve_deferred_overloads(&mut ctx);
337
338        // Re-run deferred tuple accesses after overload resolution, because some variable
339        // types (e.g., the return type of `first(xs)`) may only be resolved after
340        // `resolve_deferred_overloads` runs. This ensures that index accesses on variables
341        // with union types containing unresolved vars (e.g., `Union(None, Var)`) are
342        // retried with the now-concrete member types.
343        if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
344            unify::solve_constraints(&mut ctx);
345        }
346
347        // Check operators inside user-defined function bodies against call-site types.
348        // Uses local substitution (original params → call-site args) without modifying
349        // global state, so multiple call sites don't interfere.
350        deferred::check_user_call_body_operators(hir, &mut ctx);
351
352        // Check pattern match exhaustiveness (reuses the children index from constraint generation)
353        if !self.options.no_exhaustive_patterns {
354            for e in exhaustiveness::check_match_exhaustiveness(hir, &mut ctx, &children_index) {
355                ctx.add_error(e);
356            }
357        }
358
359        // Collect errors before finalizing
360        let errors = ctx.take_errors();
361
362        // Store inferred types
363        self.symbol_types = ctx.finalize();
364
365        errors
366    }
367
368    /// Gets the type of a symbol
369    pub fn type_of(&self, symbol: SymbolId) -> Option<&TypeScheme> {
370        self.symbol_types.get(&symbol)
371    }
372
373    /// Gets all symbol types
374    pub fn symbol_types(&self) -> &TypeEnv {
375        &self.symbol_types
376    }
377
378    /// Gets the type scheme for a specific symbol.
379    pub fn symbol_type(&self, symbol_id: SymbolId) -> Option<&TypeScheme> {
380        self.symbol_types.get(&symbol_id)
381    }
382}
383
384impl Default for TypeChecker {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use mq_hir::SymbolKind;
    use rstest::rstest;

    // Builds a single-line `mq_lang::Range` spanning columns `$start..$end` on `$line`.
    macro_rules! range {
        ($line:expr, $start:expr, $end:expr) => {
            mq_lang::Range {
                start: mq_lang::Position { line: $line, column: $start },
                end: mq_lang::Position { line: $line, column: $end },
            }
        };
    }

    #[test]
    fn test_typechecker_creation() {
        let checker = TypeChecker::new();
        assert_eq!(checker.symbol_types.len(), 0);
    }

    #[test]
    fn test_typechecker_default_options() {
        let checker = TypeChecker::default();
        assert!(!checker.options.strict_array);
        assert!(!checker.options.no_exhaustive_patterns);
        assert!(checker.symbol_types().is_empty());
    }

    #[test]
    fn test_type_env() {
        let mut env = TypeEnv::default();
        assert!(env.is_empty());
        assert_eq!(env.len(), 0);

        let mut hir = Hir::default();
        let (source_id, _) = hir.add_code(None, "42");
        let (symbol_id, _) = hir.symbols_for_source(source_id).next().unwrap();

        let scheme = TypeScheme::mono(types::Type::Number);
        env.insert(symbol_id, scheme.clone());

        assert!(!env.is_empty());
        assert_eq!(env.len(), 1);
        assert_eq!(env.get(&symbol_id), Some(&scheme));
        assert!(env.get_all().contains_key(&symbol_id));

        // Iteration through `&TypeEnv` must yield exactly the single inserted entry.
        let entries: Vec<(&SymbolId, &TypeScheme)> = (&env).into_iter().collect();
        assert_eq!(entries, vec![(&symbol_id, &scheme)]);
    }

    #[test]
    fn test_type_queries_agree() {
        let mut hir = Hir::default();
        hir.add_code(None, "1 + 2");

        let mut checker = TypeChecker::new();
        let _errors = checker.check(&hir);

        // `type_of` and `symbol_type` are two names for the same lookup.
        for (&id, scheme) in checker.symbol_types() {
            assert_eq!(checker.type_of(id), Some(scheme));
            assert_eq!(checker.symbol_type(id), Some(scheme));
        }
    }

    #[rstest]
    #[case(
        TypeError::Mismatch { expected: "n".into(), found: "s".into(), span: None, location: Some(range!(1, 5, 6)), context: None },
        Some(range!(1, 5, 6))
    )]
    #[case(
        TypeError::UnificationError { left: "l".into(), right: "r".into(), span: None, location: Some(range!(2, 10, 11)), context: None },
        Some(range!(2, 10, 11))
    )]
    #[case(
        TypeError::OccursCheck { var: "a".into(), ty: "[a]".into(), span: None, location: Some(range!(3, 0, 4)) },
        Some(range!(3, 0, 4))
    )]
    #[case(
        TypeError::NonExhaustiveMatch { missing: "none".into(), span: None, location: Some(range!(4, 2, 9)), context: None },
        Some(range!(4, 2, 9))
    )]
    #[case(TypeError::UnreachableCode { reason: "r".into(), span: None, location: None }, None)]
    #[case(TypeError::TypeVarNotFound("a".into()), None)]
    #[case(TypeError::Internal("boom".into()), None)]
    fn test_type_error_location(#[case] err: TypeError, #[case] expected: Option<mq_lang::Range>) {
        assert_eq!(err.location(), expected);
    }

    #[test]
    fn test_walk_ancestors() {
        let mut hir = Hir::default();
        // Nested structure: function -> body -> expression.
        let code = "def f(x): x + 1;";
        hir.add_code(None, code);

        // Start from the `1` literal.
        let (num_id, _) = hir
            .symbols()
            .find(|(_, s)| matches!(s.kind, SymbolKind::Number))
            .unwrap();

        let ancestors: Vec<_> = walk_ancestors(&hir, num_id).collect();
        assert!(!ancestors.is_empty());

        // The literal sits inside `f`, so the function must appear among its ancestors.
        let has_function = ancestors
            .iter()
            .any(|(_, s)| matches!(s.kind, SymbolKind::Function(_)));
        assert!(has_function);
    }
}