1use crate::SyntaxShape;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// The type of a value, used for subtyping checks ([`Type::is_subtype_of`]),
/// user-facing display, and conversion to a [`SyntaxShape`] ([`Type::to_shape`]).
// NOTE(review): variant order is observable — the test-only `EnumIter` derive iterates
// in declaration order, and serde formats that encode enum variants by index presumably
// depend on it too. Confirm before reordering or inserting variants.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type is a subtype of `Any`.
    Any,
    /// Displayed as `binary`.
    Binary,
    /// Displayed as `block`; maps to `SyntaxShape::Block`.
    Block,
    /// Displayed as `bool`.
    Bool,
    /// Displayed as `cell-path`.
    CellPath,
    /// Displayed as `closure`; maps to `SyntaxShape::Closure(None)`.
    Closure,
    /// A custom type, displayed verbatim by its name; has no dedicated syntax shape.
    Custom(Box<str>),
    /// Displayed as `datetime`.
    Date,
    /// Displayed as `duration`.
    Duration,
    /// Displayed as `error`; has no dedicated syntax shape (maps to `SyntaxShape::Any`).
    Error,
    /// Displayed as `filesize`.
    Filesize,
    /// A subtype of [`Type::Number`].
    Float,
    /// A subtype of [`Type::Number`].
    Int,
    /// `list<T>`; covariant in its element type.
    List(Box<Type>),
    /// The default type.
    #[default]
    Nothing,
    /// Supertype of [`Type::Int`] and [`Type::Float`].
    Number,
    /// Displayed as `range`.
    Range,
    /// A record with `(column, type)` pairs. An empty list leaves the columns
    /// unspecified, which makes it compatible with any record in subtyping.
    Record(Box<[(String, Type)]>),
    /// Displayed as `string`; interchangeable with [`Type::Glob`] for subtyping.
    String,
    /// Glob pattern; interchangeable with [`Type::String`] for subtyping.
    Glob,
    /// A table whose rows have the given `(column, type)` pairs. An empty list leaves
    /// the columns unspecified.
    Table(Box<[(String, Type)]>),
}
33
34impl Type {
35 pub fn list(inner: Type) -> Self {
36 Self::List(Box::new(inner))
37 }
38
39 pub fn record() -> Self {
40 Self::Record([].into())
41 }
42
43 pub fn table() -> Self {
44 Self::Table([].into())
45 }
46
47 pub fn custom(name: impl Into<Box<str>>) -> Self {
48 Self::Custom(name.into())
49 }
50
51 pub fn is_subtype_of(&self, other: &Type) -> bool {
57 let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
59 if this.is_empty() || that.is_empty() {
60 true
61 } else if this.len() < that.len() {
62 false
63 } else {
64 that.iter().all(|(col_y, ty_y)| {
65 if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
66 ty_x.is_subtype_of(ty_y)
67 } else {
68 false
69 }
70 })
71 }
72 };
73
74 match (self, other) {
75 (t, u) if t == u => true,
76 (Type::Float, Type::Number) => true,
77 (Type::Int, Type::Number) => true,
78 (Type::Glob, Type::String) => true,
79 (Type::String, Type::Glob) => true,
80 (_, Type::Any) => true,
81 (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
83 is_subtype_collection(this, that)
84 }
85 (Type::Table(_), Type::List(that)) if matches!(**that, Type::Any) => true,
86 (Type::Table(this), Type::List(that)) => {
87 matches!(that.as_ref(), Type::Record(that) if is_subtype_collection(this, that))
88 }
89 (Type::List(this), Type::Table(that)) => {
90 matches!(this.as_ref(), Type::Record(this) if is_subtype_collection(this, that))
91 }
92 _ => false,
93 }
94 }
95
96 pub fn widen(self, other: Type) -> Type {
98 if self.is_subtype_of(&other) {
99 other
100 } else if other.is_subtype_of(&self) {
101 self
102 } else {
103 Type::Any
104 }
105 }
106
107 pub fn is_numeric(&self) -> bool {
108 matches!(self, Type::Int | Type::Float | Type::Number)
109 }
110
111 pub fn is_list(&self) -> bool {
112 matches!(self, Type::List(_))
113 }
114
115 pub fn accepts_cell_paths(&self) -> bool {
117 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
118 }
119
120 pub fn to_shape(&self) -> SyntaxShape {
121 let mk_shape = |tys: &[(String, Type)]| {
122 tys.iter()
123 .map(|(key, val)| (key.clone(), val.to_shape()))
124 .collect()
125 };
126
127 match self {
128 Type::Int => SyntaxShape::Int,
129 Type::Float => SyntaxShape::Float,
130 Type::Range => SyntaxShape::Range,
131 Type::Bool => SyntaxShape::Boolean,
132 Type::String => SyntaxShape::String,
133 Type::Block => SyntaxShape::Block, Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
136 Type::Duration => SyntaxShape::Duration,
137 Type::Date => SyntaxShape::DateTime,
138 Type::Filesize => SyntaxShape::Filesize,
139 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
140 Type::Number => SyntaxShape::Number,
141 Type::Nothing => SyntaxShape::Nothing,
142 Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
143 Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
144 Type::Any => SyntaxShape::Any,
145 Type::Error => SyntaxShape::Any,
146 Type::Binary => SyntaxShape::Binary,
147 Type::Custom(_) => SyntaxShape::Any,
148 Type::Glob => SyntaxShape::GlobPattern,
149 }
150 }
151
152 pub fn get_non_specified_string(&self) -> String {
155 match self {
156 Type::Closure => String::from("closure"),
157 Type::Bool => String::from("bool"),
158 Type::Block => String::from("block"),
159 Type::CellPath => String::from("cell-path"),
160 Type::Date => String::from("datetime"),
161 Type::Duration => String::from("duration"),
162 Type::Filesize => String::from("filesize"),
163 Type::Float => String::from("float"),
164 Type::Int => String::from("int"),
165 Type::Range => String::from("range"),
166 Type::Record(_) => String::from("record"),
167 Type::Table(_) => String::from("table"),
168 Type::List(_) => String::from("list"),
169 Type::Nothing => String::from("nothing"),
170 Type::Number => String::from("number"),
171 Type::String => String::from("string"),
172 Type::Any => String::from("any"),
173 Type::Error => String::from("error"),
174 Type::Binary => String::from("binary"),
175 Type::Custom(_) => String::from("custom"),
176 Type::Glob => String::from("glob"),
177 }
178 }
179}
180
181impl Display for Type {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 match self {
184 Type::Block => write!(f, "block"),
185 Type::Closure => write!(f, "closure"),
186 Type::Bool => write!(f, "bool"),
187 Type::CellPath => write!(f, "cell-path"),
188 Type::Date => write!(f, "datetime"),
189 Type::Duration => write!(f, "duration"),
190 Type::Filesize => write!(f, "filesize"),
191 Type::Float => write!(f, "float"),
192 Type::Int => write!(f, "int"),
193 Type::Range => write!(f, "range"),
194 Type::Record(fields) => {
195 if fields.is_empty() {
196 write!(f, "record")
197 } else {
198 write!(
199 f,
200 "record<{}>",
201 fields
202 .iter()
203 .map(|(x, y)| format!("{x}: {y}"))
204 .collect::<Vec<String>>()
205 .join(", "),
206 )
207 }
208 }
209 Type::Table(columns) => {
210 if columns.is_empty() {
211 write!(f, "table")
212 } else {
213 write!(
214 f,
215 "table<{}>",
216 columns
217 .iter()
218 .map(|(x, y)| format!("{x}: {y}"))
219 .collect::<Vec<String>>()
220 .join(", ")
221 )
222 }
223 }
224 Type::List(l) => write!(f, "list<{l}>"),
225 Type::Nothing => write!(f, "nothing"),
226 Type::Number => write!(f, "number"),
227 Type::String => write!(f, "string"),
228 Type::Any => write!(f, "any"),
229 Type::Error => write!(f, "error"),
230 Type::Binary => write!(f, "binary"),
231 Type::Custom(custom) => write!(f, "{custom}"),
232 Type::Glob => write!(f, "glob"),
233 }
234 }
235}
236
237pub fn combined_type_string(types: &[Type], join_word: &str) -> Option<String> {
241 use std::fmt::Write as _;
242 match types {
243 [] => None,
244 [one] => Some(one.to_string()),
245 [one, two] => Some(format!("{one} {join_word} {two}")),
246 [initial @ .., last] => {
247 let mut out = String::new();
248 for ele in initial {
249 let _ = write!(out, "{ele}, ");
250 }
251 let _ = write!(out, "{join_word} {last}");
252 Some(out)
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    /// Properties of `Type::is_subtype_of`, checked over every variant produced by the
    /// test-only `EnumIter` derive.
    mod subtype_relation {
        use super::*;

        /// Every type is a subtype of itself.
        #[test]
        fn test_reflexivity() {
            assert!(Type::iter().all(|ty| ty.is_subtype_of(&ty)));
        }

        /// `Any` sits above every type.
        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            assert!(Type::iter().all(|ty| ty.is_subtype_of(&top)));
        }

        /// `int` and `float` both widen to `number`.
        #[test]
        fn test_number_supertype() {
            let number = Type::Number;
            for numeric in [Type::Int, Type::Float] {
                assert!(numeric.is_subtype_of(&number));
            }
        }

        /// `list<A> <: list<B>` holds exactly when `A <: B`.
        #[test]
        fn test_list_covariance() {
            let pairs = Type::iter()
                .flat_map(|elem_a| Type::iter().map(move |elem_b| (elem_a.clone(), elem_b)));
            for (elem_a, elem_b) in pairs {
                let list_a = Type::list(elem_a.clone());
                let list_b = Type::list(elem_b.clone());
                assert_eq!(list_a.is_subtype_of(&list_b), elem_a.is_subtype_of(&elem_b));
            }
        }
    }
}