Skip to main content

robin_sparkless/
schema.rs

1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum DataType {
6    String,
7    Integer,
8    Long,
9    Double,
10    Boolean,
11    Date,
12    Timestamp,
13    Array(Box<DataType>),
14    Map(Box<DataType>, Box<DataType>),
15    Struct(Vec<StructField>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StructField {
20    pub name: String,
21    pub data_type: DataType,
22    pub nullable: bool,
23}
24
25impl StructField {
26    pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27        StructField {
28            name,
29            data_type,
30            nullable,
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StructType {
37    fields: Vec<StructField>,
38}
39
40impl StructType {
41    pub fn new(fields: Vec<StructField>) -> Self {
42        StructType { fields }
43    }
44
45    pub fn from_polars_schema(schema: &Schema) -> Self {
46        let fields = schema
47            .iter()
48            .map(|(name, dtype)| StructField {
49                name: name.to_string(),
50                data_type: polars_type_to_data_type(dtype),
51                nullable: true, // Polars doesn't expose nullability in the same way
52            })
53            .collect();
54        StructType { fields }
55    }
56
57    pub fn to_polars_schema(&self) -> Schema {
58        use polars::prelude::Field;
59        let fields: Vec<Field> = self
60            .fields
61            .iter()
62            .map(|f| {
63                Field::new(
64                    f.name.as_str().into(),
65                    data_type_to_polars_type(&f.data_type),
66                )
67            })
68            .collect();
69        Schema::from_iter(fields)
70    }
71
72    pub fn fields(&self) -> &[StructField] {
73        &self.fields
74    }
75}
76
77fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
78    match polars_type {
79        PlDataType::String => DataType::String,
80        // Spark/Sparkless inferSchema uses 64-bit for integral types; map Int32 as Long.
81        PlDataType::Int32 | PlDataType::Int64 => DataType::Long,
82        // Map both Float32 and Float64 to Double for schema parity.
83        PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
84        PlDataType::Boolean => DataType::Boolean,
85        PlDataType::Date => DataType::Date,
86        PlDataType::Datetime(_, _) => DataType::Timestamp,
87        PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
88        _ => DataType::String, // Default fallback
89    }
90}
91
92fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
93    match data_type {
94        DataType::String => PlDataType::String,
95        DataType::Integer => PlDataType::Int32,
96        DataType::Long => PlDataType::Int64,
97        DataType::Double => PlDataType::Float64,
98        DataType::Boolean => PlDataType::Boolean,
99        DataType::Date => PlDataType::Date,
100        DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
101        DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
102        _ => PlDataType::String, // Default fallback
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    #[test]
    fn test_struct_field_new() {
        let field = StructField::new("age".to_string(), DataType::Integer, true);
        assert_eq!(field.name, "age");
        assert!(field.nullable);
        assert!(matches!(field.data_type, DataType::Integer));
    }

    #[test]
    fn test_struct_type_new() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ];
        let schema = StructType::new(fields);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name, "id");
        assert_eq!(schema.fields()[1].name, "name");
    }

    #[test]
    fn test_struct_type_from_polars_schema() {
        // Create a Polars schema
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let struct_type = StructType::from_polars_schema(&polars_schema);

        assert_eq!(struct_type.fields().len(), 4);
        assert_eq!(struct_type.fields()[0].name, "id");
        assert!(matches!(struct_type.fields()[0].data_type, DataType::Long));
        assert_eq!(struct_type.fields()[1].name, "name");
        assert!(matches!(
            struct_type.fields()[1].data_type,
            DataType::String
        ));
        assert_eq!(struct_type.fields()[2].name, "score");
        assert!(matches!(
            struct_type.fields()[2].data_type,
            DataType::Double
        ));
        assert_eq!(struct_type.fields()[3].name, "active");
        assert!(matches!(
            struct_type.fields()[3].data_type,
            DataType::Boolean
        ));
    }

    #[test]
    fn test_struct_type_to_polars_schema() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ];
        let struct_type = StructType::new(fields);

        let polars_schema = struct_type.to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("id"), Some(&PlDataType::Int64));
        assert_eq!(polars_schema.get("name"), Some(&PlDataType::String));
        assert_eq!(polars_schema.get("score"), Some(&PlDataType::Float64));
    }

    #[test]
    fn test_struct_type_to_polars_schema_temporal_and_array() {
        let struct_type = StructType::new(vec![
            StructField::new("d".to_string(), DataType::Date, true),
            StructField::new("ts".to_string(), DataType::Timestamp, true),
            StructField::new(
                "tags".to_string(),
                DataType::Array(Box::new(DataType::String)),
                true,
            ),
        ]);

        let polars_schema = struct_type.to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("d"), Some(&PlDataType::Date));
        assert_eq!(
            polars_schema.get("ts"),
            Some(&PlDataType::Datetime(TimeUnit::Microseconds, None))
        );
        assert_eq!(
            polars_schema.get("tags"),
            Some(&PlDataType::List(Box::new(PlDataType::String)))
        );
    }

    #[test]
    fn test_roundtrip_schema_conversion() {
        // Create a struct type, convert to Polars, convert back. Names, order and types
        // survive, except that Integer widens to Long (Int32 is read back as Long) and every
        // field comes back nullable.
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
            StructField::new("d".to_string(), DataType::Boolean, true),
            StructField::new("e".to_string(), DataType::String, true),
        ]);

        let polars_schema = original.to_polars_schema();
        let roundtrip = StructType::from_polars_schema(&polars_schema);

        assert_eq!(roundtrip.fields().len(), original.fields().len());
        for (orig, rt) in original.fields().iter().zip(roundtrip.fields().iter()) {
            assert_eq!(orig.name, rt.name);
            assert!(rt.nullable);
        }

        let rt = roundtrip.fields();
        assert!(matches!(rt[0].data_type, DataType::Long));
        assert!(matches!(rt[1].data_type, DataType::Long));
        assert!(matches!(rt[2].data_type, DataType::Double));
        assert!(matches!(rt[3].data_type, DataType::Boolean));
        assert!(matches!(rt[4].data_type, DataType::String));
    }

    #[test]
    fn test_polars_type_to_data_type_basic() {
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::String),
            DataType::String
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int64),
            DataType::Long
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Float64),
            DataType::Double
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Boolean),
            DataType::Boolean
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Date),
            DataType::Date
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_widens_narrow_types() {
        // Int32 and Float32 are widened for parity with Spark/Sparkless inferSchema.
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int32),
            DataType::Long
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Float32),
            DataType::Double
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let datetime_type = PlDataType::Datetime(TimeUnit::Microseconds, None);
        assert!(matches!(
            polars_type_to_data_type(&datetime_type),
            DataType::Timestamp
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_list() {
        let list_type = PlDataType::List(Box::new(PlDataType::Int64));
        match polars_type_to_data_type(&list_type) {
            DataType::Array(inner) => {
                assert!(matches!(*inner, DataType::Long));
            }
            _ => panic!("Expected Array type"),
        }
    }

    #[test]
    fn test_polars_type_to_data_type_fallback() {
        // Unknown type should fall back to String
        let unknown_type = PlDataType::UInt8;
        assert!(matches!(
            polars_type_to_data_type(&unknown_type),
            DataType::String
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_basic() {
        assert_eq!(
            data_type_to_polars_type(&DataType::String),
            PlDataType::String
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Integer),
            PlDataType::Int32
        );
        assert_eq!(data_type_to_polars_type(&DataType::Long), PlDataType::Int64);
        assert_eq!(
            data_type_to_polars_type(&DataType::Double),
            PlDataType::Float64
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Boolean),
            PlDataType::Boolean
        );
        assert_eq!(data_type_to_polars_type(&DataType::Date), PlDataType::Date);
    }

    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        let result = data_type_to_polars_type(&DataType::Timestamp);
        assert!(matches!(
            result,
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_array() {
        let array_type = DataType::Array(Box::new(DataType::Long));
        let result = data_type_to_polars_type(&array_type);
        match result {
            PlDataType::List(inner) => {
                assert_eq!(*inner, PlDataType::Int64);
            }
            _ => panic!("Expected List type"),
        }
    }

    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        // Map type falls back to String
        let map_type = DataType::Map(Box::new(DataType::String), Box::new(DataType::Long));
        assert_eq!(data_type_to_polars_type(&map_type), PlDataType::String);
    }

    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        // Nested Struct falls back to String
        let struct_type = DataType::Struct(vec![StructField::new(
            "nested".to_string(),
            DataType::Integer,
            true,
        )]);
        assert_eq!(data_type_to_polars_type(&struct_type), PlDataType::String);
    }

    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(vec![]);
        assert!(empty.fields().is_empty());

        let polars_schema = empty.to_polars_schema();
        assert!(polars_schema.is_empty());
    }
}