// robin_sparkless/schema.rs

1use polars::prelude::{DataType as PlDataType, Schema, TimeUnit};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub enum DataType {
6    String,
7    Integer,
8    Long,
9    Double,
10    Boolean,
11    Date,
12    Timestamp,
13    Array(Box<DataType>),
14    Map(Box<DataType>, Box<DataType>),
15    Struct(Vec<StructField>),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct StructField {
20    pub name: String,
21    pub data_type: DataType,
22    pub nullable: bool,
23}
24
25impl StructField {
26    pub fn new(name: String, data_type: DataType, nullable: bool) -> Self {
27        StructField {
28            name,
29            data_type,
30            nullable,
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct StructType {
37    fields: Vec<StructField>,
38}
39
40impl StructType {
41    pub fn new(fields: Vec<StructField>) -> Self {
42        StructType { fields }
43    }
44
45    pub fn from_polars_schema(schema: &Schema) -> Self {
46        let fields = schema
47            .iter()
48            .map(|(name, dtype)| StructField {
49                name: name.to_string(),
50                data_type: polars_type_to_data_type(dtype),
51                nullable: true, // Polars doesn't expose nullability in the same way
52            })
53            .collect();
54        StructType { fields }
55    }
56
57    pub fn to_polars_schema(&self) -> Schema {
58        use polars::prelude::Field;
59        let fields: Vec<Field> = self
60            .fields
61            .iter()
62            .map(|f| {
63                Field::new(
64                    f.name.as_str().into(),
65                    data_type_to_polars_type(&f.data_type),
66                )
67            })
68            .collect();
69        Schema::from_iter(fields)
70    }
71
72    pub fn fields(&self) -> &[StructField] {
73        &self.fields
74    }
75
76    /// Serialize the schema to a JSON string (array of field objects with name, data_type, nullable).
77    /// Useful for bindings that need to expose schema to the host without Polars types.
78    pub fn to_json(&self) -> Result<String, serde_json::Error> {
79        serde_json::to_string(self)
80    }
81
82    /// Serialize the schema to a pretty-printed JSON string.
83    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
84        serde_json::to_string_pretty(self)
85    }
86}
87
88/// Parse a schema from a JSON string (e.g. from a host binding).
89/// The JSON must match the serialization of [`StructType`] (e.g. from [`StructType::to_json`]).
90pub fn schema_from_json(json: &str) -> Result<StructType, crate::error::EngineError> {
91    serde_json::from_str(json).map_err(crate::error::EngineError::from)
92}
93
94fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
95    match polars_type {
96        PlDataType::String => DataType::String,
97        // Spark/Sparkless inferSchema uses 64-bit for integral types; map Int32 as Long.
98        PlDataType::Int32 | PlDataType::Int64 => DataType::Long,
99        // Map both Float32 and Float64 to Double for schema parity.
100        PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
101        PlDataType::Boolean => DataType::Boolean,
102        PlDataType::Date => DataType::Date,
103        PlDataType::Datetime(_, _) => DataType::Timestamp,
104        PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
105        _ => DataType::String, // Default fallback
106    }
107}
108
109fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
110    match data_type {
111        DataType::String => PlDataType::String,
112        DataType::Integer => PlDataType::Int32,
113        DataType::Long => PlDataType::Int64,
114        DataType::Double => PlDataType::Float64,
115        DataType::Boolean => PlDataType::Boolean,
116        DataType::Date => PlDataType::Date,
117        DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
118        DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
119        _ => PlDataType::String, // Default fallback
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    #[test]
    fn test_struct_field_new() {
        let field = StructField::new("age".to_string(), DataType::Integer, true);
        assert_eq!(field.name, "age");
        assert!(field.nullable);
        assert!(matches!(field.data_type, DataType::Integer));
    }

    #[test]
    fn test_struct_type_new() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ];
        let schema = StructType::new(fields);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name, "id");
        assert_eq!(schema.fields()[1].name, "name");
    }

    #[test]
    fn test_struct_type_from_polars_schema() {
        // Create a Polars schema
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);

        let struct_type = StructType::from_polars_schema(&polars_schema);

        assert_eq!(struct_type.fields().len(), 4);
        assert_eq!(struct_type.fields()[0].name, "id");
        assert!(matches!(struct_type.fields()[0].data_type, DataType::Long));
        assert_eq!(struct_type.fields()[1].name, "name");
        assert!(matches!(
            struct_type.fields()[1].data_type,
            DataType::String
        ));
        assert_eq!(struct_type.fields()[2].name, "score");
        assert!(matches!(
            struct_type.fields()[2].data_type,
            DataType::Double
        ));
        assert_eq!(struct_type.fields()[3].name, "active");
        assert!(matches!(
            struct_type.fields()[3].data_type,
            DataType::Boolean
        ));
    }

    #[test]
    fn test_struct_type_to_polars_schema() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ];
        let struct_type = StructType::new(fields);

        let polars_schema = struct_type.to_polars_schema();

        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("id"), Some(&PlDataType::Int64));
        assert_eq!(polars_schema.get("name"), Some(&PlDataType::String));
        assert_eq!(polars_schema.get("score"), Some(&PlDataType::Float64));
    }

    #[test]
    fn test_roundtrip_schema_conversion() {
        // Create a struct type, convert to Polars, convert back
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
            StructField::new("d".to_string(), DataType::Boolean, true),
            StructField::new("e".to_string(), DataType::String, true),
        ]);

        let polars_schema = original.to_polars_schema();
        let roundtrip = StructType::from_polars_schema(&polars_schema);

        assert_eq!(roundtrip.fields().len(), original.fields().len());
        for (orig, rt) in original.fields().iter().zip(roundtrip.fields()) {
            assert_eq!(orig.name, rt.name);
            // Polars carries no nullability, so everything comes back nullable.
            assert!(rt.nullable);
        }

        // Types must survive the roundtrip too; Integer (Int32) is read back as Long.
        let rt = roundtrip.fields();
        assert!(matches!(rt[0].data_type, DataType::Long));
        assert!(matches!(rt[1].data_type, DataType::Long));
        assert!(matches!(rt[2].data_type, DataType::Double));
        assert!(matches!(rt[3].data_type, DataType::Boolean));
        assert!(matches!(rt[4].data_type, DataType::String));
    }

    #[test]
    fn test_struct_type_to_json() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
        ];
        let schema = StructType::new(fields);
        let json = schema.to_json().unwrap();
        // Serialized as an object holding a `fields` array.
        assert!(json.starts_with("{\"fields\":["));
        assert!(json.contains("\"name\":\"id\""));
        assert!(json.contains("\"name\":\"name\""));
        assert!(json.contains("\"data_type\""));
        assert!(json.contains("\"nullable\""));

        let parsed: StructType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.fields().len(), 2);
        assert_eq!(parsed.fields()[0].name, "id");
        assert!(!parsed.fields()[0].nullable);
        assert!(matches!(parsed.fields()[0].data_type, DataType::Long));
        assert_eq!(parsed.fields()[1].name, "name");
        assert!(parsed.fields()[1].nullable);
        assert!(matches!(parsed.fields()[1].data_type, DataType::String));

        let pretty = schema.to_json_pretty().unwrap();
        assert!(pretty.contains('\n'));
    }

    #[test]
    fn test_schema_from_json() {
        let schema = StructType::new(vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new(
                "tags".to_string(),
                DataType::Array(Box::new(DataType::String)),
                true,
            ),
        ]);
        let json = schema.to_json().unwrap();

        let parsed = match schema_from_json(&json) {
            Ok(parsed) => parsed,
            Err(_) => panic!("schema_from_json rejected the output of to_json"),
        };
        assert_eq!(parsed.fields().len(), 2);
        assert_eq!(parsed.fields()[0].name, "id");
        assert!(matches!(parsed.fields()[0].data_type, DataType::Long));
        match &parsed.fields()[1].data_type {
            DataType::Array(inner) => assert!(matches!(**inner, DataType::String)),
            other => panic!("Expected Array type, got {other:?}"),
        }

        // Malformed JSON and incomplete field objects are both rejected.
        assert!(schema_from_json("not json").is_err());
        assert!(schema_from_json("{\"fields\":[{\"name\":\"x\"}]}").is_err());
    }

    #[test]
    fn test_polars_type_to_data_type_basic() {
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::String),
            DataType::String
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Int64),
            DataType::Long
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Float64),
            DataType::Double
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Boolean),
            DataType::Boolean
        ));
        assert!(matches!(
            polars_type_to_data_type(&PlDataType::Date),
            DataType::Date
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_datetime() {
        let datetime_type = PlDataType::Datetime(TimeUnit::Microseconds, None);
        assert!(matches!(
            polars_type_to_data_type(&datetime_type),
            DataType::Timestamp
        ));
    }

    #[test]
    fn test_polars_type_to_data_type_list() {
        let list_type = PlDataType::List(Box::new(PlDataType::Int64));
        match polars_type_to_data_type(&list_type) {
            DataType::Array(inner) => {
                assert!(matches!(*inner, DataType::Long));
            }
            other => panic!("Expected Array type, got {other:?}"),
        }
    }

    #[test]
    fn test_polars_type_to_data_type_fallback() {
        // Unknown type should fall back to String
        let unknown_type = PlDataType::UInt8;
        assert!(matches!(
            polars_type_to_data_type(&unknown_type),
            DataType::String
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_basic() {
        assert_eq!(
            data_type_to_polars_type(&DataType::String),
            PlDataType::String
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Integer),
            PlDataType::Int32
        );
        assert_eq!(data_type_to_polars_type(&DataType::Long), PlDataType::Int64);
        assert_eq!(
            data_type_to_polars_type(&DataType::Double),
            PlDataType::Float64
        );
        assert_eq!(
            data_type_to_polars_type(&DataType::Boolean),
            PlDataType::Boolean
        );
        assert_eq!(data_type_to_polars_type(&DataType::Date), PlDataType::Date);
    }

    #[test]
    fn test_data_type_to_polars_type_timestamp() {
        let result = data_type_to_polars_type(&DataType::Timestamp);
        assert!(matches!(
            result,
            PlDataType::Datetime(TimeUnit::Microseconds, None)
        ));
    }

    #[test]
    fn test_data_type_to_polars_type_array() {
        let array_type = DataType::Array(Box::new(DataType::Long));
        let result = data_type_to_polars_type(&array_type);
        match result {
            PlDataType::List(inner) => {
                assert_eq!(*inner, PlDataType::Int64);
            }
            other => panic!("Expected List type, got {other:?}"),
        }
    }

    #[test]
    fn test_data_type_to_polars_type_map_fallback() {
        // Map type falls back to String
        let map_type = DataType::Map(Box::new(DataType::String), Box::new(DataType::Long));
        assert_eq!(data_type_to_polars_type(&map_type), PlDataType::String);
    }

    #[test]
    fn test_data_type_to_polars_type_struct_fallback() {
        // Nested Struct falls back to String
        let struct_type = DataType::Struct(vec![StructField::new(
            "nested".to_string(),
            DataType::Integer,
            true,
        )]);
        assert_eq!(data_type_to_polars_type(&struct_type), PlDataType::String);
    }

    #[test]
    fn test_empty_struct_type() {
        let empty = StructType::new(vec![]);
        assert!(empty.fields().is_empty());

        let polars_schema = empty.to_polars_schema();
        assert!(polars_schema.is_empty());
    }
}