1use std::sync::Arc;
4
5use crate::arrow::datatypes::{
6 DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
7 SchemaRef as ArrowSchemaRef, TimeUnit,
8};
9use crate::arrow::error::ArrowError;
10use itertools::Itertools;
11
12use crate::error::Error;
13use crate::schema::{
14 ArrayType, DataType, MapType, MetadataValue, PrimitiveType, StructField, StructType,
15};
16
/// Name given to the child (element) field of the Arrow list produced from a kernel `ArrayType`.
pub(crate) const LIST_ARRAY_ROOT: &str = "element";
/// Name given to the entries struct field of the Arrow map produced from a kernel `MapType`.
pub(crate) const MAP_ROOT_DEFAULT: &str = "key_value";
/// Name given to the key field inside the map entries struct.
pub(crate) const MAP_KEY_DEFAULT: &str = "key";
/// Name given to the value field inside the map entries struct.
pub(crate) const MAP_VALUE_DEFAULT: &str = "value";
21
/// Fallible, consuming conversion from a kernel type into the Arrow type `ArrowType`.
///
/// A blanket implementation in this file provides this trait for every type `K` for which
/// `ArrowType: TryFromKernel<K>`, so conversions are normally written as [`TryFromKernel`] impls.
pub trait TryIntoArrow<ArrowType> {
    /// Consumes `self` and returns the equivalent `ArrowType`, or an [`ArrowError`] on failure.
    fn try_into_arrow(self) -> Result<ArrowType, ArrowError>;
}
27
/// Fallible construction of a kernel type (`Self`) from the Arrow type `ArrowType`.
///
/// Implementing this trait also makes `try_into_kernel` available on `ArrowType` through the
/// blanket [`TryIntoKernel`] implementation in this file.
pub trait TryFromArrow<ArrowType>: Sized {
    /// Builds `Self` from the Arrow value `t`, or returns an [`ArrowError`] on failure.
    fn try_from_arrow(t: ArrowType) -> Result<Self, ArrowError>;
}
33
/// Fallible, consuming conversion from an Arrow type into the kernel type `KernelType`.
///
/// A blanket implementation in this file provides this trait for every type `A` for which
/// `KernelType: TryFromArrow<A>`, so conversions are normally written as [`TryFromArrow`] impls.
pub trait TryIntoKernel<KernelType> {
    /// Consumes `self` and returns the equivalent `KernelType`, or an [`ArrowError`] on failure.
    fn try_into_kernel(self) -> Result<KernelType, ArrowError>;
}
39
/// Fallible construction of an Arrow type (`Self`) from the kernel type `KernelType`.
///
/// Implementing this trait also makes `try_into_arrow` available on `KernelType` through the
/// blanket [`TryIntoArrow`] implementation in this file.
pub trait TryFromKernel<KernelType>: Sized {
    /// Builds `Self` from the kernel value `t`, or returns an [`ArrowError`] on failure.
    fn try_from_kernel(t: KernelType) -> Result<Self, ArrowError>;
}
45
46impl<KernelType, ArrowType> TryIntoArrow<ArrowType> for KernelType
47where
48 ArrowType: TryFromKernel<KernelType>,
49{
50 fn try_into_arrow(self) -> Result<ArrowType, ArrowError> {
51 ArrowType::try_from_kernel(self)
52 }
53}
54
55impl<KernelType, ArrowType> TryIntoKernel<KernelType> for ArrowType
56where
57 KernelType: TryFromArrow<ArrowType>,
58{
59 fn try_into_kernel(self) -> Result<KernelType, ArrowError> {
60 KernelType::try_from_arrow(self)
61 }
62}
63
64impl TryFromKernel<&StructType> for ArrowSchema {
65 fn try_from_kernel(s: &StructType) -> Result<Self, ArrowError> {
66 let fields: Vec<ArrowField> = s.fields().map(|f| f.try_into_arrow()).try_collect()?;
67 Ok(ArrowSchema::new(fields))
68 }
69}
70
71impl TryFromKernel<&StructField> for ArrowField {
72 fn try_from_kernel(f: &StructField) -> Result<Self, ArrowError> {
73 let metadata = f
74 .metadata()
75 .iter()
76 .map(|(key, val)| match &val {
77 &MetadataValue::String(val) => Ok((key.clone(), val.clone())),
78 _ => Ok((key.clone(), serde_json::to_string(val)?)),
79 })
80 .collect::<Result<_, serde_json::Error>>()
81 .map_err(|err| ArrowError::JsonError(err.to_string()))?;
82
83 let field = ArrowField::new(f.name(), f.data_type().try_into_arrow()?, f.is_nullable())
84 .with_metadata(metadata);
85
86 Ok(field)
87 }
88}
89
90impl TryFromKernel<&ArrayType> for ArrowField {
91 fn try_from_kernel(a: &ArrayType) -> Result<Self, ArrowError> {
92 Ok(ArrowField::new(
93 LIST_ARRAY_ROOT,
94 a.element_type().try_into_arrow()?,
95 a.contains_null(),
96 ))
97 }
98}
99
100impl TryFromKernel<&MapType> for ArrowField {
101 fn try_from_kernel(a: &MapType) -> Result<Self, ArrowError> {
102 Ok(ArrowField::new(
103 MAP_ROOT_DEFAULT,
104 ArrowDataType::Struct(
105 vec![
106 ArrowField::new(MAP_KEY_DEFAULT, a.key_type().try_into_arrow()?, false),
107 ArrowField::new(
108 MAP_VALUE_DEFAULT,
109 a.value_type().try_into_arrow()?,
110 a.value_contains_null(),
111 ),
112 ]
113 .into(),
114 ),
115 false, ))
117 }
118}
119
120impl TryFromKernel<&DataType> for ArrowDataType {
121 fn try_from_kernel(t: &DataType) -> Result<Self, ArrowError> {
122 match t {
123 DataType::Primitive(p) => {
124 match p {
125 PrimitiveType::String => Ok(ArrowDataType::Utf8),
126 PrimitiveType::Long => Ok(ArrowDataType::Int64), PrimitiveType::Integer => Ok(ArrowDataType::Int32),
128 PrimitiveType::Short => Ok(ArrowDataType::Int16),
129 PrimitiveType::Byte => Ok(ArrowDataType::Int8),
130 PrimitiveType::Float => Ok(ArrowDataType::Float32),
131 PrimitiveType::Double => Ok(ArrowDataType::Float64),
132 PrimitiveType::Boolean => Ok(ArrowDataType::Boolean),
133 PrimitiveType::Binary => Ok(ArrowDataType::Binary),
134 PrimitiveType::Decimal(dtype) => Ok(ArrowDataType::Decimal128(
135 dtype.precision(),
136 dtype.scale() as i8, )),
138 PrimitiveType::Date => {
139 Ok(ArrowDataType::Date32)
142 }
143 PrimitiveType::Timestamp => Ok(ArrowDataType::Timestamp(
145 TimeUnit::Microsecond,
146 Some("UTC".into()),
147 )),
148 PrimitiveType::TimestampNtz => {
149 Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None))
150 }
151 }
152 }
153 DataType::Struct(s) => Ok(ArrowDataType::Struct(
154 s.fields()
155 .map(TryIntoArrow::try_into_arrow)
156 .collect::<Result<Vec<ArrowField>, ArrowError>>()?
157 .into(),
158 )),
159 DataType::Array(a) => Ok(ArrowDataType::List(Arc::new(a.as_ref().try_into_arrow()?))),
160 DataType::Map(m) => Ok(ArrowDataType::Map(
161 Arc::new(m.as_ref().try_into_arrow()?),
162 false,
163 )),
164 DataType::Variant(s) => {
165 if *t == DataType::unshredded_variant() {
166 Ok(ArrowDataType::Struct(
167 s.fields()
168 .map(TryIntoArrow::try_into_arrow)
169 .collect::<Result<Vec<ArrowField>, ArrowError>>()?
170 .into(),
171 ))
172 } else {
173 Err(ArrowError::SchemaError(format!(
174 "Incorrect Variant Schema: {t}. Only the unshredded variant schema is supported right now."
175 )))
176 }
177 }
178 }
179 }
180}
181
182impl TryFromArrow<&ArrowSchema> for StructType {
183 fn try_from_arrow(arrow_schema: &ArrowSchema) -> Result<Self, ArrowError> {
184 StructType::try_from_results(
185 arrow_schema
186 .fields()
187 .iter()
188 .map(|field| field.as_ref().try_into_kernel()),
189 )
190 .map_err(|e| ArrowError::from_external_error(e.into()))
191 }
192}
193
194impl TryFromArrow<ArrowSchemaRef> for StructType {
195 fn try_from_arrow(arrow_schema: ArrowSchemaRef) -> Result<Self, ArrowError> {
196 arrow_schema.as_ref().try_into_kernel()
197 }
198}
199
200impl TryFromArrow<&ArrowField> for StructField {
201 fn try_from_arrow(arrow_field: &ArrowField) -> Result<Self, ArrowError> {
202 Ok(StructField::new(
203 arrow_field.name().clone(),
204 DataType::try_from_arrow(arrow_field.data_type())?,
205 arrow_field.is_nullable(),
206 )
207 .with_metadata(arrow_field.metadata().iter().map(|(k, v)| (k.clone(), v))))
208 }
209}
210
211impl TryFromArrow<&ArrowDataType> for DataType {
212 fn try_from_arrow(arrow_datatype: &ArrowDataType) -> Result<Self, ArrowError> {
213 match arrow_datatype {
214 ArrowDataType::Utf8 => Ok(DataType::STRING),
215 ArrowDataType::LargeUtf8 => Ok(DataType::STRING),
216 ArrowDataType::Utf8View => Ok(DataType::STRING),
217 ArrowDataType::Int64 => Ok(DataType::LONG), ArrowDataType::Int32 => Ok(DataType::INTEGER),
219 ArrowDataType::Int16 => Ok(DataType::SHORT),
220 ArrowDataType::Int8 => Ok(DataType::BYTE),
221 ArrowDataType::UInt64 => Ok(DataType::LONG), ArrowDataType::UInt32 => Ok(DataType::INTEGER),
223 ArrowDataType::UInt16 => Ok(DataType::SHORT),
224 ArrowDataType::UInt8 => Ok(DataType::BYTE),
225 ArrowDataType::Float32 => Ok(DataType::FLOAT),
226 ArrowDataType::Float64 => Ok(DataType::DOUBLE),
227 ArrowDataType::Boolean => Ok(DataType::BOOLEAN),
228 ArrowDataType::Binary => Ok(DataType::BINARY),
229 ArrowDataType::FixedSizeBinary(_) => Ok(DataType::BINARY),
230 ArrowDataType::LargeBinary => Ok(DataType::BINARY),
231 ArrowDataType::BinaryView => Ok(DataType::BINARY),
232 ArrowDataType::Decimal128(p, s) => {
233 if *s < 0 {
234 return Err(ArrowError::from_external_error(
235 Error::invalid_decimal("Negative scales are not supported in Delta").into(),
236 ));
237 };
238 DataType::decimal(*p, *s as u8)
239 .map_err(|e| ArrowError::from_external_error(e.into()))
240 }
241 ArrowDataType::Date32 => Ok(DataType::DATE),
242 ArrowDataType::Date64 => Ok(DataType::DATE),
243 ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ),
244 ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz))
245 if tz.eq_ignore_ascii_case("utc") =>
246 {
247 Ok(DataType::TIMESTAMP)
248 }
249 ArrowDataType::Struct(fields) => DataType::try_struct_type_from_results(
250 fields.iter().map(|field| field.as_ref().try_into_kernel()),
251 )
252 .map_err(|e| ArrowError::from_external_error(e.into())),
253 ArrowDataType::List(field) => Ok(ArrayType::new(
254 (*field).data_type().try_into_kernel()?,
255 (*field).is_nullable(),
256 )
257 .into()),
258 ArrowDataType::ListView(field) => Ok(ArrayType::new(
259 (*field).data_type().try_into_kernel()?,
260 (*field).is_nullable(),
261 )
262 .into()),
263 ArrowDataType::LargeList(field) => Ok(ArrayType::new(
264 (*field).data_type().try_into_kernel()?,
265 (*field).is_nullable(),
266 )
267 .into()),
268 ArrowDataType::LargeListView(field) => Ok(ArrayType::new(
269 (*field).data_type().try_into_kernel()?,
270 (*field).is_nullable(),
271 )
272 .into()),
273 ArrowDataType::FixedSizeList(field, _) => Ok(ArrayType::new(
274 (*field).data_type().try_into_kernel()?,
275 (*field).is_nullable(),
276 )
277 .into()),
278 ArrowDataType::Map(field, _) => {
279 if let ArrowDataType::Struct(struct_fields) = field.data_type() {
280 let key_type = DataType::try_from_arrow(struct_fields[0].data_type())?;
281 let value_type = DataType::try_from_arrow(struct_fields[1].data_type())?;
282 let value_type_nullable = struct_fields[1].is_nullable();
283 Ok(MapType::new(key_type, value_type, value_type_nullable).into())
284 } else {
285 unreachable!("DataType::Map should contain a struct field child");
286 }
287 }
288 ArrowDataType::Dictionary(_, value_type) => {
291 Ok(value_type.as_ref().try_into_kernel()?)
292 }
293 s => Err(ArrowError::SchemaError(format!(
294 "Invalid data type for Delta Lake: {s}"
295 ))),
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::arrow_conversion::ArrowField;
    use crate::engine::arrow_data::unshredded_variant_arrow_type;
    use crate::{
        schema::{DataType, StructField},
        DeltaResult,
    };
    use std::collections::HashMap;

    /// String-valued kernel metadata must reach the Arrow field unchanged.
    #[test]
    fn test_metadata_string_conversion() -> DeltaResult<()> {
        let metadata = HashMap::from([("description", "hello world".to_owned())]);
        let kernel_field = StructField::not_null("name", DataType::STRING).with_metadata(metadata);

        let arrow_field = ArrowField::try_from_kernel(&kernel_field)?;

        assert_eq!(
            arrow_field.metadata().get("description"),
            Some(&"hello world".to_owned())
        );
        Ok(())
    }

    /// The unshredded variant converts to its Arrow struct layout; a shredded one is rejected.
    #[test]
    fn test_variant_shredded_type_fail() -> DeltaResult<()> {
        let unshredded = DataType::unshredded_variant();
        let unshredded_arrow = ArrowDataType::try_from_kernel(&unshredded)?;
        assert!(unshredded_arrow == unshredded_variant_arrow_type());

        let shredded = DataType::variant_type([
            StructField::nullable("metadata", DataType::BINARY),
            StructField::nullable("value", DataType::BINARY),
            StructField::nullable("typed_value", DataType::INTEGER),
        ])?;
        let conversion_error = ArrowDataType::try_from_kernel(&shredded).unwrap_err();
        assert!(conversion_error
            .to_string()
            .contains("Incorrect Variant Schema"));
        Ok(())
    }
}