Skip to main content

robin_sparkless_polars/
schema_conv.rs

//! Conversions between the core [`StructType`] and Polars schemas. This lives in this crate
//! (not in `robin_sparkless_core`) because it depends on Polars.
2
3use polars::prelude::{DataType as PlDataType, Field, Schema, TimeUnit};
4use robin_sparkless_core::{DataType, StructField, StructType};
5
6/// Extension trait for Polars schema conversion. Implemented for [`StructType`] from core.
7/// Bring this trait into scope to use `StructType::from_polars_schema` and `to_polars_schema`.
8pub trait StructTypePolarsExt: Sized {
9    fn from_polars_schema(schema: &Schema) -> Self;
10    fn to_polars_schema(&self) -> Schema;
11}
12
13impl StructTypePolarsExt for StructType {
14    fn from_polars_schema(schema: &Schema) -> Self {
15        let fields = schema
16            .iter()
17            .map(|(name, dtype)| StructField {
18                name: name.to_string(),
19                data_type: polars_type_to_data_type(dtype),
20                nullable: true, // Polars doesn't expose nullability in the same way
21            })
22            .collect();
23        StructType::new(fields)
24    }
25
26    fn to_polars_schema(&self) -> Schema {
27        let fields: Vec<Field> = self
28            .fields()
29            .iter()
30            .map(|f| {
31                Field::new(
32                    f.name.as_str().into(),
33                    data_type_to_polars_type(&f.data_type),
34                )
35            })
36            .collect();
37        Schema::from_iter(fields)
38    }
39}
40
41fn polars_type_to_data_type(polars_type: &PlDataType) -> DataType {
42    match polars_type {
43        PlDataType::String => DataType::String,
44        PlDataType::Int32 => DataType::Integer,
45        PlDataType::Int64 => DataType::Long,
46        // Polars rank(), row_number(), dense_rank() return UInt32; map to Integer for PySpark parity (int not str).
47        PlDataType::UInt32 => DataType::Integer,
48        PlDataType::UInt64 => DataType::Long,
49        PlDataType::Float32 | PlDataType::Float64 => DataType::Double,
50        PlDataType::Boolean => DataType::Boolean,
51        PlDataType::Date => DataType::Date,
52        PlDataType::Datetime(_, _) => DataType::Timestamp,
53        PlDataType::Binary => DataType::Binary,
54        PlDataType::List(inner) => DataType::Array(Box::new(polars_type_to_data_type(inner))),
55        PlDataType::Struct(fields) => DataType::Struct(
56            fields
57                .iter()
58                .map(|f| {
59                    StructField::new(
60                        f.name().to_string(),
61                        polars_type_to_data_type(f.dtype()),
62                        true,
63                    )
64                })
65                .collect(),
66        ),
67        _ => DataType::String,
68    }
69}
70
71pub(crate) fn data_type_to_polars_type(data_type: &DataType) -> PlDataType {
72    match data_type {
73        DataType::String => PlDataType::String,
74        DataType::Integer => PlDataType::Int32,
75        DataType::Long => PlDataType::Int64,
76        DataType::Double => PlDataType::Float64,
77        DataType::Boolean => PlDataType::Boolean,
78        DataType::Date => PlDataType::Date,
79        DataType::Timestamp => PlDataType::Datetime(TimeUnit::Microseconds, None),
80        DataType::Binary => PlDataType::Binary,
81        DataType::Array(inner) => PlDataType::List(Box::new(data_type_to_polars_type(inner))),
82        DataType::Struct(fields) => PlDataType::Struct(
83            fields
84                .iter()
85                .map(|f| {
86                    Field::new(
87                        f.name.as_str().into(),
88                        data_type_to_polars_type(&f.data_type),
89                    )
90                })
91                .collect(),
92        ),
93        _ => PlDataType::String,
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Field, Schema};

    #[test]
    fn test_struct_type_from_polars_schema() {
        let polars_schema = Schema::from_iter(vec![
            Field::new("id".into(), PlDataType::Int64),
            Field::new("name".into(), PlDataType::String),
            Field::new("score".into(), PlDataType::Float64),
            Field::new("active".into(), PlDataType::Boolean),
        ]);
        let struct_type = StructType::from_polars_schema(&polars_schema);
        assert_eq!(struct_type.fields().len(), 4);
        assert_eq!(struct_type.fields()[0].name, "id");
        assert!(matches!(struct_type.fields()[0].data_type, DataType::Long));
        assert!(matches!(struct_type.fields()[1].data_type, DataType::String));
        assert!(matches!(struct_type.fields()[2].data_type, DataType::Double));
        assert!(matches!(struct_type.fields()[3].data_type, DataType::Boolean));
        // Polars has no nullability flag, so every converted field is nullable.
        assert!(struct_type.fields().iter().all(|f| f.nullable));
    }

    #[test]
    fn test_struct_type_to_polars_schema() {
        let fields = vec![
            StructField::new("id".to_string(), DataType::Long, false),
            StructField::new("name".to_string(), DataType::String, true),
            StructField::new("score".to_string(), DataType::Double, true),
        ];
        let struct_type = StructType::new(fields);
        let polars_schema = struct_type.to_polars_schema();
        assert_eq!(polars_schema.len(), 3);
        assert_eq!(polars_schema.get("id"), Some(&PlDataType::Int64));
        assert_eq!(polars_schema.get("name"), Some(&PlDataType::String));
        assert_eq!(polars_schema.get("score"), Some(&PlDataType::Float64));
    }

    #[test]
    fn test_timestamp_maps_to_microsecond_datetime() {
        let struct_type = StructType::new(vec![StructField::new(
            "ts".to_string(),
            DataType::Timestamp,
            true,
        )]);
        let polars_schema = struct_type.to_polars_schema();
        assert_eq!(
            polars_schema.get("ts"),
            Some(&PlDataType::Datetime(TimeUnit::Microseconds, None))
        );
    }

    #[test]
    fn test_roundtrip_schema_conversion() {
        let original = StructType::new(vec![
            StructField::new("a".to_string(), DataType::Integer, true),
            StructField::new("b".to_string(), DataType::Long, true),
            StructField::new("c".to_string(), DataType::Double, true),
        ]);
        let polars_schema = original.to_polars_schema();
        let roundtrip = StructType::from_polars_schema(&polars_schema);
        assert_eq!(roundtrip.fields().len(), original.fields().len());
        // Names and order must survive the round trip, not just the field count.
        for (before, after) in original.fields().iter().zip(roundtrip.fields().iter()) {
            assert_eq!(before.name, after.name);
        }
        assert!(matches!(roundtrip.fields()[0].data_type, DataType::Integer));
        assert!(matches!(roundtrip.fields()[1].data_type, DataType::Long));
        assert!(matches!(roundtrip.fields()[2].data_type, DataType::Double));
    }

    #[test]
    fn test_array_roundtrip_preserves_element_type() {
        let original = StructType::new(vec![StructField::new(
            "xs".to_string(),
            DataType::Array(Box::new(DataType::Long)),
            true,
        )]);
        let polars_schema = original.to_polars_schema();
        assert_eq!(
            polars_schema.get("xs"),
            Some(&PlDataType::List(Box::new(PlDataType::Int64)))
        );
        let roundtrip = StructType::from_polars_schema(&polars_schema);
        match &roundtrip.fields()[0].data_type {
            DataType::Array(inner) => assert!(matches!(**inner, DataType::Long)),
            _ => panic!("expected Array(Long) after round trip"),
        }
    }

    /// Window rank/row_number return UInt32 in Polars. We map it to Integer for PySpark parity
    /// (int in Row, not str).
    #[test]
    fn test_window_uint_maps_to_integer_long() {
        let polars_schema = Schema::from_iter(vec![
            Field::new("rn".into(), PlDataType::UInt32),
            Field::new("rank".into(), PlDataType::UInt64),
        ]);
        let struct_type = StructType::from_polars_schema(&polars_schema);
        assert_eq!(struct_type.fields().len(), 2);
        assert!(matches!(
            struct_type.fields()[0].data_type,
            DataType::Integer
        ));
        assert!(matches!(struct_type.fields()[1].data_type, DataType::Long));
    }
}