1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum DataType {
6 String,
7 Integer,
8 Long,
9 Double,
10 Boolean,
11 Date,
12 Timestamp,
13 Array(Box<DataType>),
14 Map(Box<DataType>, Box<DataType>),
15 Struct(Vec<StructField>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StructField {
20 pub name: String,
21 pub data_type: DataType,
22 pub nullable: bool,
23}
24
25impl StructField {
26 pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27 StructField {
28 name,
29 data_type,
30 nullable,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StructType {
37 fields: Vec<StructField>,
38}
39
40impl StructType {
41 pub fn new(fields: Vec<StructField>) -> Self {
42 StructType { fields }
43 }
44
45 pub fn from_polars_schema(schema: &Schema) -> Self {
46 let fields = schema
47 .iter()
48 .map(|(name, dtype)| StructField {
49 name: name.to_string(),
50 data_type: polars_type_to_data_type(dtype),
51 nullable: true, })
53 .collect();
54 StructType { fields }
55 }
56
57 pub fn to_polars_schema(&self) -> Schema {
58 use polars::prelude::Field;
59 let fields: Vec<Field> = self
60 .fields
61 .iter()
62 .map(|f| {
63 Field::new(
64 f.name.as_str().into(),
65 data_type_to_polars_type(&f.data_type),
66 )
67 })
68 .collect();
69 Schema::from_iter(fields)
70 }
71
72 pub fn fields(&self) -> &[StructField] {
73 &self.fields
74 }
75
76 pub fn to_json(&self) -> Result<String, serde_json::Error> {
79 serde_json::to_string(self)
80 }
81
82 pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
84 serde_json::to_string_pretty(self)
85 }
86}
87
88pub fn schema_from_json(json: &str) -> Result<StructType, crate::error::EngineError> {
91 serde_json::from_str(json).map_err(crate::error::EngineError::from)
92}
93
94fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
95 match polars_type {
96 PlDataType::String => DataType::String,
97 PlDataType::Int32 | PlDataType::Int64 => DataType::Long,
99 PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
101 PlDataType::Boolean => DataType::Boolean,
102 PlDataType::Date => DataType::Date,
103 PlDataType::Datetime(_, _) => DataType::Timestamp,
104 PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
105 _ => DataType::String, }
107}
108
109fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
110 match data_type {
111 DataType::String => PlDataType::String,
112 DataType::Integer => PlDataType::Int32,
113 DataType::Long => PlDataType::Int64,
114 DataType::Double => PlDataType::Float64,
115 DataType::Boolean => PlDataType::Boolean,
116 DataType::Date => PlDataType::Date,
117 DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
118 DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
119 _ => PlDataType::String, }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    /// Shorthand for `StructField::new` that takes a `&str` name.
    fn field(name: &str, data_type: DataType, nullable: bool) -> StructField {
        StructField::new(name.to_string(), data_type, nullable)
    }

    /// Field names of a schema, in order.
    fn names(schema: &StructType) -> Vec<&str> {
        schema.fields().iter().map(|f| f.name.as_str()).collect()
    }

    #[test]
    fn test_struct_field_new() {
        let age = StructField::new("age".to_string(), DataType::Integer, true);
        assert_eq!(age.name, "age");
        assert!(age.nullable);
        assert!(matches!(age.data_type, DataType::Integer));
    }

    #[test]
    fn test_struct_type_new() {
        let schema = StructType::new(vec![
            field("id", DataType::Long, false),
            field("name", DataType::String, true),
        ]);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(names(&schema), ["id", "name"]);
    }

    #[test]
    fn test_struct_type_from_polars_schema() {
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let converted = StructType::from_polars_schema(&polars_schema);
        assert_eq!(names(&converted), ["id", "name", "score", "active"]);

        let types: Vec<&DataType> = converted.fields().iter().map(|f| &f.data_type).collect();
        assert!(matches!(types[0], DataType::Long));
        assert!(matches!(types[1], DataType::String));
        assert!(matches!(types[2], DataType::Double));
        assert!(matches!(types[3], DataType::Boolean));
    }

    #[test]
    fn test_struct_type_to_polars_schema() {
        let polars_schema = StructType::new(vec![
            field("id", DataType::Long, false),
            field("name", DataType::String, true),
            field("score", DataType::Double, true),
        ])
        .to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        for (column, expected) in [
            ("id", PlDataType::Int64),
            ("name", PlDataType::String),
            ("score", PlDataType::Float64),
        ] {
            assert_eq!(polars_schema.get(column), Some(&expected));
        }
    }

    #[test]
    fn test_roundtrip_schema_conversion() {
        let original = StructType::new(
            vec![
                ("a", DataType::Integer),
                ("b", DataType::Long),
                ("c", DataType::Double),
                ("d", DataType::Boolean),
                ("e", DataType::String),
            ]
            .into_iter()
            .map(|(name, data_type)| field(name, data_type, true))
            .collect(),
        );

        let roundtrip = StructType::from_polars_schema(&original.to_polars_schema());

        // Only column names and their order are compared here.
        assert_eq!(names(&roundtrip), names(&original));
    }

    #[test]
    fn test_struct_type_to_json() {
        let schema = StructType::new(vec![
            field("id", DataType::Long, false),
            field("name", DataType::String, true),
        ]);

        let compact = schema.to_json().unwrap();
        for needle in ["\"name\":\"id\"", "\"name\":\"name\"", "\"data_type\"", "\"nullable\""] {
            assert!(compact.contains(needle));
        }
        // The compact output must deserialize back into a schema.
        let _parsed: StructType = serde_json::from_str(&compact).unwrap();

        let pretty = schema.to_json_pretty().unwrap();
        assert!(pretty.contains('\n'));
    }

    #[test]
    fn test_polars_type_to_data_type_basic() {
        assert!(matches!(polars_type_to_data_type(&PlDataType::String), DataType::String));
        assert!(matches!(polars_type_to_data_type(&PlDataType::Int64), DataType::Long));
        assert!(matches!(polars_type_to_data_type(&PlDataType::Float64), DataType::Double));
        assert!(matches!(polars_type_to_data_type(&PlDataType::Boolean), DataType::Boolean));
        assert!(matches!(polars_type_to_data_type(&PlDataType::Date), DataType::Date));
    }

    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let converted =
            polars_type_to_data_type(&PlDataType::Datetime(TimeUnit::Microseconds, None));
        assert!(matches!(converted, DataType::Timestamp));
    }

    #[test]
    fn test_polars_type_to_data_type_list() {
        match polars_type_to_data_type(&PlDataType::List(Box::new(PlDataType::Int64))) {
            DataType::Array(element) => assert!(matches!(*element, DataType::Long)),
            other => panic!("Expected Array type, got {other:?}"),
        }
    }

    #[test]
    fn test_polars_type_to_data_type_fallback() {
        assert!(matches!(polars_type_to_data_type(&PlDataType::UInt8), DataType::String));
    }

    #[test]
    fn test_data_type_to_polars_type_basic() {
        let cases = [
            (DataType::String, PlDataType::String),
            (DataType::Integer, PlDataType::Int32),
            (DataType::Long, PlDataType::Int64),
            (DataType::Double, PlDataType::Float64),
            (DataType::Boolean, PlDataType::Boolean),
            (DataType::Date, PlDataType::Date),
        ];
        for (engine_type, expected) in cases {
            assert_eq!(data_type_to_polars_type(&engine_type), expected);
        }
    }

    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        assert!(matches!(
            data_type_to_polars_type(&DataType::Timestamp),
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_array() {
        match data_type_to_polars_type(&DataType::Array(Box::new(DataType::Long))) {
            PlDataType::List(inner) => assert_eq!(*inner, PlDataType::Int64),
            other => panic!("Expected List type, got {other:?}"),
        }
    }

    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        let map_type = DataType::Map(Box::new(DataType::String), Box::new(DataType::Long));
        assert_eq!(data_type_to_polars_type(&map_type), PlDataType::String);
    }

    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        let nested = DataType::Struct(vec![field("nested", DataType::Integer, true)]);
        assert_eq!(data_type_to_polars_type(&nested), PlDataType::String);
    }

    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(Vec::new());
        assert!(empty.fields().is_empty());
        assert!(empty.to_polars_schema().is_empty());
    }
}