Skip to main content

openjd_expr/
types.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

//! Expression type system.
//!
//! Mirrors Python `openjd.expr._types` and the spec's ExprType definition.
//! Supports primitive types, list types, union types, function signature types,
//! type variables, unresolved types, and normalization rules.
10
11use std::collections::HashMap;
12use std::fmt;
13
14/// Type codes for expression values.
15///
16/// Marked `#[non_exhaustive]` so that future revisions and extensions can
17/// introduce new primitive types (e.g. `Duration`, `Url`) without a SemVer
18/// break. External callers must use a wildcard arm when matching.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize)]
20#[non_exhaustive]
21pub enum TypeCode {
22    NullType,
23    Bool,
24    Int,
25    Float,
26    String,
27    Path,
28    List,
29    RangeExpr,
30    Any,
31    Union,
32    NoReturn,
33    Unresolved,
34    TypeVarT,
35    TypeVarT1,
36    TypeVarT2,
37    TypeVarT3,
38    /// Function signature: params are [param_type_0, ..., param_type_n, return_type].
39    Signature,
40}
41
42/// Represents a type in the expression language.
43///
44/// Fields are crate-private to enforce normalization via constructors.
45/// Use [`code()`](ExprType::code) and [`params()`](ExprType::params) for read access.
46#[derive(Debug, Clone, Eq, serde::Serialize)]
47pub struct ExprType {
48    code: TypeCode,
49    params: Vec<ExprType>,
50}
51
52// ── Constants ──
53
54impl ExprType {
55    pub const BOOL: ExprType = ExprType {
56        code: TypeCode::Bool,
57        params: Vec::new(),
58    };
59    pub const INT: ExprType = ExprType {
60        code: TypeCode::Int,
61        params: Vec::new(),
62    };
63    pub const FLOAT: ExprType = ExprType {
64        code: TypeCode::Float,
65        params: Vec::new(),
66    };
67    pub const STRING: ExprType = ExprType {
68        code: TypeCode::String,
69        params: Vec::new(),
70    };
71    pub const PATH: ExprType = ExprType {
72        code: TypeCode::Path,
73        params: Vec::new(),
74    };
75    pub const RANGE_EXPR: ExprType = ExprType {
76        code: TypeCode::RangeExpr,
77        params: Vec::new(),
78    };
79    pub const NULLTYPE: ExprType = ExprType {
80        code: TypeCode::NullType,
81        params: Vec::new(),
82    };
83    pub const ANY: ExprType = ExprType {
84        code: TypeCode::Any,
85        params: Vec::new(),
86    };
87    pub const NORETURN: ExprType = ExprType {
88        code: TypeCode::NoReturn,
89        params: Vec::new(),
90    };
91    pub const T: ExprType = ExprType {
92        code: TypeCode::TypeVarT,
93        params: Vec::new(),
94    };
95    pub const T1: ExprType = ExprType {
96        code: TypeCode::TypeVarT1,
97        params: Vec::new(),
98    };
99    pub const T2: ExprType = ExprType {
100        code: TypeCode::TypeVarT2,
101        params: Vec::new(),
102    };
103    pub const T3: ExprType = ExprType {
104        code: TypeCode::TypeVarT3,
105        params: Vec::new(),
106    };
107
108    /// The type code for this type.
109    pub fn code(&self) -> TypeCode {
110        self.code
111    }
112
113    /// The type parameters (e.g. element type for lists, member types for unions).
114    pub fn params(&self) -> &[ExprType] {
115        &self.params
116    }
117
118    pub fn list(elem: ExprType) -> Self {
119        // Normalize: list[unresolved[T]] -> unresolved[list[T]]
120        if elem.code == TypeCode::Unresolved && elem.params.len() == 1 {
121            let inner_list = ExprType {
122                code: TypeCode::List,
123                params: vec![elem.params[0].clone()],
124            };
125            return ExprType {
126                code: TypeCode::Unresolved,
127                params: vec![inner_list],
128            };
129        }
130        ExprType {
131            code: TypeCode::List,
132            params: vec![elem],
133        }
134    }
135
136    pub fn union(types: Vec<ExprType>) -> Self {
137        normalize_union(types)
138    }
139
140    pub fn unresolved(constraint: ExprType) -> Self {
141        // Normalize: unresolved[unresolved[T]] -> unresolved[T]
142        if constraint.code == TypeCode::Unresolved {
143            return constraint;
144        }
145        ExprType {
146            code: TypeCode::Unresolved,
147            params: vec![constraint],
148        }
149    }
150
151    /// Create a function signature type: `(param_types) -> return_type`.
152    /// Stored as params = [param0, param1, ..., return_type].
153    pub fn signature(param_types: Vec<ExprType>, return_type: ExprType) -> Self {
154        let mut params = param_types;
155        params.push(return_type);
156        ExprType {
157            code: TypeCode::Signature,
158            params,
159        }
160    }
161
162    /// Get the parameter types of a signature (all params except the last).
163    pub fn sig_params(&self) -> &[ExprType] {
164        debug_assert_eq!(self.code, TypeCode::Signature);
165        &self.params[..self.params.len() - 1]
166    }
167
168    /// Get the return type of a signature (the last param).
169    pub fn sig_return(&self) -> &ExprType {
170        debug_assert_eq!(self.code, TypeCode::Signature);
171        self.params.last().unwrap()
172    }
173
174    /// Try to match a call's argument types against this signature.
175    /// Returns type variable bindings if successful, None if no match.
176    pub fn match_call(&self, arg_types: &[ExprType]) -> Option<HashMap<TypeCode, ExprType>> {
177        let sig_params = self.sig_params();
178        if sig_params.len() != arg_types.len() {
179            return None;
180        }
181        let mut bindings = HashMap::new();
182        for (sig_p, arg_t) in sig_params.iter().zip(arg_types.iter()) {
183            let sub = sig_p.match_type(arg_t)?;
184            for (k, v) in sub {
185                if let Some(existing) = bindings.get(&k) {
186                    if *existing != v {
187                        return None;
188                    }
189                }
190                bindings.insert(k, v);
191            }
192        }
193        Some(bindings)
194    }
195
196    /// Resolve the return type of a signature given argument types.
197    /// Matches args, substitutes bindings into return type.
198    pub fn resolve_call(&self, arg_types: &[ExprType]) -> Option<ExprType> {
199        let bindings = self.match_call(arg_types)?;
200        Some(self.sig_return().substitute(&bindings))
201    }
202}
203
204// ── Normalization ──
205
206fn normalize_union(types: Vec<ExprType>) -> ExprType {
207    let mut members = Vec::new();
208    for t in types {
209        match t.code {
210            TypeCode::Any => return ExprType::ANY,
211            TypeCode::NoReturn => continue,
212            TypeCode::Union => members.extend(t.params),
213            _ => members.push(t),
214        }
215    }
216    // Hoist unresolved: T | unresolved[S] -> unresolved[T | S]
217    let mut unresolved_constraints = Vec::new();
218    let mut non_unresolved = Vec::new();
219    for m in &members {
220        if m.code == TypeCode::Unresolved {
221            unresolved_constraints.push(m.params[0].clone());
222        } else {
223            non_unresolved.push(m.clone());
224        }
225    }
226    if !unresolved_constraints.is_empty() {
227        let mut all_parts = non_unresolved;
228        all_parts.extend(unresolved_constraints);
229        let inner = ExprType::union(all_parts);
230        return ExprType::unresolved(inner);
231    }
232    // Deduplicate
233    members.sort_by_key(|a| a.to_string());
234    members.dedup();
235    match members.len() {
236        0 => ExprType::NORETURN,
237        1 => members.into_iter().next().unwrap(),
238        _ => ExprType {
239            code: TypeCode::Union,
240            params: members,
241        },
242    }
243}
244
245// ── Queries ──
246
247impl ExprType {
248    pub fn is_list(&self) -> bool {
249        self.code == TypeCode::List
250    }
251
252    pub fn list_element_type(&self) -> Option<&ExprType> {
253        if self.code == TypeCode::List {
254            self.params.first()
255        } else {
256            None
257        }
258    }
259
260    pub fn is_symbolic(&self) -> bool {
261        matches!(
262            self.code,
263            TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
264        ) || self.params.iter().any(|p| p.is_symbolic())
265    }
266
267    pub fn is_concrete(&self) -> bool {
268        if matches!(
269            self.code,
270            TypeCode::Any
271                | TypeCode::Union
272                | TypeCode::Unresolved
273                | TypeCode::TypeVarT
274                | TypeCode::TypeVarT1
275                | TypeCode::TypeVarT2
276                | TypeCode::TypeVarT3
277                | TypeCode::Signature
278        ) {
279            return false;
280        }
281        self.params.iter().all(|p| p.is_concrete())
282    }
283
284    /// Substitute type variables with concrete types.
285    pub fn substitute(&self, bindings: &HashMap<TypeCode, ExprType>) -> ExprType {
286        if let Some(bound) = bindings.get(&self.code) {
287            return bound.clone();
288        }
289        if self.params.is_empty() {
290            return self.clone();
291        }
292        let new_params: Vec<ExprType> =
293            self.params.iter().map(|p| p.substitute(bindings)).collect();
294        ExprType::new(self.code, new_params)
295    }
296
297    /// Try to match this (possibly symbolic) type against another type.
298    /// Returns bindings if successful, None if no match.
299    pub fn match_type(&self, other: &ExprType) -> Option<HashMap<TypeCode, ExprType>> {
300        // Type variables bind
301        if matches!(
302            self.code,
303            TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
304        ) {
305            let mut m = HashMap::new();
306            m.insert(self.code, other.clone());
307            return Some(m);
308        }
309        if matches!(
310            other.code,
311            TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
312        ) {
313            let mut m = HashMap::new();
314            m.insert(other.code, self.clone());
315            return Some(m);
316        }
317        // ANY matches anything
318        if self.code == TypeCode::Any || other.code == TypeCode::Any {
319            return Some(HashMap::new());
320        }
321        // UNRESOLVED delegates to constraint
322        if self.code == TypeCode::Unresolved {
323            return self.params[0].match_type(other);
324        }
325        if other.code == TypeCode::Unresolved {
326            return self.match_type(&other.params[0]);
327        }
328        // UNION: matches if any member matches
329        if self.code == TypeCode::Union && other.code == TypeCode::Union {
330            for s in &self.params {
331                for c in &other.params {
332                    if let Some(r) = s.match_type(c) {
333                        return Some(r);
334                    }
335                }
336            }
337            return None;
338        }
339        if self.code == TypeCode::Union {
340            for member in &self.params {
341                if let Some(r) = member.match_type(other) {
342                    return Some(r);
343                }
344            }
345            return None;
346        }
347        if other.code == TypeCode::Union {
348            for member in &other.params {
349                if let Some(r) = self.match_type(member) {
350                    return Some(r);
351                }
352            }
353            return None;
354        }
355        // Same code, match params
356        if self.code != other.code {
357            return None;
358        }
359        if self.params.len() != other.params.len() {
360            return None;
361        }
362        let mut bindings = HashMap::new();
363        for (sp, cp) in self.params.iter().zip(other.params.iter()) {
364            let sub = sp.match_type(cp)?;
365            for (k, v) in sub {
366                if let Some(existing) = bindings.get(&k) {
367                    if *existing != v {
368                        return None;
369                    }
370                }
371                bindings.insert(k, v);
372            }
373        }
374        Some(bindings)
375    }
376
377    /// Construct with normalization.
378    pub fn new(code: TypeCode, params: Vec<ExprType>) -> Self {
379        match code {
380            TypeCode::Union => normalize_union(params),
381            TypeCode::List if params.len() == 1 => {
382                ExprType::list(params.into_iter().next().unwrap())
383            }
384            TypeCode::Unresolved if params.len() == 1 => {
385                ExprType::unresolved(params.into_iter().next().unwrap())
386            }
387            _ => ExprType { code, params },
388        }
389    }
390}
391
392// ── Parsing ──
393
394impl ExprType {
395    const MAX_PARSE_DEPTH: usize = 10;
396
397    /// Parse a type string like "int", "list\[string\]", "int?", "int | string", "unresolved\[int\]".
398    pub fn parse(s: &str) -> Result<ExprType, String> {
399        Self::parse_inner(s, 0)
400    }
401
402    fn parse_inner(s: &str, depth: usize) -> Result<ExprType, String> {
403        if depth > Self::MAX_PARSE_DEPTH {
404            return Err("Type nesting depth exceeded".to_string());
405        }
406        // Signature: (T1, T2) -> T3  — check before union split since return type may contain |
407        if s.starts_with('(') {
408            if let Some(arrow_pos) = s.find(") -> ") {
409                let params_str = &s[1..arrow_pos];
410                let ret_str = &s[arrow_pos + 5..];
411                let param_types = if params_str.is_empty() {
412                    Vec::new()
413                } else {
414                    split_params(params_str)
415                        .iter()
416                        .map(|p| Self::parse_inner(p, depth + 1))
417                        .collect::<Result<Vec<_>, _>>()?
418                };
419                let return_type = Self::parse_inner(ret_str, depth + 1)?;
420                return Ok(ExprType::signature(param_types, return_type));
421            }
422        }
423        // Union: split on " | " outside brackets/parens
424        let parts = split_union(s);
425        if parts.len() > 1 {
426            let types: Result<Vec<_>, _> = parts
427                .iter()
428                .map(|p| Self::parse_inner(p, depth + 1))
429                .collect();
430            return Ok(ExprType::union(types?));
431        }
432        // Base types
433        match s {
434            "bool" => return Ok(ExprType::BOOL),
435            "int" => return Ok(ExprType::INT),
436            "float" => return Ok(ExprType::FLOAT),
437            "string" => return Ok(ExprType::STRING),
438            "path" => return Ok(ExprType::PATH),
439            "range_expr" => return Ok(ExprType::RANGE_EXPR),
440            "nulltype" => return Ok(ExprType::NULLTYPE),
441            "noreturn" => return Ok(ExprType::NORETURN),
442            "any" => return Ok(ExprType::ANY),
443            "unresolved" => return Ok(ExprType::unresolved(ExprType::ANY)),
444            _ => {}
445        }
446        // Optional: T?
447        if let Some(inner) = s.strip_suffix('?') {
448            let t = Self::parse_inner(inner, depth + 1)?;
449            return Ok(ExprType::union(vec![t, ExprType::NULLTYPE]));
450        }
451        // list[T]
452        if let Some(inner) = s.strip_prefix("list[").and_then(|s| s.strip_suffix(']')) {
453            let elem = Self::parse_inner(inner, depth + 1)?;
454            return Ok(ExprType::list(elem));
455        }
456        // unresolved[T]
457        if let Some(inner) = s
458            .strip_prefix("unresolved[")
459            .and_then(|s| s.strip_suffix(']'))
460        {
461            let constraint = Self::parse_inner(inner, depth + 1)?;
462            return Ok(ExprType::unresolved(constraint));
463        }
464        // Type variables
465        match s {
466            "T" => return Ok(ExprType::T),
467            "T1" => return Ok(ExprType::T1),
468            "T2" => return Ok(ExprType::T2),
469            "T3" => return Ok(ExprType::T3),
470            _ => {}
471        }
472        Err(format!("Unknown type string: {s}"))
473    }
474}
475
/// Split `s` on every `" | "` separator that sits outside all `[...]` and `(...)` groups.
///
/// Pieces are returned untrimmed and in order. The input always produces at least one
/// piece, which may be empty.
fn split_union(s: &str) -> Vec<&str> {
    const SEP: &[u8] = b" | ";
    let raw = s.as_bytes();
    let mut pieces = Vec::new();
    let mut nesting = 0;
    let mut piece_start = 0;
    let mut pos = 0;
    while pos < raw.len() {
        let byte = raw[pos];
        if byte == b'[' || byte == b'(' {
            nesting += 1;
        } else if byte == b']' || byte == b')' {
            nesting -= 1;
        } else if nesting == 0 && raw[pos..].starts_with(SEP) {
            // Separator bytes are ASCII, so `pos` is always a char boundary.
            pieces.push(&s[piece_start..pos]);
            pos += SEP.len();
            piece_start = pos;
            continue;
        }
        pos += 1;
    }
    pieces.push(&s[piece_start..]);
    pieces
}
503
/// Split a parameter list on `,` characters outside all `[...]` and `(...)` groups.
///
/// Each piece is whitespace-trimmed. Interior empty pieces are kept, but an empty
/// trailing piece (e.g. after a final comma) is dropped.
fn split_params(s: &str) -> Vec<&str> {
    let mut out = Vec::new();
    let mut level = 0;
    let mut seg_start = 0;
    for (idx, ch) in s.bytes().enumerate() {
        match ch {
            b'[' | b'(' => level += 1,
            b']' | b')' => level -= 1,
            b',' if level == 0 => {
                out.push(s[seg_start..idx].trim());
                seg_start = idx + 1;
            }
            _ => {}
        }
    }
    let tail = s[seg_start..].trim();
    if !tail.is_empty() {
        out.push(tail);
    }
    out
}
529
530// ── Display ──
531
532impl fmt::Display for ExprType {
533    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534        match self.code {
535            TypeCode::NullType => write!(f, "nulltype"),
536            TypeCode::Bool => write!(f, "bool"),
537            TypeCode::Int => write!(f, "int"),
538            TypeCode::Float => write!(f, "float"),
539            TypeCode::String => write!(f, "string"),
540            TypeCode::Path => write!(f, "path"),
541            TypeCode::RangeExpr => write!(f, "range_expr"),
542            TypeCode::Any => write!(f, "any"),
543            TypeCode::NoReturn => write!(f, "noreturn"),
544            TypeCode::TypeVarT => write!(f, "T"),
545            TypeCode::TypeVarT1 => write!(f, "T1"),
546            TypeCode::TypeVarT2 => write!(f, "T2"),
547            TypeCode::TypeVarT3 => write!(f, "T3"),
548            TypeCode::List => {
549                if let Some(elem) = self.params.first() {
550                    write!(f, "list[{elem}]")
551                } else {
552                    write!(f, "list")
553                }
554            }
555            TypeCode::Unresolved => {
556                if let Some(constraint) = self.params.first() {
557                    if constraint.code == TypeCode::Any {
558                        write!(f, "unresolved")
559                    } else {
560                        write!(f, "unresolved[{constraint}]")
561                    }
562                } else {
563                    write!(f, "unresolved")
564                }
565            }
566            TypeCode::Union => {
567                let non_null: Vec<_> = self
568                    .params
569                    .iter()
570                    .filter(|t| t.code != TypeCode::NullType)
571                    .collect();
572                let has_null = non_null.len() < self.params.len();
573                if has_null && non_null.len() == 1 {
574                    return write!(f, "{}?", non_null[0]);
575                }
576                let mut parts: Vec<std::string::String> =
577                    non_null.iter().map(|t| t.to_string()).collect();
578                if has_null {
579                    parts.push("nulltype".to_string());
580                }
581                write!(f, "{}", parts.join(" | "))
582            }
583            TypeCode::Signature => {
584                let params = &self.params[..self.params.len() - 1];
585                let ret = self.params.last().unwrap();
586                let param_strs: Vec<String> = params.iter().map(|p| p.to_string()).collect();
587                write!(f, "({}) -> {}", param_strs.join(", "), ret)
588            }
589        }
590    }
591}
592
593// ── Eq / Hash ──
594
595impl PartialEq for ExprType {
596    fn eq(&self, other: &Self) -> bool {
597        self.code == other.code && self.params == other.params
598    }
599}
600
601impl std::hash::Hash for ExprType {
602    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
603        self.code.hash(state);
604        self.params.hash(state);
605    }
606}
607
#[cfg(test)]
mod tests {
    //! Unit tests for construction, normalization, display/parse and matching of
    //! expression types.

    use super::*;

    // ── Basic types ──
    #[test]
    fn basic_types() {
        assert_eq!(ExprType::BOOL.code, TypeCode::Bool);
        assert_eq!(ExprType::INT.code, TypeCode::Int);
        assert_eq!(ExprType::FLOAT.code, TypeCode::Float);
        assert_eq!(ExprType::STRING.code, TypeCode::String);
        assert_eq!(ExprType::PATH.code, TypeCode::Path);
    }

    // ── Display ──
    #[test]
    fn display_int() {
        assert_eq!(ExprType::INT.to_string(), "int");
    }
    #[test]
    fn display_list_int() {
        assert_eq!(ExprType::list(ExprType::INT).to_string(), "list[int]");
    }
    #[test]
    fn display_any() {
        assert_eq!(ExprType::ANY.to_string(), "any");
    }
    #[test]
    fn display_noreturn() {
        assert_eq!(ExprType::NORETURN.to_string(), "noreturn");
    }
    #[test]
    fn display_typevar() {
        assert_eq!(ExprType::T1.to_string(), "T1");
    }
    #[test]
    fn display_union() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::STRING]).to_string(),
            "int | string"
        );
    }
    #[test]
    fn display_nullable() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::NULLTYPE]).to_string(),
            "int?"
        );
    }
    #[test]
    fn display_unresolved_bare() {
        assert_eq!(
            ExprType::unresolved(ExprType::ANY).to_string(),
            "unresolved"
        );
    }
    #[test]
    fn display_unresolved_int() {
        assert_eq!(
            ExprType::unresolved(ExprType::INT).to_string(),
            "unresolved[int]"
        );
    }

    // ── Parsing ──
    #[test]
    fn parse_int() {
        assert_eq!(ExprType::parse("int").unwrap(), ExprType::INT);
    }
    #[test]
    fn parse_float() {
        assert_eq!(ExprType::parse("float").unwrap(), ExprType::FLOAT);
    }
    #[test]
    fn parse_string() {
        assert_eq!(ExprType::parse("string").unwrap(), ExprType::STRING);
    }
    #[test]
    fn parse_bool() {
        assert_eq!(ExprType::parse("bool").unwrap(), ExprType::BOOL);
    }
    #[test]
    fn parse_path() {
        assert_eq!(ExprType::parse("path").unwrap(), ExprType::PATH);
    }
    #[test]
    fn parse_range_expr() {
        assert_eq!(ExprType::parse("range_expr").unwrap(), ExprType::RANGE_EXPR);
    }
    #[test]
    fn parse_nulltype() {
        assert_eq!(ExprType::parse("nulltype").unwrap(), ExprType::NULLTYPE);
    }
    #[test]
    fn parse_noreturn() {
        assert_eq!(ExprType::parse("noreturn").unwrap(), ExprType::NORETURN);
    }
    #[test]
    fn parse_any() {
        assert_eq!(ExprType::parse("any").unwrap(), ExprType::ANY);
    }
    #[test]
    fn parse_list_int() {
        assert_eq!(
            ExprType::parse("list[int]").unwrap(),
            ExprType::list(ExprType::INT)
        );
    }
    #[test]
    fn parse_list_string() {
        assert_eq!(
            ExprType::parse("list[string]").unwrap(),
            ExprType::list(ExprType::STRING)
        );
    }
    #[test]
    fn parse_nested_list() {
        assert_eq!(
            ExprType::parse("list[list[int]]").unwrap(),
            ExprType::list(ExprType::list(ExprType::INT))
        );
    }
    #[test]
    fn parse_optional() {
        let t = ExprType::parse("int?").unwrap();
        assert_eq!(t.code, TypeCode::Union);
        assert_eq!(t.to_string(), "int?");
    }
    #[test]
    fn parse_union() {
        let t = ExprType::parse("int | string").unwrap();
        assert_eq!(t.code, TypeCode::Union);
        assert_eq!(t.to_string(), "int | string");
    }
    #[test]
    fn parse_unresolved_bare() {
        assert_eq!(
            ExprType::parse("unresolved").unwrap(),
            ExprType::unresolved(ExprType::ANY)
        );
    }
    #[test]
    fn parse_unresolved_int() {
        assert_eq!(
            ExprType::parse("unresolved[int]").unwrap(),
            ExprType::unresolved(ExprType::INT)
        );
    }
    #[test]
    fn parse_unresolved_list() {
        assert_eq!(
            ExprType::parse("unresolved[list[string]]").unwrap(),
            ExprType::unresolved(ExprType::list(ExprType::STRING))
        );
    }
    #[test]
    fn parse_unknown_rejects() {
        assert!(ExprType::parse("notavalidtype").is_err());
    }
    #[test]
    fn parse_case_sensitive() {
        assert!(ExprType::parse("INT").is_err());
    }
    #[test]
    fn parse_whitespace_rejected() {
        assert!(ExprType::parse(" int").is_err());
    }

    // ── Roundtrip ──
    #[test]
    fn roundtrip_bare() {
        let t = ExprType::parse("unresolved").unwrap();
        assert_eq!(ExprType::parse(&t.to_string()).unwrap(), t);
    }
    #[test]
    fn roundtrip_constrained() {
        for s in &[
            "unresolved[int]",
            "unresolved[list[string]]",
            "unresolved[float | int]",
        ] {
            let t = ExprType::parse(s).unwrap();
            assert_eq!(
                ExprType::parse(&t.to_string()).unwrap(),
                t,
                "roundtrip failed for {s}"
            );
        }
    }

    // ── Equality / Hash ──
    #[test]
    fn eq_same() {
        assert_eq!(ExprType::INT, ExprType::INT);
    }
    #[test]
    fn eq_diff() {
        assert_ne!(ExprType::INT, ExprType::FLOAT);
    }
    #[test]
    fn eq_list() {
        assert_eq!(ExprType::list(ExprType::INT), ExprType::list(ExprType::INT));
    }
    #[test]
    fn ne_list() {
        assert_ne!(
            ExprType::list(ExprType::INT),
            ExprType::list(ExprType::STRING)
        );
    }
    #[test]
    fn hash_consistent() {
        use std::collections::HashSet;
        let mut s = HashSet::new();
        s.insert(ExprType::INT);
        s.insert(ExprType::FLOAT);
        s.insert(ExprType::INT);
        assert_eq!(s.len(), 2);
    }

    // ── Union normalization ──
    #[test]
    fn union_dedup() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::INT]).code,
            TypeCode::Int
        );
    }
    #[test]
    fn union_single_unwrap() {
        assert_eq!(
            ExprType::union(vec![ExprType::STRING]).code,
            TypeCode::String
        );
    }
    #[test]
    fn union_flatten() {
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::union(vec![ExprType::FLOAT, ExprType::BOOL]);
        let combined = ExprType::union(vec![u1, u2]);
        assert_eq!(combined.code, TypeCode::Union);
        assert_eq!(combined.params.len(), 4);
        assert_eq!(combined.to_string(), "bool | float | int | string");
    }
    #[test]
    fn union_any_absorbs() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::ANY]).code,
            TypeCode::Any
        );
    }
    #[test]
    fn union_noreturn_collapses() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::NORETURN]),
            ExprType::INT
        );
    }
    #[test]
    fn union_all_noreturn() {
        assert_eq!(
            ExprType::union(vec![ExprType::NORETURN, ExprType::NORETURN]),
            ExprType::NORETURN
        );
    }
    #[test]
    fn union_order_independent() {
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::union(vec![ExprType::STRING, ExprType::INT]);
        assert_eq!(u1, u2);
    }
    #[test]
    fn union_hash_consistent() {
        use std::collections::HashSet;
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::parse("int | string").unwrap();
        let mut s = HashSet::new();
        s.insert(u1);
        s.insert(u2);
        assert_eq!(s.len(), 1);
    }

    // ── Unresolved normalization ──
    #[test]
    fn list_of_unresolved_hoists() {
        let t = ExprType::list(ExprType::unresolved(ExprType::INT));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t, ExprType::unresolved(ExprType::list(ExprType::INT)));
    }
    #[test]
    fn union_with_unresolved_hoists() {
        let t = ExprType::union(vec![ExprType::STRING, ExprType::unresolved(ExprType::INT)]);
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(
            t,
            ExprType::unresolved(ExprType::union(vec![ExprType::INT, ExprType::STRING]))
        );
    }
    #[test]
    fn nested_unresolved_flattens() {
        let t = ExprType::unresolved(ExprType::unresolved(ExprType::INT));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t, ExprType::unresolved(ExprType::INT));
    }
    #[test]
    fn unresolved_never_inside_list() {
        let t = ExprType::list(ExprType::unresolved(ExprType::STRING));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t.params[0], ExprType::list(ExprType::STRING));
    }

    // ── Queries ──
    #[test]
    fn list_element_type_query() {
        let t = ExprType::list(ExprType::INT);
        assert!(t.is_list());
        assert_eq!(t.list_element_type(), Some(&ExprType::INT));
        assert!(!ExprType::INT.is_list());
        assert_eq!(ExprType::INT.list_element_type(), None);
    }

    // ── is_symbolic / is_concrete ──
    #[test]
    fn concrete_int() {
        assert!(ExprType::INT.is_concrete());
    }
    #[test]
    fn concrete_list_int() {
        assert!(ExprType::list(ExprType::INT).is_concrete());
    }
    #[test]
    fn not_concrete_any() {
        assert!(!ExprType::ANY.is_concrete());
    }
    #[test]
    fn not_concrete_union() {
        assert!(!ExprType::union(vec![ExprType::INT, ExprType::STRING]).is_concrete());
    }
    #[test]
    fn not_concrete_typevar() {
        assert!(!ExprType::T1.is_concrete());
    }
    #[test]
    fn symbolic_t1() {
        assert!(ExprType::T1.is_symbolic());
    }
    #[test]
    fn not_symbolic_int() {
        assert!(!ExprType::INT.is_symbolic());
    }
    #[test]
    fn symbolic_list_t1() {
        assert!(ExprType::list(ExprType::T1).is_symbolic());
    }

    // ── match_type ──
    #[test]
    fn match_simple_typevar() {
        let b = ExprType::T1.match_type(&ExprType::INT).unwrap();
        assert_eq!(b[&TypeCode::TypeVarT1], ExprType::INT);
    }
    #[test]
    fn match_nested_typevar() {
        let list_t1 = ExprType::list(ExprType::T1);
        let list_int = ExprType::list(ExprType::INT);
        let b = list_t1.match_type(&list_int).unwrap();
        assert_eq!(b[&TypeCode::TypeVarT1], ExprType::INT);
    }
    #[test]
    fn match_no_match() {
        let list_t1 = ExprType::list(ExprType::T1);
        assert!(list_t1.match_type(&ExprType::INT).is_none());
    }
    #[test]
    fn match_any() {
        assert!(ExprType::ANY.match_type(&ExprType::INT).is_some());
        assert!(ExprType::INT.match_type(&ExprType::ANY).is_some());
    }
    #[test]
    fn match_union_member() {
        let u = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        assert!(u.match_type(&ExprType::INT).is_some());
        assert!(u.match_type(&ExprType::STRING).is_some());
        assert!(u.match_type(&ExprType::FLOAT).is_none());
    }
    #[test]
    fn match_unresolved_delegates() {
        let t = ExprType::unresolved(ExprType::INT);
        assert!(t.match_type(&ExprType::INT).is_some());
        assert!(t.match_type(&ExprType::STRING).is_none());
    }

    // ── substitute ──
    #[test]
    fn substitute_typevar() {
        let list_t1 = ExprType::list(ExprType::T1);
        let mut bindings = HashMap::new();
        bindings.insert(TypeCode::TypeVarT1, ExprType::INT);
        assert_eq!(list_t1.substitute(&bindings), ExprType::list(ExprType::INT));
    }
    #[test]
    fn substitute_leaves_unbound_typevar() {
        let list_t2 = ExprType::list(ExprType::T2);
        let mut bindings = HashMap::new();
        bindings.insert(TypeCode::TypeVarT1, ExprType::INT);
        assert_eq!(list_t2.substitute(&bindings), list_t2);
    }
    #[test]
    fn substitute_renormalizes_union() {
        let u = ExprType::union(vec![ExprType::T1, ExprType::INT]);
        let mut bindings = HashMap::new();
        bindings.insert(TypeCode::TypeVarT1, ExprType::INT);
        assert_eq!(u.substitute(&bindings), ExprType::INT);
    }

    // ── Signatures ──
    #[test]
    fn signature_accessors() {
        let sig = ExprType::signature(vec![ExprType::INT, ExprType::STRING], ExprType::BOOL);
        assert_eq!(sig.code(), TypeCode::Signature);
        assert_eq!(
            sig.sig_params().to_vec(),
            vec![ExprType::INT, ExprType::STRING]
        );
        assert_eq!(sig.sig_return(), &ExprType::BOOL);
    }
    #[test]
    fn signature_display_roundtrip() {
        let sig = ExprType::signature(vec![ExprType::list(ExprType::T1)], ExprType::T1);
        assert_eq!(sig.to_string(), "(list[T1]) -> T1");
        assert_eq!(ExprType::parse("(list[T1]) -> T1").unwrap(), sig);

        let two = ExprType::signature(vec![ExprType::INT, ExprType::STRING], ExprType::BOOL);
        assert_eq!(two.to_string(), "(int, string) -> bool");
        assert_eq!(ExprType::parse(&two.to_string()).unwrap(), two);
    }
    #[test]
    fn parse_signature_no_params() {
        let sig = ExprType::parse("() -> int").unwrap();
        assert_eq!(sig.code(), TypeCode::Signature);
        assert!(sig.sig_params().is_empty());
        assert_eq!(sig.sig_return(), &ExprType::INT);
    }
    #[test]
    fn parse_signature_union_return() {
        let sig = ExprType::parse("(int) -> int | string").unwrap();
        assert_eq!(
            sig.sig_return(),
            &ExprType::union(vec![ExprType::INT, ExprType::STRING])
        );
    }

    // ── match_call / resolve_call ──
    #[test]
    fn match_call_arity_mismatch() {
        let sig = ExprType::signature(vec![ExprType::INT], ExprType::INT);
        assert!(sig.match_call(&[]).is_none());
        assert!(sig.match_call(&[ExprType::INT, ExprType::INT]).is_none());
    }
    #[test]
    fn match_call_conflicting_bindings() {
        let sig = ExprType::signature(vec![ExprType::T, ExprType::T], ExprType::T);
        assert!(sig.match_call(&[ExprType::INT, ExprType::STRING]).is_none());
        assert_eq!(
            sig.resolve_call(&[ExprType::INT, ExprType::INT]),
            Some(ExprType::INT)
        );
    }
    #[test]
    fn resolve_call_binds_typevar() {
        let sig = ExprType::signature(vec![ExprType::list(ExprType::T1)], ExprType::T1);
        assert_eq!(
            sig.resolve_call(&[ExprType::list(ExprType::STRING)]),
            Some(ExprType::STRING)
        );
        assert_eq!(sig.resolve_call(&[ExprType::STRING]), None);
    }
}