1use crate::SyntaxShape;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// The static type of a value.
///
/// Types form a subtyping relation (`Type::is_subtype_of`), can be converted to a
/// `SyntaxShape` (`Type::to_shape`), and render in annotation syntax such as
/// `list<int>` or `record<a: int>` through their `Display` impl.
///
/// NOTE(review): variant order is observable. It determines the variant index that
/// derived serde impls pass to the serializer, and the `EnumIter` iteration order.
/// Avoid reordering variants.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)]
// `EnumIter` is derived only in test builds; the tests walk every variant via `Type::iter()`.
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type is a subtype of `Any`.
    Any,
    Binary,
    Bool,
    CellPath,
    Closure,
    /// A custom type; displayed verbatim as the contained name.
    Custom(Box<str>),
    /// Displayed as `datetime`.
    Date,
    Duration,
    /// Has no dedicated syntax shape; `to_shape` maps it to `SyntaxShape::Any`.
    Error,
    Filesize,
    Float,
    Int,
    /// A list with the given element type. Lists are covariant in their element type.
    List(Box<Type>),
    /// The default type.
    #[default]
    Nothing,
    /// Supertype of both `Int` and `Float`.
    Number,
    Range,
    /// A record described by `(field name, field type)` pairs. An empty list displays as
    /// plain `record` and imposes no field constraints during subtype checks.
    Record(Box<[(String, Type)]>),
    String,
    Glob,
    /// A table described by `(column name, column type)` pairs. An empty list displays as
    /// plain `table` and imposes no column constraints during subtype checks.
    Table(Box<[(String, Type)]>),
}
32
33impl Type {
34 pub fn list(inner: Type) -> Self {
35 Self::List(Box::new(inner))
36 }
37
38 pub fn record() -> Self {
39 Self::Record([].into())
40 }
41
42 pub fn table() -> Self {
43 Self::Table([].into())
44 }
45
46 pub fn custom(name: impl Into<Box<str>>) -> Self {
47 Self::Custom(name.into())
48 }
49
50 pub fn is_subtype_of(&self, other: &Type) -> bool {
56 let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
58 if this.is_empty() || that.is_empty() {
59 true
60 } else if this.len() < that.len() {
61 false
62 } else {
63 that.iter().all(|(col_y, ty_y)| {
64 if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
65 ty_x.is_subtype_of(ty_y)
66 } else {
67 false
68 }
69 })
70 }
71 };
72
73 match (self, other) {
74 (t, u) if t == u => true,
75 (Type::Float, Type::Number) => true,
76 (Type::Int, Type::Number) => true,
77 (_, Type::Any) => true,
78 (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
80 is_subtype_collection(this, that)
81 }
82 (Type::Table(_), Type::List(_)) => true,
83 _ => false,
84 }
85 }
86
87 pub fn is_numeric(&self) -> bool {
88 matches!(self, Type::Int | Type::Float | Type::Number)
89 }
90
91 pub fn is_list(&self) -> bool {
92 matches!(self, Type::List(_))
93 }
94
95 pub fn accepts_cell_paths(&self) -> bool {
97 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
98 }
99
100 pub fn to_shape(&self) -> SyntaxShape {
101 let mk_shape = |tys: &[(String, Type)]| {
102 tys.iter()
103 .map(|(key, val)| (key.clone(), val.to_shape()))
104 .collect()
105 };
106
107 match self {
108 Type::Int => SyntaxShape::Int,
109 Type::Float => SyntaxShape::Float,
110 Type::Range => SyntaxShape::Range,
111 Type::Bool => SyntaxShape::Boolean,
112 Type::String => SyntaxShape::String,
113 Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
115 Type::Duration => SyntaxShape::Duration,
116 Type::Date => SyntaxShape::DateTime,
117 Type::Filesize => SyntaxShape::Filesize,
118 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
119 Type::Number => SyntaxShape::Number,
120 Type::Nothing => SyntaxShape::Nothing,
121 Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
122 Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
123 Type::Any => SyntaxShape::Any,
124 Type::Error => SyntaxShape::Any,
125 Type::Binary => SyntaxShape::Binary,
126 Type::Custom(_) => SyntaxShape::Any,
127 Type::Glob => SyntaxShape::GlobPattern,
128 }
129 }
130
131 pub fn get_non_specified_string(&self) -> String {
134 match self {
135 Type::Closure => String::from("closure"),
136 Type::Bool => String::from("bool"),
137 Type::CellPath => String::from("cell-path"),
138 Type::Date => String::from("datetime"),
139 Type::Duration => String::from("duration"),
140 Type::Filesize => String::from("filesize"),
141 Type::Float => String::from("float"),
142 Type::Int => String::from("int"),
143 Type::Range => String::from("range"),
144 Type::Record(_) => String::from("record"),
145 Type::Table(_) => String::from("table"),
146 Type::List(_) => String::from("list"),
147 Type::Nothing => String::from("nothing"),
148 Type::Number => String::from("number"),
149 Type::String => String::from("string"),
150 Type::Any => String::from("any"),
151 Type::Error => String::from("error"),
152 Type::Binary => String::from("binary"),
153 Type::Custom(_) => String::from("custom"),
154 Type::Glob => String::from("glob"),
155 }
156 }
157}
158
159impl Display for Type {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 match self {
162 Type::Closure => write!(f, "closure"),
163 Type::Bool => write!(f, "bool"),
164 Type::CellPath => write!(f, "cell-path"),
165 Type::Date => write!(f, "datetime"),
166 Type::Duration => write!(f, "duration"),
167 Type::Filesize => write!(f, "filesize"),
168 Type::Float => write!(f, "float"),
169 Type::Int => write!(f, "int"),
170 Type::Range => write!(f, "range"),
171 Type::Record(fields) => {
172 if fields.is_empty() {
173 write!(f, "record")
174 } else {
175 write!(
176 f,
177 "record<{}>",
178 fields
179 .iter()
180 .map(|(x, y)| format!("{x}: {y}"))
181 .collect::<Vec<String>>()
182 .join(", "),
183 )
184 }
185 }
186 Type::Table(columns) => {
187 if columns.is_empty() {
188 write!(f, "table")
189 } else {
190 write!(
191 f,
192 "table<{}>",
193 columns
194 .iter()
195 .map(|(x, y)| format!("{x}: {y}"))
196 .collect::<Vec<String>>()
197 .join(", ")
198 )
199 }
200 }
201 Type::List(l) => write!(f, "list<{l}>"),
202 Type::Nothing => write!(f, "nothing"),
203 Type::Number => write!(f, "number"),
204 Type::String => write!(f, "string"),
205 Type::Any => write!(f, "any"),
206 Type::Error => write!(f, "error"),
207 Type::Binary => write!(f, "binary"),
208 Type::Custom(custom) => write!(f, "{custom}"),
209 Type::Glob => write!(f, "glob"),
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    mod subtype_relation {
        use super::*;

        /// Builds a record/table field list from borrowed names, for terse test setup.
        fn fields(entries: &[(&str, Type)]) -> Box<[(String, Type)]> {
            entries
                .iter()
                .map(|&(name, ref ty)| (name.to_owned(), ty.clone()))
                .collect()
        }

        #[test]
        fn test_reflexivity() {
            for ty in Type::iter() {
                assert!(ty.is_subtype_of(&ty));
            }
        }

        #[test]
        fn test_any_is_top_type() {
            for ty in Type::iter() {
                assert!(ty.is_subtype_of(&Type::Any));
            }
        }

        #[test]
        fn test_number_supertype() {
            assert!(Type::Int.is_subtype_of(&Type::Number));
            assert!(Type::Float.is_subtype_of(&Type::Number));
        }

        #[test]
        fn test_list_covariance() {
            for ty1 in Type::iter() {
                for ty2 in Type::iter() {
                    let list_ty1 = Type::List(Box::new(ty1.clone()));
                    let list_ty2 = Type::List(Box::new(ty2.clone()));
                    assert_eq!(list_ty1.is_subtype_of(&list_ty2), ty1.is_subtype_of(&ty2));
                }
            }
        }

        #[test]
        fn test_record_structural_subtyping() {
            let wide = Type::Record(fields(&[("a", Type::Int), ("b", Type::String)]));
            let narrow = Type::Record(fields(&[("a", Type::Number)]));

            // Width + depth: extra fields are allowed and field types may be subtypes.
            assert!(wide.is_subtype_of(&narrow));
            // The reverse is rejected: `narrow` lacks `b` and `number` is not an `int`.
            assert!(!narrow.is_subtype_of(&wide));
            assert!(!narrow.is_subtype_of(&Type::Record(fields(&[("a", Type::Int)]))));
            // A required field that is missing entirely is rejected.
            assert!(!wide.is_subtype_of(&Type::Record(fields(&[("c", Type::Int)]))));
            // An empty field list is unconstrained on either side.
            assert!(Type::record().is_subtype_of(&wide));
            assert!(wide.is_subtype_of(&Type::record()));
            // Records and tables are distinct kinds.
            assert!(!wide.is_subtype_of(&Type::table()));
        }

        #[test]
        fn test_table_structural_subtyping() {
            let wide = Type::Table(fields(&[("a", Type::Int), ("b", Type::String)]));
            let narrow = Type::Table(fields(&[("a", Type::Number)]));

            assert!(wide.is_subtype_of(&narrow));
            assert!(!narrow.is_subtype_of(&wide));
            assert!(Type::table().is_subtype_of(&wide));
        }
    }
}