datafusion_python/common/
data_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::arrow::array::Array;
19use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
20use datafusion::common::ScalarValue;
21use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
22use pyo3::exceptions::PyNotImplementedError;
23use pyo3::{exceptions::PyValueError, prelude::*};
24
25#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
26pub struct PyScalarValue(pub ScalarValue);
27
28impl From<ScalarValue> for PyScalarValue {
29    fn from(value: ScalarValue) -> Self {
30        Self(value)
31    }
32}
33impl From<PyScalarValue> for ScalarValue {
34    fn from(value: PyScalarValue) -> Self {
35        value.0
36    }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
40#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")]
41pub enum RexType {
42    Alias,
43    Literal,
44    Call,
45    Reference,
46    ScalarSubquery,
47    Other,
48}
49
50/// These bindings are tying together several disparate systems.
51/// You have SQL types for the SQL strings and RDBMS systems itself.
52/// Rust types for the DataFusion code
53/// Arrow types which represents the underlying arrow format
54/// Python types which represent the type in Python
55/// It is important to keep all of those types in a single
56/// and manageable location. Therefore this structure exists
57/// to map those types and provide a simple place for developers
58/// to map types from one system to another.
59#[derive(Debug, Clone)]
60#[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
61pub struct DataTypeMap {
62    #[pyo3(get, set)]
63    pub arrow_type: PyDataType,
64    #[pyo3(get, set)]
65    pub python_type: PythonType,
66    #[pyo3(get, set)]
67    pub sql_type: SqlType,
68}
69
70impl DataTypeMap {
71    fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self {
72        DataTypeMap {
73            arrow_type: PyDataType {
74                data_type: arrow_type,
75            },
76            python_type,
77            sql_type,
78        }
79    }
80
81    pub fn map_from_arrow_type(arrow_type: &DataType) -> Result<DataTypeMap, PyErr> {
82        match arrow_type {
83            DataType::Null => Ok(DataTypeMap::new(
84                DataType::Null,
85                PythonType::None,
86                SqlType::NULL,
87            )),
88            DataType::Boolean => Ok(DataTypeMap::new(
89                DataType::Boolean,
90                PythonType::Bool,
91                SqlType::BOOLEAN,
92            )),
93            DataType::Int8 => Ok(DataTypeMap::new(
94                DataType::Int8,
95                PythonType::Int,
96                SqlType::TINYINT,
97            )),
98            DataType::Int16 => Ok(DataTypeMap::new(
99                DataType::Int16,
100                PythonType::Int,
101                SqlType::SMALLINT,
102            )),
103            DataType::Int32 => Ok(DataTypeMap::new(
104                DataType::Int32,
105                PythonType::Int,
106                SqlType::INTEGER,
107            )),
108            DataType::Int64 => Ok(DataTypeMap::new(
109                DataType::Int64,
110                PythonType::Int,
111                SqlType::BIGINT,
112            )),
113            DataType::UInt8 => Ok(DataTypeMap::new(
114                DataType::UInt8,
115                PythonType::Int,
116                SqlType::TINYINT,
117            )),
118            DataType::UInt16 => Ok(DataTypeMap::new(
119                DataType::UInt16,
120                PythonType::Int,
121                SqlType::SMALLINT,
122            )),
123            DataType::UInt32 => Ok(DataTypeMap::new(
124                DataType::UInt32,
125                PythonType::Int,
126                SqlType::INTEGER,
127            )),
128            DataType::UInt64 => Ok(DataTypeMap::new(
129                DataType::UInt64,
130                PythonType::Int,
131                SqlType::BIGINT,
132            )),
133            DataType::Float16 => Ok(DataTypeMap::new(
134                DataType::Float16,
135                PythonType::Float,
136                SqlType::FLOAT,
137            )),
138            DataType::Float32 => Ok(DataTypeMap::new(
139                DataType::Float32,
140                PythonType::Float,
141                SqlType::FLOAT,
142            )),
143            DataType::Float64 => Ok(DataTypeMap::new(
144                DataType::Float64,
145                PythonType::Float,
146                SqlType::FLOAT,
147            )),
148            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
149                DataType::Timestamp(*unit, tz.clone()),
150                PythonType::Datetime,
151                SqlType::DATE,
152            )),
153            DataType::Date32 => Ok(DataTypeMap::new(
154                DataType::Date32,
155                PythonType::Datetime,
156                SqlType::DATE,
157            )),
158            DataType::Date64 => Ok(DataTypeMap::new(
159                DataType::Date64,
160                PythonType::Datetime,
161                SqlType::DATE,
162            )),
163            DataType::Time32(unit) => Ok(DataTypeMap::new(
164                DataType::Time32(*unit),
165                PythonType::Datetime,
166                SqlType::DATE,
167            )),
168            DataType::Time64(unit) => Ok(DataTypeMap::new(
169                DataType::Time64(*unit),
170                PythonType::Datetime,
171                SqlType::DATE,
172            )),
173            DataType::Duration(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
174            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
175                DataType::Interval(*interval_unit),
176                PythonType::Datetime,
177                match interval_unit {
178                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
179                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
180                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
181                },
182            )),
183            DataType::Binary => Ok(DataTypeMap::new(
184                DataType::Binary,
185                PythonType::Bytes,
186                SqlType::BINARY,
187            )),
188            DataType::FixedSizeBinary(_) => {
189                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
190            }
191            DataType::LargeBinary => Ok(DataTypeMap::new(
192                DataType::LargeBinary,
193                PythonType::Bytes,
194                SqlType::BINARY,
195            )),
196            DataType::Utf8 => Ok(DataTypeMap::new(
197                DataType::Utf8,
198                PythonType::Str,
199                SqlType::VARCHAR,
200            )),
201            DataType::LargeUtf8 => Ok(DataTypeMap::new(
202                DataType::LargeUtf8,
203                PythonType::Str,
204                SqlType::VARCHAR,
205            )),
206            DataType::List(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
207            DataType::FixedSizeList(_, _) => {
208                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
209            }
210            DataType::LargeList(_) => {
211                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
212            }
213            DataType::Struct(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
214            DataType::Union(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
215            DataType::Dictionary(_, _) => {
216                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
217            }
218            DataType::Decimal32(precision, scale) => Ok(DataTypeMap::new(
219                DataType::Decimal32(*precision, *scale),
220                PythonType::Float,
221                SqlType::DECIMAL,
222            )),
223            DataType::Decimal64(precision, scale) => Ok(DataTypeMap::new(
224                DataType::Decimal64(*precision, *scale),
225                PythonType::Float,
226                SqlType::DECIMAL,
227            )),
228            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
229                DataType::Decimal128(*precision, *scale),
230                PythonType::Float,
231                SqlType::DECIMAL,
232            )),
233            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
234                DataType::Decimal256(*precision, *scale),
235                PythonType::Float,
236                SqlType::DECIMAL,
237            )),
238            DataType::Map(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
239            DataType::RunEndEncoded(_, _) => {
240                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
241            }
242            DataType::BinaryView => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
243            DataType::Utf8View => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
244            DataType::ListView(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
245            DataType::LargeListView(_) => {
246                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
247            }
248        }
249    }
250
251    /// Generate the `DataTypeMap` from a `ScalarValue` instance
252    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
253        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
254    }
255
256    /// Maps a `ScalarValue` to an Arrow `DataType`
257    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
258        match scalar_val {
259            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
260            ScalarValue::Float16(_) => Ok(DataType::Float16),
261            ScalarValue::Float32(_) => Ok(DataType::Float32),
262            ScalarValue::Float64(_) => Ok(DataType::Float64),
263            ScalarValue::Decimal128(_, precision, scale) => {
264                Ok(DataType::Decimal128(*precision, *scale))
265            }
266            ScalarValue::Decimal256(_, precision, scale) => {
267                Ok(DataType::Decimal256(*precision, *scale))
268            }
269            ScalarValue::Dictionary(data_type, scalar_type) => {
270                // Call this function again to map the dictionary scalar_value to an Arrow type
271                Ok(DataType::Dictionary(
272                    Box::new(*data_type.clone()),
273                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
274                ))
275            }
276            ScalarValue::Int8(_) => Ok(DataType::Int8),
277            ScalarValue::Int16(_) => Ok(DataType::Int16),
278            ScalarValue::Int32(_) => Ok(DataType::Int32),
279            ScalarValue::Int64(_) => Ok(DataType::Int64),
280            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
281            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
282            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
283            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
284            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
285            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
286            ScalarValue::Binary(_) => Ok(DataType::Binary),
287            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
288            ScalarValue::Date32(_) => Ok(DataType::Date32),
289            ScalarValue::Date64(_) => Ok(DataType::Date64),
290            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
291            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
292            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
293            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
294            ScalarValue::Null => Ok(DataType::Null),
295            ScalarValue::TimestampSecond(_, tz) => {
296                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
297            }
298            ScalarValue::TimestampMillisecond(_, tz) => {
299                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
300            }
301            ScalarValue::TimestampMicrosecond(_, tz) => {
302                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
303            }
304            ScalarValue::TimestampNanosecond(_, tz) => {
305                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
306            }
307            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
308            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
309            ScalarValue::IntervalMonthDayNano(..) => {
310                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
311            }
312            ScalarValue::List(arr) => Ok(arr.data_type().to_owned()),
313            ScalarValue::Struct(_fields) => Err(PyNotImplementedError::new_err(
314                "ScalarValue::Struct".to_string(),
315            )),
316            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
317            ScalarValue::FixedSizeList(_array_ref) => {
318                // The FieldRef was removed from ScalarValue::FixedSizeList in
319                // https://github.com/apache/arrow-datafusion/pull/8221, so we can no
320                // longer convert back to a DataType here
321                Err(PyNotImplementedError::new_err(
322                    "ScalarValue::FixedSizeList".to_string(),
323                ))
324            }
325            ScalarValue::LargeList(_) => Err(PyNotImplementedError::new_err(
326                "ScalarValue::LargeList".to_string(),
327            )),
328            ScalarValue::DurationSecond(_) => Ok(DataType::Duration(TimeUnit::Second)),
329            ScalarValue::DurationMillisecond(_) => Ok(DataType::Duration(TimeUnit::Millisecond)),
330            ScalarValue::DurationMicrosecond(_) => Ok(DataType::Duration(TimeUnit::Microsecond)),
331            ScalarValue::DurationNanosecond(_) => Ok(DataType::Duration(TimeUnit::Nanosecond)),
332            ScalarValue::Union(_, _, _) => Err(PyNotImplementedError::new_err(
333                "ScalarValue::LargeList".to_string(),
334            )),
335            ScalarValue::Utf8View(_) => Ok(DataType::Utf8View),
336            ScalarValue::BinaryView(_) => Ok(DataType::BinaryView),
337            ScalarValue::Map(_) => Err(PyNotImplementedError::new_err(
338                "ScalarValue::Map".to_string(),
339            )),
340        }
341    }
342}
343
344#[pymethods]
345impl DataTypeMap {
346    #[new]
347    pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self {
348        DataTypeMap {
349            arrow_type,
350            python_type,
351            sql_type,
352        }
353    }
354
355    #[staticmethod]
356    #[pyo3(name = "from_parquet_type_str")]
357    /// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
358    /// with a String type for schema rather than an object type. Here we make a best effort
359    /// to convert that to a physical type.
360    pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
361        let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
362            "boolean" => Ok(DataType::Boolean),
363            "int32" => Ok(DataType::Int32),
364            "int64" => Ok(DataType::Int64),
365            "int96" => {
366                // Int96 is an old parquet datatype that is now deprecated. We convert to nanosecond timestamp
367                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
368            }
369            "float" => Ok(DataType::Float32),
370            "double" => Ok(DataType::Float64),
371            "byte_array" => Ok(DataType::Utf8),
372            _ => Err(PyValueError::new_err(format!(
373                "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}"
374            ))),
375        };
376        DataTypeMap::map_from_arrow_type(&arrow_dtype?)
377    }
378
379    #[staticmethod]
380    #[pyo3(name = "arrow")]
381    pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
382        DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
383    }
384
385    #[staticmethod]
386    #[pyo3(name = "arrow_str")]
387    pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
388        let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
389        DataTypeMap::map_from_arrow_type(&data_type?.data_type)
390    }
391
392    #[staticmethod]
393    #[pyo3(name = "sql")]
394    pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
395        match sql_type {
396            SqlType::ANY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
397            SqlType::ARRAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
398            SqlType::BIGINT => Ok(DataTypeMap::new(
399                DataType::Int64,
400                PythonType::Int,
401                SqlType::BIGINT,
402            )),
403            SqlType::BINARY => Ok(DataTypeMap::new(
404                DataType::Binary,
405                PythonType::Bytes,
406                SqlType::BINARY,
407            )),
408            SqlType::BOOLEAN => Ok(DataTypeMap::new(
409                DataType::Boolean,
410                PythonType::Bool,
411                SqlType::BOOLEAN,
412            )),
413            SqlType::CHAR => Ok(DataTypeMap::new(
414                DataType::UInt8,
415                PythonType::Int,
416                SqlType::CHAR,
417            )),
418            SqlType::COLUMN_LIST => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
419            SqlType::CURSOR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
420            SqlType::DATE => Ok(DataTypeMap::new(
421                DataType::Date64,
422                PythonType::Datetime,
423                SqlType::DATE,
424            )),
425            SqlType::DECIMAL => Ok(DataTypeMap::new(
426                DataType::Decimal128(1, 1),
427                PythonType::Float,
428                SqlType::DECIMAL,
429            )),
430            SqlType::DISTINCT => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
431            SqlType::DOUBLE => Ok(DataTypeMap::new(
432                DataType::Decimal256(1, 1),
433                PythonType::Float,
434                SqlType::DOUBLE,
435            )),
436            SqlType::DYNAMIC_STAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
437            SqlType::FLOAT => Ok(DataTypeMap::new(
438                DataType::Decimal128(1, 1),
439                PythonType::Float,
440                SqlType::FLOAT,
441            )),
442            SqlType::GEOMETRY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
443            SqlType::INTEGER => Ok(DataTypeMap::new(
444                DataType::Int8,
445                PythonType::Int,
446                SqlType::INTEGER,
447            )),
448            SqlType::INTERVAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
449            SqlType::INTERVAL_DAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
450            SqlType::INTERVAL_DAY_HOUR => {
451                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
452            }
453            SqlType::INTERVAL_DAY_MINUTE => {
454                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
455            }
456            SqlType::INTERVAL_DAY_SECOND => {
457                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
458            }
459            SqlType::INTERVAL_HOUR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
460            SqlType::INTERVAL_HOUR_MINUTE => {
461                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
462            }
463            SqlType::INTERVAL_HOUR_SECOND => {
464                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
465            }
466            SqlType::INTERVAL_MINUTE => {
467                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
468            }
469            SqlType::INTERVAL_MINUTE_SECOND => {
470                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
471            }
472            SqlType::INTERVAL_MONTH => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
473            SqlType::INTERVAL_SECOND => {
474                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
475            }
476            SqlType::INTERVAL_YEAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
477            SqlType::INTERVAL_YEAR_MONTH => {
478                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
479            }
480            SqlType::MAP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
481            SqlType::MULTISET => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
482            SqlType::NULL => Ok(DataTypeMap::new(
483                DataType::Null,
484                PythonType::None,
485                SqlType::NULL,
486            )),
487            SqlType::OTHER => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
488            SqlType::REAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
489            SqlType::ROW => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
490            SqlType::SARG => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
491            SqlType::SMALLINT => Ok(DataTypeMap::new(
492                DataType::Int16,
493                PythonType::Int,
494                SqlType::SMALLINT,
495            )),
496            SqlType::STRUCTURED => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
497            SqlType::SYMBOL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
498            SqlType::TIME => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
499            SqlType::TIME_WITH_LOCAL_TIME_ZONE => {
500                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
501            }
502            SqlType::TIMESTAMP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
503            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => {
504                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
505            }
506            SqlType::TINYINT => Ok(DataTypeMap::new(
507                DataType::Int8,
508                PythonType::Int,
509                SqlType::TINYINT,
510            )),
511            SqlType::UNKNOWN => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
512            SqlType::VARBINARY => Ok(DataTypeMap::new(
513                DataType::LargeBinary,
514                PythonType::Bytes,
515                SqlType::VARBINARY,
516            )),
517            SqlType::VARCHAR => Ok(DataTypeMap::new(
518                DataType::Utf8,
519                PythonType::Str,
520                SqlType::VARCHAR,
521            )),
522        }
523    }
524
525    /// Unfortunately PyO3 does not allow for us to expose the DataType as an enum since
526    /// we cannot directly annotate the Enum instance of dependency code. Therefore, here
527    /// we provide an enum to mimic it.
528    #[pyo3(name = "friendly_arrow_type_name")]
529    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
530        Ok(match &self.arrow_type.data_type {
531            DataType::Null => "Null",
532            DataType::Boolean => "Boolean",
533            DataType::Int8 => "Int8",
534            DataType::Int16 => "Int16",
535            DataType::Int32 => "Int32",
536            DataType::Int64 => "Int64",
537            DataType::UInt8 => "UInt8",
538            DataType::UInt16 => "UInt16",
539            DataType::UInt32 => "UInt32",
540            DataType::UInt64 => "UInt64",
541            DataType::Float16 => "Float16",
542            DataType::Float32 => "Float32",
543            DataType::Float64 => "Float64",
544            DataType::Timestamp(_, _) => "Timestamp",
545            DataType::Date32 => "Date32",
546            DataType::Date64 => "Date64",
547            DataType::Time32(_) => "Time32",
548            DataType::Time64(_) => "Time64",
549            DataType::Duration(_) => "Duration",
550            DataType::Interval(_) => "Interval",
551            DataType::Binary => "Binary",
552            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
553            DataType::LargeBinary => "LargeBinary",
554            DataType::Utf8 => "Utf8",
555            DataType::LargeUtf8 => "LargeUtf8",
556            DataType::List(_) => "List",
557            DataType::FixedSizeList(_, _) => "FixedSizeList",
558            DataType::LargeList(_) => "LargeList",
559            DataType::Struct(_) => "Struct",
560            DataType::Union(_, _) => "Union",
561            DataType::Dictionary(_, _) => "Dictionary",
562            DataType::Decimal32(_, _) => "Decimal32",
563            DataType::Decimal64(_, _) => "Decimal64",
564            DataType::Decimal128(_, _) => "Decimal128",
565            DataType::Decimal256(_, _) => "Decimal256",
566            DataType::Map(_, _) => "Map",
567            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
568            DataType::BinaryView => "BinaryView",
569            DataType::Utf8View => "Utf8View",
570            DataType::ListView(_) => "ListView",
571            DataType::LargeListView(_) => "LargeListView",
572        })
573    }
574}
575
576/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass`
577/// Since `DataType` exists in another package we cannot make that happen here so we wrap
578/// `DataType` as `PyDataType` This exists solely to satisfy those constraints.
579#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
580#[pyclass(name = "DataType", module = "datafusion.common")]
581pub struct PyDataType {
582    pub data_type: DataType,
583}
584
585impl PyDataType {
586    /// There are situations when obtaining dtypes on the Python side where the Arrow type
587    /// is presented as a String rather than an actual DataType. This function is used to
588    /// convert that String to a DataType for the Python side to use.
589    pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
590        // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]"
591        let arrow_str_type = match arrow_str_type.find('[') {
592            Some(index) => arrow_str_type[0..index].to_string(),
593            None => arrow_str_type, // Return early if ',' is not found.
594        };
595
596        let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
597            "bool" => Ok(DataType::Boolean),
598            "boolean" => Ok(DataType::Boolean),
599            "uint8" => Ok(DataType::UInt8),
600            "uint16" => Ok(DataType::UInt16),
601            "uint32" => Ok(DataType::UInt32),
602            "uint64" => Ok(DataType::UInt64),
603            "int8" => Ok(DataType::Int8),
604            "int16" => Ok(DataType::Int16),
605            "int32" => Ok(DataType::Int32),
606            "int64" => Ok(DataType::Int64),
607            "float" => Ok(DataType::Float32),
608            "double" => Ok(DataType::Float64),
609            "float16" => Ok(DataType::Float16),
610            "float32" => Ok(DataType::Float32),
611            "float64" => Ok(DataType::Float64),
612            "datetime64" => Ok(DataType::Date64),
613            "object" => Ok(DataType::Utf8),
614            _ => Err(PyValueError::new_err(format!(
615                "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}"
616            ))),
617        };
618        Ok(PyDataType {
619            data_type: arrow_dtype?,
620        })
621    }
622}
623
624impl From<PyDataType> for DataType {
625    fn from(data_type: PyDataType) -> DataType {
626        data_type.data_type
627    }
628}
629
630impl From<DataType> for PyDataType {
631    fn from(data_type: DataType) -> PyDataType {
632        PyDataType { data_type }
633    }
634}
635
636/// Represents the possible Python types that can be mapped to the SQL types
637#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
638#[pyclass(eq, eq_int, name = "PythonType", module = "datafusion.common")]
639pub enum PythonType {
640    Array,
641    Bool,
642    Bytes,
643    Datetime,
644    Float,
645    Int,
646    List,
647    None,
648    Object,
649    Str,
650}
651
652/// Represents the types that are possible for DataFusion to parse
653/// from a SQL query. Aka "SqlType" and are valid values for
654/// ANSI SQL
655#[allow(non_camel_case_types)]
656#[allow(clippy::upper_case_acronyms)]
657#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
658#[pyclass(eq, eq_int, name = "SqlType", module = "datafusion.common")]
659pub enum SqlType {
660    ANY,
661    ARRAY,
662    BIGINT,
663    BINARY,
664    BOOLEAN,
665    CHAR,
666    COLUMN_LIST,
667    CURSOR,
668    DATE,
669    DECIMAL,
670    DISTINCT,
671    DOUBLE,
672    DYNAMIC_STAR,
673    FLOAT,
674    GEOMETRY,
675    INTEGER,
676    INTERVAL,
677    INTERVAL_DAY,
678    INTERVAL_DAY_HOUR,
679    INTERVAL_DAY_MINUTE,
680    INTERVAL_DAY_SECOND,
681    INTERVAL_HOUR,
682    INTERVAL_HOUR_MINUTE,
683    INTERVAL_HOUR_SECOND,
684    INTERVAL_MINUTE,
685    INTERVAL_MINUTE_SECOND,
686    INTERVAL_MONTH,
687    INTERVAL_SECOND,
688    INTERVAL_YEAR,
689    INTERVAL_YEAR_MONTH,
690    MAP,
691    MULTISET,
692    NULL,
693    OTHER,
694    REAL,
695    ROW,
696    SARG,
697    SMALLINT,
698    STRUCTURED,
699    SYMBOL,
700    TIME,
701    TIME_WITH_LOCAL_TIME_ZONE,
702    TIMESTAMP,
703    TIMESTAMP_WITH_LOCAL_TIME_ZONE,
704    TINYINT,
705    UNKNOWN,
706    VARBINARY,
707    VARCHAR,
708}
709
710/// Specifies Ignore / Respect NULL within window functions.
711/// For example
712/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
713#[allow(non_camel_case_types)]
714#[allow(clippy::upper_case_acronyms)]
715#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
716#[pyclass(eq, eq_int, name = "NullTreatment", module = "datafusion.common")]
717pub enum NullTreatment {
718    IGNORE_NULLS,
719    RESPECT_NULLS,
720}
721
722impl From<NullTreatment> for DFNullTreatment {
723    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
724        match null_treatment {
725            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
726            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
727        }
728    }
729}
730
731impl From<DFNullTreatment> for NullTreatment {
732    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
733        match null_treatment {
734            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
735            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
736        }
737    }
738}