datafusion_python/common/
data_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::arrow::array::Array;
19use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
20use datafusion::common::ScalarValue;
21use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
22use pyo3::exceptions::PyNotImplementedError;
23use pyo3::{exceptions::PyValueError, prelude::*};
24
25#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
26pub struct PyScalarValue(pub ScalarValue);
27
28impl From<ScalarValue> for PyScalarValue {
29    fn from(value: ScalarValue) -> Self {
30        Self(value)
31    }
32}
33impl From<PyScalarValue> for ScalarValue {
34    fn from(value: PyScalarValue) -> Self {
35        value.0
36    }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
40#[pyclass(frozen, eq, eq_int, name = "RexType", module = "datafusion.common")]
41pub enum RexType {
42    Alias,
43    Literal,
44    Call,
45    Reference,
46    ScalarSubquery,
47    Other,
48}
49
50/// These bindings are tying together several disparate systems.
51/// You have SQL types for the SQL strings and RDBMS systems itself.
52/// Rust types for the DataFusion code
53/// Arrow types which represents the underlying arrow format
54/// Python types which represent the type in Python
55/// It is important to keep all of those types in a single
56/// and manageable location. Therefore this structure exists
57/// to map those types and provide a simple place for developers
58/// to map types from one system to another.
59// TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now
60#[derive(Debug, Clone)]
61#[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
62pub struct DataTypeMap {
63    #[pyo3(get, set)]
64    pub arrow_type: PyDataType,
65    #[pyo3(get, set)]
66    pub python_type: PythonType,
67    #[pyo3(get, set)]
68    pub sql_type: SqlType,
69}
70
71impl DataTypeMap {
72    fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self {
73        DataTypeMap {
74            arrow_type: PyDataType {
75                data_type: arrow_type,
76            },
77            python_type,
78            sql_type,
79        }
80    }
81
82    pub fn map_from_arrow_type(arrow_type: &DataType) -> Result<DataTypeMap, PyErr> {
83        match arrow_type {
84            DataType::Null => Ok(DataTypeMap::new(
85                DataType::Null,
86                PythonType::None,
87                SqlType::NULL,
88            )),
89            DataType::Boolean => Ok(DataTypeMap::new(
90                DataType::Boolean,
91                PythonType::Bool,
92                SqlType::BOOLEAN,
93            )),
94            DataType::Int8 => Ok(DataTypeMap::new(
95                DataType::Int8,
96                PythonType::Int,
97                SqlType::TINYINT,
98            )),
99            DataType::Int16 => Ok(DataTypeMap::new(
100                DataType::Int16,
101                PythonType::Int,
102                SqlType::SMALLINT,
103            )),
104            DataType::Int32 => Ok(DataTypeMap::new(
105                DataType::Int32,
106                PythonType::Int,
107                SqlType::INTEGER,
108            )),
109            DataType::Int64 => Ok(DataTypeMap::new(
110                DataType::Int64,
111                PythonType::Int,
112                SqlType::BIGINT,
113            )),
114            DataType::UInt8 => Ok(DataTypeMap::new(
115                DataType::UInt8,
116                PythonType::Int,
117                SqlType::TINYINT,
118            )),
119            DataType::UInt16 => Ok(DataTypeMap::new(
120                DataType::UInt16,
121                PythonType::Int,
122                SqlType::SMALLINT,
123            )),
124            DataType::UInt32 => Ok(DataTypeMap::new(
125                DataType::UInt32,
126                PythonType::Int,
127                SqlType::INTEGER,
128            )),
129            DataType::UInt64 => Ok(DataTypeMap::new(
130                DataType::UInt64,
131                PythonType::Int,
132                SqlType::BIGINT,
133            )),
134            DataType::Float16 => Ok(DataTypeMap::new(
135                DataType::Float16,
136                PythonType::Float,
137                SqlType::FLOAT,
138            )),
139            DataType::Float32 => Ok(DataTypeMap::new(
140                DataType::Float32,
141                PythonType::Float,
142                SqlType::FLOAT,
143            )),
144            DataType::Float64 => Ok(DataTypeMap::new(
145                DataType::Float64,
146                PythonType::Float,
147                SqlType::FLOAT,
148            )),
149            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
150                DataType::Timestamp(*unit, tz.clone()),
151                PythonType::Datetime,
152                SqlType::DATE,
153            )),
154            DataType::Date32 => Ok(DataTypeMap::new(
155                DataType::Date32,
156                PythonType::Datetime,
157                SqlType::DATE,
158            )),
159            DataType::Date64 => Ok(DataTypeMap::new(
160                DataType::Date64,
161                PythonType::Datetime,
162                SqlType::DATE,
163            )),
164            DataType::Time32(unit) => Ok(DataTypeMap::new(
165                DataType::Time32(*unit),
166                PythonType::Datetime,
167                SqlType::DATE,
168            )),
169            DataType::Time64(unit) => Ok(DataTypeMap::new(
170                DataType::Time64(*unit),
171                PythonType::Datetime,
172                SqlType::DATE,
173            )),
174            DataType::Duration(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
175            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
176                DataType::Interval(*interval_unit),
177                PythonType::Datetime,
178                match interval_unit {
179                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
180                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
181                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
182                },
183            )),
184            DataType::Binary => Ok(DataTypeMap::new(
185                DataType::Binary,
186                PythonType::Bytes,
187                SqlType::BINARY,
188            )),
189            DataType::FixedSizeBinary(_) => {
190                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
191            }
192            DataType::LargeBinary => Ok(DataTypeMap::new(
193                DataType::LargeBinary,
194                PythonType::Bytes,
195                SqlType::BINARY,
196            )),
197            DataType::Utf8 => Ok(DataTypeMap::new(
198                DataType::Utf8,
199                PythonType::Str,
200                SqlType::VARCHAR,
201            )),
202            DataType::LargeUtf8 => Ok(DataTypeMap::new(
203                DataType::LargeUtf8,
204                PythonType::Str,
205                SqlType::VARCHAR,
206            )),
207            DataType::List(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
208            DataType::FixedSizeList(_, _) => {
209                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
210            }
211            DataType::LargeList(_) => {
212                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
213            }
214            DataType::Struct(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
215            DataType::Union(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
216            DataType::Dictionary(_, _) => {
217                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
218            }
219            DataType::Decimal32(precision, scale) => Ok(DataTypeMap::new(
220                DataType::Decimal32(*precision, *scale),
221                PythonType::Float,
222                SqlType::DECIMAL,
223            )),
224            DataType::Decimal64(precision, scale) => Ok(DataTypeMap::new(
225                DataType::Decimal64(*precision, *scale),
226                PythonType::Float,
227                SqlType::DECIMAL,
228            )),
229            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
230                DataType::Decimal128(*precision, *scale),
231                PythonType::Float,
232                SqlType::DECIMAL,
233            )),
234            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
235                DataType::Decimal256(*precision, *scale),
236                PythonType::Float,
237                SqlType::DECIMAL,
238            )),
239            DataType::Map(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
240            DataType::RunEndEncoded(_, _) => {
241                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
242            }
243            DataType::BinaryView => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
244            DataType::Utf8View => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
245            DataType::ListView(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
246            DataType::LargeListView(_) => {
247                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
248            }
249        }
250    }
251
252    /// Generate the `DataTypeMap` from a `ScalarValue` instance
253    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
254        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
255    }
256
257    /// Maps a `ScalarValue` to an Arrow `DataType`
258    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
259        match scalar_val {
260            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
261            ScalarValue::Float16(_) => Ok(DataType::Float16),
262            ScalarValue::Float32(_) => Ok(DataType::Float32),
263            ScalarValue::Float64(_) => Ok(DataType::Float64),
264            ScalarValue::Decimal128(_, precision, scale) => {
265                Ok(DataType::Decimal128(*precision, *scale))
266            }
267            ScalarValue::Decimal256(_, precision, scale) => {
268                Ok(DataType::Decimal256(*precision, *scale))
269            }
270            ScalarValue::Dictionary(data_type, scalar_type) => {
271                // Call this function again to map the dictionary scalar_value to an Arrow type
272                Ok(DataType::Dictionary(
273                    Box::new(*data_type.clone()),
274                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
275                ))
276            }
277            ScalarValue::Int8(_) => Ok(DataType::Int8),
278            ScalarValue::Int16(_) => Ok(DataType::Int16),
279            ScalarValue::Int32(_) => Ok(DataType::Int32),
280            ScalarValue::Int64(_) => Ok(DataType::Int64),
281            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
282            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
283            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
284            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
285            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
286            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
287            ScalarValue::Binary(_) => Ok(DataType::Binary),
288            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
289            ScalarValue::Date32(_) => Ok(DataType::Date32),
290            ScalarValue::Date64(_) => Ok(DataType::Date64),
291            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
292            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
293            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
294            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
295            ScalarValue::Null => Ok(DataType::Null),
296            ScalarValue::TimestampSecond(_, tz) => {
297                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
298            }
299            ScalarValue::TimestampMillisecond(_, tz) => {
300                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
301            }
302            ScalarValue::TimestampMicrosecond(_, tz) => {
303                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
304            }
305            ScalarValue::TimestampNanosecond(_, tz) => {
306                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
307            }
308            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
309            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
310            ScalarValue::IntervalMonthDayNano(..) => {
311                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
312            }
313            ScalarValue::List(arr) => Ok(arr.data_type().to_owned()),
314            ScalarValue::Struct(_fields) => Err(PyNotImplementedError::new_err(
315                "ScalarValue::Struct".to_string(),
316            )),
317            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
318            ScalarValue::FixedSizeList(_array_ref) => {
319                // The FieldRef was removed from ScalarValue::FixedSizeList in
320                // https://github.com/apache/arrow-datafusion/pull/8221, so we can no
321                // longer convert back to a DataType here
322                Err(PyNotImplementedError::new_err(
323                    "ScalarValue::FixedSizeList".to_string(),
324                ))
325            }
326            ScalarValue::LargeList(_) => Err(PyNotImplementedError::new_err(
327                "ScalarValue::LargeList".to_string(),
328            )),
329            ScalarValue::DurationSecond(_) => Ok(DataType::Duration(TimeUnit::Second)),
330            ScalarValue::DurationMillisecond(_) => Ok(DataType::Duration(TimeUnit::Millisecond)),
331            ScalarValue::DurationMicrosecond(_) => Ok(DataType::Duration(TimeUnit::Microsecond)),
332            ScalarValue::DurationNanosecond(_) => Ok(DataType::Duration(TimeUnit::Nanosecond)),
333            ScalarValue::Union(_, _, _) => Err(PyNotImplementedError::new_err(
334                "ScalarValue::LargeList".to_string(),
335            )),
336            ScalarValue::Utf8View(_) => Ok(DataType::Utf8View),
337            ScalarValue::BinaryView(_) => Ok(DataType::BinaryView),
338            ScalarValue::Map(_) => Err(PyNotImplementedError::new_err(
339                "ScalarValue::Map".to_string(),
340            )),
341        }
342    }
343}
344
345#[pymethods]
346impl DataTypeMap {
347    #[new]
348    pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self {
349        DataTypeMap {
350            arrow_type,
351            python_type,
352            sql_type,
353        }
354    }
355
356    #[staticmethod]
357    #[pyo3(name = "from_parquet_type_str")]
358    /// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
359    /// with a String type for schema rather than an object type. Here we make a best effort
360    /// to convert that to a physical type.
361    pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
362        let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
363            "boolean" => Ok(DataType::Boolean),
364            "int32" => Ok(DataType::Int32),
365            "int64" => Ok(DataType::Int64),
366            "int96" => {
367                // Int96 is an old parquet datatype that is now deprecated. We convert to nanosecond timestamp
368                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
369            }
370            "float" => Ok(DataType::Float32),
371            "double" => Ok(DataType::Float64),
372            "byte_array" => Ok(DataType::Utf8),
373            _ => Err(PyValueError::new_err(format!(
374                "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}"
375            ))),
376        };
377        DataTypeMap::map_from_arrow_type(&arrow_dtype?)
378    }
379
380    #[staticmethod]
381    #[pyo3(name = "arrow")]
382    pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
383        DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
384    }
385
386    #[staticmethod]
387    #[pyo3(name = "arrow_str")]
388    pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
389        let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
390        DataTypeMap::map_from_arrow_type(&data_type?.data_type)
391    }
392
393    #[staticmethod]
394    #[pyo3(name = "sql")]
395    pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
396        match sql_type {
397            SqlType::ANY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
398            SqlType::ARRAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
399            SqlType::BIGINT => Ok(DataTypeMap::new(
400                DataType::Int64,
401                PythonType::Int,
402                SqlType::BIGINT,
403            )),
404            SqlType::BINARY => Ok(DataTypeMap::new(
405                DataType::Binary,
406                PythonType::Bytes,
407                SqlType::BINARY,
408            )),
409            SqlType::BOOLEAN => Ok(DataTypeMap::new(
410                DataType::Boolean,
411                PythonType::Bool,
412                SqlType::BOOLEAN,
413            )),
414            SqlType::CHAR => Ok(DataTypeMap::new(
415                DataType::UInt8,
416                PythonType::Int,
417                SqlType::CHAR,
418            )),
419            SqlType::COLUMN_LIST => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
420            SqlType::CURSOR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
421            SqlType::DATE => Ok(DataTypeMap::new(
422                DataType::Date64,
423                PythonType::Datetime,
424                SqlType::DATE,
425            )),
426            SqlType::DECIMAL => Ok(DataTypeMap::new(
427                DataType::Decimal128(1, 1),
428                PythonType::Float,
429                SqlType::DECIMAL,
430            )),
431            SqlType::DISTINCT => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
432            SqlType::DOUBLE => Ok(DataTypeMap::new(
433                DataType::Decimal256(1, 1),
434                PythonType::Float,
435                SqlType::DOUBLE,
436            )),
437            SqlType::DYNAMIC_STAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
438            SqlType::FLOAT => Ok(DataTypeMap::new(
439                DataType::Decimal128(1, 1),
440                PythonType::Float,
441                SqlType::FLOAT,
442            )),
443            SqlType::GEOMETRY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
444            SqlType::INTEGER => Ok(DataTypeMap::new(
445                DataType::Int8,
446                PythonType::Int,
447                SqlType::INTEGER,
448            )),
449            SqlType::INTERVAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
450            SqlType::INTERVAL_DAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
451            SqlType::INTERVAL_DAY_HOUR => {
452                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
453            }
454            SqlType::INTERVAL_DAY_MINUTE => {
455                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
456            }
457            SqlType::INTERVAL_DAY_SECOND => {
458                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
459            }
460            SqlType::INTERVAL_HOUR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
461            SqlType::INTERVAL_HOUR_MINUTE => {
462                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
463            }
464            SqlType::INTERVAL_HOUR_SECOND => {
465                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
466            }
467            SqlType::INTERVAL_MINUTE => {
468                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
469            }
470            SqlType::INTERVAL_MINUTE_SECOND => {
471                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
472            }
473            SqlType::INTERVAL_MONTH => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
474            SqlType::INTERVAL_SECOND => {
475                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
476            }
477            SqlType::INTERVAL_YEAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
478            SqlType::INTERVAL_YEAR_MONTH => {
479                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
480            }
481            SqlType::MAP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
482            SqlType::MULTISET => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
483            SqlType::NULL => Ok(DataTypeMap::new(
484                DataType::Null,
485                PythonType::None,
486                SqlType::NULL,
487            )),
488            SqlType::OTHER => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
489            SqlType::REAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
490            SqlType::ROW => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
491            SqlType::SARG => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
492            SqlType::SMALLINT => Ok(DataTypeMap::new(
493                DataType::Int16,
494                PythonType::Int,
495                SqlType::SMALLINT,
496            )),
497            SqlType::STRUCTURED => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
498            SqlType::SYMBOL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
499            SqlType::TIME => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
500            SqlType::TIME_WITH_LOCAL_TIME_ZONE => {
501                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
502            }
503            SqlType::TIMESTAMP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
504            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => {
505                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
506            }
507            SqlType::TINYINT => Ok(DataTypeMap::new(
508                DataType::Int8,
509                PythonType::Int,
510                SqlType::TINYINT,
511            )),
512            SqlType::UNKNOWN => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
513            SqlType::VARBINARY => Ok(DataTypeMap::new(
514                DataType::LargeBinary,
515                PythonType::Bytes,
516                SqlType::VARBINARY,
517            )),
518            SqlType::VARCHAR => Ok(DataTypeMap::new(
519                DataType::Utf8,
520                PythonType::Str,
521                SqlType::VARCHAR,
522            )),
523        }
524    }
525
526    /// Unfortunately PyO3 does not allow for us to expose the DataType as an enum since
527    /// we cannot directly annotate the Enum instance of dependency code. Therefore, here
528    /// we provide an enum to mimic it.
529    #[pyo3(name = "friendly_arrow_type_name")]
530    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
531        Ok(match &self.arrow_type.data_type {
532            DataType::Null => "Null",
533            DataType::Boolean => "Boolean",
534            DataType::Int8 => "Int8",
535            DataType::Int16 => "Int16",
536            DataType::Int32 => "Int32",
537            DataType::Int64 => "Int64",
538            DataType::UInt8 => "UInt8",
539            DataType::UInt16 => "UInt16",
540            DataType::UInt32 => "UInt32",
541            DataType::UInt64 => "UInt64",
542            DataType::Float16 => "Float16",
543            DataType::Float32 => "Float32",
544            DataType::Float64 => "Float64",
545            DataType::Timestamp(_, _) => "Timestamp",
546            DataType::Date32 => "Date32",
547            DataType::Date64 => "Date64",
548            DataType::Time32(_) => "Time32",
549            DataType::Time64(_) => "Time64",
550            DataType::Duration(_) => "Duration",
551            DataType::Interval(_) => "Interval",
552            DataType::Binary => "Binary",
553            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
554            DataType::LargeBinary => "LargeBinary",
555            DataType::Utf8 => "Utf8",
556            DataType::LargeUtf8 => "LargeUtf8",
557            DataType::List(_) => "List",
558            DataType::FixedSizeList(_, _) => "FixedSizeList",
559            DataType::LargeList(_) => "LargeList",
560            DataType::Struct(_) => "Struct",
561            DataType::Union(_, _) => "Union",
562            DataType::Dictionary(_, _) => "Dictionary",
563            DataType::Decimal32(_, _) => "Decimal32",
564            DataType::Decimal64(_, _) => "Decimal64",
565            DataType::Decimal128(_, _) => "Decimal128",
566            DataType::Decimal256(_, _) => "Decimal256",
567            DataType::Map(_, _) => "Map",
568            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
569            DataType::BinaryView => "BinaryView",
570            DataType::Utf8View => "Utf8View",
571            DataType::ListView(_) => "ListView",
572            DataType::LargeListView(_) => "LargeListView",
573        })
574    }
575}
576
577/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass`
578/// Since `DataType` exists in another package we cannot make that happen here so we wrap
579/// `DataType` as `PyDataType` This exists solely to satisfy those constraints.
580#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
581#[pyclass(frozen, name = "DataType", module = "datafusion.common")]
582pub struct PyDataType {
583    pub data_type: DataType,
584}
585
586impl PyDataType {
587    /// There are situations when obtaining dtypes on the Python side where the Arrow type
588    /// is presented as a String rather than an actual DataType. This function is used to
589    /// convert that String to a DataType for the Python side to use.
590    pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
591        // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]"
592        let arrow_str_type = match arrow_str_type.find('[') {
593            Some(index) => arrow_str_type[0..index].to_string(),
594            None => arrow_str_type, // Return early if ',' is not found.
595        };
596
597        let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
598            "bool" => Ok(DataType::Boolean),
599            "boolean" => Ok(DataType::Boolean),
600            "uint8" => Ok(DataType::UInt8),
601            "uint16" => Ok(DataType::UInt16),
602            "uint32" => Ok(DataType::UInt32),
603            "uint64" => Ok(DataType::UInt64),
604            "int8" => Ok(DataType::Int8),
605            "int16" => Ok(DataType::Int16),
606            "int32" => Ok(DataType::Int32),
607            "int64" => Ok(DataType::Int64),
608            "float" => Ok(DataType::Float32),
609            "double" => Ok(DataType::Float64),
610            "float16" => Ok(DataType::Float16),
611            "float32" => Ok(DataType::Float32),
612            "float64" => Ok(DataType::Float64),
613            "datetime64" => Ok(DataType::Date64),
614            "object" => Ok(DataType::Utf8),
615            _ => Err(PyValueError::new_err(format!(
616                "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}"
617            ))),
618        };
619        Ok(PyDataType {
620            data_type: arrow_dtype?,
621        })
622    }
623}
624
625impl From<PyDataType> for DataType {
626    fn from(data_type: PyDataType) -> DataType {
627        data_type.data_type
628    }
629}
630
631impl From<DataType> for PyDataType {
632    fn from(data_type: DataType) -> PyDataType {
633        PyDataType { data_type }
634    }
635}
636
637/// Represents the possible Python types that can be mapped to the SQL types
638#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
639#[pyclass(frozen, eq, eq_int, name = "PythonType", module = "datafusion.common")]
640pub enum PythonType {
641    Array,
642    Bool,
643    Bytes,
644    Datetime,
645    Float,
646    Int,
647    List,
648    None,
649    Object,
650    Str,
651}
652
653/// Represents the types that are possible for DataFusion to parse
654/// from a SQL query. Aka "SqlType" and are valid values for
655/// ANSI SQL
656#[allow(non_camel_case_types)]
657#[allow(clippy::upper_case_acronyms)]
658#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
659#[pyclass(frozen, eq, eq_int, name = "SqlType", module = "datafusion.common")]
660pub enum SqlType {
661    ANY,
662    ARRAY,
663    BIGINT,
664    BINARY,
665    BOOLEAN,
666    CHAR,
667    COLUMN_LIST,
668    CURSOR,
669    DATE,
670    DECIMAL,
671    DISTINCT,
672    DOUBLE,
673    DYNAMIC_STAR,
674    FLOAT,
675    GEOMETRY,
676    INTEGER,
677    INTERVAL,
678    INTERVAL_DAY,
679    INTERVAL_DAY_HOUR,
680    INTERVAL_DAY_MINUTE,
681    INTERVAL_DAY_SECOND,
682    INTERVAL_HOUR,
683    INTERVAL_HOUR_MINUTE,
684    INTERVAL_HOUR_SECOND,
685    INTERVAL_MINUTE,
686    INTERVAL_MINUTE_SECOND,
687    INTERVAL_MONTH,
688    INTERVAL_SECOND,
689    INTERVAL_YEAR,
690    INTERVAL_YEAR_MONTH,
691    MAP,
692    MULTISET,
693    NULL,
694    OTHER,
695    REAL,
696    ROW,
697    SARG,
698    SMALLINT,
699    STRUCTURED,
700    SYMBOL,
701    TIME,
702    TIME_WITH_LOCAL_TIME_ZONE,
703    TIMESTAMP,
704    TIMESTAMP_WITH_LOCAL_TIME_ZONE,
705    TINYINT,
706    UNKNOWN,
707    VARBINARY,
708    VARCHAR,
709}
710
711/// Specifies Ignore / Respect NULL within window functions.
712/// For example
713/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
714#[allow(non_camel_case_types)]
715#[allow(clippy::upper_case_acronyms)]
716#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
717#[pyclass(
718    frozen,
719    eq,
720    eq_int,
721    name = "NullTreatment",
722    module = "datafusion.common"
723)]
724pub enum NullTreatment {
725    IGNORE_NULLS,
726    RESPECT_NULLS,
727}
728
729impl From<NullTreatment> for DFNullTreatment {
730    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
731        match null_treatment {
732            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
733            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
734        }
735    }
736}
737
738impl From<DFNullTreatment> for NullTreatment {
739    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
740        match null_treatment {
741            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
742            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
743        }
744    }
745}