1use crate::{SyntaxShape, ast::PathMember};
2use serde::{Deserialize, Serialize};
3use std::{borrow::Cow, fmt::Display};
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// Describes the type of a value.
///
/// Types are related by [`Type::is_subtype_of`], with [`Type::Any`] at the top, and can be
/// merged with [`Type::widen`].
///
/// The derived `Ord`/`PartialOrd` compare variants in declaration order, so reordering the
/// variants changes comparison results.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash, Ord, PartialOrd)]
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type is a subtype of `Any`.
    Any,
    /// Displayed as `binary`.
    Binary,
    /// Displayed as `block`.
    Block,
    /// Displayed as `bool`.
    Bool,
    /// Displayed as `cell-path`. `String` and `Int` (and a `OneOf` made only of such types)
    /// are accepted as subtypes of it.
    CellPath,
    /// Displayed as `closure`.
    Closure,
    /// A custom type identified by its name; the name is what `Display` prints.
    Custom(Box<str>),
    /// Displayed as `datetime`.
    Date,
    /// Displayed as `duration`.
    Duration,
    /// Displayed as `error`.
    Error,
    /// Displayed as `filesize`.
    Filesize,
    /// Displayed as `float`; a subtype of `Number`.
    Float,
    /// Displayed as `int`; a subtype of `Number` and of `CellPath`.
    Int,
    /// A list with the given item type; subtyping is covariant in the item type.
    List(Box<Type>),
    /// Displayed as `nothing`. This is the `Default` value of `Type`.
    #[default]
    Nothing,
    /// Supertype of `Int` and `Float`.
    Number,
    /// A union of alternatives; a type is a subtype of it when it is a subtype of any
    /// alternative.
    OneOf(Box<[Type]>),
    /// Displayed as `range`.
    Range,
    /// A record described by `(field name, field type)` pairs. An empty list means no fields
    /// are specified: it is displayed as plain `record` and is compatible with every record
    /// in subtype checks.
    Record(Box<[(String, Type)]>),
    /// Displayed as `string`. `String` and `Glob` are subtypes of each other.
    String,
    /// Displayed as `glob`. `Glob` and `String` are subtypes of each other.
    Glob,
    /// A table described by `(column name, column type)` pairs. An empty list means no
    /// columns are specified: it is displayed as plain `table` and is compatible with every
    /// table in subtype checks.
    Table(Box<[(String, Type)]>),
}
37
38fn follow_cell_path_recursive<'a>(
39 current: Cow<'a, Type>,
40 path_members: &mut dyn Iterator<Item = &'a PathMember>,
41) -> Option<Cow<'a, Type>> {
42 let Some(first) = path_members.next() else {
43 return Some(current);
44 };
45 match (current.as_ref(), first) {
46 (Type::Record(fields), PathMember::String { val, .. }) => {
47 let idx = fields.iter().position(|(name, _)| name == val)?;
48 let next = match current {
49 Cow::Borrowed(Type::Record(f)) => Cow::Borrowed(&f[idx].1),
50 Cow::Owned(Type::Record(f)) => Cow::Owned(f[idx].1.to_owned()),
51 _ => unreachable!(),
52 };
53 follow_cell_path_recursive(next, path_members)
54 }
55
56 (Type::Table(f), PathMember::Int { .. }) => {
58 follow_cell_path_recursive(Cow::Owned(Type::Record(f.clone())), path_members)
59 }
60
61 (Type::Table(fields), PathMember::String { val, .. }) => {
63 let (_, sub_type) = fields.iter().find(|(name, _)| name == val)?;
64 let list_type = Type::List(Box::new(sub_type.clone()));
65 follow_cell_path_recursive(Cow::Owned(list_type), path_members)
66 }
67
68 (Type::List(_), PathMember::Int { .. }) => {
69 let next = match current {
70 Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
71 Cow::Owned(Type::List(i)) => Cow::Owned(*i),
72 _ => unreachable!(),
73 };
74 follow_cell_path_recursive(next, path_members)
75 }
76
77 (Type::List(_), PathMember::String { .. }) => {
79 let next = match current {
80 Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
81 Cow::Owned(Type::List(i)) => Cow::Owned(*i),
82 _ => unreachable!(),
83 };
84
85 let mut found_int_member = false;
86 let mut new_iter = std::iter::once(first).chain(path_members).filter(|pm| {
87 let first_int = !found_int_member && matches!(pm, PathMember::Int { .. });
88 if first_int {
89 found_int_member = true;
90 }
91 !first_int
92 });
93 let inner_ty = follow_cell_path_recursive(next, &mut new_iter);
94
95 if found_int_member {
98 inner_ty
99 } else {
100 inner_ty.map(|inner_ty| Cow::Owned(Type::List(Box::new(inner_ty.into_owned()))))
101 }
102 }
103
104 _ => None,
105 }
106}
107
108impl Type {
109 pub fn list(inner: Type) -> Self {
110 Self::List(Box::new(inner))
111 }
112
113 pub fn one_of(types: impl IntoIterator<Item = Type>) -> Self {
116 let mut flattened = Vec::new();
117 for t in types {
118 Self::oneof_add(&mut flattened, t);
119 }
120 Self::OneOf(flattened.into())
121 }
122
123 pub fn record() -> Self {
124 Self::Record([].into())
125 }
126
127 pub fn table() -> Self {
128 Self::Table([].into())
129 }
130
131 pub fn custom(name: impl Into<Box<str>>) -> Self {
132 Self::Custom(name.into())
133 }
134
135 pub fn is_subtype_of(&self, other: &Type) -> bool {
141 let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
143 if this.is_empty() || that.is_empty() {
144 true
145 } else if this.len() < that.len() {
146 false
147 } else {
148 that.iter().all(|(col_y, ty_y)| {
149 if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
150 ty_x.is_subtype_of(ty_y)
151 } else {
152 false
153 }
154 })
155 }
156 };
157
158 match (self, other) {
159 (t, u) if t == u => true,
160 (_, Type::Any) => true,
161 (Type::String | Type::Int, Type::CellPath) => true,
164 (Type::OneOf(oneof), Type::CellPath) => {
165 oneof.iter().all(|t| t.is_subtype_of(&Type::CellPath))
166 }
167 (Type::Float | Type::Int, Type::Number) => true,
168 (Type::Glob, Type::String) | (Type::String, Type::Glob) => true,
169 (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
171 is_subtype_collection(this, that)
172 }
173 (Type::Table(_), Type::List(that)) if matches!(**that, Type::Any) => true,
174 (Type::Table(this), Type::List(that)) => {
175 matches!(that.as_ref(), Type::Record(that) if is_subtype_collection(this, that))
176 }
177 (Type::List(this), Type::Table(that)) => {
178 matches!(this.as_ref(), Type::Record(this) if is_subtype_collection(this, that))
179 }
180 (Type::OneOf(this), that @ Type::OneOf(_)) => {
181 this.iter().all(|t| t.is_subtype_of(that))
182 }
183 (this, Type::OneOf(that)) => that.iter().any(|t| this.is_subtype_of(t)),
184 _ => false,
185 }
186 }
187
188 fn flat_widen(lhs: Type, rhs: Type) -> Result<Type, (Type, Type)> {
190 if lhs == rhs {
192 return Ok(lhs);
193 }
194
195 if matches!(lhs, Type::Any) || matches!(rhs, Type::Any) {
197 return Ok(Type::Any);
198 }
199
200 if matches!(lhs, Type::Int | Type::Float | Type::Number)
203 && matches!(rhs, Type::Int | Type::Float | Type::Number)
204 {
205 return Ok(Type::Number);
206 }
207
208 if (matches!(lhs, Type::Glob) && matches!(rhs, Type::String))
210 || (matches!(lhs, Type::String) && matches!(rhs, Type::Glob))
211 {
212 return Err((lhs, rhs));
213 }
214
215 match (&lhs, &rhs) {
218 (Type::Record(this), Type::Record(that)) => {
219 let widened = Self::widen_collection(this.clone(), that.clone());
220 return Ok(Type::Record(widened));
221 }
222 (Type::Table(this), Type::Table(that)) => {
223 let widened = Self::widen_collection(this.clone(), that.clone());
224 return Ok(Type::Table(widened));
225 }
226
227 (Type::List(_list_item), Type::Table(_table))
228 | (Type::Table(_table), Type::List(_list_item)) => {
229 let item = match (lhs, rhs) {
231 (Type::List(list_item), Type::Table(table)) => match *list_item {
232 Type::Record(record) => Type::Record(Self::widen_collection(record, table)),
233 list_item => Type::one_of([list_item, Type::Record(table)]),
234 },
235 (Type::Table(table), Type::List(list_item)) => match *list_item {
236 Type::Record(record) => Type::Record(Self::widen_collection(record, table)),
237 list_item => Type::one_of([list_item, Type::Record(table)]),
238 },
239 _ => unreachable!(),
240 };
241 return Ok(Type::List(Box::new(item)));
242 }
243
244 (Type::List(lhs), Type::List(rhs)) => {
245 let lhs_inner = lhs.clone();
247 let rhs_inner = rhs.clone();
248 return Ok(Type::list(lhs_inner.widen(*rhs_inner)));
249 }
250
251 _ => {}
252 }
253
254 if lhs.is_subtype_of(&rhs) {
256 return Ok(rhs);
257 }
258 if rhs.is_subtype_of(&lhs) {
259 return Ok(lhs);
260 }
261
262 Err((lhs, rhs))
264 }
265
266 fn widen_collection(
267 lhs: Box<[(String, Type)]>,
268 rhs: Box<[(String, Type)]>,
269 ) -> Box<[(String, Type)]> {
270 if lhs.is_empty() || rhs.is_empty() {
271 return [].into();
272 }
273
274 let (small, big) = if lhs.len() <= rhs.len() {
276 (lhs, rhs)
277 } else {
278 (rhs, lhs)
279 };
280
281 const MAP_THRESH: usize = 16;
282 if big.len() > MAP_THRESH {
283 use std::collections::HashMap;
284 let mut big_map: HashMap<String, Type> = big.into_iter().collect();
285 small
286 .into_iter()
287 .filter_map(|(col, typ)| big_map.remove(&col).map(|b_typ| (col, typ.widen(b_typ))))
288 .collect()
289 } else {
290 small
291 .into_iter()
292 .filter_map(|(col, typ)| {
293 big.iter()
294 .find_map(|(b_col, b_typ)| (&col == b_col).then(|| b_typ.clone()))
295 .map(|b_typ| (col, typ.widen(b_typ)))
296 })
297 .collect()
298 }
299 }
300
301 pub fn widen(self, other: Type) -> Type {
303 fn shortcut_allowed(lhs: &Type, rhs: &Type) -> bool {
307 !matches!(
308 (lhs, rhs),
309 (Type::List(_), Type::Table(_)) | (Type::Table(_), Type::List(_))
310 )
311 }
312
313 if self.is_subtype_of(&other)
316 && !other.is_subtype_of(&self)
317 && shortcut_allowed(&self, &other)
318 {
319 return other;
320 }
321
322 let tu = match Self::flat_widen(self, other) {
323 Ok(t) => return t,
324 Err(tu) => tu,
325 };
326
327 match tu {
328 (Type::OneOf(ts), Type::OneOf(us)) => {
329 let (big, small) = match ts.len() >= us.len() {
330 true => (ts, us),
331 false => (us, ts),
332 };
333 let mut out = big.into_vec();
334 for t in small.into_iter() {
335 Self::oneof_add_widen(&mut out, t);
336 }
337 Type::one_of(out)
338 }
339 (Type::OneOf(oneof), t) | (t, Type::OneOf(oneof)) => {
340 let mut out = oneof.into_vec();
341 Self::oneof_add_widen(&mut out, t);
342 Type::one_of(out)
343 }
344 (this, other) => Type::one_of([this, other]),
345 }
346 }
347
348 fn oneof_add_widen(oneof: &mut Vec<Type>, mut t: Type) {
350 if let Type::OneOf(inner) = t {
352 for sub_t in inner.into_vec() {
353 Self::oneof_add_widen(oneof, sub_t);
354 }
355 return;
356 }
357
358 let mut i = 0;
359 while i < oneof.len() {
360 let one = std::mem::replace(&mut oneof[i], Type::Any);
361 match Self::flat_widen(one, t) {
362 Ok(one_t) => {
363 oneof[i] = one_t;
364 return;
365 }
366 Err((one_old, t_old)) => {
367 oneof[i] = one_old;
368 t = t_old; i += 1;
370 }
371 }
372 }
373
374 oneof.push(t);
375 }
376
377 fn oneof_add(oneof: &mut Vec<Type>, t: Type) {
379 match t {
380 Type::OneOf(inner) => {
381 for sub_t in inner.into_vec() {
382 Self::oneof_add(oneof, sub_t);
383 }
384 }
385 t => {
386 if !oneof.contains(&t) {
387 oneof.push(t);
388 }
389 }
390 }
391 }
392
393 pub fn supertype_of(it: impl IntoIterator<Item = Type>) -> Option<Self> {
395 let mut it = it.into_iter();
396 it.next().and_then(|head| {
397 it.try_fold(head, |acc, e| match acc.widen(e) {
398 Type::Any => None,
399 r => Some(r),
400 })
401 })
402 }
403
404 pub fn is_numeric(&self) -> bool {
405 matches!(self, Type::Int | Type::Float | Type::Number)
406 }
407
408 pub fn is_list(&self) -> bool {
409 matches!(self, Type::List(_))
410 }
411
412 pub fn accepts_cell_paths(&self) -> bool {
414 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
415 }
416
417 pub fn to_shape(&self) -> SyntaxShape {
418 let mk_shape = |tys: &[(String, Type)]| {
419 tys.iter()
420 .map(|(key, val)| (key.clone(), val.to_shape()))
421 .collect()
422 };
423
424 match self {
425 Type::Int => SyntaxShape::Int,
426 Type::Float => SyntaxShape::Float,
427 Type::Range => SyntaxShape::Range,
428 Type::Bool => SyntaxShape::Boolean,
429 Type::String => SyntaxShape::String,
430 Type::Block => SyntaxShape::Block, Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
433 Type::Duration => SyntaxShape::Duration,
434 Type::Date => SyntaxShape::DateTime,
435 Type::Filesize => SyntaxShape::Filesize,
436 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
437 Type::Number => SyntaxShape::Number,
438 Type::OneOf(types) => SyntaxShape::OneOf(types.iter().map(Type::to_shape).collect()),
439 Type::Nothing => SyntaxShape::Nothing,
440 Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
441 Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
442 Type::Any => SyntaxShape::Any,
443 Type::Error => SyntaxShape::Any,
444 Type::Binary => SyntaxShape::Binary,
445 Type::Custom(_) => SyntaxShape::Any,
446 Type::Glob => SyntaxShape::GlobPattern,
447 }
448 }
449
450 pub fn get_non_specified_string(&self) -> String {
453 match self {
454 Type::Closure => String::from("closure"),
455 Type::Bool => String::from("bool"),
456 Type::Block => String::from("block"),
457 Type::CellPath => String::from("cell-path"),
458 Type::Date => String::from("datetime"),
459 Type::Duration => String::from("duration"),
460 Type::Filesize => String::from("filesize"),
461 Type::Float => String::from("float"),
462 Type::Int => String::from("int"),
463 Type::Range => String::from("range"),
464 Type::Record(_) => String::from("record"),
465 Type::Table(_) => String::from("table"),
466 Type::List(_) => String::from("list"),
467 Type::Nothing => String::from("nothing"),
468 Type::Number => String::from("number"),
469 Type::OneOf(_) => String::from("oneof"),
470 Type::String => String::from("string"),
471 Type::Any => String::from("any"),
472 Type::Error => String::from("error"),
473 Type::Binary => String::from("binary"),
474 Type::Custom(_) => String::from("custom"),
475 Type::Glob => String::from("glob"),
476 }
477 }
478
479 pub fn follow_cell_path<'a>(&'a self, path_members: &'a [PathMember]) -> Option<Cow<'a, Self>> {
480 follow_cell_path_recursive(Cow::Borrowed(self), &mut path_members.iter())
481 }
482}
483
484impl Display for Type {
485 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486 match self {
487 Type::Block => write!(f, "block"),
488 Type::Closure => write!(f, "closure"),
489 Type::Bool => write!(f, "bool"),
490 Type::CellPath => write!(f, "cell-path"),
491 Type::Date => write!(f, "datetime"),
492 Type::Duration => write!(f, "duration"),
493 Type::Filesize => write!(f, "filesize"),
494 Type::Float => write!(f, "float"),
495 Type::Int => write!(f, "int"),
496 Type::Range => write!(f, "range"),
497 Type::Record(fields) => {
498 if fields.is_empty() {
499 write!(f, "record")
500 } else {
501 write!(
502 f,
503 "record<{}>",
504 fields
505 .iter()
506 .map(|(x, y)| format!("{x}: {y}"))
507 .collect::<Vec<String>>()
508 .join(", "),
509 )
510 }
511 }
512 Type::Table(columns) => {
513 if columns.is_empty() {
514 write!(f, "table")
515 } else {
516 write!(
517 f,
518 "table<{}>",
519 columns
520 .iter()
521 .map(|(x, y)| format!("{x}: {y}"))
522 .collect::<Vec<String>>()
523 .join(", ")
524 )
525 }
526 }
527 Type::List(l) => write!(f, "list<{l}>"),
528 Type::Nothing => write!(f, "nothing"),
529 Type::Number => write!(f, "number"),
530 Type::OneOf(types) => {
531 write!(f, "oneof")?;
532 let [first, rest @ ..] = &**types else {
533 return Ok(());
534 };
535 write!(f, "<{first}")?;
536 for t in rest {
537 write!(f, ", {t}")?;
538 }
539 f.write_str(">")
540 }
541 Type::String => write!(f, "string"),
542 Type::Any => write!(f, "any"),
543 Type::Error => write!(f, "error"),
544 Type::Binary => write!(f, "binary"),
545 Type::Custom(custom) => write!(f, "{custom}"),
546 Type::Glob => write!(f, "glob"),
547 }
548 }
549}
550
551pub fn combined_type_string(types: &[Type], join_word: &str) -> Option<String> {
555 use std::fmt::Write as _;
556
557 let mut seen = Vec::new();
560 for t in types {
561 if !seen.contains(t) {
562 seen.push(t.clone());
563 }
564 }
565
566 match seen.as_slice() {
567 [] => None,
568 [one] => Some(one.to_string()),
569 [one, two] => Some(format!("{one} {join_word} {two}")),
570 [initial @ .., last] => {
571 let mut out = String::new();
572 for ele in initial {
573 let _ = write!(out, "{ele}, ");
574 }
575 let _ = write!(out, "{join_word} {last}");
576 Some(out)
577 }
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    /// Properties of `Type::is_subtype_of`.
    mod subtype_relation {
        use super::*;

        /// Every type is a subtype of itself.
        #[test]
        fn test_reflexivity() {
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&ty)));
        }

        /// Every type is a subtype of `any`.
        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&top)));
        }

        /// `int` and `float` are both subtypes of `number`.
        #[test]
        fn test_number_supertype() {
            for numeric in [Type::Int, Type::Float] {
                assert!(numeric.is_subtype_of(&Type::Number));
            }
        }

        /// `list<T>` is a subtype of `list<U>` exactly when `T` is a subtype of `U`.
        #[test]
        fn test_list_covariance() {
            for inner1 in Type::iter() {
                for inner2 in Type::iter() {
                    let lists_related = Type::List(Box::new(inner1.clone()))
                        .is_subtype_of(&Type::List(Box::new(inner2.clone())));
                    assert_eq!(lists_related, inner1.is_subtype_of(&inner2));
                }
            }
        }
    }

    /// Flattening and deduplication of `oneof` alternatives.
    mod oneof_flattening {
        use super::*;

        /// Unwraps the alternatives of a `oneof`, failing the test for any other type.
        fn expect_oneof(ty: Type) -> Vec<Type> {
            match ty {
                Type::OneOf(types) => types.to_vec(),
                _ => panic!("Expected OneOf"),
            }
        }

        /// `one_of` inlines the alternatives of a nested `oneof`.
        #[test]
        fn test_oneof_creation_flattens() {
            let alternatives = expect_oneof(Type::one_of([
                Type::String,
                Type::one_of([Type::Int, Type::Float]),
                Type::Bool,
            ]));

            assert_eq!(alternatives.len(), 4);
            for expected in [Type::String, Type::Int, Type::Float, Type::Bool] {
                assert!(alternatives.contains(&expected));
            }
        }

        /// Widening two `oneof`s merges them, combining `int` and `float` into `number`.
        #[test]
        fn test_widen_flattens_oneof() {
            let lhs = Type::one_of([Type::String, Type::Int]);
            let rhs = Type::one_of([Type::Float, Type::Bool]);
            let alternatives = expect_oneof(lhs.widen(rhs));

            assert_eq!(alternatives.len(), 3);
            for expected in [Type::String, Type::Number, Type::Bool] {
                assert!(alternatives.contains(&expected));
            }
        }

        /// `one_of` keeps a single copy of equal alternatives.
        #[test]
        fn test_oneof_deduplicates() {
            let record_type =
                Type::Record(vec![("content".to_string(), Type::list(Type::String))].into());
            let alternatives = expect_oneof(Type::one_of([
                Type::String,
                record_type.clone(),
                record_type.clone(),
            ]));

            assert_eq!(alternatives.len(), 2);
            assert!(alternatives.contains(&Type::String));
            assert!(alternatives.contains(&record_type));
        }
    }

    /// Special cases of `Type::widen`.
    mod widen_shortcuts {
        use super::*;

        /// Widening a `oneof` with a subtype of one of its alternatives changes nothing,
        /// whichever side the `oneof` is on.
        #[test]
        fn test_widen_subtype_shortcut() {
            let string_or_number = Type::one_of([Type::String, Type::Number]);
            assert_eq!(
                string_or_number.clone().widen(Type::Int),
                string_or_number
            );

            let int_or_string = Type::one_of([Type::Int, Type::String]);
            assert_eq!(Type::Int.widen(int_or_string.clone()), int_or_string);
        }

        /// Widening with the same type over and over does not grow the `oneof`.
        #[test]
        fn test_chain_shortcut() {
            let widened = (0..100).fold(Type::String, |acc, _| acc.widen(Type::Int));
            assert_eq!(widened, Type::one_of([Type::String, Type::Int]));
        }

        /// A list of records widened with a matching table stays a list of records, in
        /// either order.
        #[test]
        fn test_list_table_widen_preserves_list() {
            let columns = || -> Box<[(String, Type)]> { vec![("a".to_string(), Type::Int)].into() };
            let list_record = Type::List(Box::new(Type::Record(columns())));
            let table = Type::Table(columns());
            let expected = Type::List(Box::new(Type::Record(columns())));

            assert_eq!(list_record.clone().widen(table.clone()), expected);
            assert_eq!(table.widen(list_record), expected);
        }

        /// `glob` and `string` are kept side by side in a `oneof`, in argument order.
        #[test]
        fn test_glob_string_union() {
            assert_eq!(
                Type::Glob.widen(Type::String),
                Type::one_of([Type::Glob, Type::String])
            );
            assert_eq!(
                Type::String.widen(Type::Glob),
                Type::one_of([Type::String, Type::Glob])
            );
        }
    }
}