1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
/// Logical column type used by this module's schema model.
///
/// The scalar variants carry no payload. `Array`, `Map` and `Struct` nest
/// further `DataType` values, so arbitrarily deep types can be described.
/// Conversions to and from Polars dtypes live in `data_type_to_polars_type`
/// and `polars_type_to_data_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataType {
    /// Text; converted to the Polars `String` dtype.
    String,
    /// Converted to the Polars `Int32` dtype.
    Integer,
    /// Converted to the Polars `Int64` dtype.
    Long,
    /// Converted to the Polars `Float64` dtype.
    Double,
    /// Converted to the Polars `Boolean` dtype.
    Boolean,
    /// Converted to the Polars `Date` dtype.
    Date,
    /// Converted to a Polars `Datetime` with microsecond unit and no time zone.
    Timestamp,
    /// Homogeneous list; the boxed value is the element type.
    Array(Box<DataType>),
    /// Two boxed types, presumably key then value (TODO confirm ordering with
    /// callers). Has no dedicated Polars conversion in this file and falls
    /// back to the Polars `String` dtype.
    Map(Box<DataType>, Box<DataType>),
    /// Nested record made of named fields. Has no dedicated Polars conversion
    /// in this file and falls back to the Polars `String` dtype.
    Struct(Vec<StructField>),
}
17
/// A single named, typed entry of a `StructType` (or of a `DataType::Struct`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructField {
    /// Field (column) name.
    pub name: String,
    /// Logical type of the field.
    pub data_type: DataType,
    /// Nullability flag. `StructType::from_polars_schema` always sets it to
    /// `true`, and `StructType::to_polars_schema` does not read it.
    /// NOTE(review): presumably "may contain nulls" — confirm with consumers.
    pub nullable: bool,
}
24
25impl StructField {
26 pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27 StructField {
28 name,
29 data_type,
30 nullable,
31 }
32 }
33}
34
/// An ordered collection of `StructField`s describing a table-like schema.
///
/// Can be built directly from fields or converted to and from a Polars
/// `Schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructType {
    /// Fields in declaration order; read access goes through `fields()`.
    fields: Vec<StructField>,
}
39
40impl StructType {
41 pub fn new(fields: Vec<StructField>) -> Self {
42 StructType { fields }
43 }
44
45 pub fn from_polars_schema(schema: &Schema) -> Self {
46 let fields = schema
47 .iter()
48 .map(|(name, dtype)| StructField {
49 name: name.to_string(),
50 data_type: polars_type_to_data_type(dtype),
51 nullable: true, })
53 .collect();
54 StructType { fields }
55 }
56
57 pub fn to_polars_schema(&self) -> Schema {
58 use polars::prelude::Field;
59 let fields: Vec<Field> = self
60 .fields
61 .iter()
62 .map(|f| {
63 Field::new(
64 f.name.as_str().into(),
65 data_type_to_polars_type(&f.data_type),
66 )
67 })
68 .collect();
69 Schema::from_iter(fields)
70 }
71
72 pub fn fields(&self) -> &[StructField] {
73 &self.fields
74 }
75}
76
77fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
78 match polars_type {
79 PlDataType::String => DataType::String,
80 PlDataType::Int32 | PlDataType::Int64 => DataType::Long,
82 PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
84 PlDataType::Boolean => DataType::Boolean,
85 PlDataType::Date => DataType::Date,
86 PlDataType::Datetime(_, _) => DataType::Timestamp,
87 PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
88 _ => DataType::String, }
90}
91
92fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
93 match data_type {
94 DataType::String => PlDataType::String,
95 DataType::Integer => PlDataType::Int32,
96 DataType::Long => PlDataType::Int64,
97 DataType::Double => PlDataType::Float64,
98 DataType::Boolean => PlDataType::Boolean,
99 DataType::Date => PlDataType::Date,
100 DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
101 DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
102 _ => PlDataType::String, }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    /// Field names of `schema`, in declaration order.
    fn field_names(schema: &StructType) -> Vec<&str> {
        schema.fields().iter().map(|f| f.name.as_str()).collect()
    }

    /// The constructor stores its three arguments unchanged.
    #[test]
    fn test_struct_field_new() {
        let StructField {
            name,
            data_type,
            nullable,
        } = StructField::new("age".to_string(), DataType::Integer, true);

        assert_eq!(name, "age");
        assert!(nullable);
        assert!(matches!(data_type, DataType::Integer));
    }

    /// Fields are kept in the order they were supplied.
    #[test]
    fn test_struct_type_new() {
        let schema = StructType::new(vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ]);

        assert_eq!(field_names(&schema), ["id", "name"]);
    }

    /// Names and dtypes of a Polars schema are carried over column by column.
    #[test]
    fn test_struct_type_from_polars_schema() {
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let converted = StructType::from_polars_schema(&polars_schema);
        assert_eq!(field_names(&converted), ["id", "name", "score", "active"]);

        let fields = converted.fields();
        assert!(matches!(fields[0].data_type, DataType::Long));
        assert!(matches!(fields[1].data_type, DataType::String));
        assert!(matches!(fields[2].data_type, DataType::Double));
        assert!(matches!(fields[3].data_type, DataType::Boolean));
    }

    /// Each field shows up in the Polars schema with the expected dtype.
    #[test]
    fn test_struct_type_to_polars_schema() {
        let polars_schema = StructType::new(vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ])
        .to_polars_schema();

        assert_eq!(polars_schema.len(), 3);

        let expected = [
            ("id", PlDataType::Int64),
            ("name", PlDataType::String),
            ("score", PlDataType::Float64),
        ];
        for (column, dtype) in expected.iter() {
            assert_eq!(polars_schema.get(*column), Some(dtype));
        }
    }

    /// Converting to Polars and back preserves the number and names of fields.
    #[test]
    fn test_roundtrip_schema_conversion() {
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
            StructField::new("d".to_string(), DataType::Boolean, true),
            StructField::new("e".to_string(), DataType::String, true),
        ]);

        let roundtrip = StructType::from_polars_schema(&original.to_polars_schema());

        assert_eq!(roundtrip.fields().len(), original.fields().len());
        assert_eq!(field_names(&original), field_names(&roundtrip));
    }

    /// Scalar Polars dtypes map onto their direct counterparts.
    #[test]
    fn test_polars_type_to_data_type_basic() {
        let from_string = polars_type_to_data_type(&PlDataType::String);
        assert!(matches!(from_string, DataType::String));

        let from_int64 = polars_type_to_data_type(&PlDataType::Int64);
        assert!(matches!(from_int64, DataType::Long));

        let from_float64 = polars_type_to_data_type(&PlDataType::Float64);
        assert!(matches!(from_float64, DataType::Double));

        let from_boolean = polars_type_to_data_type(&PlDataType::Boolean);
        assert!(matches!(from_boolean, DataType::Boolean));

        let from_date = polars_type_to_data_type(&PlDataType::Date);
        assert!(matches!(from_date, DataType::Date));
    }

    /// A Polars datetime becomes `Timestamp`.
    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let converted =
            polars_type_to_data_type(&PlDataType::Datetime(TimeUnit::Microseconds, None));
        assert!(matches!(converted, DataType::Timestamp));
    }

    /// A Polars list becomes `Array` with a converted element type.
    #[test]
    fn test_polars_type_to_data_type_list() {
        let converted = polars_type_to_data_type(&PlDataType::List(Box::new(PlDataType::Int64)));

        if let DataType::Array(element) = converted {
            assert!(matches!(*element, DataType::Long));
        } else {
            panic!("Expected Array type");
        }
    }

    /// Dtypes without an explicit mapping degrade to `String`.
    #[test]
    fn test_polars_type_to_data_type_fallback() {
        let converted = polars_type_to_data_type(&PlDataType::UInt8);
        assert!(matches!(converted, DataType::String));
    }

    /// Scalar `DataType`s map onto their direct Polars counterparts.
    #[test]
    fn test_data_type_to_polars_type_basic() {
        let cases = [
            (DataType::String, PlDataType::String),
            (DataType::Integer, PlDataType::Int32),
            (DataType::Long, PlDataType::Int64),
            (DataType::Double, PlDataType::Float64),
            (DataType::Boolean, PlDataType::Boolean),
            (DataType::Date, PlDataType::Date),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(data_type_to_polars_type(input), *expected);
        }
    }

    /// `Timestamp` becomes a microsecond datetime without a time zone.
    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        assert!(matches!(
            data_type_to_polars_type(&DataType::Timestamp),
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    /// `Array` becomes a Polars list with a converted element dtype.
    #[test]
    fn test_data_type_to_polars_type_array() {
        let converted = data_type_to_polars_type(&DataType::Array(Box::new(DataType::Long)));

        if let PlDataType::List(element) = converted {
            assert_eq!(*element, PlDataType::Int64);
        } else {
            panic!("Expected List type");
        }
    }

    /// `Map` has no Polars counterpart here and falls back to `String`.
    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        let key = Box::new(DataType::String);
        let value = Box::new(DataType::Long);

        assert_eq!(
            data_type_to_polars_type(&DataType::Map(key, value)),
            PlDataType::String
        );
    }

    /// `Struct` has no Polars counterpart here and falls back to `String`.
    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        let nested = StructField::new("nested".to_string(), DataType::Integer, true);

        assert_eq!(
            data_type_to_polars_type(&DataType::Struct(vec![nested])),
            PlDataType::String
        );
    }

    /// An empty `StructType` converts to an empty Polars schema.
    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(Vec::new());
        assert!(empty.fields().is_empty());
        assert!(empty.to_polars_schema().is_empty());
    }
}