Skip to main content

robin_sparkless/
schema.rs

1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum DataType {
6    String,
7    Integer,
8    Long,
9    Double,
10    Boolean,
11    Date,
12    Timestamp,
13    Array(Box<DataType>),
14    Map(Box<DataType>, Box<DataType>),
15    Struct(Vec<StructField>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StructField {
20    pub name: String,
21    pub data_type: DataType,
22    pub nullable: bool,
23}
24
25impl StructField {
26    pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27        StructField {
28            name,
29            data_type,
30            nullable,
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StructType {
37    fields: Vec<StructField>,
38}
39
40impl StructType {
41    pub fn new(fields: Vec<StructField>) -> Self {
42        StructType { fields }
43    }
44
45    pub fn from_polars_schema(schema: &Schema) -> Self {
46        let fields = schema
47            .iter()
48            .map(|(name, dtype)| StructField {
49                name: name.to_string(),
50                data_type: polars_type_to_data_type(dtype),
51                nullable: true, // Polars doesn't expose nullability in the same way
52            })
53            .collect();
54        StructType { fields }
55    }
56
57    pub fn to_polars_schema(&self) -> Schema {
58        use polars::prelude::Field;
59        let fields: Vec<Field> = self
60            .fields
61            .iter()
62            .map(|f| {
63                Field::new(
64                    f.name.as_str().into(),
65                    data_type_to_polars_type(&f.data_type),
66                )
67            })
68            .collect();
69        Schema::from_iter(fields)
70    }
71
72    pub fn fields(&self) -> &[StructField] {
73        &self.fields
74    }
75}
76
77fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
78    match polars_type {
79        PlDataType::String => DataType::String,
80        PlDataType::Int32 => DataType::Integer,
81        PlDataType::Int64 => DataType::Long,
82        PlDataType::Float64 => DataType::Double,
83        PlDataType::Boolean => DataType::Boolean,
84        PlDataType::Date => DataType::Date,
85        PlDataType::Datetime(_, _) => DataType::Timestamp,
86        PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
87        _ => DataType::String, // Default fallback
88    }
89}
90
91fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
92    match data_type {
93        DataType::String => PlDataType::String,
94        DataType::Integer => PlDataType::Int32,
95        DataType::Long => PlDataType::Int64,
96        DataType::Double => PlDataType::Float64,
97        DataType::Boolean => PlDataType::Boolean,
98        DataType::Date => PlDataType::Date,
99        DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
100        DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
101        _ => PlDataType::String, // Default fallback
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    /// `StructField::new` stores its three arguments unchanged.
    #[test]
    fn test_struct_field_new() {
        let field = StructField::new("age".to_string(), DataType::Integer, true);
        assert_eq!(field.name, "age");
        assert!(field.nullable);
        assert!(matches!(field.data_type, DataType::Integer));
    }

    /// `StructType::new` keeps the fields, their order and their flags.
    #[test]
    fn test_struct_type_new() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ];
        let schema = StructType::new(fields);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name, "id");
        assert!(!schema.fields()[0].nullable);
        assert_eq!(schema.fields()[1].name, "name");
        assert!(schema.fields()[1].nullable);
    }

    /// Names and dtypes of a Polars schema are translated field by field.
    #[test]
    fn test_struct_type_from_polars_schema() {
        // Create a Polars schema
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let struct_type = StructType::from_polars_schema(&polars_schema);

        assert_eq!(struct_type.fields().len(), 4);
        assert_eq!(struct_type.fields()[0].name, "id");
        assert!(matches!(struct_type.fields()[0].data_type, DataType::Long));
        assert_eq!(struct_type.fields()[1].name, "name");
        assert!(matches!(
            struct_type.fields()[1].data_type,
            DataType::String
        ));
        assert_eq!(struct_type.fields()[2].name, "score");
        assert!(matches!(
            struct_type.fields()[2].data_type,
            DataType::Double
        ));
        assert_eq!(struct_type.fields()[3].name, "active");
        assert!(matches!(
            struct_type.fields()[3].data_type,
            DataType::Boolean
        ));
        // `from_polars_schema` marks every field as nullable.
        assert!(struct_type.fields().iter().all(|f| f.nullable));
    }

    /// Each field shows up in the Polars schema under its name and dtype.
    #[test]
    fn test_struct_type_to_polars_schema() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ];
        let struct_type = StructType::new(fields);

        let polars_schema = struct_type.to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("id"), Some(&PlDataType::Int64));
        assert_eq!(polars_schema.get("name"), Some(&PlDataType::String));
        assert_eq!(polars_schema.get("score"), Some(&PlDataType::Float64));
    }

    /// Converting to Polars and back preserves names, types and nullability
    /// for types that have a one-to-one mapping.
    #[test]
    fn test_roundtrip_schema_conversion() {
        // Create a struct type, convert to Polars, convert back
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
            StructField::new("d".to_string(), DataType::Boolean, true),
            StructField::new("e".to_string(), DataType::String, true),
        ]);

        let polars_schema = original.to_polars_schema();
        let roundtrip = StructType::from_polars_schema(&polars_schema);

        assert_eq!(roundtrip.fields().len(), original.fields().len());
        for (orig, rt) in original.fields().iter().zip(roundtrip.fields().iter()) {
            assert_eq!(orig.name, rt.name);
            // Compare variants via their discriminant so this check does not
            // depend on `DataType` implementing `PartialEq`. All types used
            // above are payload-free, so the discriminant is the whole value.
            assert_eq!(
                std::mem::discriminant(&orig.data_type),
                std::mem::discriminant(&rt.data_type),
                "data type of field `{}` changed across the round trip",
                orig.name
            );
            assert_eq!(orig.nullable, rt.nullable);
        }
    }

    /// Scalar Polars dtypes map to their dedicated `DataType` variants.
    #[test]
    fn test_polars_type_to_data_type_basic() {
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::String),
            DataType::String
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int32),
            DataType::Integer
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int64),
            DataType::Long
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Float64),
            DataType::Double
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Boolean),
            DataType::Boolean
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Date),
            DataType::Date
        ));
    }

    /// A Polars `Datetime` maps to `Timestamp`.
    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let datetime_type = PlDataType::Datetime(TimeUnit::Microseconds, None);
        assert!(matches!(
            polars_type_to_data_type(&datetime_type),
            DataType::Timestamp
        ));
    }

    /// A Polars `List` maps to `Array` with a translated element type.
    #[test]
    fn test_polars_type_to_data_type_list() {
        let list_type = PlDataType::List(Box::new(PlDataType::Int64));
        match polars_type_to_data_type(&list_type) {
            DataType::Array(inner) => {
                assert!(matches!(*inner, DataType::Long));
            }
            _ => panic!("Expected Array type"),
        }
    }

    /// Dtypes without a dedicated mapping fall back to `String`.
    #[test]
    fn test_polars_type_to_data_type_fallback() {
        // Unknown type should fall back to String
        let unknown_type = PlDataType::UInt8;
        assert!(matches!(
            polars_type_to_data_type(&unknown_type),
            DataType::String
        ));
    }

    /// Scalar `DataType` variants map to their Polars counterparts.
    #[test]
    fn test_data_type_to_polars_type_basic() {
        assert_eq!(
            data_type_to_polars_type(&DataType::String),
            PlDataType::String
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Integer),
            PlDataType::Int32
        );
        assert_eq!(data_type_to_polars_type(&DataType::Long), PlDataType::Int64);
        assert_eq!(
            data_type_to_polars_type(&DataType::Double),
            PlDataType::Float64
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Boolean),
            PlDataType::Boolean
        );
        assert_eq!(data_type_to_polars_type(&DataType::Date), PlDataType::Date);
    }

    /// `Timestamp` maps to a microsecond `Datetime` without a time zone.
    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        let result = data_type_to_polars_type(&DataType::Timestamp);
        assert!(matches!(
            result,
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    /// `Array` maps to a Polars `List` with a translated element type.
    #[test]
    fn test_data_type_to_polars_type_array() {
        let array_type = DataType::Array(Box::new(DataType::Long));
        let result = data_type_to_polars_type(&array_type);
        match result {
            PlDataType::List(inner) => {
                assert_eq!(*inner, PlDataType::Int64);
            }
            _ => panic!("Expected List type"),
        }
    }

    /// `Map` has no dedicated mapping and falls back to `String`.
    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        // Map type falls back to String
        let map_type = DataType::Map(Box::new(DataType::String), Box::new(DataType::Long));
        assert_eq!(data_type_to_polars_type(&map_type), PlDataType::String);
    }

    /// A nested `Struct` has no dedicated mapping and falls back to `String`.
    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        // Nested Struct falls back to String
        let struct_type = DataType::Struct(vec![StructField::new(
            "nested".to_string(),
            DataType::Integer,
            true,
        )]);
        assert_eq!(data_type_to_polars_type(&struct_type), PlDataType::String);
    }

    /// An empty `StructType` converts to an empty Polars schema.
    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(vec![]);
        assert!(empty.fields().is_empty());

        let polars_schema = empty.to_polars_schema();
        assert!(polars_schema.is_empty());
    }
}