1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum DataType {
6 String,
7 Integer,
8 Long,
9 Double,
10 Boolean,
11 Date,
12 Timestamp,
13 Array(Box<DataType>),
14 Map(Box<DataType>, Box<DataType>),
15 Struct(Vec<StructField>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StructField {
20 pub name: String,
21 pub data_type: DataType,
22 pub nullable: bool,
23}
24
25impl StructField {
26 pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27 StructField {
28 name,
29 data_type,
30 nullable,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StructType {
37 fields: Vec<StructField>,
38}
39
40impl StructType {
41 pub fn new(fields: Vec<StructField>) -> Self {
42 StructType { fields }
43 }
44
45 pub fn from_polars_schema(schema: &Schema) -> Self {
46 let fields = schema
47 .iter()
48 .map(|(name, dtype)| StructField {
49 name: name.to_string(),
50 data_type: polars_type_to_data_type(dtype),
51 nullable: true, })
53 .collect();
54 StructType { fields }
55 }
56
57 pub fn to_polars_schema(&self) -> Schema {
58 use polars::prelude::Field;
59 let fields: Vec<Field> = self
60 .fields
61 .iter()
62 .map(|f| {
63 Field::new(
64 f.name.as_str().into(),
65 data_type_to_polars_type(&f.data_type),
66 )
67 })
68 .collect();
69 Schema::from_iter(fields)
70 }
71
72 pub fn fields(&self) -> &[StructField] {
73 &self.fields
74 }
75}
76
77fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
78 match polars_type {
79 PlDataType::String => DataType::String,
80 PlDataType::Int32 => DataType::Integer,
81 PlDataType::Int64 => DataType::Long,
82 PlDataType::Float64 => DataType::Double,
83 PlDataType::Boolean => DataType::Boolean,
84 PlDataType::Date => DataType::Date,
85 PlDataType::Datetime(_, _) => DataType::Timestamp,
86 PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
87 _ => DataType::String, }
89}
90
91fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
92 match data_type {
93 DataType::String => PlDataType::String,
94 DataType::Integer => PlDataType::Int32,
95 DataType::Long => PlDataType::Int64,
96 DataType::Double => PlDataType::Float64,
97 DataType::Boolean => PlDataType::Boolean,
98 DataType::Date => PlDataType::Date,
99 DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
100 DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
101 _ => PlDataType::String, }
103}
104
// Unit tests for the schema types and the Polars dtype conversions.
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    #[test]
    fn test_struct_field_new() {
        let field = StructField::new("age".to_string(), DataType::Integer, true);
        assert_eq!(field.name, "age");
        assert!(field.nullable);
        assert!(matches!(field.data_type, DataType::Integer));
    }

    #[test]
    fn test_struct_type_new() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ];
        let schema = StructType::new(fields);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name, "id");
        assert_eq!(schema.fields()[1].name, "name");
    }

    #[test]
    fn test_struct_type_from_polars_schema() {
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let struct_type = StructType::from_polars_schema(&polars_schema);

        assert_eq!(struct_type.fields().len(), 4);
        assert_eq!(struct_type.fields()[0].name, "id");
        assert!(matches!(struct_type.fields()[0].data_type, DataType::Long));
        assert_eq!(struct_type.fields()[1].name, "name");
        assert!(matches!(
            struct_type.fields()[1].data_type,
            DataType::String
        ));
        assert_eq!(struct_type.fields()[2].name, "score");
        assert!(matches!(
            struct_type.fields()[2].data_type,
            DataType::Double
        ));
        assert_eq!(struct_type.fields()[3].name, "active");
        assert!(matches!(
            struct_type.fields()[3].data_type,
            DataType::Boolean
        ));
        // Fields read from a Polars schema are always marked nullable.
        assert!(struct_type.fields().iter().all(|f| f.nullable));
    }

    #[test]
    fn test_struct_type_to_polars_schema() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ];
        let struct_type = StructType::new(fields);

        let polars_schema = struct_type.to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("id"), Some(&PlDataType::Int64));
        assert_eq!(polars_schema.get("name"), Some(&PlDataType::String));
        assert_eq!(polars_schema.get("score"), Some(&PlDataType::Float64));
    }

    #[test]
    fn test_roundtrip_schema_conversion() {
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
            StructField::new("d".to_string(), DataType::Boolean, true),
            StructField::new("e".to_string(), DataType::String, true),
            StructField::new("f".to_string(), DataType::Date, true),
            StructField::new("g".to_string(), DataType::Timestamp, true),
        ]);

        let polars_schema = original.to_polars_schema();
        let roundtrip = StructType::from_polars_schema(&polars_schema);

        assert_eq!(roundtrip.fields().len(), original.fields().len());
        for (orig, rt) in original.fields().iter().zip(roundtrip.fields().iter()) {
            assert_eq!(orig.name, rt.name);
            // Compare variants too, so a broken mapping in either direction is
            // caught. None of these types are nested, so the discriminant
            // identifies the type fully.
            assert_eq!(
                std::mem::discriminant(&orig.data_type),
                std::mem::discriminant(&rt.data_type),
                "data type mismatch for column {}",
                orig.name
            );
            assert_eq!(orig.nullable, rt.nullable);
        }
    }

    #[test]
    fn test_polars_type_to_data_type_basic() {
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::String),
            DataType::String
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int32),
            DataType::Integer
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int64),
            DataType::Long
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Float64),
            DataType::Double
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Boolean),
            DataType::Boolean
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Date),
            DataType::Date
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let datetime_type = PlDataType::Datetime(TimeUnit::Microseconds, None);
        assert!(matches!(
            polars_type_to_data_type(&datetime_type),
            DataType::Timestamp
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_list() {
        let list_type = PlDataType::List(Box::new(PlDataType::Int64));
        match polars_type_to_data_type(&list_type) {
            DataType::Array(inner) => {
                assert!(matches!(*inner, DataType::Long));
            }
            _ => panic!("Expected Array type"),
        }
    }

    #[test]
    fn test_polars_type_to_data_type_fallback() {
        let unknown_type = PlDataType::UInt8;
        assert!(matches!(
            polars_type_to_data_type(&unknown_type),
            DataType::String
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_basic() {
        assert_eq!(
            data_type_to_polars_type(&DataType::String),
            PlDataType::String
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Integer),
            PlDataType::Int32
        );
        assert_eq!(data_type_to_polars_type(&DataType::Long), PlDataType::Int64);
        assert_eq!(
            data_type_to_polars_type(&DataType::Double),
            PlDataType::Float64
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Boolean),
            PlDataType::Boolean
        );
        assert_eq!(data_type_to_polars_type(&DataType::Date), PlDataType::Date);
    }

    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        let result = data_type_to_polars_type(&DataType::Timestamp);
        assert!(matches!(
            result,
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_array() {
        let array_type = DataType::Array(Box::new(DataType::Long));
        let result = data_type_to_polars_type(&array_type);
        match result {
            PlDataType::List(inner) => {
                assert_eq!(*inner, PlDataType::Int64);
            }
            _ => panic!("Expected List type"),
        }
    }

    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        let map_type = DataType::Map(Box::new(DataType::String), Box::new(DataType::Long));
        assert_eq!(data_type_to_polars_type(&map_type), PlDataType::String);
    }

    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        let struct_type = DataType::Struct(vec![StructField::new(
            "nested".to_string(),
            DataType::Integer,
            true,
        )]);
        assert_eq!(data_type_to_polars_type(&struct_type), PlDataType::String);
    }

    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(vec![]);
        assert!(empty.fields().is_empty());

        let polars_schema = empty.to_polars_schema();
        assert!(polars_schema.is_empty());
    }
}