1use crate::SyntaxShape;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// The static type of a value, used for subtype checks (`is_subtype_of`) and for
/// mapping to a `SyntaxShape` (`to_shape`).
///
/// NOTE(review): keep the variant order stable. The derived `Hash` and the test-only
/// `EnumIter` follow declaration order, and serde formats that encode enum variants by
/// index would produce a different encoding if variants were reordered.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)]
// `EnumIter` is only derived for test builds so the tests can enumerate every variant.
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// `any`: the top type — every type is a subtype of it.
    Any,
    /// `binary`.
    Binary,
    /// `bool` (mapped to `SyntaxShape::Boolean`).
    Bool,
    /// `cell-path`.
    CellPath,
    /// `closure`.
    Closure,
    /// A custom type identified by its name; displayed as that name and mapped to
    /// `SyntaxShape::Any`.
    Custom(Box<str>),
    /// Displayed as `datetime` (mapped to `SyntaxShape::DateTime`).
    Date,
    /// `duration`.
    Duration,
    /// `error`; has no dedicated syntax shape and maps to `SyntaxShape::Any`.
    Error,
    /// `filesize`.
    Filesize,
    /// `float`; a subtype of `number`.
    Float,
    /// `int`; a subtype of `number`.
    Int,
    /// `list<T>` with element type `T`; lists are covariant in their element type.
    List(Box<Type>),
    /// `nothing`; this is the `Default` variant.
    #[default]
    Nothing,
    /// `number`: the common supertype of `int` and `float`.
    Number,
    /// `range`.
    Range,
    /// `record<…>` as `(field name, field type)` pairs. An empty slice displays as plain
    /// `record` and is treated as compatible with any record when checking subtypes.
    Record(Box<[(String, Type)]>),
    /// `string`.
    String,
    /// `glob` (mapped to `SyntaxShape::GlobPattern`).
    Glob,
    /// `table<…>` as `(column name, column type)` pairs. An empty slice displays as plain
    /// `table` and is treated as compatible with any table when checking subtypes.
    Table(Box<[(String, Type)]>),
}
32
33impl Type {
34    pub fn list(inner: Type) -> Self {
35        Self::List(Box::new(inner))
36    }
37
38    pub fn record() -> Self {
39        Self::Record([].into())
40    }
41
42    pub fn table() -> Self {
43        Self::Table([].into())
44    }
45
46    pub fn custom(name: impl Into<Box<str>>) -> Self {
47        Self::Custom(name.into())
48    }
49
50    pub fn is_subtype_of(&self, other: &Type) -> bool {
56        let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
58            if this.is_empty() || that.is_empty() {
59                true
60            } else if this.len() < that.len() {
61                false
62            } else {
63                that.iter().all(|(col_y, ty_y)| {
64                    if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
65                        ty_x.is_subtype_of(ty_y)
66                    } else {
67                        false
68                    }
69                })
70            }
71        };
72
73        match (self, other) {
74            (t, u) if t == u => true,
75            (Type::Float, Type::Number) => true,
76            (Type::Int, Type::Number) => true,
77            (_, Type::Any) => true,
78            (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
80                is_subtype_collection(this, that)
81            }
82            (Type::Table(_), Type::List(_)) => true,
83            _ => false,
84        }
85    }
86
87    pub fn is_numeric(&self) -> bool {
88        matches!(self, Type::Int | Type::Float | Type::Number)
89    }
90
91    pub fn is_list(&self) -> bool {
92        matches!(self, Type::List(_))
93    }
94
95    pub fn accepts_cell_paths(&self) -> bool {
97        matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
98    }
99
100    pub fn to_shape(&self) -> SyntaxShape {
101        let mk_shape = |tys: &[(String, Type)]| {
102            tys.iter()
103                .map(|(key, val)| (key.clone(), val.to_shape()))
104                .collect()
105        };
106
107        match self {
108            Type::Int => SyntaxShape::Int,
109            Type::Float => SyntaxShape::Float,
110            Type::Range => SyntaxShape::Range,
111            Type::Bool => SyntaxShape::Boolean,
112            Type::String => SyntaxShape::String,
113            Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
115            Type::Duration => SyntaxShape::Duration,
116            Type::Date => SyntaxShape::DateTime,
117            Type::Filesize => SyntaxShape::Filesize,
118            Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
119            Type::Number => SyntaxShape::Number,
120            Type::Nothing => SyntaxShape::Nothing,
121            Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
122            Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
123            Type::Any => SyntaxShape::Any,
124            Type::Error => SyntaxShape::Any,
125            Type::Binary => SyntaxShape::Binary,
126            Type::Custom(_) => SyntaxShape::Any,
127            Type::Glob => SyntaxShape::GlobPattern,
128        }
129    }
130
131    pub fn get_non_specified_string(&self) -> String {
134        match self {
135            Type::Closure => String::from("closure"),
136            Type::Bool => String::from("bool"),
137            Type::CellPath => String::from("cell-path"),
138            Type::Date => String::from("datetime"),
139            Type::Duration => String::from("duration"),
140            Type::Filesize => String::from("filesize"),
141            Type::Float => String::from("float"),
142            Type::Int => String::from("int"),
143            Type::Range => String::from("range"),
144            Type::Record(_) => String::from("record"),
145            Type::Table(_) => String::from("table"),
146            Type::List(_) => String::from("list"),
147            Type::Nothing => String::from("nothing"),
148            Type::Number => String::from("number"),
149            Type::String => String::from("string"),
150            Type::Any => String::from("any"),
151            Type::Error => String::from("error"),
152            Type::Binary => String::from("binary"),
153            Type::Custom(_) => String::from("custom"),
154            Type::Glob => String::from("glob"),
155        }
156    }
157}
158
159impl Display for Type {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            Type::Closure => write!(f, "closure"),
163            Type::Bool => write!(f, "bool"),
164            Type::CellPath => write!(f, "cell-path"),
165            Type::Date => write!(f, "datetime"),
166            Type::Duration => write!(f, "duration"),
167            Type::Filesize => write!(f, "filesize"),
168            Type::Float => write!(f, "float"),
169            Type::Int => write!(f, "int"),
170            Type::Range => write!(f, "range"),
171            Type::Record(fields) => {
172                if fields.is_empty() {
173                    write!(f, "record")
174                } else {
175                    write!(
176                        f,
177                        "record<{}>",
178                        fields
179                            .iter()
180                            .map(|(x, y)| format!("{x}: {y}"))
181                            .collect::<Vec<String>>()
182                            .join(", "),
183                    )
184                }
185            }
186            Type::Table(columns) => {
187                if columns.is_empty() {
188                    write!(f, "table")
189                } else {
190                    write!(
191                        f,
192                        "table<{}>",
193                        columns
194                            .iter()
195                            .map(|(x, y)| format!("{x}: {y}"))
196                            .collect::<Vec<String>>()
197                            .join(", ")
198                    )
199                }
200            }
201            Type::List(l) => write!(f, "list<{l}>"),
202            Type::Nothing => write!(f, "nothing"),
203            Type::Number => write!(f, "number"),
204            Type::String => write!(f, "string"),
205            Type::Any => write!(f, "any"),
206            Type::Error => write!(f, "error"),
207            Type::Binary => write!(f, "binary"),
208            Type::Custom(custom) => write!(f, "{custom}"),
209            Type::Glob => write!(f, "glob"),
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    /// Algebraic properties of `Type::is_subtype_of`, checked over every variant that
    /// `EnumIter` yields.
    mod subtype_relation {
        use super::*;

        /// Every type is a subtype of itself.
        #[test]
        fn test_reflexivity() {
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&ty)));
        }

        /// `any` sits above every type.
        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&top)));
        }

        /// Both concrete numeric types widen to `number`.
        #[test]
        fn test_number_supertype() {
            for numeric in [Type::Int, Type::Float] {
                assert!(numeric.is_subtype_of(&Type::Number));
            }
        }

        /// `list<a> <: list<b>` holds exactly when `a <: b`.
        #[test]
        fn test_list_covariance() {
            let every_type: Vec<Type> = Type::iter().collect();
            for elem_a in &every_type {
                for elem_b in &every_type {
                    let list_a = Type::list(elem_a.clone());
                    let list_b = Type::list(elem_b.clone());
                    assert_eq!(
                        list_a.is_subtype_of(&list_b),
                        elem_a.is_subtype_of(elem_b)
                    );
                }
            }
        }
    }
}