Skip to main content

nu_protocol/
ty.rs

1use crate::{
2    CollectionColumns, CompareTypes, OneOf, SyntaxShape, TypeRelation, TypeSet, ast::PathMember,
3};
4use serde::{Deserialize, Serialize};
5use std::{borrow::Cow, fmt::Display};
6#[cfg(test)]
7use strum_macros::EnumIter;
8
/// A type, with subtyping relations defined by the `CompareTypes` impl for this enum.
///
/// `Nothing` is the `Default` variant.
// NOTE: `Ord`/`PartialOrd` are derived, so the declaration order of the variants below is
// observable through comparisons (and through `EnumIter` in tests). Do not reorder casually.
9#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash, Ord, PartialOrd)]
10#[cfg_attr(test, derive(EnumIter))]
11pub enum Type {
12    /// Top type, supertype of all types
13    Any,
14    Binary,
    /// Treated as a subtype of `Closure` by `compare_types`.
15    Block,
16    Bool,
    /// Supertype of `String` and `Int` in `compare_types`, so plain strings and ints are
    /// accepted where a cell-path is expected.
17    CellPath,
    /// Treated as a supertype of `Block` by `compare_types`.
18    Closure,
    /// A custom type identified by name; `Display` prints the name as-is.
19    Custom(Box<str>),
    /// Displayed as `datetime`.
20    Date,
21    Duration,
22    Error,
23    Filesize,
24    Float,
25    Int,
    /// A list with the given element type. Covariant in the element type.
26    List(Box<Type>),
27    #[default]
28    Nothing,
29    /// Supertype of Int and Float. Equivalent to `oneof<int, float>`
30    Number,
31    /// Supertype of all types it contains.
32    OneOf(OneOf),
33    Range,
    /// A record described by its (name, type) columns.
34    Record(CollectionColumns<Type>),
35    String,
    /// Supertype of `String` in `compare_types`; strings are assignable to globs but not
    /// the other way around.
36    Glob,
    /// A table described by its (name, type) columns.
37    Table(CollectionColumns<Type>),
38}
39
40fn follow_cell_path_recursive<'a>(
41    current: Cow<'a, Type>,
42    path_members: &mut dyn Iterator<Item = &'a PathMember>,
43) -> Option<Cow<'a, Type>> {
44    let Some(first) = path_members.next() else {
45        return Some(current);
46    };
47    match (current.as_ref(), first) {
48        (Type::Record(_), PathMember::String { val, .. }) => {
49            let next = match current {
50                Cow::Borrowed(Type::Record(f)) => {
51                    Cow::Borrowed(&f.iter().find(|(name, _)| name == val)?.1)
52                }
53                Cow::Owned(Type::Record(f)) => {
54                    Cow::Owned(f.into_iter().find(|(name, _)| name == val)?.1)
55                }
56                _ => unreachable!(),
57            };
58            follow_cell_path_recursive(next, path_members)
59        }
60
61        // Table to Record (Int)
62        (Type::Table(f), PathMember::Int { .. }) => {
63            follow_cell_path_recursive(Cow::Owned(Type::Record(f.clone())), path_members)
64        }
65
66        // Table to List (String)
67        (Type::Table(columns), PathMember::String { val, .. }) => {
68            let (_, sub_type) = columns.iter().find(|(name, _)| name == val)?;
69            let list_type = Type::List(Box::new(sub_type.clone()));
70            follow_cell_path_recursive(Cow::Owned(list_type), path_members)
71        }
72
73        (Type::List(_), PathMember::Int { .. }) => {
74            let next = match current {
75                Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
76                Cow::Owned(Type::List(i)) => Cow::Owned(*i),
77                _ => unreachable!(),
78            };
79            follow_cell_path_recursive(next, path_members)
80        }
81
82        // List of Records indexed by key names
83        (Type::List(_), PathMember::String { .. }) => {
84            let next = match current {
85                Cow::Borrowed(Type::List(i)) => Cow::Borrowed(i.as_ref()),
86                Cow::Owned(Type::List(i)) => Cow::Owned(*i),
87                _ => unreachable!(),
88            };
89
90            let mut found_int_member = false;
91            let mut new_iter = std::iter::once(first).chain(path_members).filter(|pm| {
92                let first_int = !found_int_member && matches!(pm, PathMember::Int { .. });
93                if first_int {
94                    found_int_member = true;
95                }
96                !first_int
97            });
98            let inner_ty = follow_cell_path_recursive(next, &mut new_iter);
99
100            // If there's no int path member, need to wrap in a List type
101            // e.g. [{foo: bar}].foo -> [bar], list<record<foo: string>> -> list<string>
102            if found_int_member {
103                inner_ty
104            } else {
105                inner_ty.map(|inner_ty| Cow::Owned(Type::List(Box::new(inner_ty.into_owned()))))
106            }
107        }
108
109        _ => None,
110    }
111}
112
113impl Type {
114    pub fn list(inner: Type) -> Self {
115        Self::List(Box::new(inner))
116    }
117
118    /// Creates a OneOf type from an iterator of types.
119    /// Flattens any nested OneOf types and removes duplicates.
120    pub fn one_of(types: impl IntoIterator<Item = Type>) -> Self {
121        Self::OneOf(OneOf::from_iter(types))
122    }
123
124    pub fn record() -> Self {
125        Self::Record(Default::default())
126    }
127
128    pub fn table() -> Self {
129        Self::Table(Default::default())
130    }
131
132    pub fn custom(name: impl Into<Box<str>>) -> Self {
133        Self::Custom(name.into())
134    }
135
136    /// Returns supertype of arguments without creating a `oneof`, or falling back to `any` (unless one or both of the arguments are `any`)
137    pub(crate) fn flat_widen(lhs: Type, rhs: Type) -> Result<Type, (Type, Type)> {
138        match (lhs, rhs) {
139            // short circuit on `any`
140            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
141
142            // primitive number hierarchy is extremely common
143            (Type::Int, Type::Float) | (Type::Float, Type::Int) => Ok(Type::Number),
144
145            // despite their subtyping relation, these pairs should not combine into one or the other
146            tys @ ((Type::Glob, Type::String)
147            | (Type::String, Type::Glob)
148            | (Type::String | Type::Int, Type::CellPath)
149            | (Type::CellPath, Type::String | Type::Int)) => Err(tys),
150
151            // widen structural collections without checking for subtyping
152            (Type::Record(lhs), Type::Record(rhs)) => Ok(Type::Record(lhs.union(rhs))),
153            (Type::Table(lhs), Type::Table(rhs)) => Ok(Type::Table(lhs.union(rhs))),
154
155            // We want to have `oneof<list<T>, table>`, regardless whether one counts as a subtype
156            // of the other.
157            tys @ ((Type::List(_), Type::Table(_)) | (Type::Table(_), Type::List(_))) => Err(tys),
158
159            // If one type is already a subtype of the other, we can skip all of the heavier logic below.
160            (lhs, rhs) => match lhs.compare_types(&rhs) {
161                Some(rel) => Ok(match rel {
162                    TypeRelation::Subtype => rhs,
163                    TypeRelation::Equal => lhs,
164                    TypeRelation::Supertype => lhs,
165                }),
166                // Fallback - the two types are unrelated. Move them out so that callers don't have to clone again.
167                None => Err((lhs, rhs)),
168            },
169        }
170    }
171
172    /// Returns a supertype of all types within `it` *that is not `Any`*.
173    /// If `it` contains `Type::Any`, short circuits and returns `None`.
174    pub fn supertype_of(it: impl IntoIterator<Item = Type>) -> Option<Self> {
175        let mut it = it.into_iter();
176        it.next().and_then(|head| {
177            it.try_fold(head, |acc, e| match acc.union(e) {
178                Type::Any => None,
179                r => Some(r),
180            })
181        })
182    }
183
184    pub fn is_numeric(&self) -> bool {
185        matches!(self, Type::Int | Type::Float | Type::Number)
186    }
187
188    pub fn is_list(&self) -> bool {
189        matches!(self, Type::List(_))
190    }
191
192    /// Does this type represent a data structure containing values that can be addressed using 'cell paths'?
193    pub fn accepts_cell_paths(&self) -> bool {
194        matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
195    }
196
197    pub fn to_shape(&self) -> SyntaxShape {
198        match self {
199            Type::Int => SyntaxShape::Int,
200            Type::Float => SyntaxShape::Float,
201            Type::Range => SyntaxShape::Range,
202            Type::Bool => SyntaxShape::Boolean,
203            Type::String => SyntaxShape::String,
204            Type::Block => SyntaxShape::Block, // FIXME needs more accuracy
205            Type::Closure => SyntaxShape::Closure(None), // FIXME needs more accuracy
206            Type::CellPath => SyntaxShape::CellPath,
207            Type::Duration => SyntaxShape::Duration,
208            Type::Date => SyntaxShape::DateTime,
209            Type::Filesize => SyntaxShape::Filesize,
210            Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
211            Type::Number => SyntaxShape::Number,
212            Type::OneOf(types) => SyntaxShape::OneOf(types.iter().map(Type::to_shape).collect()),
213            Type::Nothing => SyntaxShape::Nothing,
214            Type::Record(entries) => SyntaxShape::Record(entries.map(Type::to_shape)),
215            Type::Table(columns) => SyntaxShape::Table(columns.map(Type::to_shape)),
216            Type::Any => SyntaxShape::Any,
217            Type::Error => SyntaxShape::Any,
218            Type::Binary => SyntaxShape::Binary,
219            Type::Custom(_) => SyntaxShape::Any,
220            Type::Glob => SyntaxShape::GlobPattern,
221        }
222    }
223
224    /// Get a string representation, without inner type specification of lists,
225    /// tables and records (get `list` instead of `list<any>`
226    pub fn get_non_specified_string(&self) -> String {
227        match self {
228            Type::Closure => String::from("closure"),
229            Type::Bool => String::from("bool"),
230            Type::Block => String::from("block"),
231            Type::CellPath => String::from("cell-path"),
232            Type::Date => String::from("datetime"),
233            Type::Duration => String::from("duration"),
234            Type::Filesize => String::from("filesize"),
235            Type::Float => String::from("float"),
236            Type::Int => String::from("int"),
237            Type::Range => String::from("range"),
238            Type::Record(_) => String::from("record"),
239            Type::Table(_) => String::from("table"),
240            Type::List(_) => String::from("list"),
241            Type::Nothing => String::from("nothing"),
242            Type::Number => String::from("number"),
243            Type::OneOf(_) => String::from("oneof"),
244            Type::String => String::from("string"),
245            Type::Any => String::from("any"),
246            Type::Error => String::from("error"),
247            Type::Binary => String::from("binary"),
248            Type::Custom(_) => String::from("custom"),
249            Type::Glob => String::from("glob"),
250        }
251    }
252
253    pub fn follow_cell_path<'a>(&'a self, path_members: &'a [PathMember]) -> Option<Cow<'a, Self>> {
254        follow_cell_path_recursive(Cow::Borrowed(self), &mut path_members.iter())
255    }
256}
257
258impl CompareTypes for Type {
259    fn compare_types(&self, other: &Self) -> Option<TypeRelation> {
260        match (self, other) {
261            (_, Type::Any) => Some(TypeRelation::Subtype),
262            (Type::Any, _) => Some(TypeRelation::Supertype),
263
264            // I don't know how this was decided but this is the behavior that was present in the
265            // parser
266            (Type::Closure, Type::Block) => Some(TypeRelation::Supertype),
267            (Type::Block, Type::Closure) => Some(TypeRelation::Subtype),
268
269            // We want `get`/`select`/etc to accept string and int values, so it's convenient to
270            // use them with variables, without having to explicitly convert them into cell-paths
271            (Type::String | Type::Int, Type::CellPath) => Some(TypeRelation::Subtype),
272            (Type::CellPath, Type::String | Type::Int) => Some(TypeRelation::Supertype),
273
274            (Type::Float | Type::Int, Type::Number) => Some(TypeRelation::Subtype),
275            (Type::Number, Type::Float | Type::Int) => Some(TypeRelation::Supertype),
276
277            (Type::Glob, Type::String) => Some(TypeRelation::Supertype),
278            (Type::String, Type::Glob) => Some(TypeRelation::Subtype),
279
280            // List is covariant
281            (Type::List(t), Type::List(u)) => t.compare_types(u.as_ref()),
282
283            (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
284                this.compare_types(that)
285            }
286
287            (Type::Table(table_cols), Type::List(list_elem)) => match list_elem.as_ref() {
288                Type::Any => Some(TypeRelation::Subtype),
289                Type::Record(record_cols) => table_cols.compare_types(record_cols),
290                _ => None,
291            },
292            (Type::List(list_elem), Type::Table(table_cols)) => match list_elem.as_ref() {
293                Type::Any => Some(TypeRelation::Supertype),
294                Type::Record(record_cols) => record_cols.compare_types(table_cols),
295                _ => None,
296            },
297
298            (Type::OneOf(lhs_oneof), Type::OneOf(rhs_oneof)) => lhs_oneof.compare_types(rhs_oneof),
299            (Type::OneOf(lhs_oneof), rhs) => lhs_oneof.compare_types(rhs),
300            (lhs, Type::OneOf(rhs_oneof)) => lhs.compare_types(rhs_oneof),
301
302            (t, u) if t == u => Some(TypeRelation::Equal),
303
304            _ => None,
305        }
306    }
307
308    /// Determine of the [`Type`] is a [subtype](https://en.wikipedia.org/wiki/Subtyping) of `other`.
309    ///
310    /// This should only be used at parse-time.
311    /// If you have a concrete [`Value`](crate::Value) or [`PipelineData`](crate::PipelineData),
312    /// you should use their respective `is_subtype_of` methods instead.
313    // This is identical to this method's default implementation. Written here to attach doccomment.
314    fn is_subtype_of(&self, other: &Self) -> bool {
315        matches!(
316            self.compare_types(other),
317            Some(TypeRelation::Subtype | TypeRelation::Equal)
318        )
319    }
320
321    fn is_any(&self) -> bool {
322        matches!(self, Type::Any)
323    }
324
325    fn is_assignable_to(&self, dst: &Self) -> bool {
326        let src = self;
327        match (dst, src) {
328            (Type::Table(dst_cols), Type::List(src_ty))
329                if let Type::Record(src_cols) = src_ty.as_ref() =>
330            {
331                src_cols.is_assignable_to(dst_cols)
332            }
333            (Type::List(dst_ty), Type::Table(src_cols))
334                if let Type::Record(dst_cols) = dst_ty.as_ref() =>
335            {
336                src_cols.is_assignable_to(dst_cols)
337            }
338            (Type::Record(dst_cols), Type::Record(src_cols))
339            | (Type::Table(dst_cols), Type::Table(src_cols)) => src_cols.is_assignable_to(dst_cols),
340            // strings can be coerced globs
341            (Type::Glob, Type::String) => true,
342            // but not the other way around
343            (Type::String, Type::Glob) => false,
344            // strings can be coerced to semver
345            (Type::OneOf(dst_tys), Type::OneOf(src_tys)) => src_tys.is_assignable_to(dst_tys),
346            (Type::OneOf(dst_tys), src_ty) => src_ty.is_assignable_to(dst_tys),
347            (dst_ty, Type::OneOf(src_tys)) => src_tys.is_assignable_to(dst_ty),
348            // leave it to the runtime
349            (Type::List(_) | Type::Table(_) | Type::Record(_), Type::Custom(_)) => true,
350            (lhs, rhs @ Type::CellPath) => rhs.is_subtype_of(lhs),
351            (lhs, rhs) => rhs.compare_types(lhs).is_some(),
352        }
353    }
354}
355
356impl TypeSet for Type {
357    fn union(self, other: Self) -> Self {
358        let (lhs, rhs) = match Self::flat_widen(self, other) {
359            Ok(t) => return t,
360            Err(tys) => tys,
361        };
362
363        match (lhs, rhs) {
364            (Type::OneOf(ts), Type::OneOf(us)) => Type::OneOf(ts.union(us)),
365            (Type::OneOf(oneof), t) | (t, Type::OneOf(oneof)) => Type::OneOf(oneof.add_ty(t)),
366            (this, other) => Type::one_of([this, other]),
367        }
368    }
369}
370
371impl Display for Type {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match self {
374            Type::Block => write!(f, "block"),
375            Type::Closure => write!(f, "closure"),
376            Type::Bool => write!(f, "bool"),
377            Type::CellPath => write!(f, "cell-path"),
378            Type::Date => write!(f, "datetime"),
379            Type::Duration => write!(f, "duration"),
380            Type::Filesize => write!(f, "filesize"),
381            Type::Float => write!(f, "float"),
382            Type::Int => write!(f, "int"),
383            Type::Range => write!(f, "range"),
384            Type::Record(columns) => write!(f, "record{columns}"),
385            Type::Table(columns) => write!(f, "table{columns}"),
386            Type::List(l) => write!(f, "list<{l}>"),
387            Type::Nothing => write!(f, "nothing"),
388            Type::Number => write!(f, "number"),
389            Type::OneOf(oneof) => write!(f, "{oneof}"),
390            Type::String => write!(f, "string"),
391            Type::Any => write!(f, "any"),
392            Type::Error => write!(f, "error"),
393            Type::Binary => write!(f, "binary"),
394            Type::Custom(custom) => write!(f, "{custom}"),
395            Type::Glob => write!(f, "glob"),
396        }
397    }
398}
399
400/// Get a string nicely combining multiple types
401///
402/// Helpful for listing types in errors
403pub fn combined_type_string<'a, I>(types: I, join_word: &str) -> Option<String>
404where
405    I: IntoIterator<Item = &'a Type>,
406{
407    use std::fmt::Write as _;
408
409    // Deduplicate types to avoid confusing repeated entries like
410    // "binary, binary, binary, or binary" in error messages.
411    let mut seen = Vec::new();
412    for t in types {
413        if !seen.contains(t) {
414            seen.push(t.clone());
415        }
416    }
417
418    match seen.as_slice() {
419        [] => None,
420        [one] => Some(one.to_string()),
421        [one, two] => Some(format!("{one} {join_word} {two}")),
422        [initial @ .., last] => {
423            let mut out = String::new();
424            for ele in initial {
425                let _ = write!(out, "{ele}, ");
426            }
427            let _ = write!(out, "{join_word} {last}");
428            Some(out)
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    /// Basic laws of `is_subtype_of`.
    mod subtype_relation {
        use super::*;

        #[test]
        fn test_reflexivity() {
            // Every type is a subtype of itself.
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&ty)));
        }

        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&top)));
        }

        #[test]
        fn test_number_supertype() {
            for numeric in [Type::Int, Type::Float] {
                assert!(numeric.is_subtype_of(&Type::Number));
            }
        }

        #[test]
        fn test_list_covariance() {
            // list<A> <: list<B> exactly when A <: B, for every pair of variants.
            for elem_a in Type::iter() {
                let list_a = Type::list(elem_a.clone());
                for elem_b in Type::iter() {
                    let list_b = Type::list(elem_b.clone());
                    assert_eq!(
                        list_a.is_subtype_of(&list_b),
                        elem_a.is_subtype_of(&elem_b)
                    );
                }
            }
        }
    }

    /// `compare_types` with a `oneof` on either side.
    mod oneof {
        use super::*;

        #[test]
        fn oneof_lhs() {
            let union = Type::one_of([Type::Int, Type::Nothing]);
            assert_eq!(union.compare_types(&Type::Int), Some(TypeRelation::Supertype));
        }

        #[test]
        fn oneof_rhs() {
            let union = Type::one_of([Type::Int, Type::Nothing]);
            assert_eq!(Type::Int.compare_types(&union), Some(TypeRelation::Subtype));
        }
    }

    /// Construction and widening of `oneof` must flatten and deduplicate.
    mod oneof_flattening {
        use super::*;

        /// Returns the member types of a `oneof`, panicking for any other type.
        fn members(ty: Type) -> Vec<Type> {
            let Type::OneOf(oneof) = ty else {
                panic!("Expected OneOf");
            };
            oneof.into_iter().collect()
        }

        #[test]
        fn test_oneof_creation_flattens() {
            let inner = Type::one_of([Type::Int, Type::Float]);
            let found = members(Type::one_of([Type::String, inner, Type::Bool]));

            assert_eq!(found.len(), 3);
            for expected in [Type::String, Type::Number, Type::Bool] {
                assert!(found.contains(&expected));
            }
        }

        #[test]
        fn test_widen_flattens_oneof() {
            let left = Type::one_of([Type::String, Type::Int]);
            let right = Type::one_of([Type::Float, Type::Bool]);
            let found = members(left.union(right));

            assert_eq!(found.len(), 3);
            // Int + Float -> Number
            for expected in [Type::String, Type::Number, Type::Bool] {
                assert!(found.contains(&expected));
            }
        }

        #[test]
        fn test_oneof_deduplicates() {
            let record_type =
                Type::Record(vec![("content".to_string(), Type::list(Type::String))].into());
            let found = members(Type::one_of([
                Type::String,
                record_type.clone(),
                record_type.clone(),
            ]));

            assert_eq!(found.len(), 2);
            assert!(found.contains(&Type::String));
            assert!(found.contains(&record_type));
        }
    }

    // regressions and performance tests for the subtype shortcut in `flat_widen`
    mod widen_shortcuts {
        use super::*;

        #[test]
        fn test_widen_subtype_shortcut() {
            // widening a union that already covers the new type should return the original
            // union unchanged.
            let covering = Type::one_of([Type::String, Type::Number]);
            assert_eq!(covering.clone().union(Type::Int), covering);

            // symmetric case where the left side is the subtype
            let covering = Type::one_of([Type::Int, Type::String]);
            assert_eq!(Type::Int.union(covering.clone()), covering);
        }

        #[test]
        fn test_chain_shortcut() {
            // repeatedly widen the same type pair
            let widened = (0..100).fold(Type::String, |acc, _| acc.union(Type::Int));
            assert_eq!(widened, Type::one_of([Type::String, Type::Int]));
        }

        #[test]
        fn test_list_table_widen_preserves_list() {
            let columns = || vec![("a".to_string(), Type::Int)];
            let list_record = Type::list(Type::Record(columns().into()));
            let table = Type::Table(columns().into());

            let widened = list_record.clone().union(table.clone());

            assert_eq!(widened, Type::one_of([list_record, table]));
        }

        #[test]
        fn test_glob_string_union() {
            // Neither order may collapse into just `glob` or just `string`.
            assert_eq!(
                Type::Glob.union(Type::String),
                Type::one_of([Type::Glob, Type::String])
            );
            assert_eq!(
                Type::String.union(Type::Glob),
                Type::one_of([Type::String, Type::Glob])
            );
        }
    }
}