1use crate::physical_type::PhysicalType;
2use smol_str::SmolStr;
3
/// A logical (user-facing) data type.
///
/// Each variant maps onto a storage representation via `LogicalType::physical_type`
/// and renders as a name via `LogicalType::type_name` (also used by `Display`).
/// The type derives `Eq`/`Hash`, so it can be used as a key in hash maps and sets,
/// and serde `Serialize`/`Deserialize`.
/// NOTE(review): renaming or reordering variants may change the serialized form,
/// depending on the serde format in use — confirm before editing the variant list.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum LogicalType {
    /// `ANY`; stored as `PhysicalType::String`.
    Any,
    /// `BOOL`.
    Bool,
    /// `INT8`.
    Int8,
    /// `INT16`.
    Int16,
    /// `INT32`.
    Int32,
    /// `INT64`.
    Int64,
    /// `INT128`.
    Int128,
    /// `UINT8`.
    UInt8,
    /// `UINT16`.
    UInt16,
    /// `UINT32`.
    UInt32,
    /// `UINT64`.
    UInt64,
    /// `FLOAT`; stored as `PhysicalType::Float32`.
    Float,
    /// `DOUBLE`; stored as `PhysicalType::Float64`.
    Double,
    /// `DATE`; stored as `PhysicalType::Int64`.
    Date,
    /// `TIMESTAMP`; stored as `PhysicalType::Int64`.
    Timestamp,
    /// `TIMESTAMP_SEC`; stored as `PhysicalType::Int64`.
    TimestampSec,
    /// `TIMESTAMP_MS`; stored as `PhysicalType::Int64`.
    TimestampMs,
    /// `TIMESTAMP_NS`; stored as `PhysicalType::Int64`.
    TimestampNs,
    /// `TIMESTAMP_TZ`; stored as `PhysicalType::Int64`.
    TimestampTz,
    /// `INTERVAL`; has its own `PhysicalType::Interval` representation.
    Interval,
    /// `DECIMAL(precision,scale)`; stored as `PhysicalType::Int128` regardless of
    /// precision. No validation of the two parameters is performed by this type.
    Decimal {
        /// Precision, rendered first in `DECIMAL(precision,scale)`.
        precision: u8,
        /// Scale, rendered second in `DECIMAL(precision,scale)`.
        scale: u8,
    },
    /// `INTERNAL_ID`; has its own `PhysicalType::InternalId` representation.
    InternalId,
    /// `SERIAL`; stored as `PhysicalType::Int64` and classified as an integer.
    Serial,
    /// `STRING`.
    String,
    /// `BLOB`; stored as `PhysicalType::String`.
    Blob,
    /// `UUID`; stored as `PhysicalType::String`.
    Uuid,
    /// `NODE`; laid out as `PhysicalType::Struct` but not reported by `is_nested`.
    Node,
    /// `REL`; laid out as `PhysicalType::Struct` but not reported by `is_nested`.
    Rel,
    /// `RECURSIVE_REL`; laid out as `PhysicalType::Struct` but not reported by `is_nested`.
    RecursiveRel,
    /// Variable-length list of the inner type, rendered as `INNER[]`.
    List(Box<LogicalType>),
    /// Array of `element` rendered as `ELEMENT[size]`, e.g. `FLOAT[1536]`.
    Array {
        /// Element type.
        element: Box<LogicalType>,
        /// Length shown in the type name (presumably fixed for every value — confirm).
        size: u64,
    },
    /// Named fields, rendered as `STRUCT(name: TYPE, ...)` in declaration order.
    Struct(Vec<(SmolStr, LogicalType)>),
    /// `MAP(KEY, VALUE)`; physically laid out as `PhysicalType::List`.
    Map {
        /// Key type.
        key: Box<LogicalType>,
        /// Value type.
        value: Box<LogicalType>,
    },
    /// Named alternatives, rendered as `UNION(name: TYPE, ...)`; laid out as a struct.
    Union(Vec<(SmolStr, LogicalType)>),
    /// `POINTER`; stored as `PhysicalType::Int64`.
    Pointer,
}
54
55impl LogicalType {
56 pub fn physical_type(&self) -> PhysicalType {
58 match self {
59 Self::Any => PhysicalType::String,
60 Self::Bool => PhysicalType::Bool,
61 Self::Int8 => PhysicalType::Int8,
62 Self::Int16 => PhysicalType::Int16,
63 Self::Int32 => PhysicalType::Int32,
64 Self::Int64
65 | Self::Date
66 | Self::Timestamp
67 | Self::TimestampSec
68 | Self::TimestampMs
69 | Self::TimestampNs
70 | Self::TimestampTz
71 | Self::Serial => PhysicalType::Int64,
72 Self::Int128 | Self::Decimal { .. } => PhysicalType::Int128,
73 Self::UInt8 => PhysicalType::UInt8,
74 Self::UInt16 => PhysicalType::UInt16,
75 Self::UInt32 => PhysicalType::UInt32,
76 Self::UInt64 => PhysicalType::UInt64,
77 Self::Float => PhysicalType::Float32,
78 Self::Double => PhysicalType::Float64,
79 Self::Interval => PhysicalType::Interval,
80 Self::InternalId => PhysicalType::InternalId,
81 Self::String | Self::Blob | Self::Uuid => PhysicalType::String,
82 Self::List(_) | Self::Map { .. } => PhysicalType::List,
83 Self::Array { .. } => PhysicalType::Array,
84 Self::Struct(_) | Self::Node | Self::Rel | Self::RecursiveRel | Self::Union(_) => {
85 PhysicalType::Struct
86 }
87 Self::Pointer => PhysicalType::Int64,
88 }
89 }
90
91 pub fn type_name(&self) -> std::borrow::Cow<'static, str> {
93 match self {
94 Self::Any => "ANY".into(),
95 Self::Bool => "BOOL".into(),
96 Self::Int8 => "INT8".into(),
97 Self::Int16 => "INT16".into(),
98 Self::Int32 => "INT32".into(),
99 Self::Int64 => "INT64".into(),
100 Self::Int128 => "INT128".into(),
101 Self::UInt8 => "UINT8".into(),
102 Self::UInt16 => "UINT16".into(),
103 Self::UInt32 => "UINT32".into(),
104 Self::UInt64 => "UINT64".into(),
105 Self::Float => "FLOAT".into(),
106 Self::Double => "DOUBLE".into(),
107 Self::Date => "DATE".into(),
108 Self::Timestamp => "TIMESTAMP".into(),
109 Self::TimestampSec => "TIMESTAMP_SEC".into(),
110 Self::TimestampMs => "TIMESTAMP_MS".into(),
111 Self::TimestampNs => "TIMESTAMP_NS".into(),
112 Self::TimestampTz => "TIMESTAMP_TZ".into(),
113 Self::Interval => "INTERVAL".into(),
114 Self::Decimal { precision, scale } => format!("DECIMAL({precision},{scale})").into(),
115 Self::InternalId => "INTERNAL_ID".into(),
116 Self::Serial => "SERIAL".into(),
117 Self::String => "STRING".into(),
118 Self::Blob => "BLOB".into(),
119 Self::Uuid => "UUID".into(),
120 Self::Node => "NODE".into(),
121 Self::Rel => "REL".into(),
122 Self::RecursiveRel => "RECURSIVE_REL".into(),
123 Self::List(inner) => format!("{}[]", inner.type_name()).into(),
124 Self::Array { element, size } => format!("{}[{size}]", element.type_name()).into(),
125 Self::Struct(fields) => {
126 let fields_str: Vec<_> = fields
127 .iter()
128 .map(|(name, ty)| format!("{name}: {}", ty.type_name()))
129 .collect();
130 format!("STRUCT({})", fields_str.join(", ")).into()
131 }
132 Self::Map { key, value } => {
133 format!("MAP({}, {})", key.type_name(), value.type_name()).into()
134 }
135 Self::Union(fields) => {
136 let fields_str: Vec<_> = fields
137 .iter()
138 .map(|(name, ty)| format!("{name}: {}", ty.type_name()))
139 .collect();
140 format!("UNION({})", fields_str.join(", ")).into()
141 }
142 Self::Pointer => "POINTER".into(),
143 }
144 }
145
146 pub fn is_numeric(&self) -> bool {
148 matches!(
149 self,
150 Self::Int8
151 | Self::Int16
152 | Self::Int32
153 | Self::Int64
154 | Self::Int128
155 | Self::UInt8
156 | Self::UInt16
157 | Self::UInt32
158 | Self::UInt64
159 | Self::Float
160 | Self::Double
161 | Self::Serial
162 | Self::Decimal { .. }
163 )
164 }
165
166 pub fn is_integer(&self) -> bool {
168 matches!(
169 self,
170 Self::Int8
171 | Self::Int16
172 | Self::Int32
173 | Self::Int64
174 | Self::Int128
175 | Self::UInt8
176 | Self::UInt16
177 | Self::UInt32
178 | Self::UInt64
179 | Self::Serial
180 )
181 }
182
183 pub fn is_temporal(&self) -> bool {
185 matches!(
186 self,
187 Self::Date
188 | Self::Timestamp
189 | Self::TimestampSec
190 | Self::TimestampMs
191 | Self::TimestampNs
192 | Self::TimestampTz
193 | Self::Interval
194 )
195 }
196
197 pub fn is_nested(&self) -> bool {
199 matches!(
200 self,
201 Self::List(_)
202 | Self::Array { .. }
203 | Self::Struct(_)
204 | Self::Map { .. }
205 | Self::Union(_)
206 )
207 }
208}
209
210impl std::fmt::Display for LogicalType {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 write!(f, "{}", self.type_name())
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn physical_type_mapping_integers() {
        assert_eq!(LogicalType::Int8.physical_type(), PhysicalType::Int8);
        assert_eq!(LogicalType::Int16.physical_type(), PhysicalType::Int16);
        assert_eq!(LogicalType::Int32.physical_type(), PhysicalType::Int32);
        assert_eq!(LogicalType::Int64.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::Int128.physical_type(), PhysicalType::Int128);
        assert_eq!(LogicalType::UInt8.physical_type(), PhysicalType::UInt8);
        assert_eq!(LogicalType::UInt16.physical_type(), PhysicalType::UInt16);
        assert_eq!(LogicalType::UInt32.physical_type(), PhysicalType::UInt32);
        assert_eq!(LogicalType::UInt64.physical_type(), PhysicalType::UInt64);
    }

    #[test]
    fn physical_type_mapping_temporal() {
        assert_eq!(LogicalType::Date.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::Timestamp.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::TimestampSec.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::TimestampMs.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::TimestampNs.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::TimestampTz.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::Serial.physical_type(), PhysicalType::Int64);
        assert_eq!(LogicalType::Interval.physical_type(), PhysicalType::Interval);
    }

    #[test]
    fn physical_type_mapping_strings() {
        assert_eq!(LogicalType::String.physical_type(), PhysicalType::String);
        assert_eq!(LogicalType::Blob.physical_type(), PhysicalType::String);
        assert_eq!(LogicalType::Uuid.physical_type(), PhysicalType::String);
        assert_eq!(LogicalType::Any.physical_type(), PhysicalType::String);
    }

    #[test]
    fn physical_type_mapping_misc_scalars() {
        assert_eq!(LogicalType::Bool.physical_type(), PhysicalType::Bool);
        assert_eq!(LogicalType::Float.physical_type(), PhysicalType::Float32);
        assert_eq!(LogicalType::Double.physical_type(), PhysicalType::Float64);
        assert_eq!(LogicalType::InternalId.physical_type(), PhysicalType::InternalId);
        assert_eq!(LogicalType::Pointer.physical_type(), PhysicalType::Int64);
        let dec = LogicalType::Decimal {
            precision: 18,
            scale: 3,
        };
        assert_eq!(dec.physical_type(), PhysicalType::Int128);
    }

    #[test]
    fn physical_type_mapping_nested() {
        let list = LogicalType::List(Box::new(LogicalType::Int64));
        assert_eq!(list.physical_type(), PhysicalType::List);

        let arr = LogicalType::Array {
            element: Box::new(LogicalType::Float),
            size: 1536,
        };
        assert_eq!(arr.physical_type(), PhysicalType::Array);

        let strukt = LogicalType::Struct(vec![
            (SmolStr::new("name"), LogicalType::String),
            (SmolStr::new("age"), LogicalType::Int64),
        ]);
        assert_eq!(strukt.physical_type(), PhysicalType::Struct);

        let map = LogicalType::Map {
            key: Box::new(LogicalType::String),
            value: Box::new(LogicalType::Int64),
        };
        assert_eq!(map.physical_type(), PhysicalType::List);

        assert_eq!(LogicalType::Union(vec![]).physical_type(), PhysicalType::Struct);
    }

    #[test]
    fn physical_type_mapping_graph_types() {
        assert_eq!(LogicalType::Node.physical_type(), PhysicalType::Struct);
        assert_eq!(LogicalType::Rel.physical_type(), PhysicalType::Struct);
        assert_eq!(LogicalType::RecursiveRel.physical_type(), PhysicalType::Struct);
    }

    #[test]
    fn type_name_simple() {
        assert_eq!(LogicalType::Bool.type_name().as_ref(), "BOOL");
        assert_eq!(LogicalType::Int64.type_name().as_ref(), "INT64");
        assert_eq!(LogicalType::String.type_name().as_ref(), "STRING");
    }

    #[test]
    fn type_name_complex() {
        let list = LogicalType::List(Box::new(LogicalType::Int64));
        assert_eq!(list.type_name().as_ref(), "INT64[]");

        let arr = LogicalType::Array {
            element: Box::new(LogicalType::Float),
            size: 1536,
        };
        assert_eq!(arr.type_name().as_ref(), "FLOAT[1536]");

        let dec = LogicalType::Decimal {
            precision: 18,
            scale: 3,
        };
        assert_eq!(dec.type_name().as_ref(), "DECIMAL(18,3)");

        let list_of_arrays = LogicalType::List(Box::new(LogicalType::Array {
            element: Box::new(LogicalType::Double),
            size: 3,
        }));
        assert_eq!(list_of_arrays.type_name().as_ref(), "DOUBLE[3][]");
    }

    #[test]
    fn type_name_struct_union_map() {
        let strukt = LogicalType::Struct(vec![
            (SmolStr::new("name"), LogicalType::String),
            (SmolStr::new("age"), LogicalType::Int64),
        ]);
        assert_eq!(
            strukt.type_name().as_ref(),
            "STRUCT(name: STRING, age: INT64)"
        );
        assert_eq!(LogicalType::Struct(vec![]).type_name().as_ref(), "STRUCT()");

        let union = LogicalType::Union(vec![
            (SmolStr::new("i"), LogicalType::Int32),
            (SmolStr::new("s"), LogicalType::String),
        ]);
        assert_eq!(union.type_name().as_ref(), "UNION(i: INT32, s: STRING)");
        assert_eq!(LogicalType::Union(vec![]).type_name().as_ref(), "UNION()");

        let map = LogicalType::Map {
            key: Box::new(LogicalType::String),
            value: Box::new(LogicalType::Int64),
        };
        assert_eq!(map.type_name().as_ref(), "MAP(STRING, INT64)");
    }

    #[test]
    fn classification_methods() {
        assert!(LogicalType::Int64.is_numeric());
        assert!(LogicalType::Float.is_numeric());
        assert!(!LogicalType::String.is_numeric());

        assert!(LogicalType::Int32.is_integer());
        assert!(LogicalType::UInt64.is_integer());
        assert!(!LogicalType::Float.is_integer());

        assert!(LogicalType::Date.is_temporal());
        assert!(LogicalType::Interval.is_temporal());
        assert!(!LogicalType::Int64.is_temporal());

        assert!(LogicalType::List(Box::new(LogicalType::Any)).is_nested());
        assert!(!LogicalType::String.is_nested());
    }

    #[test]
    fn classification_edge_cases() {
        let dec = LogicalType::Decimal {
            precision: 10,
            scale: 2,
        };
        assert!(dec.is_numeric());
        assert!(!dec.is_integer());

        assert!(LogicalType::Serial.is_numeric());
        assert!(LogicalType::Serial.is_integer());
        assert!(!LogicalType::Bool.is_numeric());

        assert!(LogicalType::TimestampTz.is_temporal());

        assert!(LogicalType::Union(vec![]).is_nested());
        assert!(!LogicalType::Node.is_nested());
        assert!(!LogicalType::Rel.is_nested());
    }

    #[test]
    fn equality_and_hash() {
        use std::collections::HashSet;

        let a = LogicalType::List(Box::new(LogicalType::Int64));
        let b = LogicalType::List(Box::new(LogicalType::Int64));
        let c = LogicalType::List(Box::new(LogicalType::Int32));
        assert_eq!(a, b);
        assert_ne!(a, c);

        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn display() {
        assert_eq!(LogicalType::Int64.to_string(), "INT64");
        let list = LogicalType::List(Box::new(LogicalType::String));
        assert_eq!(list.to_string(), "STRING[]");
        let dec = LogicalType::Decimal {
            precision: 18,
            scale: 3,
        };
        assert_eq!(dec.to_string(), "DECIMAL(18,3)");
    }
}