1use crate::{
2 CollectionColumns, CompareTypes, OneOf, SyntaxShape, TypeRelation, TypeSet, ast::PathMember,
3};
4use serde::{Deserialize, Serialize};
5use std::{borrow::Cow, fmt::Display};
6#[cfg(test)]
7use strum_macros::EnumIter;
8
/// A value type, as related by the `CompareTypes` and `TypeSet` impls in this module.
///
/// NOTE: the declaration order of the variants is observable: the derived
/// `PartialOrd`/`Ord` compare variants by it, and `EnumIter` (test builds only) yields
/// them in it. Reordering variants is therefore a behavior change.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash, Ord, PartialOrd)]
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type compares as a subtype of `Any`.
    Any,
    /// Binary data (shape `SyntaxShape::Binary`).
    Binary,
    /// A block; compared as a subtype of `Closure`.
    Block,
    /// A boolean (displayed as `bool`).
    Bool,
    /// A cell path; `String` and `Int` compare as its subtypes.
    CellPath,
    /// A closure; compared as a supertype of `Block`.
    Closure,
    /// A custom type, identified and displayed by the contained name.
    Custom(Box<str>),
    /// A date/time value (displayed as `datetime`).
    Date,
    /// A duration.
    Duration,
    /// An error value; it has no dedicated syntax shape (maps to `SyntaxShape::Any`).
    Error,
    /// A file size.
    Filesize,
    /// A floating-point number; a subtype of `Number`.
    Float,
    /// An integer; a subtype of `Number` (and of `CellPath`).
    Int,
    /// A list whose elements have the boxed type (displayed as `list<T>`).
    List(Box<Type>),
    /// The absence of a value; this is the `Default` variant.
    #[default]
    Nothing,
    /// Either `Int` or `Float`; widening `Int` with `Float` yields `Number`.
    Number,
    /// A union of several types.
    OneOf(OneOf),
    /// A range.
    Range,
    /// A record with named, typed fields (displayed as `record<...>`).
    Record(CollectionColumns<Type>),
    /// A string; a subtype of `Glob` and of `CellPath`.
    String,
    /// A glob pattern; compared as a supertype of `String`.
    Glob,
    /// A table with named, typed columns (displayed as `table<...>`); related to
    /// `list<record<...>>` by `compare_types`.
    Table(CollectionColumns<Type>),
}
39
40fn follow_cell_path_recursive<'a>(
41 current: Cow<'a, Type>,
42 path_members: &mut dyn Iterator<Item = &'a PathMember>,
43) -> Option<Cow<'a, Type>> {
44 let Some(first) = path_members.next() else {
45 return Some(current);
46 };
47 match (current.as_ref(), first) {
48 (Type::Record(_), PathMember::String { val, .. }) => {
49 let next = match current {
50 Cow::Borrowed(Type::Record(f)) => {
51 Cow::Borrowed(&f.iter().find(|(name, _)| name == val)?.1)
52 }
53 Cow::Owned(Type::Record(f)) => {
54 Cow::Owned(f.into_iter().find(|(name, _)| name == val)?.1)
55 }
56 _ => unreachable!(),
57 };
58 follow_cell_path_recursive(next, path_members)
59 }
60
61 (Type::Table(f), PathMember::Int { .. }) => {
63 follow_cell_path_recursive(Cow::Owned(Type::Record(f.clone())), path_members)
64 }
65
66 (Type::Table(columns), PathMember::String { val, .. }) => {
68 let (_, sub_type) = columns.iter().find(|(name, _)| name == val)?;
69 let list_type = Type::List(Box::new(sub_type.clone()));
70 follow_cell_path_recursive(Cow::Owned(list_type), path_members)
71 }
72
73 (Type::List(_), PathMember::Int { .. }) => {
74 let next = match current {
75 Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
76 Cow::Owned(Type::List(i)) => Cow::Owned(*i),
77 _ => unreachable!(),
78 };
79 follow_cell_path_recursive(next, path_members)
80 }
81
82 (Type::List(_), PathMember::String { .. }) => {
84 let next = match current {
85 Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
86 Cow::Owned(Type::List(i)) => Cow::Owned(*i),
87 _ => unreachable!(),
88 };
89
90 let mut found_int_member = false;
91 let mut new_iter = std::iter::once(first).chain(path_members).filter(|pm| {
92 let first_int = !found_int_member && matches!(pm, PathMember::Int { .. });
93 if first_int {
94 found_int_member = true;
95 }
96 !first_int
97 });
98 let inner_ty = follow_cell_path_recursive(next, &mut new_iter);
99
100 if found_int_member {
103 inner_ty
104 } else {
105 inner_ty.map(|inner_ty| Cow::Owned(Type::List(Box::new(inner_ty.into_owned()))))
106 }
107 }
108
109 _ => None,
110 }
111}
112
113impl Type {
114 pub fn list(inner: Type) -> Self {
115 Self::List(Box::new(inner))
116 }
117
118 pub fn one_of(types: impl IntoIterator<Item = Type>) -> Self {
121 Self::OneOf(OneOf::from_iter(types))
122 }
123
124 pub fn record() -> Self {
125 Self::Record(Default::default())
126 }
127
128 pub fn table() -> Self {
129 Self::Table(Default::default())
130 }
131
132 pub fn custom(name: impl Into<Box<str>>) -> Self {
133 Self::Custom(name.into())
134 }
135
136 pub(crate) fn flat_widen(lhs: Type, rhs: Type) -> Result<Type, (Type, Type)> {
138 match (lhs, rhs) {
139 (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
141
142 (Type::Int, Type::Float) | (Type::Float, Type::Int) => Ok(Type::Number),
144
145 tys @ ((Type::Glob, Type::String)
147 | (Type::String, Type::Glob)
148 | (Type::String | Type::Int, Type::CellPath)
149 | (Type::CellPath, Type::String | Type::Int)) => Err(tys),
150
151 (Type::Record(lhs), Type::Record(rhs)) => Ok(Type::Record(lhs.union(rhs))),
153 (Type::Table(lhs), Type::Table(rhs)) => Ok(Type::Table(lhs.union(rhs))),
154
155 tys @ ((Type::List(_), Type::Table(_)) | (Type::Table(_), Type::List(_))) => Err(tys),
158
159 (lhs, rhs) => match lhs.compare_types(&rhs) {
161 Some(rel) => Ok(match rel {
162 TypeRelation::Subtype => rhs,
163 TypeRelation::Equal => lhs,
164 TypeRelation::Supertype => lhs,
165 }),
166 None => Err((lhs, rhs)),
168 },
169 }
170 }
171
172 pub fn supertype_of(it: impl IntoIterator<Item = Type>) -> Option<Self> {
175 let mut it = it.into_iter();
176 it.next().and_then(|head| {
177 it.try_fold(head, |acc, e| match acc.union(e) {
178 Type::Any => None,
179 r => Some(r),
180 })
181 })
182 }
183
184 pub fn is_numeric(&self) -> bool {
185 matches!(self, Type::Int | Type::Float | Type::Number)
186 }
187
188 pub fn is_list(&self) -> bool {
189 matches!(self, Type::List(_))
190 }
191
192 pub fn accepts_cell_paths(&self) -> bool {
194 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
195 }
196
197 pub fn to_shape(&self) -> SyntaxShape {
198 match self {
199 Type::Int => SyntaxShape::Int,
200 Type::Float => SyntaxShape::Float,
201 Type::Range => SyntaxShape::Range,
202 Type::Bool => SyntaxShape::Boolean,
203 Type::String => SyntaxShape::String,
204 Type::Block => SyntaxShape::Block, Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
207 Type::Duration => SyntaxShape::Duration,
208 Type::Date => SyntaxShape::DateTime,
209 Type::Filesize => SyntaxShape::Filesize,
210 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
211 Type::Number => SyntaxShape::Number,
212 Type::OneOf(types) => SyntaxShape::OneOf(types.iter().map(Type::to_shape).collect()),
213 Type::Nothing => SyntaxShape::Nothing,
214 Type::Record(entries) => SyntaxShape::Record(entries.map(Type::to_shape)),
215 Type::Table(columns) => SyntaxShape::Table(columns.map(Type::to_shape)),
216 Type::Any => SyntaxShape::Any,
217 Type::Error => SyntaxShape::Any,
218 Type::Binary => SyntaxShape::Binary,
219 Type::Custom(_) => SyntaxShape::Any,
220 Type::Glob => SyntaxShape::GlobPattern,
221 }
222 }
223
224 pub fn get_non_specified_string(&self) -> String {
227 match self {
228 Type::Closure => String::from("closure"),
229 Type::Bool => String::from("bool"),
230 Type::Block => String::from("block"),
231 Type::CellPath => String::from("cell-path"),
232 Type::Date => String::from("datetime"),
233 Type::Duration => String::from("duration"),
234 Type::Filesize => String::from("filesize"),
235 Type::Float => String::from("float"),
236 Type::Int => String::from("int"),
237 Type::Range => String::from("range"),
238 Type::Record(_) => String::from("record"),
239 Type::Table(_) => String::from("table"),
240 Type::List(_) => String::from("list"),
241 Type::Nothing => String::from("nothing"),
242 Type::Number => String::from("number"),
243 Type::OneOf(_) => String::from("oneof"),
244 Type::String => String::from("string"),
245 Type::Any => String::from("any"),
246 Type::Error => String::from("error"),
247 Type::Binary => String::from("binary"),
248 Type::Custom(_) => String::from("custom"),
249 Type::Glob => String::from("glob"),
250 }
251 }
252
253 pub fn follow_cell_path<'a>(&'a self, path_members: &'a [PathMember]) -> Option<Cow<'a, Self>> {
254 follow_cell_path_recursive(Cow::Borrowed(self), &mut path_members.iter())
255 }
256}
257
/// Subtyping and assignability rules between [`Type`]s.
impl CompareTypes for Type {
    /// Reports how `self` relates to `other`.
    ///
    /// `Subtype` means `self` is the narrower type (e.g. `int` vs `number`). `Supertype`
    /// means `self` is the wider type, `Equal` means the two are identical, and `None`
    /// means they are unrelated.
    ///
    /// Arm order matters. The `Any` checks run first, so `Any` vs `Any` reports `Subtype`
    /// rather than `Equal`. Structural equality is only consulted after every special case.
    fn compare_types(&self, other: &Self) -> Option<TypeRelation> {
        match (self, other) {
            // `Any` is the top type.
            (_, Type::Any) => Some(TypeRelation::Subtype),
            (Type::Any, _) => Some(TypeRelation::Supertype),

            // `Block` is treated as a subtype of `Closure`.
            (Type::Closure, Type::Block) => Some(TypeRelation::Supertype),
            (Type::Block, Type::Closure) => Some(TypeRelation::Subtype),

            // Strings and ints are treated as (single-member) cell paths.
            (Type::String | Type::Int, Type::CellPath) => Some(TypeRelation::Subtype),
            (Type::CellPath, Type::String | Type::Int) => Some(TypeRelation::Supertype),

            (Type::Float | Type::Int, Type::Number) => Some(TypeRelation::Subtype),
            (Type::Number, Type::Float | Type::Int) => Some(TypeRelation::Supertype),

            // `Glob` is wider than `String`.
            (Type::Glob, Type::String) => Some(TypeRelation::Supertype),
            (Type::String, Type::Glob) => Some(TypeRelation::Subtype),

            // Lists are covariant in their element type.
            (Type::List(t), Type::List(u)) => t.compare_types(u.as_ref()),

            // Same-kind collections compare column-wise.
            (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
                this.compare_types(that)
            }

            // A table relates to `list<any>` (as its subtype) and to `list<record<..>>`
            // (column-wise); lists of any other element type are unrelated to tables.
            (Type::Table(table_cols), Type::List(list_elem)) => match list_elem.as_ref() {
                Type::Any => Some(TypeRelation::Subtype),
                Type::Record(record_cols) => table_cols.compare_types(record_cols),
                _ => None,
            },
            (Type::List(list_elem), Type::Table(table_cols)) => match list_elem.as_ref() {
                Type::Any => Some(TypeRelation::Supertype),
                Type::Record(record_cols) => record_cols.compare_types(table_cols),
                _ => None,
            },

            // Unions delegate to `OneOf`'s own comparison rules.
            (Type::OneOf(lhs_oneof), Type::OneOf(rhs_oneof)) => lhs_oneof.compare_types(rhs_oneof),
            (Type::OneOf(lhs_oneof), rhs) => lhs_oneof.compare_types(rhs),
            (lhs, Type::OneOf(rhs_oneof)) => lhs.compare_types(rhs_oneof),

            (t, u) if t == u => Some(TypeRelation::Equal),

            _ => None,
        }
    }

    /// `true` when `self` is a subtype of, or equal to, `other`.
    fn is_subtype_of(&self, other: &Self) -> bool {
        matches!(
            self.compare_types(other),
            Some(TypeRelation::Subtype | TypeRelation::Equal)
        )
    }

    /// `true` only for `Type::Any`.
    fn is_any(&self) -> bool {
        matches!(self, Type::Any)
    }

    /// Whether a value of type `self` (the source) may be assigned to a slot of type `dst`.
    ///
    /// Note the match is on `(dst, src)`. This is more lenient than `is_subtype_of`: the
    /// final fallback accepts any relation in either direction.
    fn is_assignable_to(&self, dst: &Self) -> bool {
        let src = self;
        match (dst, src) {
            // `list<record<..>>` and `table<..>` are interchangeable, column rules permitting.
            (Type::Table(dst_cols), Type::List(src_ty))
                if let Type::Record(src_cols) = src_ty.as_ref() =>
            {
                src_cols.is_assignable_to(dst_cols)
            }
            (Type::List(dst_ty), Type::Table(src_cols))
                if let Type::Record(dst_cols) = dst_ty.as_ref() =>
            {
                src_cols.is_assignable_to(dst_cols)
            }
            (Type::Record(dst_cols), Type::Record(src_cols))
            | (Type::Table(dst_cols), Type::Table(src_cols)) => src_cols.is_assignable_to(dst_cols),
            // A string may fill a glob slot, but a glob may not fill a string slot.
            (Type::Glob, Type::String) => true,
            (Type::String, Type::Glob) => false,
            (Type::OneOf(dst_tys), Type::OneOf(src_tys)) => src_tys.is_assignable_to(dst_tys),
            (Type::OneOf(dst_tys), src_ty) => src_ty.is_assignable_to(dst_tys),
            (dst_ty, Type::OneOf(src_tys)) => src_tys.is_assignable_to(dst_ty),
            // Custom values are accepted wherever a list, table or record is expected.
            (Type::List(_) | Type::Table(_) | Type::Record(_), Type::Custom(_)) => true,
            // A cell path only fits where `cell-path` is a subtype-or-equal of the slot, so
            // the `string`/`int` <: `cell-path` relation does not let a cell path fill a
            // string or int slot.
            (lhs, rhs @ Type::CellPath) => rhs.is_subtype_of(lhs),
            // Lenient fallback: any relation (sub-, super- or equal) is enough.
            (lhs, rhs) => rhs.compare_types(lhs).is_some(),
        }
    }
}
355
356impl TypeSet for Type {
357 fn union(self, other: Self) -> Self {
358 let (lhs, rhs) = match Self::flat_widen(self, other) {
359 Ok(t) => return t,
360 Err(tys) => tys,
361 };
362
363 match (lhs, rhs) {
364 (Type::OneOf(ts), Type::OneOf(us)) => Type::OneOf(ts.union(us)),
365 (Type::OneOf(oneof), t) | (t, Type::OneOf(oneof)) => Type::OneOf(oneof.add_ty(t)),
366 (this, other) => Type::one_of([this, other]),
367 }
368 }
369}
370
371impl Display for Type {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 match self {
374 Type::Block => write!(f, "block"),
375 Type::Closure => write!(f, "closure"),
376 Type::Bool => write!(f, "bool"),
377 Type::CellPath => write!(f, "cell-path"),
378 Type::Date => write!(f, "datetime"),
379 Type::Duration => write!(f, "duration"),
380 Type::Filesize => write!(f, "filesize"),
381 Type::Float => write!(f, "float"),
382 Type::Int => write!(f, "int"),
383 Type::Range => write!(f, "range"),
384 Type::Record(columns) => write!(f, "record{columns}"),
385 Type::Table(columns) => write!(f, "table{columns}"),
386 Type::List(l) => write!(f, "list<{l}>"),
387 Type::Nothing => write!(f, "nothing"),
388 Type::Number => write!(f, "number"),
389 Type::OneOf(oneof) => write!(f, "{oneof}"),
390 Type::String => write!(f, "string"),
391 Type::Any => write!(f, "any"),
392 Type::Error => write!(f, "error"),
393 Type::Binary => write!(f, "binary"),
394 Type::Custom(custom) => write!(f, "{custom}"),
395 Type::Glob => write!(f, "glob"),
396 }
397 }
398}
399
400pub fn combined_type_string<'a, I>(types: I, join_word: &str) -> Option<String>
404where
405 I: IntoIterator<Item = &'a Type>,
406{
407 use std::fmt::Write as _;
408
409 let mut seen = Vec::new();
412 for t in types {
413 if !seen.contains(t) {
414 seen.push(t.clone());
415 }
416 }
417
418 match seen.as_slice() {
419 [] => None,
420 [one] => Some(one.to_string()),
421 [one, two] => Some(format!("{one} {join_word} {two}")),
422 [initial @ .., last] => {
423 let mut out = String::new();
424 for ele in initial {
425 let _ = write!(out, "{ele}, ");
426 }
427 let _ = write!(out, "{join_word} {last}");
428 Some(out)
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    mod subtype_relation {
        use super::*;

        #[test]
        fn test_reflexivity() {
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&ty)));
        }

        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&top)));
        }

        #[test]
        fn test_number_supertype() {
            let number = Type::Number;
            assert!(Type::Int.is_subtype_of(&number));
            assert!(Type::Float.is_subtype_of(&number));
        }

        #[test]
        fn test_list_covariance() {
            for elem_a in Type::iter() {
                for elem_b in Type::iter() {
                    let elems_related = elem_a.is_subtype_of(&elem_b);
                    let lists_related =
                        Type::list(elem_a.clone()).is_subtype_of(&Type::list(elem_b));
                    assert_eq!(lists_related, elems_related);
                }
            }
        }
    }

    mod oneof {
        use super::*;

        #[test]
        fn oneof_lhs() {
            let union = Type::one_of([Type::Int, Type::Nothing]);
            assert_eq!(union.compare_types(&Type::Int), Some(TypeRelation::Supertype));
        }

        #[test]
        fn oneof_rhs() {
            let union = Type::one_of([Type::Int, Type::Nothing]);
            assert_eq!(Type::Int.compare_types(&union), Some(TypeRelation::Subtype));
        }
    }

    mod oneof_flattening {
        use super::*;

        /// Unwraps a `Type::OneOf` into its member types, failing the test otherwise.
        fn members_of(ty: Type) -> Vec<Type> {
            match ty {
                Type::OneOf(oneof) => oneof.into_iter().collect(),
                _ => panic!("Expected OneOf"),
            }
        }

        #[test]
        fn test_oneof_creation_flattens() {
            let nested = Type::one_of([
                Type::String,
                Type::one_of([Type::Int, Type::Float]),
                Type::Bool,
            ]);
            let members = members_of(nested);
            assert_eq!(members.len(), 3);
            for expected in [Type::String, Type::Number, Type::Bool] {
                assert!(members.contains(&expected));
            }
        }

        #[test]
        fn test_widen_flattens_oneof() {
            let left = Type::one_of([Type::String, Type::Int]);
            let right = Type::one_of([Type::Float, Type::Bool]);
            let members = members_of(left.union(right));
            assert_eq!(members.len(), 3);
            for expected in [Type::String, Type::Number, Type::Bool] {
                assert!(members.contains(&expected));
            }
        }

        #[test]
        fn test_oneof_deduplicates() {
            let record_type =
                Type::Record(vec![("content".to_string(), Type::list(Type::String))].into());
            let members = members_of(Type::one_of([
                Type::String,
                record_type.clone(),
                record_type.clone(),
            ]));
            assert_eq!(members.len(), 2);
            assert!(members.contains(&Type::String));
            assert!(members.contains(&record_type));
        }
    }

    mod widen_shortcuts {
        use super::*;

        #[test]
        fn test_widen_subtype_shortcut() {
            // Widening by a type the union already covers (via a supertype member or an
            // identical member) leaves the union unchanged, on either side.
            let union = Type::one_of([Type::String, Type::Number]);
            assert_eq!(union.clone().union(Type::Int), union);

            let union2 = Type::one_of([Type::Int, Type::String]);
            assert_eq!(Type::Int.union(union2.clone()), union2);
        }

        #[test]
        fn test_chain_shortcut() {
            let widened = (0..100).fold(Type::String, |acc, _| acc.union(Type::Int));
            assert_eq!(widened, Type::one_of([Type::String, Type::Int]));
        }

        #[test]
        fn test_list_table_widen_preserves_list() {
            let list_record = Type::list(Type::Record(vec![("a".to_string(), Type::Int)].into()));
            let table = Type::Table(vec![("a".to_string(), Type::Int)].into());

            let widened = list_record.clone().union(table.clone());
            assert_eq!(widened, Type::one_of([list_record, table]));
        }

        #[test]
        fn test_glob_string_union() {
            let glob_first = Type::Glob.union(Type::String);
            let string_first = Type::String.union(Type::Glob);
            assert_eq!(glob_first, Type::one_of([Type::Glob, Type::String]));
            assert_eq!(string_first, Type::one_of([Type::String, Type::Glob]));
        }
    }
}