1use std::collections::HashMap;
12use std::fmt;
13
/// Discriminant of an [`ExprType`]: which kind of type a value describes.
///
/// Variant declaration order determines the derived `PartialOrd`/`Ord`
/// ordering, so reordering variants changes comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize)]
#[non_exhaustive]
pub enum TypeCode {
    /// `nulltype`; `X?` parses as the union `X | nulltype`.
    NullType,
    /// `bool`.
    Bool,
    /// `int`.
    Int,
    /// `float`.
    Float,
    /// `string`.
    String,
    /// `path`.
    Path,
    /// `list[X]`; the element type is the single parameter.
    List,
    /// `range_expr`.
    RangeExpr,
    /// `any`; a union containing it collapses to `any`.
    Any,
    /// `A | B | ...`; the parameters are the members.
    Union,
    /// `noreturn`; dropped from unions, and the result of an empty union.
    NoReturn,
    /// `unresolved[X]` (printed bare as `unresolved` when X is `any`); the
    /// parameter is the constraint.
    Unresolved,
    /// Type variable `T`; bound by `match_type`, replaced by `substitute`.
    TypeVarT,
    /// Type variable `T1`.
    TypeVarT1,
    /// Type variable `T2`.
    TypeVarT2,
    /// Type variable `T3`.
    TypeVarT3,
    /// `(A, B) -> R`; the parameters are the argument types followed by the
    /// return type as the last element.
    Signature,
}
41
/// An expression type: a [`TypeCode`] tag plus a list of child types.
///
/// The constructors `list`, `union`, `unresolved` and `new` normalize their
/// result (hoisting `Unresolved` outward, flattening/sorting/deduplicating
/// unions), so equivalent types built through them compare equal
/// structurally. Equality and hashing compare `code` and `params`.
#[derive(Debug, Clone, Eq, serde::Serialize)]
pub struct ExprType {
    /// Which kind of type this is.
    code: TypeCode,
    /// Child types, interpreted according to `code`:
    /// `List` -> the element type; `Union` -> the members;
    /// `Unresolved` -> the constraint; `Signature` -> parameter types followed
    /// by the return type. Empty for the scalar constants and type variables.
    params: Vec<ExprType>,
}
51
52impl ExprType {
55 pub const BOOL: ExprType = ExprType {
56 code: TypeCode::Bool,
57 params: Vec::new(),
58 };
59 pub const INT: ExprType = ExprType {
60 code: TypeCode::Int,
61 params: Vec::new(),
62 };
63 pub const FLOAT: ExprType = ExprType {
64 code: TypeCode::Float,
65 params: Vec::new(),
66 };
67 pub const STRING: ExprType = ExprType {
68 code: TypeCode::String,
69 params: Vec::new(),
70 };
71 pub const PATH: ExprType = ExprType {
72 code: TypeCode::Path,
73 params: Vec::new(),
74 };
75 pub const RANGE_EXPR: ExprType = ExprType {
76 code: TypeCode::RangeExpr,
77 params: Vec::new(),
78 };
79 pub const NULLTYPE: ExprType = ExprType {
80 code: TypeCode::NullType,
81 params: Vec::new(),
82 };
83 pub const ANY: ExprType = ExprType {
84 code: TypeCode::Any,
85 params: Vec::new(),
86 };
87 pub const NORETURN: ExprType = ExprType {
88 code: TypeCode::NoReturn,
89 params: Vec::new(),
90 };
91 pub const T: ExprType = ExprType {
92 code: TypeCode::TypeVarT,
93 params: Vec::new(),
94 };
95 pub const T1: ExprType = ExprType {
96 code: TypeCode::TypeVarT1,
97 params: Vec::new(),
98 };
99 pub const T2: ExprType = ExprType {
100 code: TypeCode::TypeVarT2,
101 params: Vec::new(),
102 };
103 pub const T3: ExprType = ExprType {
104 code: TypeCode::TypeVarT3,
105 params: Vec::new(),
106 };
107
108 pub fn code(&self) -> TypeCode {
110 self.code
111 }
112
113 pub fn params(&self) -> &[ExprType] {
115 &self.params
116 }
117
118 pub fn list(elem: ExprType) -> Self {
119 if elem.code == TypeCode::Unresolved && elem.params.len() == 1 {
121 let inner_list = ExprType {
122 code: TypeCode::List,
123 params: vec![elem.params[0].clone()],
124 };
125 return ExprType {
126 code: TypeCode::Unresolved,
127 params: vec![inner_list],
128 };
129 }
130 ExprType {
131 code: TypeCode::List,
132 params: vec![elem],
133 }
134 }
135
136 pub fn union(types: Vec<ExprType>) -> Self {
137 normalize_union(types)
138 }
139
140 pub fn unresolved(constraint: ExprType) -> Self {
141 if constraint.code == TypeCode::Unresolved {
143 return constraint;
144 }
145 ExprType {
146 code: TypeCode::Unresolved,
147 params: vec![constraint],
148 }
149 }
150
151 pub fn signature(param_types: Vec<ExprType>, return_type: ExprType) -> Self {
154 let mut params = param_types;
155 params.push(return_type);
156 ExprType {
157 code: TypeCode::Signature,
158 params,
159 }
160 }
161
162 pub fn sig_params(&self) -> &[ExprType] {
164 debug_assert_eq!(self.code, TypeCode::Signature);
165 &self.params[..self.params.len() - 1]
166 }
167
168 pub fn sig_return(&self) -> &ExprType {
170 debug_assert_eq!(self.code, TypeCode::Signature);
171 self.params.last().unwrap()
172 }
173
174 pub fn match_call(&self, arg_types: &[ExprType]) -> Option<HashMap<TypeCode, ExprType>> {
177 let sig_params = self.sig_params();
178 if sig_params.len() != arg_types.len() {
179 return None;
180 }
181 let mut bindings = HashMap::new();
182 for (sig_p, arg_t) in sig_params.iter().zip(arg_types.iter()) {
183 let sub = sig_p.match_type(arg_t)?;
184 for (k, v) in sub {
185 if let Some(existing) = bindings.get(&k) {
186 if *existing != v {
187 return None;
188 }
189 }
190 bindings.insert(k, v);
191 }
192 }
193 Some(bindings)
194 }
195
196 pub fn resolve_call(&self, arg_types: &[ExprType]) -> Option<ExprType> {
199 let bindings = self.match_call(arg_types)?;
200 Some(self.sig_return().substitute(&bindings))
201 }
202}
203
204fn normalize_union(types: Vec<ExprType>) -> ExprType {
207 let mut members = Vec::new();
208 for t in types {
209 match t.code {
210 TypeCode::Any => return ExprType::ANY,
211 TypeCode::NoReturn => continue,
212 TypeCode::Union => members.extend(t.params),
213 _ => members.push(t),
214 }
215 }
216 let mut unresolved_constraints = Vec::new();
218 let mut non_unresolved = Vec::new();
219 for m in &members {
220 if m.code == TypeCode::Unresolved {
221 unresolved_constraints.push(m.params[0].clone());
222 } else {
223 non_unresolved.push(m.clone());
224 }
225 }
226 if !unresolved_constraints.is_empty() {
227 let mut all_parts = non_unresolved;
228 all_parts.extend(unresolved_constraints);
229 let inner = ExprType::union(all_parts);
230 return ExprType::unresolved(inner);
231 }
232 members.sort_by_key(|a| a.to_string());
234 members.dedup();
235 match members.len() {
236 0 => ExprType::NORETURN,
237 1 => members.into_iter().next().unwrap(),
238 _ => ExprType {
239 code: TypeCode::Union,
240 params: members,
241 },
242 }
243}
244
245impl ExprType {
248 pub fn is_list(&self) -> bool {
249 self.code == TypeCode::List
250 }
251
252 pub fn list_element_type(&self) -> Option<&ExprType> {
253 if self.code == TypeCode::List {
254 self.params.first()
255 } else {
256 None
257 }
258 }
259
260 pub fn is_symbolic(&self) -> bool {
261 matches!(
262 self.code,
263 TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
264 ) || self.params.iter().any(|p| p.is_symbolic())
265 }
266
267 pub fn is_concrete(&self) -> bool {
268 if matches!(
269 self.code,
270 TypeCode::Any
271 | TypeCode::Union
272 | TypeCode::Unresolved
273 | TypeCode::TypeVarT
274 | TypeCode::TypeVarT1
275 | TypeCode::TypeVarT2
276 | TypeCode::TypeVarT3
277 | TypeCode::Signature
278 ) {
279 return false;
280 }
281 self.params.iter().all(|p| p.is_concrete())
282 }
283
284 pub fn substitute(&self, bindings: &HashMap<TypeCode, ExprType>) -> ExprType {
286 if let Some(bound) = bindings.get(&self.code) {
287 return bound.clone();
288 }
289 if self.params.is_empty() {
290 return self.clone();
291 }
292 let new_params: Vec<ExprType> =
293 self.params.iter().map(|p| p.substitute(bindings)).collect();
294 ExprType::new(self.code, new_params)
295 }
296
297 pub fn match_type(&self, other: &ExprType) -> Option<HashMap<TypeCode, ExprType>> {
300 if matches!(
302 self.code,
303 TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
304 ) {
305 let mut m = HashMap::new();
306 m.insert(self.code, other.clone());
307 return Some(m);
308 }
309 if matches!(
310 other.code,
311 TypeCode::TypeVarT | TypeCode::TypeVarT1 | TypeCode::TypeVarT2 | TypeCode::TypeVarT3
312 ) {
313 let mut m = HashMap::new();
314 m.insert(other.code, self.clone());
315 return Some(m);
316 }
317 if self.code == TypeCode::Any || other.code == TypeCode::Any {
319 return Some(HashMap::new());
320 }
321 if self.code == TypeCode::Unresolved {
323 return self.params[0].match_type(other);
324 }
325 if other.code == TypeCode::Unresolved {
326 return self.match_type(&other.params[0]);
327 }
328 if self.code == TypeCode::Union && other.code == TypeCode::Union {
330 for s in &self.params {
331 for c in &other.params {
332 if let Some(r) = s.match_type(c) {
333 return Some(r);
334 }
335 }
336 }
337 return None;
338 }
339 if self.code == TypeCode::Union {
340 for member in &self.params {
341 if let Some(r) = member.match_type(other) {
342 return Some(r);
343 }
344 }
345 return None;
346 }
347 if other.code == TypeCode::Union {
348 for member in &other.params {
349 if let Some(r) = self.match_type(member) {
350 return Some(r);
351 }
352 }
353 return None;
354 }
355 if self.code != other.code {
357 return None;
358 }
359 if self.params.len() != other.params.len() {
360 return None;
361 }
362 let mut bindings = HashMap::new();
363 for (sp, cp) in self.params.iter().zip(other.params.iter()) {
364 let sub = sp.match_type(cp)?;
365 for (k, v) in sub {
366 if let Some(existing) = bindings.get(&k) {
367 if *existing != v {
368 return None;
369 }
370 }
371 bindings.insert(k, v);
372 }
373 }
374 Some(bindings)
375 }
376
377 pub fn new(code: TypeCode, params: Vec<ExprType>) -> Self {
379 match code {
380 TypeCode::Union => normalize_union(params),
381 TypeCode::List if params.len() == 1 => {
382 ExprType::list(params.into_iter().next().unwrap())
383 }
384 TypeCode::Unresolved if params.len() == 1 => {
385 ExprType::unresolved(params.into_iter().next().unwrap())
386 }
387 _ => ExprType { code, params },
388 }
389 }
390}
391
392impl ExprType {
395 const MAX_PARSE_DEPTH: usize = 10;
396
397 pub fn parse(s: &str) -> Result<ExprType, String> {
399 Self::parse_inner(s, 0)
400 }
401
402 fn parse_inner(s: &str, depth: usize) -> Result<ExprType, String> {
403 if depth > Self::MAX_PARSE_DEPTH {
404 return Err("Type nesting depth exceeded".to_string());
405 }
406 if s.starts_with('(') {
408 if let Some(arrow_pos) = s.find(") -> ") {
409 let params_str = &s[1..arrow_pos];
410 let ret_str = &s[arrow_pos + 5..];
411 let param_types = if params_str.is_empty() {
412 Vec::new()
413 } else {
414 split_params(params_str)
415 .iter()
416 .map(|p| Self::parse_inner(p, depth + 1))
417 .collect::<Result<Vec<_>, _>>()?
418 };
419 let return_type = Self::parse_inner(ret_str, depth + 1)?;
420 return Ok(ExprType::signature(param_types, return_type));
421 }
422 }
423 let parts = split_union(s);
425 if parts.len() > 1 {
426 let types: Result<Vec<_>, _> = parts
427 .iter()
428 .map(|p| Self::parse_inner(p, depth + 1))
429 .collect();
430 return Ok(ExprType::union(types?));
431 }
432 match s {
434 "bool" => return Ok(ExprType::BOOL),
435 "int" => return Ok(ExprType::INT),
436 "float" => return Ok(ExprType::FLOAT),
437 "string" => return Ok(ExprType::STRING),
438 "path" => return Ok(ExprType::PATH),
439 "range_expr" => return Ok(ExprType::RANGE_EXPR),
440 "nulltype" => return Ok(ExprType::NULLTYPE),
441 "noreturn" => return Ok(ExprType::NORETURN),
442 "any" => return Ok(ExprType::ANY),
443 "unresolved" => return Ok(ExprType::unresolved(ExprType::ANY)),
444 _ => {}
445 }
446 if let Some(inner) = s.strip_suffix('?') {
448 let t = Self::parse_inner(inner, depth + 1)?;
449 return Ok(ExprType::union(vec![t, ExprType::NULLTYPE]));
450 }
451 if let Some(inner) = s.strip_prefix("list[").and_then(|s| s.strip_suffix(']')) {
453 let elem = Self::parse_inner(inner, depth + 1)?;
454 return Ok(ExprType::list(elem));
455 }
456 if let Some(inner) = s
458 .strip_prefix("unresolved[")
459 .and_then(|s| s.strip_suffix(']'))
460 {
461 let constraint = Self::parse_inner(inner, depth + 1)?;
462 return Ok(ExprType::unresolved(constraint));
463 }
464 match s {
466 "T" => return Ok(ExprType::T),
467 "T1" => return Ok(ExprType::T1),
468 "T2" => return Ok(ExprType::T2),
469 "T3" => return Ok(ExprType::T3),
470 _ => {}
471 }
472 Err(format!("Unknown type string: {s}"))
473 }
474}
475
/// Splits `s` on top-level ` | ` separators, i.e. those not nested inside
/// `[...]` or `(...)`. Pieces are returned in order and are not trimmed; an
/// input without a top-level separator comes back as a single piece.
fn split_union(s: &str) -> Vec<&str> {
    const SEP: &[u8] = b" | ";
    let raw = s.as_bytes();
    let mut pieces: Vec<&str> = Vec::new();
    let mut nesting: i32 = 0;
    let mut piece_start = 0usize;
    let mut cursor = 0usize;
    while cursor < raw.len() {
        let byte = raw[cursor];
        if byte == b'[' || byte == b'(' {
            nesting += 1;
        } else if byte == b']' || byte == b')' {
            nesting -= 1;
        } else if nesting == 0 && raw[cursor..].starts_with(SEP) {
            pieces.push(&s[piece_start..cursor]);
            cursor += SEP.len();
            piece_start = cursor;
            continue;
        }
        cursor += 1;
    }
    pieces.push(&s[piece_start..]);
    pieces
}
503
/// Splits a signature's parameter list on top-level commas (those not inside
/// `[...]` or `(...)`), trimming whitespace around every piece. An empty
/// final piece is dropped; empty pieces elsewhere are kept.
fn split_params(s: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut nesting: i32 = 0;
    let mut piece_start = 0usize;
    for (pos, ch) in s.char_indices() {
        if ch == '[' || ch == '(' {
            nesting += 1;
        } else if ch == ']' || ch == ')' {
            nesting -= 1;
        } else if ch == ',' && nesting == 0 {
            pieces.push(s[piece_start..pos].trim());
            piece_start = pos + 1;
        }
    }
    match s[piece_start..].trim() {
        "" => {}
        tail => pieces.push(tail),
    }
    pieces
}
529
530impl fmt::Display for ExprType {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 match self.code {
535 TypeCode::NullType => write!(f, "nulltype"),
536 TypeCode::Bool => write!(f, "bool"),
537 TypeCode::Int => write!(f, "int"),
538 TypeCode::Float => write!(f, "float"),
539 TypeCode::String => write!(f, "string"),
540 TypeCode::Path => write!(f, "path"),
541 TypeCode::RangeExpr => write!(f, "range_expr"),
542 TypeCode::Any => write!(f, "any"),
543 TypeCode::NoReturn => write!(f, "noreturn"),
544 TypeCode::TypeVarT => write!(f, "T"),
545 TypeCode::TypeVarT1 => write!(f, "T1"),
546 TypeCode::TypeVarT2 => write!(f, "T2"),
547 TypeCode::TypeVarT3 => write!(f, "T3"),
548 TypeCode::List => {
549 if let Some(elem) = self.params.first() {
550 write!(f, "list[{elem}]")
551 } else {
552 write!(f, "list")
553 }
554 }
555 TypeCode::Unresolved => {
556 if let Some(constraint) = self.params.first() {
557 if constraint.code == TypeCode::Any {
558 write!(f, "unresolved")
559 } else {
560 write!(f, "unresolved[{constraint}]")
561 }
562 } else {
563 write!(f, "unresolved")
564 }
565 }
566 TypeCode::Union => {
567 let non_null: Vec<_> = self
568 .params
569 .iter()
570 .filter(|t| t.code != TypeCode::NullType)
571 .collect();
572 let has_null = non_null.len() < self.params.len();
573 if has_null && non_null.len() == 1 {
574 return write!(f, "{}?", non_null[0]);
575 }
576 let mut parts: Vec<std::string::String> =
577 non_null.iter().map(|t| t.to_string()).collect();
578 if has_null {
579 parts.push("nulltype".to_string());
580 }
581 write!(f, "{}", parts.join(" | "))
582 }
583 TypeCode::Signature => {
584 let params = &self.params[..self.params.len() - 1];
585 let ret = self.params.last().unwrap();
586 let param_strs: Vec<String> = params.iter().map(|p| p.to_string()).collect();
587 write!(f, "({}) -> {}", param_strs.join(", "), ret)
588 }
589 }
590 }
591}
592
593impl PartialEq for ExprType {
596 fn eq(&self, other: &Self) -> bool {
597 self.code == other.code && self.params == other.params
598 }
599}
600
601impl std::hash::Hash for ExprType {
602 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
603 self.code.hash(state);
604 self.params.hash(state);
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    // ---- basic constants ----

    #[test]
    fn basic_types() {
        assert_eq!(ExprType::BOOL.code, TypeCode::Bool);
        assert_eq!(ExprType::INT.code, TypeCode::Int);
        assert_eq!(ExprType::FLOAT.code, TypeCode::Float);
        assert_eq!(ExprType::STRING.code, TypeCode::String);
        assert_eq!(ExprType::PATH.code, TypeCode::Path);
    }

    // ---- Display ----

    #[test]
    fn display_int() {
        assert_eq!(ExprType::INT.to_string(), "int");
    }
    #[test]
    fn display_list_int() {
        assert_eq!(ExprType::list(ExprType::INT).to_string(), "list[int]");
    }
    #[test]
    fn display_any() {
        assert_eq!(ExprType::ANY.to_string(), "any");
    }
    #[test]
    fn display_noreturn() {
        assert_eq!(ExprType::NORETURN.to_string(), "noreturn");
    }
    #[test]
    fn display_typevar() {
        assert_eq!(ExprType::T1.to_string(), "T1");
    }
    #[test]
    fn display_union() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::STRING]).to_string(),
            "int | string"
        );
    }
    #[test]
    fn display_nullable() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::NULLTYPE]).to_string(),
            "int?"
        );
    }
    #[test]
    fn display_unresolved_bare() {
        assert_eq!(
            ExprType::unresolved(ExprType::ANY).to_string(),
            "unresolved"
        );
    }
    #[test]
    fn display_unresolved_int() {
        assert_eq!(
            ExprType::unresolved(ExprType::INT).to_string(),
            "unresolved[int]"
        );
    }

    // ---- parsing ----

    #[test]
    fn parse_int() {
        assert_eq!(ExprType::parse("int").unwrap(), ExprType::INT);
    }
    #[test]
    fn parse_float() {
        assert_eq!(ExprType::parse("float").unwrap(), ExprType::FLOAT);
    }
    #[test]
    fn parse_string() {
        assert_eq!(ExprType::parse("string").unwrap(), ExprType::STRING);
    }
    #[test]
    fn parse_bool() {
        assert_eq!(ExprType::parse("bool").unwrap(), ExprType::BOOL);
    }
    #[test]
    fn parse_path() {
        assert_eq!(ExprType::parse("path").unwrap(), ExprType::PATH);
    }
    #[test]
    fn parse_range_expr() {
        assert_eq!(ExprType::parse("range_expr").unwrap(), ExprType::RANGE_EXPR);
    }
    #[test]
    fn parse_nulltype() {
        assert_eq!(ExprType::parse("nulltype").unwrap(), ExprType::NULLTYPE);
    }
    #[test]
    fn parse_noreturn() {
        assert_eq!(ExprType::parse("noreturn").unwrap(), ExprType::NORETURN);
    }
    #[test]
    fn parse_any() {
        assert_eq!(ExprType::parse("any").unwrap(), ExprType::ANY);
    }
    #[test]
    fn parse_list_int() {
        assert_eq!(
            ExprType::parse("list[int]").unwrap(),
            ExprType::list(ExprType::INT)
        );
    }
    #[test]
    fn parse_list_string() {
        assert_eq!(
            ExprType::parse("list[string]").unwrap(),
            ExprType::list(ExprType::STRING)
        );
    }
    #[test]
    fn parse_nested_list() {
        assert_eq!(
            ExprType::parse("list[list[int]]").unwrap(),
            ExprType::list(ExprType::list(ExprType::INT))
        );
    }
    #[test]
    fn parse_optional() {
        let t = ExprType::parse("int?").unwrap();
        assert_eq!(t.code, TypeCode::Union);
        assert_eq!(t.to_string(), "int?");
    }
    #[test]
    fn parse_union() {
        let t = ExprType::parse("int | string").unwrap();
        assert_eq!(t.code, TypeCode::Union);
        assert_eq!(t.to_string(), "int | string");
    }
    #[test]
    fn parse_unresolved_bare() {
        assert_eq!(
            ExprType::parse("unresolved").unwrap(),
            ExprType::unresolved(ExprType::ANY)
        );
    }
    #[test]
    fn parse_unresolved_int() {
        assert_eq!(
            ExprType::parse("unresolved[int]").unwrap(),
            ExprType::unresolved(ExprType::INT)
        );
    }
    #[test]
    fn parse_unresolved_list() {
        assert_eq!(
            ExprType::parse("unresolved[list[string]]").unwrap(),
            ExprType::unresolved(ExprType::list(ExprType::STRING))
        );
    }
    #[test]
    fn parse_unknown_rejects() {
        assert!(ExprType::parse("notavalidtype").is_err());
    }
    #[test]
    fn parse_case_sensitive() {
        assert!(ExprType::parse("INT").is_err());
    }
    #[test]
    fn parse_whitespace_rejected() {
        assert!(ExprType::parse(" int").is_err());
    }

    // ---- Display/parse roundtrips ----

    #[test]
    fn roundtrip_bare() {
        let t = ExprType::parse("unresolved").unwrap();
        assert_eq!(ExprType::parse(&t.to_string()).unwrap(), t);
    }
    #[test]
    fn roundtrip_constrained() {
        for s in &[
            "unresolved[int]",
            "unresolved[list[string]]",
            "unresolved[float | int]",
        ] {
            let t = ExprType::parse(s).unwrap();
            assert_eq!(
                ExprType::parse(&t.to_string()).unwrap(),
                t,
                "roundtrip failed for {s}"
            );
        }
    }

    // ---- equality & hashing ----

    #[test]
    fn eq_same() {
        assert_eq!(ExprType::INT, ExprType::INT);
    }
    #[test]
    fn eq_diff() {
        assert_ne!(ExprType::INT, ExprType::FLOAT);
    }
    #[test]
    fn eq_list() {
        assert_eq!(ExprType::list(ExprType::INT), ExprType::list(ExprType::INT));
    }
    #[test]
    fn ne_list() {
        assert_ne!(
            ExprType::list(ExprType::INT),
            ExprType::list(ExprType::STRING)
        );
    }
    #[test]
    fn hash_consistent() {
        use std::collections::HashSet;
        let mut s = HashSet::new();
        s.insert(ExprType::INT);
        s.insert(ExprType::FLOAT);
        s.insert(ExprType::INT);
        assert_eq!(s.len(), 2);
    }

    // ---- union normalization ----

    #[test]
    fn union_dedup() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::INT]).code,
            TypeCode::Int
        );
    }
    #[test]
    fn union_single_unwrap() {
        assert_eq!(
            ExprType::union(vec![ExprType::STRING]).code,
            TypeCode::String
        );
    }
    #[test]
    fn union_flatten() {
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::union(vec![ExprType::FLOAT, ExprType::BOOL]);
        let combined = ExprType::union(vec![u1, u2]);
        assert_eq!(combined.code, TypeCode::Union);
        assert_eq!(combined.params.len(), 4);
        assert_eq!(combined.to_string(), "bool | float | int | string");
    }
    #[test]
    fn union_any_absorbs() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::ANY]).code,
            TypeCode::Any
        );
    }
    #[test]
    fn union_noreturn_collapses() {
        assert_eq!(
            ExprType::union(vec![ExprType::INT, ExprType::NORETURN]),
            ExprType::INT
        );
    }
    #[test]
    fn union_all_noreturn() {
        assert_eq!(
            ExprType::union(vec![ExprType::NORETURN, ExprType::NORETURN]),
            ExprType::NORETURN
        );
    }
    #[test]
    fn union_order_independent() {
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::union(vec![ExprType::STRING, ExprType::INT]);
        assert_eq!(u1, u2);
    }
    #[test]
    fn union_hash_consistent() {
        use std::collections::HashSet;
        let u1 = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        let u2 = ExprType::parse("int | string").unwrap();
        let mut s = HashSet::new();
        s.insert(u1);
        s.insert(u2);
        assert_eq!(s.len(), 1);
    }

    // ---- unresolved hoisting ----

    #[test]
    fn list_of_unresolved_hoists() {
        let t = ExprType::list(ExprType::unresolved(ExprType::INT));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t, ExprType::unresolved(ExprType::list(ExprType::INT)));
    }
    #[test]
    fn union_with_unresolved_hoists() {
        let t = ExprType::union(vec![ExprType::STRING, ExprType::unresolved(ExprType::INT)]);
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(
            t,
            ExprType::unresolved(ExprType::union(vec![ExprType::INT, ExprType::STRING]))
        );
    }
    #[test]
    fn nested_unresolved_flattens() {
        let t = ExprType::unresolved(ExprType::unresolved(ExprType::INT));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t, ExprType::unresolved(ExprType::INT));
    }
    #[test]
    fn unresolved_never_inside_list() {
        let t = ExprType::list(ExprType::unresolved(ExprType::STRING));
        assert_eq!(t.code, TypeCode::Unresolved);
        assert_eq!(t.params[0], ExprType::list(ExprType::STRING));
    }

    // ---- concreteness & symbolic types ----

    #[test]
    fn concrete_int() {
        assert!(ExprType::INT.is_concrete());
    }
    #[test]
    fn concrete_list_int() {
        assert!(ExprType::list(ExprType::INT).is_concrete());
    }
    #[test]
    fn not_concrete_any() {
        assert!(!ExprType::ANY.is_concrete());
    }
    #[test]
    fn not_concrete_union() {
        assert!(!ExprType::union(vec![ExprType::INT, ExprType::STRING]).is_concrete());
    }
    #[test]
    fn not_concrete_typevar() {
        assert!(!ExprType::T1.is_concrete());
    }
    #[test]
    fn symbolic_t1() {
        assert!(ExprType::T1.is_symbolic());
    }
    #[test]
    fn not_symbolic_int() {
        assert!(!ExprType::INT.is_symbolic());
    }
    #[test]
    fn symbolic_list_t1() {
        assert!(ExprType::list(ExprType::T1).is_symbolic());
    }

    // ---- matching ----

    #[test]
    fn match_simple_typevar() {
        let b = ExprType::T1.match_type(&ExprType::INT).unwrap();
        assert_eq!(b[&TypeCode::TypeVarT1], ExprType::INT);
    }
    #[test]
    fn match_nested_typevar() {
        let list_t1 = ExprType::list(ExprType::T1);
        let list_int = ExprType::list(ExprType::INT);
        let b = list_t1.match_type(&list_int).unwrap();
        assert_eq!(b[&TypeCode::TypeVarT1], ExprType::INT);
    }
    #[test]
    fn match_no_match() {
        let list_t1 = ExprType::list(ExprType::T1);
        assert!(list_t1.match_type(&ExprType::INT).is_none());
    }
    #[test]
    fn match_any() {
        assert!(ExprType::ANY.match_type(&ExprType::INT).is_some());
        assert!(ExprType::INT.match_type(&ExprType::ANY).is_some());
    }
    #[test]
    fn match_union_member() {
        let u = ExprType::union(vec![ExprType::INT, ExprType::STRING]);
        assert!(u.match_type(&ExprType::INT).is_some());
        assert!(u.match_type(&ExprType::STRING).is_some());
        assert!(u.match_type(&ExprType::FLOAT).is_none());
    }
    #[test]
    fn match_unresolved_delegates() {
        let t = ExprType::unresolved(ExprType::INT);
        assert!(t.match_type(&ExprType::INT).is_some());
        assert!(t.match_type(&ExprType::STRING).is_none());
    }

    // ---- substitution ----

    #[test]
    fn substitute_typevar() {
        let list_t1 = ExprType::list(ExprType::T1);
        let mut bindings = HashMap::new();
        bindings.insert(TypeCode::TypeVarT1, ExprType::INT);
        assert_eq!(list_t1.substitute(&bindings), ExprType::list(ExprType::INT));
    }

    // ---- signatures & calls ----

    #[test]
    fn parse_signature() {
        let t = ExprType::parse("(int, string) -> bool").unwrap();
        assert_eq!(t.code, TypeCode::Signature);
        assert_eq!(t.sig_params().to_vec(), vec![ExprType::INT, ExprType::STRING]);
        assert_eq!(t.sig_return(), &ExprType::BOOL);
        assert_eq!(t.to_string(), "(int, string) -> bool");
    }
    #[test]
    fn parse_signature_no_params() {
        let t = ExprType::parse("() -> int").unwrap();
        assert!(t.sig_params().is_empty());
        assert_eq!(t.sig_return(), &ExprType::INT);
        assert_eq!(t.to_string(), "() -> int");
    }
    #[test]
    fn resolve_call_binds_typevar() {
        let sig = ExprType::parse("(list[T1]) -> T1").unwrap();
        assert_eq!(
            sig.resolve_call(&[ExprType::list(ExprType::INT)]),
            Some(ExprType::INT)
        );
        assert_eq!(sig.resolve_call(&[ExprType::INT]), None);
        assert_eq!(sig.resolve_call(&[]), None);
    }
    #[test]
    fn match_call_rejects_conflicting_bindings() {
        let sig = ExprType::parse("(T1, T1) -> T1").unwrap();
        assert_eq!(
            sig.resolve_call(&[ExprType::INT, ExprType::INT]),
            Some(ExprType::INT)
        );
        assert!(sig.match_call(&[ExprType::INT, ExprType::STRING]).is_none());
    }
}