Skip to main content

robin_sparkless_polars/
schema_conv.rs

//! Polars schema conversion for [`StructType`]. This lives in this crate rather than in core because it depends on Polars.
2
3use polars::prelude::{DataType as PlDataType, Field, Schema, TimeUnit};
4use robin_sparkless_core::{DataType, StructField, StructType};
5
6/// Extension trait for Polars schema conversion. Implemented for [`StructType`] from core.
7/// Bring this trait into scope to use `StructType::from_polars_schema` and `to_polars_schema`.
8pub trait StructTypePolarsExt: Sized {
9    fn from_polars_schema(schema: &Schema) -> Self;
10    fn to_polars_schema(&self) -> Schema;
11}
12
13impl StructTypePolarsExt for StructType {
14    fn from_polars_schema(schema: &Schema) -> Self {
15        let fields = schema
16            .iter()
17            .map(|(name, dtype)| StructField {
18                name: name.to_string(),
19                data_type: polars_type_to_data_type(dtype),
20                nullable: true, // Polars doesn't expose nullability in the same way
21            })
22            .collect();
23        StructType::new(fields)
24    }
25
26    fn to_polars_schema(&self) -> Schema {
27        let fields: Vec<Field> = self
28            .fields()
29            .iter()
30            .map(|f| {
31                Field::new(
32                    f.name.as_str().into(),
33                    data_type_to_polars_type(&f.data_type),
34                )
35            })
36            .collect();
37        Schema::from_iter(fields)
38    }
39}
40
41fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
42    match polars_type {
43        PlDataType::String => DataType::String,
44        PlDataType::Int32 | PlDataType::Int64 => DataType::Long,
45        PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
46        PlDataType::Boolean => DataType::Boolean,
47        PlDataType::Date => DataType::Date,
48        PlDataType::Datetime(_, _) => DataType::Timestamp,
49        PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
50        _ => DataType::String,
51    }
52}
53
54pub(super) fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
55    match data_type {
56        DataType::String => PlDataType::String,
57        DataType::Integer => PlDataType::Int32,
58        DataType::Long => PlDataType::Int64,
59        DataType::Double => PlDataType::Float64,
60        DataType::Boolean => PlDataType::Boolean,
61        DataType::Date => PlDataType::Date,
62        DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
63        DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
64        _ => PlDataType::String,
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    /// Each Polars column becomes one field, in order; `Int64` converts to `Long`.
    #[test]
    fn test_struct_type_from_polars_schema() {
        let columns = vec![
            ("id", PlDataType::Int64),
            ("name", PlDataType::String),
            ("score", PlDataType::Float64),
            ("active", PlDataType::Boolean),
        ];
        let polars_schema = Schema::from_iter(
            columns
                .into_iter()
                .map(|(col, dtype)| Field::new(col.into(), dtype)),
        );

        let struct_type = StructType::from_polars_schema(&polars_schema);
        let fields = struct_type.fields();
        assert_eq!(fields.len(), 4);

        let first = &fields[0];
        assert_eq!(first.name, "id");
        assert!(matches!(first.data_type, DataType::Long));
    }

    /// Each struct field becomes one Polars column with the mapped dtype.
    #[test]
    fn test_struct_type_to_polars_schema() {
        let struct_type = StructType::new(vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ]);

        let polars_schema = struct_type.to_polars_schema();
        assert_eq!(polars_schema.len(), 3);

        let expectations = vec![("id", PlDataType::Int64), ("name", PlDataType::String)];
        for (col, expected) in expectations {
            assert_eq!(polars_schema.get(col), Some(&expected));
        }
    }

    /// Converting to Polars and back keeps the number of fields.
    #[test]
    fn test_roundtrip_schema_conversion() {
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
        ]);

        let roundtrip = StructType::from_polars_schema(&original.to_polars_schema());
        assert_eq!(roundtrip.fields().len(), original.fields().len());
    }
}