1use crate::SyntaxShape;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// A type in the language's type system.
///
/// Types are related by [`Type::is_subtype_of`] and combined by [`Type::widen`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type is a subtype of `Any`.
    Any,
    Binary,
    Block,
    Bool,
    /// `String` and `Int` are also subtypes of `CellPath`.
    CellPath,
    Closure,
    /// A type known only by its name, which is also how it is displayed.
    Custom(Box<str>),
    Date,
    Duration,
    Error,
    Filesize,
    Float,
    Int,
    /// A list whose items have the boxed type; lists are covariant in it.
    List(Box<Type>),
    /// The default variant.
    #[default]
    Nothing,
    /// Supertype of `Int` and `Float`.
    Number,
    /// A union of the member types.
    OneOf(Box<[Type]>),
    Range,
    /// A record described by `(column name, column type)` pairs. An empty list is treated as
    /// compatible with any column list in subtype checks.
    Record(Box<[(String, Type)]>),
    /// Interchangeable with `Glob` in subtype checks.
    String,
    /// Interchangeable with `String` in subtype checks.
    Glob,
    /// A table described by `(column name, column type)` pairs; an empty list is treated the
    /// same way as for `Record`.
    Table(Box<[(String, Type)]>),
}
37
38impl Type {
39 pub fn list(inner: Type) -> Self {
40 Self::List(Box::new(inner))
41 }
42
43 pub fn one_of(types: impl IntoIterator<Item = Type>) -> Self {
44 Self::OneOf(types.into_iter().collect())
45 }
46
47 pub fn record() -> Self {
48 Self::Record([].into())
49 }
50
51 pub fn table() -> Self {
52 Self::Table([].into())
53 }
54
55 pub fn custom(name: impl Into<Box<str>>) -> Self {
56 Self::Custom(name.into())
57 }
58
59 pub fn is_subtype_of(&self, other: &Type) -> bool {
65 let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
67 if this.is_empty() || that.is_empty() {
68 true
69 } else if this.len() < that.len() {
70 false
71 } else {
72 that.iter().all(|(col_y, ty_y)| {
73 if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
74 ty_x.is_subtype_of(ty_y)
75 } else {
76 false
77 }
78 })
79 }
80 };
81
82 match (self, other) {
83 (t, u) if t == u => true,
84 (_, Type::Any) => true,
85 (Type::String | Type::Int, Type::CellPath) => true,
88 (Type::OneOf(oneof), Type::CellPath) => {
89 oneof.iter().all(|t| t.is_subtype_of(&Type::CellPath))
90 }
91 (Type::Float | Type::Int, Type::Number) => true,
92 (Type::Glob, Type::String) | (Type::String, Type::Glob) => true,
93 (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
95 is_subtype_collection(this, that)
96 }
97 (Type::Table(_), Type::List(that)) if matches!(**that, Type::Any) => true,
98 (Type::Table(this), Type::List(that)) => {
99 matches!(that.as_ref(), Type::Record(that) if is_subtype_collection(this, that))
100 }
101 (Type::List(this), Type::Table(that)) => {
102 matches!(this.as_ref(), Type::Record(this) if is_subtype_collection(this, that))
103 }
104 (Type::OneOf(this), that @ Type::OneOf(_)) => {
105 this.iter().all(|t| t.is_subtype_of(that))
106 }
107 (this, Type::OneOf(that)) => that.iter().any(|t| this.is_subtype_of(t)),
108 _ => false,
109 }
110 }
111
112 pub fn widen(self, other: Type) -> Type {
114 fn flat_widen(lhs: Type, rhs: Type) -> Result<Type, (Type, Type)> {
117 Ok(match (lhs, rhs) {
118 (lhs, rhs) if lhs == rhs => lhs,
119 (Type::Any, _) | (_, Type::Any) => Type::Any,
120 (
122 Type::Int | Type::Float | Type::Number,
123 Type::Int | Type::Float | Type::Number,
124 ) => Type::Number,
125
126 (Type::Glob, Type::String) | (Type::String, Type::Glob) => Type::String,
127 (Type::Record(this), Type::Record(that)) => {
128 Type::Record(widen_collection(this, that))
129 }
130 (Type::Table(this), Type::Table(that)) => Type::Table(widen_collection(this, that)),
131 (Type::List(list_item), Type::Table(table))
132 | (Type::Table(table), Type::List(list_item)) => {
133 let item = match *list_item {
134 Type::Record(record) => Type::Record(widen_collection(record, table)),
135 list_item => Type::one_of([list_item, Type::Record(table)]),
136 };
137 Type::List(Box::new(item))
138 }
139 (Type::List(lhs), Type::List(rhs)) => Type::list(lhs.widen(*rhs)),
140 (t, u) => return Err((t, u)),
141 })
142 }
143 fn widen_collection(
144 lhs: Box<[(String, Type)]>,
145 rhs: Box<[(String, Type)]>,
146 ) -> Box<[(String, Type)]> {
147 if lhs.is_empty() || rhs.is_empty() {
148 return [].into();
149 }
150 let (small, big) = match lhs.len() <= rhs.len() {
151 true => (lhs, rhs),
152 false => (rhs, lhs),
153 };
154 small
155 .into_iter()
156 .filter_map(|(col, typ)| {
157 big.iter()
158 .find_map(|(b_col, b_typ)| (&col == b_col).then(|| b_typ.clone()))
159 .map(|b_typ| (col, typ, b_typ))
160 })
161 .map(|(col, t, u)| (col, t.widen(u)))
162 .collect()
163 }
164
165 fn oneof_add(oneof: &mut Vec<Type>, mut t: Type) {
166 if oneof.contains(&t) {
167 return;
168 }
169
170 for one in oneof.iter_mut() {
171 match flat_widen(std::mem::replace(one, Type::Any), t) {
172 Ok(one_t) => {
173 *one = one_t;
174 return;
175 }
176 Err((one_, t_)) => {
177 *one = one_;
178 t = t_;
179 }
180 }
181 }
182
183 oneof.push(t);
184 }
185
186 let tu = match flat_widen(self, other) {
187 Ok(t) => return t,
188 Err(tu) => tu,
189 };
190
191 match tu {
192 (Type::OneOf(ts), Type::OneOf(us)) => {
193 let (big, small) = match ts.len() >= us.len() {
194 true => (ts, us),
195 false => (us, ts),
196 };
197 let mut out = big.into_vec();
198 for t in small.into_iter() {
199 oneof_add(&mut out, t);
200 }
201 Type::one_of(out)
202 }
203 (Type::OneOf(oneof), t) | (t, Type::OneOf(oneof)) => {
204 let mut out = oneof.into_vec();
205 oneof_add(&mut out, t);
206 Type::one_of(out)
207 }
208 (this, other) => Type::one_of([this, other]),
209 }
210 }
211
212 pub fn supertype_of(it: impl IntoIterator<Item = Type>) -> Option<Self> {
214 let mut it = it.into_iter();
215 it.next().and_then(|head| {
216 it.try_fold(head, |acc, e| match acc.widen(e) {
217 Type::Any => None,
218 r => Some(r),
219 })
220 })
221 }
222
223 pub fn is_numeric(&self) -> bool {
224 matches!(self, Type::Int | Type::Float | Type::Number)
225 }
226
227 pub fn is_list(&self) -> bool {
228 matches!(self, Type::List(_))
229 }
230
231 pub fn accepts_cell_paths(&self) -> bool {
233 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
234 }
235
236 pub fn to_shape(&self) -> SyntaxShape {
237 let mk_shape = |tys: &[(String, Type)]| {
238 tys.iter()
239 .map(|(key, val)| (key.clone(), val.to_shape()))
240 .collect()
241 };
242
243 match self {
244 Type::Int => SyntaxShape::Int,
245 Type::Float => SyntaxShape::Float,
246 Type::Range => SyntaxShape::Range,
247 Type::Bool => SyntaxShape::Boolean,
248 Type::String => SyntaxShape::String,
249 Type::Block => SyntaxShape::Block, Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
252 Type::Duration => SyntaxShape::Duration,
253 Type::Date => SyntaxShape::DateTime,
254 Type::Filesize => SyntaxShape::Filesize,
255 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
256 Type::Number => SyntaxShape::Number,
257 Type::OneOf(types) => SyntaxShape::OneOf(types.iter().map(Type::to_shape).collect()),
258 Type::Nothing => SyntaxShape::Nothing,
259 Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
260 Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
261 Type::Any => SyntaxShape::Any,
262 Type::Error => SyntaxShape::Any,
263 Type::Binary => SyntaxShape::Binary,
264 Type::Custom(_) => SyntaxShape::Any,
265 Type::Glob => SyntaxShape::GlobPattern,
266 }
267 }
268
269 pub fn get_non_specified_string(&self) -> String {
272 match self {
273 Type::Closure => String::from("closure"),
274 Type::Bool => String::from("bool"),
275 Type::Block => String::from("block"),
276 Type::CellPath => String::from("cell-path"),
277 Type::Date => String::from("datetime"),
278 Type::Duration => String::from("duration"),
279 Type::Filesize => String::from("filesize"),
280 Type::Float => String::from("float"),
281 Type::Int => String::from("int"),
282 Type::Range => String::from("range"),
283 Type::Record(_) => String::from("record"),
284 Type::Table(_) => String::from("table"),
285 Type::List(_) => String::from("list"),
286 Type::Nothing => String::from("nothing"),
287 Type::Number => String::from("number"),
288 Type::OneOf(_) => String::from("oneof"),
289 Type::String => String::from("string"),
290 Type::Any => String::from("any"),
291 Type::Error => String::from("error"),
292 Type::Binary => String::from("binary"),
293 Type::Custom(_) => String::from("custom"),
294 Type::Glob => String::from("glob"),
295 }
296 }
297}
298
299impl Display for Type {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 match self {
302 Type::Block => write!(f, "block"),
303 Type::Closure => write!(f, "closure"),
304 Type::Bool => write!(f, "bool"),
305 Type::CellPath => write!(f, "cell-path"),
306 Type::Date => write!(f, "datetime"),
307 Type::Duration => write!(f, "duration"),
308 Type::Filesize => write!(f, "filesize"),
309 Type::Float => write!(f, "float"),
310 Type::Int => write!(f, "int"),
311 Type::Range => write!(f, "range"),
312 Type::Record(fields) => {
313 if fields.is_empty() {
314 write!(f, "record")
315 } else {
316 write!(
317 f,
318 "record<{}>",
319 fields
320 .iter()
321 .map(|(x, y)| format!("{x}: {y}"))
322 .collect::<Vec<String>>()
323 .join(", "),
324 )
325 }
326 }
327 Type::Table(columns) => {
328 if columns.is_empty() {
329 write!(f, "table")
330 } else {
331 write!(
332 f,
333 "table<{}>",
334 columns
335 .iter()
336 .map(|(x, y)| format!("{x}: {y}"))
337 .collect::<Vec<String>>()
338 .join(", ")
339 )
340 }
341 }
342 Type::List(l) => write!(f, "list<{l}>"),
343 Type::Nothing => write!(f, "nothing"),
344 Type::Number => write!(f, "number"),
345 Type::OneOf(types) => {
346 write!(f, "oneof")?;
347 let [first, rest @ ..] = &**types else {
348 return Ok(());
349 };
350 write!(f, "<{first}")?;
351 for t in rest {
352 write!(f, ", {t}")?;
353 }
354 f.write_str(">")
355 }
356 Type::String => write!(f, "string"),
357 Type::Any => write!(f, "any"),
358 Type::Error => write!(f, "error"),
359 Type::Binary => write!(f, "binary"),
360 Type::Custom(custom) => write!(f, "{custom}"),
361 Type::Glob => write!(f, "glob"),
362 }
363 }
364}
365
366pub fn combined_type_string(types: &[Type], join_word: &str) -> Option<String> {
370 use std::fmt::Write as _;
371 match types {
372 [] => None,
373 [one] => Some(one.to_string()),
374 [one, two] => Some(format!("{one} {join_word} {two}")),
375 [initial @ .., last] => {
376 let mut out = String::new();
377 for ele in initial {
378 let _ = write!(out, "{ele}, ");
379 }
380 let _ = write!(out, "{join_word} {last}");
381 Some(out)
382 }
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    /// Properties of `Type::is_subtype_of`, checked over every variant yielded by `EnumIter`.
    mod subtype_relation {
        use super::*;

        /// Every type is a subtype of itself.
        #[test]
        fn test_reflexivity() {
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&ty)));
        }

        /// `Any` is above every type.
        #[test]
        fn test_any_is_top_type() {
            Type::iter().for_each(|ty| assert!(ty.is_subtype_of(&Type::Any)));
        }

        /// Both numeric types are subtypes of `Number`.
        #[test]
        fn test_number_supertype() {
            for numeric in [Type::Int, Type::Float] {
                assert!(numeric.is_subtype_of(&Type::Number));
            }
        }

        /// `list<a> <: list<b>` holds exactly when `a <: b`.
        #[test]
        fn test_list_covariance() {
            for ty1 in Type::iter() {
                let list_ty1 = Type::list(ty1.clone());
                for ty2 in Type::iter() {
                    let list_ty2 = Type::list(ty2.clone());
                    assert_eq!(list_ty1.is_subtype_of(&list_ty2), ty1.is_subtype_of(&ty2));
                }
            }
        }
    }
}