delta_kernel/engine/
arrow_conversion.rs

//! Conversions between kernel schema types and Arrow schema types, in both directions.
2
3use std::sync::Arc;
4
5use crate::arrow::datatypes::{
6    DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
7    SchemaRef as ArrowSchemaRef, TimeUnit,
8};
9use crate::arrow::error::ArrowError;
10use itertools::Itertools;
11
12use crate::error::Error;
13use crate::schema::{
14    ArrayType, DataType, MapType, MetadataValue, PrimitiveType, StructField, StructType,
15};
16
// Child-field names used when building nested Arrow types from kernel types.

/// Name given to the element field of an Arrow list.
pub(crate) const LIST_ARRAY_ROOT: &str = "element";

/// Name given to the entries (key/value struct) field of an Arrow map.
pub(crate) const MAP_ROOT_DEFAULT: &str = "key_value";
/// Name given to the key child inside a map's entries struct.
pub(crate) const MAP_KEY_DEFAULT: &str = "key";
/// Name given to the value child inside a map's entries struct.
pub(crate) const MAP_VALUE_DEFAULT: &str = "value";
21
22/// Convert a kernel type into an arrow type (automatically implemented for all types that
23/// implement [`TryFromKernel`])
24pub trait TryIntoArrow<ArrowType> {
25    fn try_into_arrow(self) -> Result<ArrowType, ArrowError>;
26}
27
28/// Convert an arrow type into a kernel type (a similar [`TryIntoKernel`] trait is automatically
29/// implemented for all types that implement [`TryFromArrow`])
30pub trait TryFromArrow<ArrowType>: Sized {
31    fn try_from_arrow(t: ArrowType) -> Result<Self, ArrowError>;
32}
33
34/// Convert an arrow type into a kernel type (automatically implemented for all types that
35/// implement [`TryFromArrow`])
36pub trait TryIntoKernel<KernelType> {
37    fn try_into_kernel(self) -> Result<KernelType, ArrowError>;
38}
39
40/// Convert a kernel type into an arrow type (a similar [`TryIntoArrow`] trait is automatically
41/// implemented for all types that implement [`TryFromKernel`])
42pub trait TryFromKernel<KernelType>: Sized {
43    fn try_from_kernel(t: KernelType) -> Result<Self, ArrowError>;
44}
45
46impl<KernelType, ArrowType> TryIntoArrow<ArrowType> for KernelType
47where
48    ArrowType: TryFromKernel<KernelType>,
49{
50    fn try_into_arrow(self) -> Result<ArrowType, ArrowError> {
51        ArrowType::try_from_kernel(self)
52    }
53}
54
55impl<KernelType, ArrowType> TryIntoKernel<KernelType> for ArrowType
56where
57    KernelType: TryFromArrow<ArrowType>,
58{
59    fn try_into_kernel(self) -> Result<KernelType, ArrowError> {
60        KernelType::try_from_arrow(self)
61    }
62}
63
64impl TryFromKernel<&StructType> for ArrowSchema {
65    fn try_from_kernel(s: &StructType) -> Result<Self, ArrowError> {
66        let fields: Vec<ArrowField> = s.fields().map(|f| f.try_into_arrow()).try_collect()?;
67        Ok(ArrowSchema::new(fields))
68    }
69}
70
71impl TryFromKernel<&StructField> for ArrowField {
72    fn try_from_kernel(f: &StructField) -> Result<Self, ArrowError> {
73        let metadata = f
74            .metadata()
75            .iter()
76            .map(|(key, val)| match &val {
77                &MetadataValue::String(val) => Ok((key.clone(), val.clone())),
78                _ => Ok((key.clone(), serde_json::to_string(val)?)),
79            })
80            .collect::<Result<_, serde_json::Error>>()
81            .map_err(|err| ArrowError::JsonError(err.to_string()))?;
82
83        let field = ArrowField::new(f.name(), f.data_type().try_into_arrow()?, f.is_nullable())
84            .with_metadata(metadata);
85
86        Ok(field)
87    }
88}
89
90impl TryFromKernel<&ArrayType> for ArrowField {
91    fn try_from_kernel(a: &ArrayType) -> Result<Self, ArrowError> {
92        Ok(ArrowField::new(
93            LIST_ARRAY_ROOT,
94            a.element_type().try_into_arrow()?,
95            a.contains_null(),
96        ))
97    }
98}
99
100impl TryFromKernel<&MapType> for ArrowField {
101    fn try_from_kernel(a: &MapType) -> Result<Self, ArrowError> {
102        Ok(ArrowField::new(
103            MAP_ROOT_DEFAULT,
104            ArrowDataType::Struct(
105                vec![
106                    ArrowField::new(MAP_KEY_DEFAULT, a.key_type().try_into_arrow()?, false),
107                    ArrowField::new(
108                        MAP_VALUE_DEFAULT,
109                        a.value_type().try_into_arrow()?,
110                        a.value_contains_null(),
111                    ),
112                ]
113                .into(),
114            ),
115            false, // always non-null
116        ))
117    }
118}
119
120impl TryFromKernel<&DataType> for ArrowDataType {
121    fn try_from_kernel(t: &DataType) -> Result<Self, ArrowError> {
122        match t {
123            DataType::Primitive(p) => {
124                match p {
125                    PrimitiveType::String => Ok(ArrowDataType::Utf8),
126                    PrimitiveType::Long => Ok(ArrowDataType::Int64), // undocumented type
127                    PrimitiveType::Integer => Ok(ArrowDataType::Int32),
128                    PrimitiveType::Short => Ok(ArrowDataType::Int16),
129                    PrimitiveType::Byte => Ok(ArrowDataType::Int8),
130                    PrimitiveType::Float => Ok(ArrowDataType::Float32),
131                    PrimitiveType::Double => Ok(ArrowDataType::Float64),
132                    PrimitiveType::Boolean => Ok(ArrowDataType::Boolean),
133                    PrimitiveType::Binary => Ok(ArrowDataType::Binary),
134                    PrimitiveType::Decimal(dtype) => Ok(ArrowDataType::Decimal128(
135                        dtype.precision(),
136                        dtype.scale() as i8, // 0..=38
137                    )),
138                    PrimitiveType::Date => {
139                        // A calendar date, represented as a year-month-day triple without a
140                        // timezone. Stored as 4 bytes integer representing days since 1970-01-01
141                        Ok(ArrowDataType::Date32)
142                    }
143                    // TODO: https://github.com/delta-io/delta/issues/643
144                    PrimitiveType::Timestamp => Ok(ArrowDataType::Timestamp(
145                        TimeUnit::Microsecond,
146                        Some("UTC".into()),
147                    )),
148                    PrimitiveType::TimestampNtz => {
149                        Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None))
150                    }
151                }
152            }
153            DataType::Struct(s) => Ok(ArrowDataType::Struct(
154                s.fields()
155                    .map(TryIntoArrow::try_into_arrow)
156                    .collect::<Result<Vec<ArrowField>, ArrowError>>()?
157                    .into(),
158            )),
159            DataType::Array(a) => Ok(ArrowDataType::List(Arc::new(a.as_ref().try_into_arrow()?))),
160            DataType::Map(m) => Ok(ArrowDataType::Map(
161                Arc::new(m.as_ref().try_into_arrow()?),
162                false,
163            )),
164            DataType::Variant(s) => {
165                if *t == DataType::unshredded_variant() {
166                    Ok(ArrowDataType::Struct(
167                        s.fields()
168                            .map(TryIntoArrow::try_into_arrow)
169                            .collect::<Result<Vec<ArrowField>, ArrowError>>()?
170                            .into(),
171                    ))
172                } else {
173                    Err(ArrowError::SchemaError(format!(
174                        "Incorrect Variant Schema: {t}. Only the unshredded variant schema is supported right now."
175                    )))
176                }
177            }
178        }
179    }
180}
181
182impl TryFromArrow<&ArrowSchema> for StructType {
183    fn try_from_arrow(arrow_schema: &ArrowSchema) -> Result<Self, ArrowError> {
184        StructType::try_from_results(
185            arrow_schema
186                .fields()
187                .iter()
188                .map(|field| field.as_ref().try_into_kernel()),
189        )
190        .map_err(|e| ArrowError::from_external_error(e.into()))
191    }
192}
193
194impl TryFromArrow<ArrowSchemaRef> for StructType {
195    fn try_from_arrow(arrow_schema: ArrowSchemaRef) -> Result<Self, ArrowError> {
196        arrow_schema.as_ref().try_into_kernel()
197    }
198}
199
200impl TryFromArrow<&ArrowField> for StructField {
201    fn try_from_arrow(arrow_field: &ArrowField) -> Result<Self, ArrowError> {
202        Ok(StructField::new(
203            arrow_field.name().clone(),
204            DataType::try_from_arrow(arrow_field.data_type())?,
205            arrow_field.is_nullable(),
206        )
207        .with_metadata(arrow_field.metadata().iter().map(|(k, v)| (k.clone(), v))))
208    }
209}
210
211impl TryFromArrow<&ArrowDataType> for DataType {
212    fn try_from_arrow(arrow_datatype: &ArrowDataType) -> Result<Self, ArrowError> {
213        match arrow_datatype {
214            ArrowDataType::Utf8 => Ok(DataType::STRING),
215            ArrowDataType::LargeUtf8 => Ok(DataType::STRING),
216            ArrowDataType::Utf8View => Ok(DataType::STRING),
217            ArrowDataType::Int64 => Ok(DataType::LONG), // undocumented type
218            ArrowDataType::Int32 => Ok(DataType::INTEGER),
219            ArrowDataType::Int16 => Ok(DataType::SHORT),
220            ArrowDataType::Int8 => Ok(DataType::BYTE),
221            ArrowDataType::UInt64 => Ok(DataType::LONG), // undocumented type
222            ArrowDataType::UInt32 => Ok(DataType::INTEGER),
223            ArrowDataType::UInt16 => Ok(DataType::SHORT),
224            ArrowDataType::UInt8 => Ok(DataType::BYTE),
225            ArrowDataType::Float32 => Ok(DataType::FLOAT),
226            ArrowDataType::Float64 => Ok(DataType::DOUBLE),
227            ArrowDataType::Boolean => Ok(DataType::BOOLEAN),
228            ArrowDataType::Binary => Ok(DataType::BINARY),
229            ArrowDataType::FixedSizeBinary(_) => Ok(DataType::BINARY),
230            ArrowDataType::LargeBinary => Ok(DataType::BINARY),
231            ArrowDataType::BinaryView => Ok(DataType::BINARY),
232            ArrowDataType::Decimal128(p, s) => {
233                if *s < 0 {
234                    return Err(ArrowError::from_external_error(
235                        Error::invalid_decimal("Negative scales are not supported in Delta").into(),
236                    ));
237                };
238                DataType::decimal(*p, *s as u8)
239                    .map_err(|e| ArrowError::from_external_error(e.into()))
240            }
241            ArrowDataType::Date32 => Ok(DataType::DATE),
242            ArrowDataType::Date64 => Ok(DataType::DATE),
243            ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ),
244            ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz))
245                if tz.eq_ignore_ascii_case("utc") =>
246            {
247                Ok(DataType::TIMESTAMP)
248            }
249            ArrowDataType::Struct(fields) => DataType::try_struct_type_from_results(
250                fields.iter().map(|field| field.as_ref().try_into_kernel()),
251            )
252            .map_err(|e| ArrowError::from_external_error(e.into())),
253            ArrowDataType::List(field) => Ok(ArrayType::new(
254                (*field).data_type().try_into_kernel()?,
255                (*field).is_nullable(),
256            )
257            .into()),
258            ArrowDataType::ListView(field) => Ok(ArrayType::new(
259                (*field).data_type().try_into_kernel()?,
260                (*field).is_nullable(),
261            )
262            .into()),
263            ArrowDataType::LargeList(field) => Ok(ArrayType::new(
264                (*field).data_type().try_into_kernel()?,
265                (*field).is_nullable(),
266            )
267            .into()),
268            ArrowDataType::LargeListView(field) => Ok(ArrayType::new(
269                (*field).data_type().try_into_kernel()?,
270                (*field).is_nullable(),
271            )
272            .into()),
273            ArrowDataType::FixedSizeList(field, _) => Ok(ArrayType::new(
274                (*field).data_type().try_into_kernel()?,
275                (*field).is_nullable(),
276            )
277            .into()),
278            ArrowDataType::Map(field, _) => {
279                if let ArrowDataType::Struct(struct_fields) = field.data_type() {
280                    let key_type = DataType::try_from_arrow(struct_fields[0].data_type())?;
281                    let value_type = DataType::try_from_arrow(struct_fields[1].data_type())?;
282                    let value_type_nullable = struct_fields[1].is_nullable();
283                    Ok(MapType::new(key_type, value_type, value_type_nullable).into())
284                } else {
285                    unreachable!("DataType::Map should contain a struct field child");
286                }
287            }
288            // Dictionary types are just an optimized in-memory representation of an array.
289            // Schema-wise, they are the same as the value type.
290            ArrowDataType::Dictionary(_, value_type) => {
291                Ok(value_type.as_ref().try_into_kernel()?)
292            }
293            s => Err(ArrowError::SchemaError(format!(
294                "Invalid data type for Delta Lake: {s}"
295            ))),
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::arrow_conversion::ArrowField;
    use crate::engine::arrow_data::unshredded_variant_arrow_type;
    use crate::{
        schema::{DataType, StructField},
        DeltaResult,
    };
    use std::collections::HashMap;

    /// String-valued field metadata must survive the kernel -> Arrow conversion verbatim.
    #[test]
    fn test_metadata_string_conversion() -> DeltaResult<()> {
        let metadata = HashMap::from([("description", "hello world".to_owned())]);
        let field = StructField::not_null("name", DataType::STRING).with_metadata(metadata);

        let arrow_field = ArrowField::try_from_kernel(&field)?;

        assert_eq!(
            arrow_field
                .metadata()
                .get("description")
                .map(String::as_str),
            Some("hello world")
        );
        Ok(())
    }

    /// The unshredded variant converts to its canonical Arrow struct; a shredded layout is
    /// rejected with a descriptive schema error.
    #[test]
    fn test_variant_shredded_type_fail() -> DeltaResult<()> {
        let unshredded = ArrowDataType::try_from_kernel(&DataType::unshredded_variant())?;
        assert_eq!(unshredded, unshredded_variant_arrow_type());

        let shredded_variant = DataType::variant_type([
            StructField::nullable("metadata", DataType::BINARY),
            StructField::nullable("value", DataType::BINARY),
            StructField::nullable("typed_value", DataType::INTEGER),
        ])?;
        let err = ArrowDataType::try_from_kernel(&shredded_variant).unwrap_err();
        assert!(err.to_string().contains("Incorrect Variant Schema"));
        Ok(())
    }
}