1use crate::SyntaxShape;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4#[cfg(test)]
5use strum_macros::EnumIter;
6
/// Static type descriptor.
///
/// Subtyping between variants is decided by `Type::is_subtype_of`, the matching
/// parser shape is produced by `Type::to_shape`, and the textual form comes from
/// the `Display` implementation.
///
/// The default value is [`Type::Nothing`]. Under `cfg(test)` the enum also derives
/// `EnumIter` so tests can walk every variant.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(test, derive(EnumIter))]
pub enum Type {
    /// Top type: every type is a subtype of `Any`. Displayed as `any`.
    Any,
    /// Displayed as `binary`.
    Binary,
    /// Displayed as `block`.
    Block,
    /// Boolean; displayed as `bool`.
    Bool,
    /// Displayed as `cell-path`.
    CellPath,
    /// Displayed as `closure`.
    Closure,
    /// A custom type carrying its name; `Display` prints the name verbatim.
    Custom(Box<str>),
    /// Displayed as `datetime`.
    Date,
    /// Displayed as `duration`.
    Duration,
    /// Displayed as `error`.
    Error,
    /// Displayed as `filesize`.
    Filesize,
    /// Displayed as `float`; a subtype of [`Type::Number`].
    Float,
    /// Displayed as `int`; a subtype of [`Type::Number`].
    Int,
    /// A list with the given element type; displayed as `list<T>`.
    List(Box<Type>),
    /// The default variant; displayed as `nothing`.
    #[default]
    Nothing,
    /// Supertype of [`Type::Int`] and [`Type::Float`]; displayed as `number`.
    Number,
    /// Displayed as `range`.
    Range,
    /// A record described by `(field name, field type)` pairs.
    ///
    /// Displayed as `record<name: T, ...>`, or plain `record` when no fields are listed.
    Record(Box<[(String, Type)]>),
    /// Displayed as `string`.
    String,
    /// Displayed as `glob`.
    Glob,
    /// A table described by `(column name, column type)` pairs.
    ///
    /// Displayed as `table<name: T, ...>`, or plain `table` when no columns are listed.
    Table(Box<[(String, Type)]>),
}
33
34impl Type {
35 pub fn list(inner: Type) -> Self {
36 Self::List(Box::new(inner))
37 }
38
39 pub fn record() -> Self {
40 Self::Record([].into())
41 }
42
43 pub fn table() -> Self {
44 Self::Table([].into())
45 }
46
47 pub fn custom(name: impl Into<Box<str>>) -> Self {
48 Self::Custom(name.into())
49 }
50
51 pub fn is_subtype_of(&self, other: &Type) -> bool {
57 let is_subtype_collection = |this: &[(String, Type)], that: &[(String, Type)]| {
59 if this.is_empty() || that.is_empty() {
60 true
61 } else if this.len() < that.len() {
62 false
63 } else {
64 that.iter().all(|(col_y, ty_y)| {
65 if let Some((_, ty_x)) = this.iter().find(|(col_x, _)| col_x == col_y) {
66 ty_x.is_subtype_of(ty_y)
67 } else {
68 false
69 }
70 })
71 }
72 };
73
74 match (self, other) {
75 (t, u) if t == u => true,
76 (Type::Float, Type::Number) => true,
77 (Type::Int, Type::Number) => true,
78 (_, Type::Any) => true,
79 (Type::List(t), Type::List(u)) if t.is_subtype_of(u) => true, (Type::Record(this), Type::Record(that)) | (Type::Table(this), Type::Table(that)) => {
81 is_subtype_collection(this, that)
82 }
83 (Type::Table(_), Type::List(_)) => true,
84 _ => false,
85 }
86 }
87
88 pub fn is_numeric(&self) -> bool {
89 matches!(self, Type::Int | Type::Float | Type::Number)
90 }
91
92 pub fn is_list(&self) -> bool {
93 matches!(self, Type::List(_))
94 }
95
96 pub fn accepts_cell_paths(&self) -> bool {
98 matches!(self, Type::List(_) | Type::Record(_) | Type::Table(_))
99 }
100
101 pub fn to_shape(&self) -> SyntaxShape {
102 let mk_shape = |tys: &[(String, Type)]| {
103 tys.iter()
104 .map(|(key, val)| (key.clone(), val.to_shape()))
105 .collect()
106 };
107
108 match self {
109 Type::Int => SyntaxShape::Int,
110 Type::Float => SyntaxShape::Float,
111 Type::Range => SyntaxShape::Range,
112 Type::Bool => SyntaxShape::Boolean,
113 Type::String => SyntaxShape::String,
114 Type::Block => SyntaxShape::Block, Type::Closure => SyntaxShape::Closure(None), Type::CellPath => SyntaxShape::CellPath,
117 Type::Duration => SyntaxShape::Duration,
118 Type::Date => SyntaxShape::DateTime,
119 Type::Filesize => SyntaxShape::Filesize,
120 Type::List(x) => SyntaxShape::List(Box::new(x.to_shape())),
121 Type::Number => SyntaxShape::Number,
122 Type::Nothing => SyntaxShape::Nothing,
123 Type::Record(entries) => SyntaxShape::Record(mk_shape(entries)),
124 Type::Table(columns) => SyntaxShape::Table(mk_shape(columns)),
125 Type::Any => SyntaxShape::Any,
126 Type::Error => SyntaxShape::Any,
127 Type::Binary => SyntaxShape::Binary,
128 Type::Custom(_) => SyntaxShape::Any,
129 Type::Glob => SyntaxShape::GlobPattern,
130 }
131 }
132
133 pub fn get_non_specified_string(&self) -> String {
136 match self {
137 Type::Closure => String::from("closure"),
138 Type::Bool => String::from("bool"),
139 Type::Block => String::from("block"),
140 Type::CellPath => String::from("cell-path"),
141 Type::Date => String::from("datetime"),
142 Type::Duration => String::from("duration"),
143 Type::Filesize => String::from("filesize"),
144 Type::Float => String::from("float"),
145 Type::Int => String::from("int"),
146 Type::Range => String::from("range"),
147 Type::Record(_) => String::from("record"),
148 Type::Table(_) => String::from("table"),
149 Type::List(_) => String::from("list"),
150 Type::Nothing => String::from("nothing"),
151 Type::Number => String::from("number"),
152 Type::String => String::from("string"),
153 Type::Any => String::from("any"),
154 Type::Error => String::from("error"),
155 Type::Binary => String::from("binary"),
156 Type::Custom(_) => String::from("custom"),
157 Type::Glob => String::from("glob"),
158 }
159 }
160}
161
162impl Display for Type {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 match self {
165 Type::Block => write!(f, "block"),
166 Type::Closure => write!(f, "closure"),
167 Type::Bool => write!(f, "bool"),
168 Type::CellPath => write!(f, "cell-path"),
169 Type::Date => write!(f, "datetime"),
170 Type::Duration => write!(f, "duration"),
171 Type::Filesize => write!(f, "filesize"),
172 Type::Float => write!(f, "float"),
173 Type::Int => write!(f, "int"),
174 Type::Range => write!(f, "range"),
175 Type::Record(fields) => {
176 if fields.is_empty() {
177 write!(f, "record")
178 } else {
179 write!(
180 f,
181 "record<{}>",
182 fields
183 .iter()
184 .map(|(x, y)| format!("{x}: {y}"))
185 .collect::<Vec<String>>()
186 .join(", "),
187 )
188 }
189 }
190 Type::Table(columns) => {
191 if columns.is_empty() {
192 write!(f, "table")
193 } else {
194 write!(
195 f,
196 "table<{}>",
197 columns
198 .iter()
199 .map(|(x, y)| format!("{x}: {y}"))
200 .collect::<Vec<String>>()
201 .join(", ")
202 )
203 }
204 }
205 Type::List(l) => write!(f, "list<{l}>"),
206 Type::Nothing => write!(f, "nothing"),
207 Type::Number => write!(f, "number"),
208 Type::String => write!(f, "string"),
209 Type::Any => write!(f, "any"),
210 Type::Error => write!(f, "error"),
211 Type::Binary => write!(f, "binary"),
212 Type::Custom(custom) => write!(f, "{custom}"),
213 Type::Glob => write!(f, "glob"),
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::Type;
    use strum::IntoEnumIterator;

    /// Properties of `Type::is_subtype_of`, checked over every enum variant
    /// produced by the `EnumIter` derive.
    mod subtype_relation {
        use super::*;

        /// Every type is a subtype of itself.
        #[test]
        fn test_reflexivity() {
            Type::iter().for_each(|candidate| {
                assert!(candidate.is_subtype_of(&candidate));
            });
        }

        /// `Any` sits above every other type.
        #[test]
        fn test_any_is_top_type() {
            let top = Type::Any;
            Type::iter().for_each(|candidate| {
                assert!(candidate.is_subtype_of(&top));
            });
        }

        /// `Number` is a supertype of both `Int` and `Float`.
        #[test]
        fn test_number_supertype() {
            let number = Type::Number;
            assert!(Type::Int.is_subtype_of(&number));
            assert!(Type::Float.is_subtype_of(&number));
        }

        /// `list<A>` is a subtype of `list<B>` exactly when `A` is a subtype of `B`.
        #[test]
        fn test_list_covariance() {
            for inner_sub in Type::iter() {
                for inner_sup in Type::iter() {
                    let list_sub = Type::list(inner_sub.clone());
                    let list_sup = Type::list(inner_sup.clone());
                    assert_eq!(
                        list_sub.is_subtype_of(&list_sup),
                        inner_sub.is_subtype_of(&inner_sup)
                    );
                }
            }
        }
    }
}