Skip to main content

mq_check/
lib.rs

//! Hindley-Milner type inference engine for the mq language.
2//!
3//! This crate provides static type checking and type inference capabilities for mq.
4//! It implements a Hindley-Milner style type inference algorithm with support for:
5//! - Automatic type inference (no type annotations required)
6//! - Polymorphic functions (generics)
7//! - Type constraints and unification
8//! - Integration with mq-hir for symbol and scope information
9//! - Error location reporting with source spans
10//!
11//! ## Error Location Reporting
12//!
13//! Type errors include location information (line and column numbers) extracted from
14//! the HIR symbols. This information is converted to `miette::SourceSpan` for diagnostic
15//! display. The span information helps users identify exactly where type errors occur
16//! in their source code.
17//!
18//! Example error output:
19//! ```text
20//! Error: Type mismatch: expected number, found string
21//!   Span: SourceSpan { offset: 42, length: 6 }
22//! ```
23
24// Suppress false-positive warnings for fields used in thiserror/miette macros
25#![allow(unused_assignments)]
26
27pub mod builtin;
28pub mod constraint;
29pub(crate) mod deferred;
30pub mod infer;
31pub mod narrowing;
32pub mod types;
33pub mod unify;
34
35use miette::Diagnostic;
36use mq_hir::{Hir, SymbolId};
37use rustc_hash::FxHashMap;
38use thiserror::Error;
39use types::TypeScheme;
40
41/// Result type for type checking operations
42pub type Result<T> = std::result::Result<T, TypeError>;
43
44/// Type environment mapping symbol IDs to their inferred type schemes
45#[derive(Debug, Clone, Default)]
46pub struct TypeEnv(FxHashMap<SymbolId, TypeScheme>);
47
48impl TypeEnv {
49    pub fn insert(&mut self, symbol_id: SymbolId, scheme: TypeScheme) {
50        self.0.insert(symbol_id, scheme);
51    }
52
53    pub fn get(&self, symbol_id: &SymbolId) -> Option<&TypeScheme> {
54        self.0.get(symbol_id)
55    }
56
57    pub fn get_all(&self) -> &FxHashMap<SymbolId, TypeScheme> {
58        &self.0
59    }
60
61    pub fn len(&self) -> usize {
62        self.0.len()
63    }
64
65    pub fn is_empty(&self) -> bool {
66        self.0.is_empty()
67    }
68}
69
70impl<'a> IntoIterator for &'a TypeEnv {
71    type Item = (&'a SymbolId, &'a TypeScheme);
72    type IntoIter = std::collections::hash_map::Iter<'a, SymbolId, TypeScheme>;
73
74    fn into_iter(self) -> Self::IntoIter {
75        self.0.iter()
76    }
77}
78
79/// Type checking errors
80#[derive(Debug, Error, Clone, Diagnostic)]
81#[allow(unused_assignments)]
82pub enum TypeError {
83    #[error("Type mismatch: expected {expected}, found {found}")]
84    #[diagnostic(code(typechecker::type_mismatch))]
85    #[allow(dead_code)]
86    Mismatch {
87        expected: String,
88        found: String,
89        #[label("type mismatch here")]
90        span: Option<miette::SourceSpan>,
91        location: Option<mq_lang::Range>,
92        #[help]
93        context: Option<String>,
94    },
95    #[error("Cannot unify types: {left} and {right}")]
96    #[diagnostic(code(typechecker::unification_error))]
97    #[allow(dead_code)]
98    UnificationError {
99        left: String,
100        right: String,
101        #[label("cannot unify these types")]
102        span: Option<miette::SourceSpan>,
103        location: Option<mq_lang::Range>,
104        #[help]
105        context: Option<String>,
106    },
107    #[error("Occurs check failed: type variable {var} occurs in {ty}")]
108    #[diagnostic(code(typechecker::occurs_check))]
109    #[allow(dead_code)]
110    OccursCheck {
111        var: String,
112        ty: String,
113        #[label("infinite type")]
114        span: Option<miette::SourceSpan>,
115        location: Option<mq_lang::Range>,
116    },
117    #[error("Undefined symbol: {name}")]
118    #[diagnostic(code(typechecker::undefined_symbol))]
119    #[allow(dead_code)]
120    UndefinedSymbol {
121        name: String,
122        #[label("undefined symbol")]
123        span: Option<miette::SourceSpan>,
124        location: Option<mq_lang::Range>,
125    },
126    #[error("Wrong number of arguments: expected {expected}, found {found}")]
127    #[diagnostic(code(typechecker::wrong_arity))]
128    #[allow(dead_code)]
129    WrongArity {
130        expected: usize,
131        found: usize,
132        #[label("wrong number of arguments")]
133        span: Option<miette::SourceSpan>,
134        location: Option<mq_lang::Range>,
135        #[help]
136        context: Option<String>,
137    },
138    #[error("Undefined field `{field}` in record type {record_ty}")]
139    #[diagnostic(code(typechecker::undefined_field))]
140    #[allow(dead_code)]
141    UndefinedField {
142        field: String,
143        record_ty: String,
144        #[label("field not found")]
145        span: Option<miette::SourceSpan>,
146        location: Option<mq_lang::Range>,
147    },
148    #[error("Heterogeneous array: elements have mixed types [{types}]")]
149    #[diagnostic(code(typechecker::heterogeneous_array))]
150    #[allow(dead_code)]
151    HeterogeneousArray {
152        types: String,
153        #[label("mixed types in array")]
154        span: Option<miette::SourceSpan>,
155        location: Option<mq_lang::Range>,
156    },
157    #[error("Type variable not found: {0}")]
158    #[diagnostic(code(typechecker::type_var_not_found))]
159    TypeVarNotFound(String),
160    #[error("Internal error: {0}")]
161    #[diagnostic(code(typechecker::internal_error))]
162    Internal(String),
163    /// Emitted when an operation receives a nullable argument and may silently
164    /// propagate `none` instead of producing the expected concrete type.
165    #[error("Operation `{op}` may propagate `none`: argument `{nullable_arg}` may be `none`")]
166    #[diagnostic(code(typechecker::nullable_propagation))]
167    #[allow(dead_code)]
168    NullablePropagation {
169        op: String,
170        nullable_arg: String,
171        #[label("this argument may be `none`")]
172        span: Option<miette::SourceSpan>,
173        location: Option<mq_lang::Range>,
174        #[help]
175        context: Option<String>,
176    },
177    /// Emitted when a branch is statically determined to be unreachable based on
178    /// the known type of a variable and the type predicate in the condition.
179    #[error("Unreachable code: this branch can never be executed")]
180    #[diagnostic(code(typechecker::unreachable_code))]
181    #[allow(dead_code)]
182    UnreachableCode {
183        reason: String,
184        #[label("this branch is unreachable")]
185        span: Option<miette::SourceSpan>,
186        location: Option<mq_lang::Range>,
187    },
188}
189
190impl TypeError {
191    /// Returns the location (range) of the error, if available.
192    pub fn location(&self) -> Option<mq_lang::Range> {
193        match self {
194            TypeError::Mismatch { location, .. }
195            | TypeError::UnificationError { location, .. }
196            | TypeError::OccursCheck { location, .. }
197            | TypeError::UndefinedSymbol { location, .. }
198            | TypeError::WrongArity { location, .. }
199            | TypeError::UndefinedField { location, .. }
200            | TypeError::HeterogeneousArray { location, .. }
201            | TypeError::NullablePropagation { location, .. }
202            | TypeError::UnreachableCode { location, .. } => *location,
203            _ => None,
204        }
205    }
206}
207
208/// Walks the HIR parent chain from `start`, yielding `(SymbolId, &Symbol)` pairs.
209///
210/// Begins with the parent of `start` and follows the chain upward.
211/// Includes a depth limit to prevent infinite loops on cyclic structures.
212pub(crate) fn walk_ancestors(
213    hir: &Hir,
214    start: mq_hir::SymbolId,
215) -> impl Iterator<Item = (mq_hir::SymbolId, &mq_hir::Symbol)> {
216    let mut current = hir.symbol(start).and_then(|s| s.parent);
217    let mut depth = 0usize;
218    std::iter::from_fn(move || {
219        let id = current?;
220        depth += 1;
221        if depth > 200 {
222            current = None;
223            return None;
224        }
225        let sym = hir.symbol(id)?;
226        current = sym.parent;
227        Some((id, sym))
228    })
229}
230
/// Options for configuring the type checker behavior.
///
/// The [`Default`] value leaves every option disabled. The struct is a plain `Copy`
/// value, so it can be compared, hashed and passed around freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TypeCheckerOptions {
    /// When true, arrays must contain elements of a single type.
    /// Heterogeneous arrays like `[1, "hello"]` will produce a type error.
    pub strict_array: bool,
}
238
239/// Type checker for mq programs
240///
241/// Provides type inference and checking capabilities based on HIR information.
242pub struct TypeChecker {
243    /// Symbol type mappings
244    symbol_types: TypeEnv,
245    /// Type checker options
246    options: TypeCheckerOptions,
247}
248
249impl TypeChecker {
250    /// Creates a new type checker with default options
251    pub fn new() -> Self {
252        Self {
253            symbol_types: TypeEnv::default(),
254            options: TypeCheckerOptions::default(),
255        }
256    }
257
258    /// Creates a new type checker with the given options
259    pub fn with_options(options: TypeCheckerOptions) -> Self {
260        Self {
261            symbol_types: TypeEnv::default(),
262            options,
263        }
264    }
265
266    /// Runs type inference on the given HIR
267    ///
268    /// Returns a list of type errors found. An empty list means no errors.
269    pub fn check(&mut self, hir: &Hir) -> Vec<TypeError> {
270        // Create inference context with options
271        let mut ctx = infer::InferenceContext::with_options(self.options.strict_array);
272
273        builtin::register_all(&mut ctx);
274
275        // Generate constraints from HIR (collects errors internally)
276        constraint::generate_constraints(hir, &mut ctx);
277
278        // Solve constraints through unification (collects errors internally)
279        unify::solve_constraints(&mut ctx);
280
281        // Apply type narrowings from type predicate conditions (e.g., is_string(x))
282        // in if/elif branches. This overrides Ref types within narrowed branches.
283        // Returns dead branch ranges for unreachable code detection.
284        let dead_branches = narrowing::resolve_type_narrowings(hir, &mut ctx);
285        for (reason, range) in dead_branches {
286            ctx.add_error(TypeError::UnreachableCode {
287                reason,
288                span: range.as_ref().map(unify::range_to_span),
289                location: range,
290            });
291        }
292
293        // Resolve deferred tuple index accesses now that variable types are known.
294        if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
295            unify::solve_constraints(&mut ctx);
296        }
297
298        // Resolve deferred record field accesses now that variable types are known.
299        // This binds bracket access return types (e.g., v[:key]) to specific field types
300        // from Record types, enabling type error detection for subsequent operations.
301        if deferred::resolve_record_field_accesses(&mut ctx) {
302            // Re-run unification to propagate newly resolved record field types
303            unify::solve_constraints(&mut ctx);
304        }
305
306        // Resolve deferred selector field accesses (.field on records)
307        deferred::resolve_selector_field_accesses(&mut ctx);
308
309        // Propagate return types from user-defined function calls.
310        // After unification, the original function's return type may be concrete,
311        // allowing us to connect it to the fresh return type at each call site.
312        // This must run BEFORE deferred overload resolution so that operators using
313        // return types have concrete operands.
314        deferred::propagate_user_call_returns(&mut ctx);
315
316        // Process deferred overload resolutions (operators with type variable operands)
317        // After return type propagation + unification, operand types may now be resolved.
318        // Unresolved overloads are stored back for later processing.
319        deferred::resolve_deferred_overloads(&mut ctx);
320
321        // Re-run deferred tuple accesses after overload resolution, because some variable
322        // types (e.g., the return type of `first(xs)`) may only be resolved after
323        // `resolve_deferred_overloads` runs. This ensures that index accesses on variables
324        // with union types containing unresolved vars (e.g., `Union(None, Var)`) are
325        // retried with the now-concrete member types.
326        if deferred::resolve_deferred_tuple_accesses(&mut ctx) {
327            unify::solve_constraints(&mut ctx);
328        }
329
330        // Check operators inside user-defined function bodies against call-site types.
331        // Uses local substitution (original params → call-site args) without modifying
332        // global state, so multiple call sites don't interfere.
333        deferred::check_user_call_body_operators(hir, &mut ctx);
334
335        // Collect errors before finalizing
336        let errors = ctx.take_errors();
337
338        // Store inferred types
339        self.symbol_types = ctx.finalize();
340
341        errors
342    }
343
344    /// Gets the type of a symbol
345    pub fn type_of(&self, symbol: SymbolId) -> Option<&TypeScheme> {
346        self.symbol_types.get(&symbol)
347    }
348
349    /// Gets all symbol types
350    pub fn symbol_types(&self) -> &TypeEnv {
351        &self.symbol_types
352    }
353
354    /// Gets the type scheme for a specific symbol.
355    pub fn symbol_type(&self, symbol_id: SymbolId) -> Option<&TypeScheme> {
356        self.symbol_types.get(&symbol_id)
357    }
358}
359
360impl Default for TypeChecker {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use mq_hir::SymbolKind;
    use rstest::rstest;

    /// A freshly built checker has not inferred anything yet.
    #[test]
    fn test_typechecker_creation() {
        let fresh = TypeChecker::new();
        assert!(fresh.symbol_types.is_empty());
    }

    /// Exercises insertion, lookup and borrowing iteration on `TypeEnv`.
    #[test]
    fn test_type_env() {
        let mut env = TypeEnv::default();
        assert!(env.is_empty());
        assert_eq!(env.len(), 0);

        let mut hir = Hir::default();
        let (source_id, _) = hir.add_code(None, "42");
        let (symbol_id, _) = hir
            .symbols_for_source(source_id)
            .next()
            .expect("source `42` yields at least one symbol");

        let scheme = TypeScheme::mono(types::Type::Number);
        env.insert(symbol_id, scheme.clone());

        assert!(!env.is_empty());
        assert_eq!(env.len(), 1);
        assert_eq!(env.get(&symbol_id), Some(&scheme));
        assert!(env.get_all().contains_key(&symbol_id));

        // Exactly one entry is present, so iterating `&env` must yield just that pair.
        let pairs: Vec<_> = (&env).into_iter().collect();
        assert_eq!(pairs, vec![(&symbol_id, &scheme)]);
    }

    #[rstest]
    #[case(
        TypeError::Mismatch {
            expected: "n".into(),
            found: "s".into(),
            span: None,
            location: Some(mq_lang::Range {
                start: mq_lang::Position { line: 1, column: 5 },
                end: mq_lang::Position { line: 1, column: 6 },
            }),
            context: None,
        },
        Some(mq_lang::Range {
            start: mq_lang::Position { line: 1, column: 5 },
            end: mq_lang::Position { line: 1, column: 6 },
        })
    )]
    #[case(
        TypeError::UnificationError {
            left: "l".into(),
            right: "r".into(),
            span: None,
            location: Some(mq_lang::Range {
                start: mq_lang::Position { line: 2, column: 10 },
                end: mq_lang::Position { line: 2, column: 11 },
            }),
            context: None,
        },
        Some(mq_lang::Range {
            start: mq_lang::Position { line: 2, column: 10 },
            end: mq_lang::Position { line: 2, column: 11 },
        })
    )]
    #[case(TypeError::TypeVarNotFound("a".into()), None)]
    fn test_type_error_location(#[case] err: TypeError, #[case] expected: Option<mq_lang::Range>) {
        assert_eq!(err.location(), expected);
    }

    /// Ancestors of a literal nested in a function body must include that function.
    #[test]
    fn test_walk_ancestors() {
        let mut hir = Hir::default();
        // Nested structure: function -> body -> `x + 1` -> literal `1`.
        hir.add_code(None, "def f(x): x + 1;");

        let literal_id = hir
            .symbols()
            .find_map(|(id, sym)| matches!(sym.kind, SymbolKind::Number).then_some(id))
            .unwrap();

        let ancestor_kinds: Vec<_> = walk_ancestors(&hir, literal_id)
            .map(|(_, sym)| &sym.kind)
            .collect();
        assert!(!ancestor_kinds.is_empty());

        // The literal sits under the `+` operation, which in turn sits under the function.
        assert!(ancestor_kinds
            .iter()
            .any(|kind| matches!(kind, SymbolKind::Function(_))));
    }
}