Skip to main content

mq_check/
types.rs

1//! Type representations for the mq type system.
2
3use rustc_hash::FxHashMap;
4use slotmap::SlotMap;
5use std::collections::BTreeMap;
6use std::fmt;
7
8slotmap::new_key_type! {
9    /// Unique identifier for type variables
10    pub struct TypeVarId;
11}
12
13/// Represents a type in the mq type system
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub enum Type {
16    /// Integer type
17    Int,
18    /// Floating point type
19    Float,
20    /// Number type (unified numeric type)
21    Number,
22    /// String type
23    String,
24    /// Boolean type
25    Bool,
26    /// Symbol type
27    Symbol,
28    /// None/null type
29    None,
30    /// Markdown document type
31    Markdown,
32    /// Array type with element type
33    Array(Box<Type>),
34    /// Tuple type with known element types (e.g., `(number, string)`)
35    ///
36    /// Used when `--tuple` mode is enabled to track per-element types
37    /// for small heterogeneous arrays like `[1, "hello"]`.
38    Tuple(Vec<Type>),
39    /// Dictionary type with key and value types
40    Dict(Box<Type>, Box<Type>),
41    /// Function type: arguments -> return type
42    Function(Vec<Type>, Box<Type>),
43    /// Union type: represents a value that could be one of multiple types
44    /// Used for try/catch expressions with different branch types
45    Union(Vec<Type>),
46    /// Record type with known fields and optional row extension (row polymorphism).
47    ///
48    /// The first element is a map of field names to their types.
49    /// The second element is the row tail: either `RowEmpty` (closed record)
50    /// or `Var(id)` (open record that may have additional fields).
51    ///
52    /// Examples:
53    /// - `{a: number, b: string}` → `Record({"a": Number, "b": String}, RowEmpty)`
54    /// - `{a: number | r}` → `Record({"a": Number}, Var(r))`
55    Record(BTreeMap<String, Type>, Box<Type>),
56    /// Empty row — marks a closed record with no additional fields
57    RowEmpty,
58    /// Type variable for inference
59    Var(TypeVarId),
60    /// Never type (bottom type) — represents unreachable code paths.
61    ///
62    /// Produced when type narrowing eliminates all possible types from a union.
63    /// For example, `Union(Array(String), Array(Number))` minus `is_array` predicate
64    /// leaves no viable types, producing `Never` for the else-branch.
65    Never,
66}
67
68impl Type {
69    /// Creates a new function type
70    pub fn function(params: Vec<Type>, ret: Type) -> Self {
71        Type::Function(params, Box::new(ret))
72    }
73
74    /// Creates a new array type
75    pub fn array(elem: Type) -> Self {
76        Type::Array(Box::new(elem))
77    }
78
79    /// Creates a new tuple type with known element types
80    pub fn tuple(elems: Vec<Type>) -> Self {
81        Type::Tuple(elems)
82    }
83
84    /// Creates a new dict type
85    pub fn dict(key: Type, value: Type) -> Self {
86        Type::Dict(Box::new(key), Box::new(value))
87    }
88
89    /// Creates a new record type with known fields
90    pub fn record(fields: BTreeMap<String, Type>, rest: Type) -> Self {
91        Type::Record(fields, Box::new(rest))
92    }
93
94    /// Creates a new union type from two or more types
95    /// Automatically normalizes the union (removes duplicates, flattens nested unions)
96    pub fn union(types: Vec<Type>) -> Self {
97        let mut normalized = Vec::with_capacity(types.len());
98        for ty in types {
99            match ty {
100                // Flatten nested unions
101                Type::Union(inner) => normalized.extend(inner),
102                _ => normalized.push(ty),
103            }
104        }
105
106        // Deduplicate and sort
107        if normalized.len() <= 4 {
108            // For small unions, avoid HashSet allocation with an O(n²) scan
109            let mut i = 0;
110            while i < normalized.len() {
111                if normalized[..i].iter().any(|t| t == &normalized[i]) {
112                    normalized.swap_remove(i);
113                } else {
114                    i += 1;
115                }
116            }
117            normalized.sort_by_key(|t| t.discriminant());
118        } else {
119            // For larger unions, use a HashSet for O(n) deduplication
120            let mut seen = rustc_hash::FxHashSet::default();
121            normalized.retain(|t| seen.insert(t.clone()));
122            normalized.sort_by_key(|t| t.discriminant());
123        }
124
125        // If only one type remains, return it directly
126        if normalized.len() == 1 {
127            normalized.into_iter().next().unwrap()
128        } else {
129            Type::Union(normalized)
130        }
131    }
132
133    /// Removes a type from a union by discriminant, returning the remaining type.
134    ///
135    /// If this is not a union, returns self unchanged.
136    /// If only one type remains after subtraction, returns it directly (unwrapped).
137    /// If all members are removed (e.g., `Union(Array(String), Array(Number))` minus
138    /// `Array`), returns `Type::Never` to signal an unreachable code path.
139    pub fn subtract(&self, exclude: &Type) -> Type {
140        match self {
141            Type::Union(members) => {
142                let remaining: Vec<Type> = members
143                    .iter()
144                    .filter(|t| std::mem::discriminant(*t) != std::mem::discriminant(exclude))
145                    .cloned()
146                    .collect();
147                if remaining.is_empty() {
148                    Type::Never
149                } else {
150                    Type::union(remaining)
151                }
152            }
153            _ => self.clone(),
154        }
155    }
156
157    /// Checks if this type is nullable (is `None` or a `Union` containing `None`).
158    pub fn is_nullable(&self) -> bool {
159        match self {
160            Type::None => true,
161            Type::Union(members) => members.iter().any(|m| matches!(m, Type::None)),
162            _ => false,
163        }
164    }
165
166    /// Checks if this is the bottom type (never/unreachable).
167    pub fn is_never(&self) -> bool {
168        matches!(self, Type::Never)
169    }
170
171    /// Returns a numeric discriminant for ordering purposes.
172    /// Used by `Type::union` to sort union members without allocating.
173    fn discriminant(&self) -> u8 {
174        match self {
175            Type::Int => 0,
176            Type::Float => 1,
177            Type::Number => 2,
178            Type::String => 3,
179            Type::Bool => 4,
180            Type::Symbol => 5,
181            Type::None => 6,
182            Type::Markdown => 7,
183            Type::Array(_) => 8,
184            Type::Tuple(_) => 9,
185            Type::Dict(_, _) => 10,
186            Type::Function(_, _) => 11,
187            Type::Union(_) => 12,
188            Type::Record(_, _) => 13,
189            Type::RowEmpty => 14,
190            Type::Var(_) => 15,
191            Type::Never => 16,
192        }
193    }
194
195    /// Checks if this is a type variable
196    pub fn is_var(&self) -> bool {
197        matches!(self, Type::Var(_))
198    }
199
200    /// Checks if this type contains no free type variables (is fully concrete)
201    pub fn is_concrete(&self) -> bool {
202        self.free_vars().is_empty()
203    }
204
205    /// Checks if this is a union type
206    pub fn is_union(&self) -> bool {
207        matches!(self, Type::Union(_))
208    }
209
210    /// Gets the type variable ID if this is a type variable
211    pub fn as_var(&self) -> Option<TypeVarId> {
212        match self {
213            Type::Var(id) => Some(*id),
214            _ => None,
215        }
216    }
217
218    /// Substitutes type variables according to the given substitution
219    pub fn apply_subst(&self, subst: &Substitution) -> Type {
220        if subst.is_empty() {
221            return self.clone();
222        }
223        match self {
224            Type::Var(id) => subst.lookup(*id).map_or_else(|| self.clone(), |t| t.apply_subst(subst)),
225            Type::Array(elem) => Type::Array(Box::new(elem.apply_subst(subst))),
226            Type::Tuple(elems) => Type::Tuple(elems.iter().map(|e| e.apply_subst(subst)).collect()),
227            Type::Dict(key, value) => Type::Dict(Box::new(key.apply_subst(subst)), Box::new(value.apply_subst(subst))),
228            Type::Function(params, ret) => {
229                let new_params = params.iter().map(|p| p.apply_subst(subst)).collect();
230                Type::Function(new_params, Box::new(ret.apply_subst(subst)))
231            }
232            Type::Union(types) => {
233                let new_types = types.iter().map(|t| t.apply_subst(subst)).collect();
234                Type::union(new_types)
235            }
236            Type::Record(fields, rest) => {
237                let new_fields = fields.iter().map(|(k, v)| (k.clone(), v.apply_subst(subst))).collect();
238                Type::Record(new_fields, Box::new(rest.apply_subst(subst)))
239            }
240            // Never and other ground types have no type variables to substitute
241            _ => self.clone(),
242        }
243    }
244
245    /// Gets all free type variables in this type
246    pub fn free_vars(&self) -> Vec<TypeVarId> {
247        match self {
248            Type::Var(id) => vec![*id],
249            Type::Array(elem) => elem.free_vars(),
250            Type::Tuple(elems) => elems.iter().flat_map(|e| e.free_vars()).collect(),
251            Type::Dict(key, value) => {
252                let mut vars = key.free_vars();
253                vars.extend(value.free_vars());
254                vars
255            }
256            Type::Function(params, ret) => {
257                let mut vars: Vec<TypeVarId> = params.iter().flat_map(|p| p.free_vars()).collect();
258                vars.extend(ret.free_vars());
259                vars
260            }
261            Type::Union(types) => types.iter().flat_map(|t| t.free_vars()).collect(),
262            Type::Record(fields, rest) => {
263                let mut vars: Vec<TypeVarId> = fields.values().flat_map(|v| v.free_vars()).collect();
264                vars.extend(rest.free_vars());
265                vars
266            }
267            _ => Vec::new(),
268        }
269    }
270
271    /// Checks if this type can match with another type (for overload resolution).
272    ///
273    /// This is a weaker check than unification - it returns true if the types
274    /// could potentially be unified, but doesn't require them to be identical.
275    /// Type variables always match.
276    pub fn can_match(&self, other: &Type) -> bool {
277        match (self, other) {
278            // Never (bottom type) can match anything
279            (Type::Never, _) | (_, Type::Never) => true,
280
281            // Type variables always match
282            (Type::Var(_), _) | (_, Type::Var(_)) => true,
283
284            // Union types match if any of their constituent types can match
285            (Type::Union(types), other) => types.iter().any(|t| t.can_match(other)),
286            (other, Type::Union(types)) => types.iter().any(|t| other.can_match(t)),
287
288            // Concrete types must match exactly
289            (Type::Int, Type::Int)
290            | (Type::Float, Type::Float)
291            | (Type::Number, Type::Number)
292            | (Type::String, Type::String)
293            | (Type::Bool, Type::Bool)
294            | (Type::Symbol, Type::Symbol)
295            | (Type::None, Type::None)
296            | (Type::Markdown, Type::Markdown) => true,
297
298            // Arrays match if their element types can match
299            (Type::Array(elem1), Type::Array(elem2)) => elem1.can_match(elem2),
300
301            // Tuples match if they have the same length and all elements can match
302            (Type::Tuple(elems1), Type::Tuple(elems2)) => {
303                elems1.len() == elems2.len() && elems1.iter().zip(elems2.iter()).all(|(e1, e2)| e1.can_match(e2))
304            }
305
306            // Tuple can match Array (compatibility)
307            (Type::Tuple(_), Type::Array(_)) | (Type::Array(_), Type::Tuple(_)) => true,
308
309            // Dicts match if both key and value types can match
310            (Type::Dict(k1, v1), Type::Dict(k2, v2)) => k1.can_match(k2) && v1.can_match(v2),
311
312            // Functions match if they have the same arity and all parameter/return types can match
313            (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
314                params1.len() == params2.len()
315                    && params1.iter().zip(params2.iter()).all(|(p1, p2)| p1.can_match(p2))
316                    && ret1.can_match(ret2)
317            }
318
319            // Records match if common fields can match and rests can match
320            (Type::Record(f1, r1), Type::Record(f2, r2)) => {
321                // Common fields must match
322                for (k, v1) in f1 {
323                    if let Some(v2) = f2.get(k)
324                        && !v1.can_match(v2)
325                    {
326                        return false;
327                    }
328                }
329                r1.can_match(r2)
330            }
331
332            // Record can match Dict (compatibility)
333            (Type::Record(_, _), Type::Dict(_, _)) | (Type::Dict(_, _), Type::Record(_, _)) => true,
334
335            // RowEmpty matches RowEmpty
336            (Type::RowEmpty, Type::RowEmpty) => true,
337
338            // Everything else doesn't match
339            _ => false,
340        }
341    }
342
343    /// Strict branch compatibility check for if-expression unification decisions.
344    ///
345    /// Unlike `can_match`, this treats an unresolved type variable (`Var`) as incompatible
346    /// with any known concrete type. This is used when deciding whether if-expression
347    /// branches should be unified or placed in a Union type: if one branch has `None` in
348    /// a tuple element position and another has `Var` (which might resolve to `Number`),
349    /// we conservatively choose Union to avoid a spurious unification error later.
350    ///
351    /// - `Var(a)` vs `Var(b)` → `true` (both unknown; let unification decide)
352    /// - `Var(a)` vs `concrete` → `false` (Var might not match the concrete type)
353    /// - `concrete` vs `Var(a)` → `false` (symmetric)
354    /// - All other cases mirror `can_match`
355    pub fn can_branch_unify_with(&self, other: &Type) -> bool {
356        match (self, other) {
357            // Never (bottom type) can unify with anything
358            (Type::Never, _) | (_, Type::Never) => true,
359
360            // Two type variables are compatible (unification will sort them out)
361            (Type::Var(_), Type::Var(_)) => true,
362
363            // Var vs concrete — unknown; do NOT assume compatible
364            (Type::Var(_), _) | (_, Type::Var(_)) => false,
365
366            // Union types: strict — any member must strictly match
367            (Type::Union(types), other) => types.iter().any(|t| t.can_branch_unify_with(other)),
368            (other, Type::Union(types)) => types.iter().any(|t| other.can_branch_unify_with(t)),
369
370            // Concrete types must match exactly
371            (Type::Int, Type::Int)
372            | (Type::Float, Type::Float)
373            | (Type::Number, Type::Number)
374            | (Type::String, Type::String)
375            | (Type::Bool, Type::Bool)
376            | (Type::Symbol, Type::Symbol)
377            | (Type::None, Type::None)
378            | (Type::Markdown, Type::Markdown) => true,
379
380            // Arrays: recurse strictly
381            (Type::Array(elem1), Type::Array(elem2)) => elem1.can_branch_unify_with(elem2),
382
383            // Tuples: same length and all elements strictly match
384            (Type::Tuple(elems1), Type::Tuple(elems2)) => {
385                elems1.len() == elems2.len()
386                    && elems1
387                        .iter()
388                        .zip(elems2.iter())
389                        .all(|(e1, e2)| e1.can_branch_unify_with(e2))
390            }
391
392            // Dicts: recurse strictly
393            (Type::Dict(k1, v1), Type::Dict(k2, v2)) => k1.can_branch_unify_with(k2) && v1.can_branch_unify_with(v2),
394
395            // Functions: same arity and all param/ret types strictly match
396            (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
397                params1.len() == params2.len()
398                    && params1
399                        .iter()
400                        .zip(params2.iter())
401                        .all(|(p1, p2)| p1.can_branch_unify_with(p2))
402                    && ret1.can_branch_unify_with(ret2)
403            }
404
405            // Everything else is incompatible
406            _ => false,
407        }
408    }
409
410    /// Computes a match score for overload resolution.
411    /// Higher scores indicate better matches. Returns None if types cannot match.
412    ///
413    /// Scoring:
414    /// - Exact match: 100
415    /// - Union type: best match among variants (slightly penalized)
416    /// - Type variable: 10
417    /// - Structural match (array/dict/function): sum of component scores
418    pub fn match_score(&self, other: &Type) -> Option<u32> {
419        if !self.can_match(other) {
420            return None;
421        }
422
423        match (self, other) {
424            // Never (bottom type) can match anything but with lowest score
425            (Type::Never, _) | (_, Type::Never) => Some(1),
426
427            // Exact matches get highest score
428            (Type::Int, Type::Int)
429            | (Type::Float, Type::Float)
430            | (Type::Number, Type::Number)
431            | (Type::String, Type::String)
432            | (Type::Bool, Type::Bool)
433            | (Type::Symbol, Type::Symbol)
434            | (Type::None, Type::None)
435            | (Type::Markdown, Type::Markdown) => Some(100),
436
437            // Type variables get low score (prefer concrete types).
438            // This arm must come BEFORE the union arms so that a Var parameter
439            // (e.g. `to_number: (Var) -> Number`) scores 10 against a union arg
440            // rather than 0 (the union arm penalises by -15, giving 10-15=0).
441            (Type::Var(_), _) | (_, Type::Var(_)) => Some(10),
442
443            // Union types: take the best match among all variants, but penalize
444            (Type::Union(types), other) => types
445                .iter()
446                .filter_map(|t| t.match_score(other))
447                .max()
448                .map(|s| s.saturating_sub(15)),
449            (other, Type::Union(types)) => types
450                .iter()
451                .filter_map(|t| other.match_score(t))
452                .max()
453                .map(|s| s.saturating_sub(15)),
454
455            // Arrays: structural match scores higher than bare type variable
456            (Type::Array(elem1), Type::Array(elem2)) => elem1.match_score(elem2).map(|s| s + 20),
457
458            // Tuples: structural match on all elements
459            (Type::Tuple(elems1), Type::Tuple(elems2)) if elems1.len() == elems2.len() => {
460                let total: u32 = elems1
461                    .iter()
462                    .zip(elems2.iter())
463                    .map(|(e1, e2)| e1.match_score(e2).unwrap_or(0))
464                    .sum();
465                Some(total / elems1.len() as u32 + 20)
466            }
467
468            // Tuple ↔ Array compatibility (lower score than direct Tuple match)
469            (Type::Tuple(_), Type::Array(_)) | (Type::Array(_), Type::Tuple(_)) => Some(15),
470
471            // Dicts: structural match scores higher than bare type variable
472            (Type::Dict(k1, v1), Type::Dict(k2, v2)) => {
473                let key_score = k1.match_score(k2)?;
474                let val_score = v1.match_score(v2)?;
475                Some((key_score + val_score) / 2 + 20)
476            }
477
478            // Records: structural match on fields
479            (Type::Record(f1, r1), Type::Record(f2, r2)) => {
480                let mut total = 0u32;
481                let mut count = 0u32;
482                for (k, v1) in f1 {
483                    if let Some(v2) = f2.get(k) {
484                        total += v1.match_score(v2)?;
485                        count += 1;
486                    }
487                }
488                let field_score = if count > 0 { total / count } else { 10 };
489                let rest_score = r1.match_score(r2).unwrap_or(10);
490                Some(field_score + rest_score + 20)
491            }
492
493            // Record ↔ Dict compatibility (lower score than direct Record match)
494            (Type::Record(_, _), Type::Dict(_, _)) | (Type::Dict(_, _), Type::Record(_, _)) => Some(15),
495
496            (Type::RowEmpty, Type::RowEmpty) => Some(100),
497
498            // Functions: sum all parameter scores and return score
499            (Type::Function(params1, ret1), Type::Function(params2, ret2)) => {
500                let param_score: u32 = params1
501                    .iter()
502                    .zip(params2.iter())
503                    .map(|(p1, p2)| p1.match_score(p2).unwrap_or(0))
504                    .sum();
505                let ret_score = ret1.match_score(ret2)?;
506                Some(param_score + ret_score)
507            }
508
509            _ => None,
510        }
511    }
512}
513
514impl Type {
515    /// Formats the type as a string, resolving type variables to their readable names.
516    /// This is used for better error messages.
517    pub fn display_resolved(&self) -> String {
518        match self {
519            Type::Int => "int".to_string(),
520            Type::Float => "float".to_string(),
521            Type::Number => "number".to_string(),
522            Type::String => "string".to_string(),
523            Type::Bool => "bool".to_string(),
524            Type::Symbol => "symbol".to_string(),
525            Type::None => "none".to_string(),
526            Type::Markdown => "markdown".to_string(),
527            Type::Array(elem) => format!("[{}]", elem.display_resolved()),
528            Type::Tuple(elems) => {
529                let elems_str = elems
530                    .iter()
531                    .map(|e| e.display_resolved())
532                    .collect::<Vec<_>>()
533                    .join(", ");
534                format!("({})", elems_str)
535            }
536            Type::Dict(key, value) => format!("{{{}: {}}}", key.display_resolved(), value.display_resolved()),
537            Type::Record(fields, rest) => {
538                let fields_str = fields
539                    .iter()
540                    .map(|(k, v)| format!("{}: {}", k, v.display_resolved()))
541                    .collect::<Vec<_>>()
542                    .join(", ");
543                match rest.as_ref() {
544                    Type::RowEmpty => format!("{{{}}}", fields_str),
545                    _ => {
546                        if fields_str.is_empty() {
547                            format!("{{| {}}}", rest.display_resolved())
548                        } else {
549                            format!("{{{} | {}}}", fields_str, rest.display_resolved())
550                        }
551                    }
552                }
553            }
554            Type::RowEmpty => "{}".to_string(),
555            Type::Function(params, ret) => {
556                let params_str = params
557                    .iter()
558                    .map(|p| p.display_resolved())
559                    .collect::<Vec<_>>()
560                    .join(", ");
561                format!("({}) -> {}", params_str, ret.display_resolved())
562            }
563            Type::Union(types) => {
564                let types_str = types
565                    .iter()
566                    .map(|t| t.display_resolved())
567                    .collect::<Vec<_>>()
568                    .join(" | ");
569                format!("({})", types_str)
570            }
571            Type::Var(id) => {
572                // Convert TypeVarId to a readable name like 'a, 'b, 'c, etc.
573                type_var_name(*id)
574            }
575            Type::Never => "never".to_string(),
576        }
577    }
578
579    /// Formats the type with renumbered type variables starting from `'a`.
580    ///
581    /// This produces clean, sequential type variable names regardless of internal
582    /// slotmap indices, which can be very large due to builtin registrations.
583    /// For example, instead of `'y32` or `'x3`, this produces `'a`, `'b`, etc.
584    pub fn display_renumbered(&self) -> String {
585        let mut var_map = FxHashMap::default();
586        let mut counter = 0usize;
587        self.fmt_renumbered(&mut var_map, &mut counter)
588    }
589
590    /// Internal helper for renumbered formatting.
591    pub(crate) fn fmt_renumbered(&self, var_map: &mut FxHashMap<TypeVarId, usize>, counter: &mut usize) -> String {
592        match self {
593            Type::Int => "int".to_string(),
594            Type::Float => "float".to_string(),
595            Type::Number => "number".to_string(),
596            Type::String => "string".to_string(),
597            Type::Bool => "bool".to_string(),
598            Type::Symbol => "symbol".to_string(),
599            Type::None => "none".to_string(),
600            Type::Markdown => "markdown".to_string(),
601            Type::Array(elem) => format!("[{}]", elem.fmt_renumbered(var_map, counter)),
602            Type::Tuple(elems) => {
603                let elems_str = elems
604                    .iter()
605                    .map(|e| e.fmt_renumbered(var_map, counter))
606                    .collect::<Vec<_>>()
607                    .join(", ");
608                format!("({})", elems_str)
609            }
610            Type::Dict(key, value) => {
611                format!(
612                    "{{{}: {}}}",
613                    key.fmt_renumbered(var_map, counter),
614                    value.fmt_renumbered(var_map, counter)
615                )
616            }
617            Type::Function(params, ret) => {
618                let params_str = params
619                    .iter()
620                    .map(|p| p.fmt_renumbered(var_map, counter))
621                    .collect::<Vec<_>>()
622                    .join(", ");
623                format!("({}) -> {}", params_str, ret.fmt_renumbered(var_map, counter))
624            }
625            Type::Union(types) => {
626                let types_str = types
627                    .iter()
628                    .map(|t| t.fmt_renumbered(var_map, counter))
629                    .collect::<Vec<_>>()
630                    .join(" | ");
631                format!("({})", types_str)
632            }
633            Type::Record(fields, rest) => {
634                let fields_str = fields
635                    .iter()
636                    .map(|(k, v)| format!("{}: {}", k, v.fmt_renumbered(var_map, counter)))
637                    .collect::<Vec<_>>()
638                    .join(", ");
639                match rest.as_ref() {
640                    Type::RowEmpty => format!("{{{}}}", fields_str),
641                    _ => {
642                        let rest_str = rest.fmt_renumbered(var_map, counter);
643                        if fields_str.is_empty() {
644                            format!("{{| {}}}", rest_str)
645                        } else {
646                            format!("{{{} | {}}}", fields_str, rest_str)
647                        }
648                    }
649                }
650            }
651            Type::RowEmpty => "{}".to_string(),
652            Type::Never => "never".to_string(),
653            Type::Var(id) => {
654                let index = *var_map.entry(*id).or_insert_with(|| {
655                    let i = *counter;
656                    *counter += 1;
657                    i
658                });
659                format_var_name(index)
660            }
661        }
662    }
663}
664
665/// Converts a TypeVarId to a readable name like `'a`, `'b`, ..., `'z`, `'a1`, `'b1`, etc.
666///
667/// Uses the slotmap key's index (via `KeyData`) for a reliable numeric index,
668/// then maps it to a human-readable alphabetic name.
669fn type_var_name(id: TypeVarId) -> String {
670    use slotmap::Key;
671    let index = id.data().as_ffi() as u32 as usize;
672    format_var_name(index)
673}
674
/// Spells a sequential index as a type variable name.
///
/// Indices cycle through the alphabet, and each completed cycle adds a numeric
/// suffix: 0 → `'a`, 25 → `'z`, 26 → `'a1`, 27 → `'b1`, 52 → `'a2`.
pub fn format_var_name(index: usize) -> String {
    const ALPHABET_LEN: usize = 26;
    let letter = char::from(b'a' + (index % ALPHABET_LEN) as u8);
    match index / ALPHABET_LEN {
        0 => format!("'{}", letter),
        round => format!("'{}{}", letter, round),
    }
}
687
688/// Formats a list of types as a comma-separated string using renumbered display.
689///
690/// Convenience helper used in error messages and overload reporting.
691pub(crate) fn format_type_list(types: &[Type]) -> String {
692    types
693        .iter()
694        .map(|t| t.display_renumbered())
695        .collect::<Vec<_>>()
696        .join(", ")
697}
698
699impl fmt::Display for Type {
700    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701        // Use display_renumbered() to produce clean, sequential type variable names
702        // regardless of internal slotmap indices.
703        write!(f, "{}", self.display_renumbered())
704    }
705}
706
707/// Type scheme for polymorphic types (generalized types)
708///
709/// A type scheme represents a polymorphic type by quantifying over type variables.
710/// For example: forall a b. (a -> b) -> [a] -> [b]
711#[derive(Debug, Clone, PartialEq, Eq)]
712pub struct TypeScheme {
713    /// Quantified type variables
714    pub quantified: Vec<TypeVarId>,
715    /// The actual type
716    pub ty: Type,
717}
718
719impl TypeScheme {
720    /// Creates a monomorphic type scheme (no quantified variables)
721    pub fn mono(ty: Type) -> Self {
722        Self {
723            quantified: Vec::new(),
724            ty,
725        }
726    }
727
728    /// Creates a polymorphic type scheme
729    pub fn poly(quantified: Vec<TypeVarId>, ty: Type) -> Self {
730        Self { quantified, ty }
731    }
732
733    /// Instantiates this type scheme with fresh type variables
734    pub fn instantiate(&self, ctx: &mut TypeVarContext) -> Type {
735        if self.quantified.is_empty() {
736            return self.ty.clone();
737        }
738
739        // Create fresh type variables for each quantified variable
740        let mut subst = Substitution::empty();
741        for var_id in &self.quantified {
742            let fresh = ctx.fresh();
743            subst.insert(*var_id, Type::Var(fresh));
744        }
745
746        self.ty.apply_subst(&subst)
747    }
748
749    /// Generalizes a type into a type scheme
750    pub fn generalize(ty: Type, env_vars: &[TypeVarId]) -> Self {
751        let ty_vars = ty.free_vars();
752        let quantified: Vec<TypeVarId> = ty_vars.into_iter().filter(|v| !env_vars.contains(v)).collect();
753        Self::poly(quantified, ty)
754    }
755}
756
757impl fmt::Display for TypeScheme {
758    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
759        if self.quantified.is_empty() {
760            // Monomorphic type — renumber for clean display
761            write!(f, "{}", self.ty.display_renumbered())
762        } else {
763            // Polymorphic type — renumber quantified vars to 'a, 'b, ...
764            let mut var_map: FxHashMap<TypeVarId, usize> = FxHashMap::default();
765            for (i, var) in self.quantified.iter().enumerate() {
766                var_map.insert(*var, i);
767            }
768            let mut counter = self.quantified.len();
769
770            write!(f, "forall ")?;
771            for (i, var) in self.quantified.iter().enumerate() {
772                if i > 0 {
773                    write!(f, " ")?;
774                }
775                write!(f, "{}", format_var_name(var_map[var]))?;
776            }
777            write!(f, ". {}", self.ty.fmt_renumbered(&mut var_map, &mut counter))
778        }
779    }
780}
781
782/// Type variable context for generating fresh type variables
783pub struct TypeVarContext {
784    vars: SlotMap<TypeVarId, Option<Type>>,
785}
786
787impl TypeVarContext {
788    /// Creates a new type variable context
789    pub fn new() -> Self {
790        Self {
791            vars: SlotMap::with_key(),
792        }
793    }
794
795    /// Generates a fresh type variable
796    pub fn fresh(&mut self) -> TypeVarId {
797        self.vars.insert(None)
798    }
799
800    /// Gets the resolved type for a type variable
801    pub fn get(&self, var: TypeVarId) -> Option<&Type> {
802        self.vars.get(var).and_then(|opt| opt.as_ref())
803    }
804
805    /// Sets the resolved type for a type variable
806    pub fn set(&mut self, var: TypeVarId, ty: Type) {
807        if let Some(slot) = self.vars.get_mut(var) {
808            *slot = Some(ty);
809        }
810    }
811
812    /// Checks if a type variable is resolved
813    pub fn is_resolved(&self, var: TypeVarId) -> bool {
814        self.vars.get(var).and_then(|opt| opt.as_ref()).is_some()
815    }
816}
817
818impl Default for TypeVarContext {
819    fn default() -> Self {
820        Self::new()
821    }
822}
823
824/// Type substitution mapping type variables to types
825#[derive(Debug, Clone, Default)]
826pub struct Substitution {
827    map: FxHashMap<TypeVarId, Type>,
828}
829
830impl Substitution {
831    /// Creates an empty substitution
832    pub fn empty() -> Self {
833        Self {
834            map: FxHashMap::default(),
835        }
836    }
837
838    /// Inserts a substitution
839    pub fn insert(&mut self, var: TypeVarId, ty: Type) {
840        self.map.insert(var, ty);
841    }
842
843    /// Looks up a type variable in the substitution
844    pub fn lookup(&self, var: TypeVarId) -> Option<&Type> {
845        self.map.get(&var)
846    }
847
848    /// Returns true if the substitution has no bindings
849    pub fn is_empty(&self) -> bool {
850        self.map.is_empty()
851    }
852
853    /// Composes two substitutions
854    pub fn compose(&self, other: &Substitution) -> Substitution {
855        let mut result = Substitution::empty();
856
857        // Apply other to all types in self
858        for (var, ty) in &self.map {
859            result.insert(*var, ty.apply_subst(other));
860        }
861
862        // Add mappings from other that aren't in self
863        for (var, ty) in &other.map {
864            if !self.map.contains_key(var) {
865                result.insert(*var, ty.clone());
866            }
867        }
868
869        result
870    }
871}
872
/// Unit tests for type construction, matching, substitutions, schemes and
/// the type-variable context.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    fn test_type_display() {
        assert_eq!(Type::Number.to_string(), "number");
        assert_eq!(Type::String.to_string(), "string");
        assert_eq!(Type::array(Type::Number).to_string(), "[number]");
        assert_eq!(
            Type::function(vec![Type::Number, Type::String], Type::Bool).to_string(),
            "(number, string) -> bool"
        );
    }

    #[test]
    fn test_type_var_context() {
        let mut ctx = TypeVarContext::new();
        let var1 = ctx.fresh();
        let var2 = ctx.fresh();
        assert_ne!(var1, var2);
    }

    #[test]
    fn test_type_var_context_get_set() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        // A fresh variable starts out unresolved.
        assert!(!ctx.is_resolved(var));
        assert_eq!(ctx.get(var), None);

        ctx.set(var, Type::Number);
        assert!(ctx.is_resolved(var));
        assert_eq!(ctx.get(var), Some(&Type::Number));
    }

    #[test]
    fn test_substitution() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();
        let ty = Type::Var(var);

        let mut subst = Substitution::empty();
        subst.insert(var, Type::Number);

        let result = ty.apply_subst(&subst);
        assert_eq!(result, Type::Number);
    }

    #[test]
    fn test_substitution_is_empty() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        let mut subst = Substitution::empty();
        assert!(subst.is_empty());

        subst.insert(var, Type::Number);
        assert!(!subst.is_empty());
    }

    #[test]
    fn test_type_scheme_instantiate() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        let scheme = TypeScheme::poly(vec![var], Type::Var(var));
        let inst1 = scheme.instantiate(&mut ctx);
        let inst2 = scheme.instantiate(&mut ctx);

        // Each instantiation should create fresh variables
        assert_ne!(inst1, inst2);
    }

    #[test]
    fn test_type_scheme_mono_instantiate_is_identity() {
        let mut ctx = TypeVarContext::new();
        let scheme = TypeScheme::mono(Type::array(Type::Number));
        assert_eq!(scheme.instantiate(&mut ctx), Type::array(Type::Number));
    }

    #[test]
    fn test_can_match_concrete_types() {
        assert!(Type::Number.can_match(&Type::Number));
        assert!(Type::String.can_match(&Type::String));
        assert!(!Type::Number.can_match(&Type::String));
    }

    #[test]
    fn test_can_match_type_variables() {
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();

        // Type variables can match anything
        assert!(Type::Var(var).can_match(&Type::Number));
        assert!(Type::Number.can_match(&Type::Var(var)));
        assert!(Type::Var(var).can_match(&Type::String));
    }

    #[test]
    fn test_can_match_arrays() {
        let arr_num = Type::array(Type::Number);
        let arr_str = Type::array(Type::String);

        assert!(arr_num.can_match(&arr_num));
        assert!(!arr_num.can_match(&arr_str));
    }

    #[test]
    fn test_can_match_functions() {
        let func1 = Type::function(vec![Type::Number], Type::String);
        let func2 = Type::function(vec![Type::Number], Type::String);
        let func3 = Type::function(vec![Type::String], Type::String);

        assert!(func1.can_match(&func2));
        assert!(!func1.can_match(&func3));
    }

    #[test]
    fn test_match_score() {
        // Exact matches get highest score
        assert_eq!(Type::Number.match_score(&Type::Number), Some(100));
        assert_eq!(Type::String.match_score(&Type::String), Some(100));

        // Type variables get lower score
        let mut ctx = TypeVarContext::new();
        let var = ctx.fresh();
        assert_eq!(Type::Var(var).match_score(&Type::Number), Some(10));

        // Incompatible types return None
        assert_eq!(Type::Number.match_score(&Type::String), None);
    }

    #[rstest]
    #[case(vec![Type::Number, Type::Number], Type::Number)]
    #[case(vec![Type::Number, Type::String], Type::union(vec![Type::Number, Type::String]))]
    #[case(vec![Type::union(vec![Type::Number, Type::String]), Type::Bool], Type::union(vec![Type::Number, Type::String, Type::Bool]))]
    #[case(vec![Type::Number, Type::String, Type::Number], Type::union(vec![Type::Number, Type::String]))]
    fn test_type_union(#[case] types: Vec<Type>, #[case] expected: Type) {
        assert_eq!(Type::union(types), expected);
    }

    #[rstest]
    #[case(Type::union(vec![Type::Number, Type::String]), &Type::Number, Type::String)]
    #[case(Type::union(vec![Type::Number, Type::String, Type::Bool]), &Type::String, Type::union(vec![Type::Number, Type::Bool]))]
    #[case(Type::Number, &Type::Number, Type::Number)]
    fn test_type_subtract(#[case] ty: Type, #[case] exclude: &Type, #[case] expected: Type) {
        assert_eq!(ty.subtract(exclude), expected);
    }

    #[rstest]
    #[case(Type::Number, Type::Number, true)]
    #[case(Type::Number, Type::String, false)]
    #[case(Type::array(Type::Number), Type::array(Type::Number), true)]
    #[case(Type::array(Type::Number), Type::array(Type::String), false)]
    #[case(Type::tuple(vec![Type::Number]), Type::array(Type::Number), true)]
    #[case(Type::record(BTreeMap::from([("a".to_string(), Type::Number)]), Type::RowEmpty), Type::dict(Type::String, Type::Number), true)]
    fn test_can_match_complex(#[case] t1: Type, #[case] t2: Type, #[case] expected: bool) {
        assert_eq!(t1.can_match(&t2), expected);
    }

    #[rstest]
    #[case(Type::Number, Type::Number, true)]
    #[case(Type::Number, Type::String, false)]
    fn test_can_branch_unify_with(#[case] t1: Type, #[case] t2: Type, #[case] expected: bool) {
        assert_eq!(t1.can_branch_unify_with(&t2), expected);
    }

    #[test]
    fn test_can_branch_unify_with_vars() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();
        assert!(Type::Var(v1).can_branch_unify_with(&Type::Var(v2)));
        assert!(!Type::Var(v1).can_branch_unify_with(&Type::Number));
    }

    #[test]
    fn test_substitution_compose() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();

        let mut s1 = Substitution::empty();
        s1.insert(v1, Type::Var(v2));

        let mut s2 = Substitution::empty();
        s2.insert(v2, Type::Number);

        let s3 = s1.compose(&s2);
        assert_eq!(s3.lookup(v1), Some(&Type::Number));
        assert_eq!(s3.lookup(v2), Some(&Type::Number));
    }

    #[test]
    fn test_substitution_compose_prefers_self_binding() {
        let mut ctx = TypeVarContext::new();
        let v = ctx.fresh();

        let mut s1 = Substitution::empty();
        s1.insert(v, Type::String);

        let mut s2 = Substitution::empty();
        s2.insert(v, Type::Number);

        // When both bind `v`, the receiver's binding (after applying `other`) wins.
        let s3 = s1.compose(&s2);
        assert_eq!(s3.lookup(v), Some(&Type::String));
    }

    #[test]
    fn test_type_scheme_generalize() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();

        let ty = Type::function(vec![Type::Var(v1)], Type::Var(v2));
        let env_vars = vec![v1];

        let scheme = TypeScheme::generalize(ty, &env_vars);
        assert_eq!(scheme.quantified, vec![v2]);
    }

    #[test]
    fn test_display_renumbered() {
        let mut ctx = TypeVarContext::new();
        let v1 = ctx.fresh();
        let v2 = ctx.fresh();
        let ty = Type::function(vec![Type::Var(v1)], Type::Var(v2));

        // Renumbered should always start from 'a
        assert_eq!(ty.display_renumbered(), "('a) -> 'b");
    }

    #[test]
    fn test_type_scheme_display() {
        let mut ctx = TypeVarContext::new();
        let _unused = ctx.fresh();
        let v = ctx.fresh();

        // Monomorphic schemes print just the body.
        assert_eq!(TypeScheme::mono(Type::Number).to_string(), "number");

        // Quantified variables are renamed from 'a regardless of their ids.
        let scheme = TypeScheme::poly(vec![v], Type::function(vec![Type::Var(v)], Type::Var(v)));
        assert_eq!(scheme.to_string(), "forall 'a. ('a) -> 'a");
    }
}