Skip to main content

datafusion_python/common/
data_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::arrow::array::Array;
19use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
20use datafusion::common::ScalarValue;
21use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
22use pyo3::exceptions::{PyNotImplementedError, PyValueError};
23use pyo3::prelude::*;
24
25/// A [`ScalarValue`] wrapped in a Python object. This struct allows for conversion
26/// from a variety of Python objects into a [`ScalarValue`]. See
27/// ``FromPyArrow::from_pyarrow_bound`` conversion details.
28#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
29pub struct PyScalarValue(pub ScalarValue);
30
31impl From<ScalarValue> for PyScalarValue {
32    fn from(value: ScalarValue) -> Self {
33        Self(value)
34    }
35}
36impl From<PyScalarValue> for ScalarValue {
37    fn from(value: PyScalarValue) -> Self {
38        value.0
39    }
40}
41
/// Coarse category of a relational expression, exposed to Python as
/// `datafusion.common.RexType`.
///
/// NOTE(review): the classification of DataFusion expressions into these categories
/// happens elsewhere; confirm the exact meaning of each variant against its callers.
///
/// Do not reorder the variants: declaration order fixes the discriminants (compared as
/// integers from Python because of `eq_int`) and the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "RexType", module = "datafusion.common")]
pub enum RexType {
    Alias,
    Literal,
    Call,
    Reference,
    ScalarSubquery,
    Other,
}
52
/// Maps one logical type across the several type systems these bindings tie together:
///
/// * SQL types, used in SQL strings and by the RDBMS itself;
/// * Rust types, used by the DataFusion code;
/// * Arrow types, which represent the underlying Arrow format;
/// * Python types, which represent the type in Python.
///
/// Keeping all of these types in a single, manageable location gives developers one
/// simple place to map a type from one system to another.
// TODO: this appears to need pyo3 tracking, so it is left unfrozen for now.
#[derive(Debug, Clone)]
#[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
pub struct DataTypeMap {
    /// Arrow representation of the type (readable and writable from Python).
    #[pyo3(get, set)]
    pub arrow_type: PyDataType,
    /// Python representation of the type (readable and writable from Python).
    #[pyo3(get, set)]
    pub python_type: PythonType,
    /// SQL representation of the type (readable and writable from Python).
    #[pyo3(get, set)]
    pub sql_type: SqlType,
}
73
74impl DataTypeMap {
75    fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self {
76        DataTypeMap {
77            arrow_type: PyDataType {
78                data_type: arrow_type,
79            },
80            python_type,
81            sql_type,
82        }
83    }
84
85    pub fn map_from_arrow_type(arrow_type: &DataType) -> Result<DataTypeMap, PyErr> {
86        match arrow_type {
87            DataType::Null => Ok(DataTypeMap::new(
88                DataType::Null,
89                PythonType::None,
90                SqlType::NULL,
91            )),
92            DataType::Boolean => Ok(DataTypeMap::new(
93                DataType::Boolean,
94                PythonType::Bool,
95                SqlType::BOOLEAN,
96            )),
97            DataType::Int8 => Ok(DataTypeMap::new(
98                DataType::Int8,
99                PythonType::Int,
100                SqlType::TINYINT,
101            )),
102            DataType::Int16 => Ok(DataTypeMap::new(
103                DataType::Int16,
104                PythonType::Int,
105                SqlType::SMALLINT,
106            )),
107            DataType::Int32 => Ok(DataTypeMap::new(
108                DataType::Int32,
109                PythonType::Int,
110                SqlType::INTEGER,
111            )),
112            DataType::Int64 => Ok(DataTypeMap::new(
113                DataType::Int64,
114                PythonType::Int,
115                SqlType::BIGINT,
116            )),
117            DataType::UInt8 => Ok(DataTypeMap::new(
118                DataType::UInt8,
119                PythonType::Int,
120                SqlType::TINYINT,
121            )),
122            DataType::UInt16 => Ok(DataTypeMap::new(
123                DataType::UInt16,
124                PythonType::Int,
125                SqlType::SMALLINT,
126            )),
127            DataType::UInt32 => Ok(DataTypeMap::new(
128                DataType::UInt32,
129                PythonType::Int,
130                SqlType::INTEGER,
131            )),
132            DataType::UInt64 => Ok(DataTypeMap::new(
133                DataType::UInt64,
134                PythonType::Int,
135                SqlType::BIGINT,
136            )),
137            DataType::Float16 => Ok(DataTypeMap::new(
138                DataType::Float16,
139                PythonType::Float,
140                SqlType::FLOAT,
141            )),
142            DataType::Float32 => Ok(DataTypeMap::new(
143                DataType::Float32,
144                PythonType::Float,
145                SqlType::FLOAT,
146            )),
147            DataType::Float64 => Ok(DataTypeMap::new(
148                DataType::Float64,
149                PythonType::Float,
150                SqlType::FLOAT,
151            )),
152            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
153                DataType::Timestamp(*unit, tz.clone()),
154                PythonType::Datetime,
155                SqlType::DATE,
156            )),
157            DataType::Date32 => Ok(DataTypeMap::new(
158                DataType::Date32,
159                PythonType::Datetime,
160                SqlType::DATE,
161            )),
162            DataType::Date64 => Ok(DataTypeMap::new(
163                DataType::Date64,
164                PythonType::Datetime,
165                SqlType::DATE,
166            )),
167            DataType::Time32(unit) => Ok(DataTypeMap::new(
168                DataType::Time32(*unit),
169                PythonType::Datetime,
170                SqlType::DATE,
171            )),
172            DataType::Time64(unit) => Ok(DataTypeMap::new(
173                DataType::Time64(*unit),
174                PythonType::Datetime,
175                SqlType::DATE,
176            )),
177            DataType::Duration(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
178            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
179                DataType::Interval(*interval_unit),
180                PythonType::Datetime,
181                match interval_unit {
182                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
183                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
184                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
185                },
186            )),
187            DataType::Binary => Ok(DataTypeMap::new(
188                DataType::Binary,
189                PythonType::Bytes,
190                SqlType::BINARY,
191            )),
192            DataType::FixedSizeBinary(_) => {
193                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
194            }
195            DataType::LargeBinary => Ok(DataTypeMap::new(
196                DataType::LargeBinary,
197                PythonType::Bytes,
198                SqlType::BINARY,
199            )),
200            DataType::Utf8 => Ok(DataTypeMap::new(
201                DataType::Utf8,
202                PythonType::Str,
203                SqlType::VARCHAR,
204            )),
205            DataType::LargeUtf8 => Ok(DataTypeMap::new(
206                DataType::LargeUtf8,
207                PythonType::Str,
208                SqlType::VARCHAR,
209            )),
210            DataType::List(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
211            DataType::FixedSizeList(_, _) => {
212                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
213            }
214            DataType::LargeList(_) => {
215                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
216            }
217            DataType::Struct(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
218            DataType::Union(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
219            DataType::Dictionary(_, _) => {
220                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
221            }
222            DataType::Decimal32(precision, scale) => Ok(DataTypeMap::new(
223                DataType::Decimal32(*precision, *scale),
224                PythonType::Float,
225                SqlType::DECIMAL,
226            )),
227            DataType::Decimal64(precision, scale) => Ok(DataTypeMap::new(
228                DataType::Decimal64(*precision, *scale),
229                PythonType::Float,
230                SqlType::DECIMAL,
231            )),
232            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
233                DataType::Decimal128(*precision, *scale),
234                PythonType::Float,
235                SqlType::DECIMAL,
236            )),
237            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
238                DataType::Decimal256(*precision, *scale),
239                PythonType::Float,
240                SqlType::DECIMAL,
241            )),
242            DataType::Map(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
243            DataType::RunEndEncoded(_, _) => {
244                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
245            }
246            DataType::BinaryView => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
247            DataType::Utf8View => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
248            DataType::ListView(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
249            DataType::LargeListView(_) => {
250                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
251            }
252        }
253    }
254
255    /// Generate the `DataTypeMap` from a `ScalarValue` instance
256    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
257        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
258    }
259
260    /// Maps a `ScalarValue` to an Arrow `DataType`
261    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
262        match scalar_val {
263            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
264            ScalarValue::Float16(_) => Ok(DataType::Float16),
265            ScalarValue::Float32(_) => Ok(DataType::Float32),
266            ScalarValue::Float64(_) => Ok(DataType::Float64),
267            ScalarValue::Decimal32(_, precision, scale) => {
268                Ok(DataType::Decimal32(*precision, *scale))
269            }
270            ScalarValue::Decimal64(_, precision, scale) => {
271                Ok(DataType::Decimal64(*precision, *scale))
272            }
273            ScalarValue::Decimal128(_, precision, scale) => {
274                Ok(DataType::Decimal128(*precision, *scale))
275            }
276            ScalarValue::Decimal256(_, precision, scale) => {
277                Ok(DataType::Decimal256(*precision, *scale))
278            }
279            ScalarValue::Dictionary(data_type, scalar_type) => {
280                // Call this function again to map the dictionary scalar_value to an Arrow type
281                Ok(DataType::Dictionary(
282                    Box::new(*data_type.clone()),
283                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
284                ))
285            }
286            ScalarValue::Int8(_) => Ok(DataType::Int8),
287            ScalarValue::Int16(_) => Ok(DataType::Int16),
288            ScalarValue::Int32(_) => Ok(DataType::Int32),
289            ScalarValue::Int64(_) => Ok(DataType::Int64),
290            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
291            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
292            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
293            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
294            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
295            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
296            ScalarValue::Binary(_) => Ok(DataType::Binary),
297            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
298            ScalarValue::Date32(_) => Ok(DataType::Date32),
299            ScalarValue::Date64(_) => Ok(DataType::Date64),
300            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
301            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
302            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
303            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
304            ScalarValue::Null => Ok(DataType::Null),
305            ScalarValue::TimestampSecond(_, tz) => {
306                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
307            }
308            ScalarValue::TimestampMillisecond(_, tz) => {
309                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
310            }
311            ScalarValue::TimestampMicrosecond(_, tz) => {
312                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
313            }
314            ScalarValue::TimestampNanosecond(_, tz) => {
315                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
316            }
317            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
318            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
319            ScalarValue::IntervalMonthDayNano(..) => {
320                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
321            }
322            ScalarValue::List(arr) => Ok(arr.data_type().to_owned()),
323            ScalarValue::Struct(_fields) => Err(PyNotImplementedError::new_err(
324                "ScalarValue::Struct".to_string(),
325            )),
326            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
327            ScalarValue::FixedSizeList(_array_ref) => {
328                // The FieldRef was removed from ScalarValue::FixedSizeList in
329                // https://github.com/apache/arrow-datafusion/pull/8221, so we can no
330                // longer convert back to a DataType here
331                Err(PyNotImplementedError::new_err(
332                    "ScalarValue::FixedSizeList".to_string(),
333                ))
334            }
335            ScalarValue::LargeList(_) => Err(PyNotImplementedError::new_err(
336                "ScalarValue::LargeList".to_string(),
337            )),
338            ScalarValue::DurationSecond(_) => Ok(DataType::Duration(TimeUnit::Second)),
339            ScalarValue::DurationMillisecond(_) => Ok(DataType::Duration(TimeUnit::Millisecond)),
340            ScalarValue::DurationMicrosecond(_) => Ok(DataType::Duration(TimeUnit::Microsecond)),
341            ScalarValue::DurationNanosecond(_) => Ok(DataType::Duration(TimeUnit::Nanosecond)),
342            ScalarValue::Union(_, _, _) => Err(PyNotImplementedError::new_err(
343                "ScalarValue::LargeList".to_string(),
344            )),
345            ScalarValue::Utf8View(_) => Ok(DataType::Utf8View),
346            ScalarValue::BinaryView(_) => Ok(DataType::BinaryView),
347            ScalarValue::Map(_) => Err(PyNotImplementedError::new_err(
348                "ScalarValue::Map".to_string(),
349            )),
350        }
351    }
352}
353
354#[pymethods]
355impl DataTypeMap {
356    #[new]
357    pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self {
358        DataTypeMap {
359            arrow_type,
360            python_type,
361            sql_type,
362        }
363    }
364
365    #[staticmethod]
366    #[pyo3(name = "from_parquet_type_str")]
367    /// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
368    /// with a String type for schema rather than an object type. Here we make a best effort
369    /// to convert that to a physical type.
370    pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
371        let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
372            "boolean" => Ok(DataType::Boolean),
373            "int32" => Ok(DataType::Int32),
374            "int64" => Ok(DataType::Int64),
375            "int96" => {
376                // Int96 is an old parquet datatype that is now deprecated. We convert to nanosecond timestamp
377                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
378            }
379            "float" => Ok(DataType::Float32),
380            "double" => Ok(DataType::Float64),
381            "byte_array" => Ok(DataType::Utf8),
382            _ => Err(PyValueError::new_err(format!(
383                "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}"
384            ))),
385        };
386        DataTypeMap::map_from_arrow_type(&arrow_dtype?)
387    }
388
389    #[staticmethod]
390    #[pyo3(name = "arrow")]
391    pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
392        DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
393    }
394
395    #[staticmethod]
396    #[pyo3(name = "arrow_str")]
397    pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
398        let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
399        DataTypeMap::map_from_arrow_type(&data_type?.data_type)
400    }
401
402    #[staticmethod]
403    #[pyo3(name = "sql")]
404    pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
405        match sql_type {
406            SqlType::ANY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
407            SqlType::ARRAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
408            SqlType::BIGINT => Ok(DataTypeMap::new(
409                DataType::Int64,
410                PythonType::Int,
411                SqlType::BIGINT,
412            )),
413            SqlType::BINARY => Ok(DataTypeMap::new(
414                DataType::Binary,
415                PythonType::Bytes,
416                SqlType::BINARY,
417            )),
418            SqlType::BOOLEAN => Ok(DataTypeMap::new(
419                DataType::Boolean,
420                PythonType::Bool,
421                SqlType::BOOLEAN,
422            )),
423            SqlType::CHAR => Ok(DataTypeMap::new(
424                DataType::UInt8,
425                PythonType::Int,
426                SqlType::CHAR,
427            )),
428            SqlType::COLUMN_LIST => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
429            SqlType::CURSOR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
430            SqlType::DATE => Ok(DataTypeMap::new(
431                DataType::Date64,
432                PythonType::Datetime,
433                SqlType::DATE,
434            )),
435            SqlType::DECIMAL => Ok(DataTypeMap::new(
436                DataType::Decimal128(1, 1),
437                PythonType::Float,
438                SqlType::DECIMAL,
439            )),
440            SqlType::DISTINCT => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
441            SqlType::DOUBLE => Ok(DataTypeMap::new(
442                DataType::Decimal256(1, 1),
443                PythonType::Float,
444                SqlType::DOUBLE,
445            )),
446            SqlType::DYNAMIC_STAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
447            SqlType::FLOAT => Ok(DataTypeMap::new(
448                DataType::Decimal128(1, 1),
449                PythonType::Float,
450                SqlType::FLOAT,
451            )),
452            SqlType::GEOMETRY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
453            SqlType::INTEGER => Ok(DataTypeMap::new(
454                DataType::Int8,
455                PythonType::Int,
456                SqlType::INTEGER,
457            )),
458            SqlType::INTERVAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
459            SqlType::INTERVAL_DAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
460            SqlType::INTERVAL_DAY_HOUR => {
461                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
462            }
463            SqlType::INTERVAL_DAY_MINUTE => {
464                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
465            }
466            SqlType::INTERVAL_DAY_SECOND => {
467                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
468            }
469            SqlType::INTERVAL_HOUR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
470            SqlType::INTERVAL_HOUR_MINUTE => {
471                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
472            }
473            SqlType::INTERVAL_HOUR_SECOND => {
474                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
475            }
476            SqlType::INTERVAL_MINUTE => {
477                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
478            }
479            SqlType::INTERVAL_MINUTE_SECOND => {
480                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
481            }
482            SqlType::INTERVAL_MONTH => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
483            SqlType::INTERVAL_SECOND => {
484                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
485            }
486            SqlType::INTERVAL_YEAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
487            SqlType::INTERVAL_YEAR_MONTH => {
488                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
489            }
490            SqlType::MAP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
491            SqlType::MULTISET => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
492            SqlType::NULL => Ok(DataTypeMap::new(
493                DataType::Null,
494                PythonType::None,
495                SqlType::NULL,
496            )),
497            SqlType::OTHER => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
498            SqlType::REAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
499            SqlType::ROW => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
500            SqlType::SARG => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
501            SqlType::SMALLINT => Ok(DataTypeMap::new(
502                DataType::Int16,
503                PythonType::Int,
504                SqlType::SMALLINT,
505            )),
506            SqlType::STRUCTURED => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
507            SqlType::SYMBOL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
508            SqlType::TIME => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
509            SqlType::TIME_WITH_LOCAL_TIME_ZONE => {
510                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
511            }
512            SqlType::TIMESTAMP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
513            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => {
514                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
515            }
516            SqlType::TINYINT => Ok(DataTypeMap::new(
517                DataType::Int8,
518                PythonType::Int,
519                SqlType::TINYINT,
520            )),
521            SqlType::UNKNOWN => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
522            SqlType::VARBINARY => Ok(DataTypeMap::new(
523                DataType::LargeBinary,
524                PythonType::Bytes,
525                SqlType::VARBINARY,
526            )),
527            SqlType::VARCHAR => Ok(DataTypeMap::new(
528                DataType::Utf8,
529                PythonType::Str,
530                SqlType::VARCHAR,
531            )),
532        }
533    }
534
535    /// Unfortunately PyO3 does not allow for us to expose the DataType as an enum since
536    /// we cannot directly annotate the Enum instance of dependency code. Therefore, here
537    /// we provide an enum to mimic it.
538    #[pyo3(name = "friendly_arrow_type_name")]
539    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
540        Ok(match &self.arrow_type.data_type {
541            DataType::Null => "Null",
542            DataType::Boolean => "Boolean",
543            DataType::Int8 => "Int8",
544            DataType::Int16 => "Int16",
545            DataType::Int32 => "Int32",
546            DataType::Int64 => "Int64",
547            DataType::UInt8 => "UInt8",
548            DataType::UInt16 => "UInt16",
549            DataType::UInt32 => "UInt32",
550            DataType::UInt64 => "UInt64",
551            DataType::Float16 => "Float16",
552            DataType::Float32 => "Float32",
553            DataType::Float64 => "Float64",
554            DataType::Timestamp(_, _) => "Timestamp",
555            DataType::Date32 => "Date32",
556            DataType::Date64 => "Date64",
557            DataType::Time32(_) => "Time32",
558            DataType::Time64(_) => "Time64",
559            DataType::Duration(_) => "Duration",
560            DataType::Interval(_) => "Interval",
561            DataType::Binary => "Binary",
562            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
563            DataType::LargeBinary => "LargeBinary",
564            DataType::Utf8 => "Utf8",
565            DataType::LargeUtf8 => "LargeUtf8",
566            DataType::List(_) => "List",
567            DataType::FixedSizeList(_, _) => "FixedSizeList",
568            DataType::LargeList(_) => "LargeList",
569            DataType::Struct(_) => "Struct",
570            DataType::Union(_, _) => "Union",
571            DataType::Dictionary(_, _) => "Dictionary",
572            DataType::Decimal32(_, _) => "Decimal32",
573            DataType::Decimal64(_, _) => "Decimal64",
574            DataType::Decimal128(_, _) => "Decimal128",
575            DataType::Decimal256(_, _) => "Decimal256",
576            DataType::Map(_, _) => "Map",
577            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
578            DataType::BinaryView => "BinaryView",
579            DataType::Utf8View => "Utf8View",
580            DataType::ListView(_) => "ListView",
581            DataType::LargeListView(_) => "LargeListView",
582        })
583    }
584}
585
/// Wrapper around an Arrow [`DataType`], exposed to Python as `datafusion.common.DataType`.
///
/// PyO3 requires objects passed between Rust and Python to implement the `PyClass`
/// trait. `DataType` lives in another crate, so that cannot be done for it here;
/// wrapping it as `PyDataType` exists solely to satisfy that constraint.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, name = "DataType", module = "datafusion.common")]
pub struct PyDataType {
    /// The wrapped Arrow type.
    pub data_type: DataType,
}
594
595impl PyDataType {
596    /// There are situations when obtaining dtypes on the Python side where the Arrow type
597    /// is presented as a String rather than an actual DataType. This function is used to
598    /// convert that String to a DataType for the Python side to use.
599    pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
600        // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]"
601        let arrow_str_type = match arrow_str_type.find('[') {
602            Some(index) => arrow_str_type[0..index].to_string(),
603            None => arrow_str_type, // Return early if ',' is not found.
604        };
605
606        let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
607            "bool" => Ok(DataType::Boolean),
608            "boolean" => Ok(DataType::Boolean),
609            "uint8" => Ok(DataType::UInt8),
610            "uint16" => Ok(DataType::UInt16),
611            "uint32" => Ok(DataType::UInt32),
612            "uint64" => Ok(DataType::UInt64),
613            "int8" => Ok(DataType::Int8),
614            "int16" => Ok(DataType::Int16),
615            "int32" => Ok(DataType::Int32),
616            "int64" => Ok(DataType::Int64),
617            "float" => Ok(DataType::Float32),
618            "double" => Ok(DataType::Float64),
619            "float16" => Ok(DataType::Float16),
620            "float32" => Ok(DataType::Float32),
621            "float64" => Ok(DataType::Float64),
622            "datetime64" => Ok(DataType::Date64),
623            "object" => Ok(DataType::Utf8),
624            _ => Err(PyValueError::new_err(format!(
625                "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}"
626            ))),
627        };
628        Ok(PyDataType {
629            data_type: arrow_dtype?,
630        })
631    }
632}
633
634impl From<PyDataType> for DataType {
635    fn from(data_type: PyDataType) -> DataType {
636        data_type.data_type
637    }
638}
639
640impl From<DataType> for PyDataType {
641    fn from(data_type: DataType) -> PyDataType {
642        PyDataType { data_type }
643    }
644}
645
/// Python-side type category stored in `DataTypeMap::python_type`, alongside the
/// corresponding Arrow and SQL types. Exposed to Python as `datafusion.common.PythonType`.
///
/// Do not reorder the variants: declaration order fixes the discriminants (compared as
/// integers from Python because of `eq_int`) and the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "PythonType", module = "datafusion.common")]
pub enum PythonType {
    Array,
    Bool,
    Bytes,
    Datetime,
    Float,
    Int,
    List,
    None,
    Object,
    Str,
}
661
/// SQL type names that DataFusion can parse from a SQL query, exposed to Python as
/// `datafusion.common.SqlType`.
///
/// NOTE(review): the list also contains names that are not ANSI SQL column types
/// (e.g. `SARG`, `DYNAMIC_STAR`, `CURSOR`); it looks modeled on another SQL engine's
/// type-name list — confirm the intended scope.
///
/// Do not reorder the variants: declaration order fixes the discriminants (compared as
/// integers from Python because of `eq_int`) and the derived `PartialOrd`/`Ord`.
// Variant names use the upper-case SQL spelling, which is why the naming lints are allowed.
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "SqlType", module = "datafusion.common")]
pub enum SqlType {
    ANY,
    ARRAY,
    BIGINT,
    BINARY,
    BOOLEAN,
    CHAR,
    COLUMN_LIST,
    CURSOR,
    DATE,
    DECIMAL,
    DISTINCT,
    DOUBLE,
    DYNAMIC_STAR,
    FLOAT,
    GEOMETRY,
    INTEGER,
    INTERVAL,
    INTERVAL_DAY,
    INTERVAL_DAY_HOUR,
    INTERVAL_DAY_MINUTE,
    INTERVAL_DAY_SECOND,
    INTERVAL_HOUR,
    INTERVAL_HOUR_MINUTE,
    INTERVAL_HOUR_SECOND,
    INTERVAL_MINUTE,
    INTERVAL_MINUTE_SECOND,
    INTERVAL_MONTH,
    INTERVAL_SECOND,
    INTERVAL_YEAR,
    INTERVAL_YEAR_MONTH,
    MAP,
    MULTISET,
    NULL,
    OTHER,
    REAL,
    ROW,
    SARG,
    SMALLINT,
    STRUCTURED,
    SYMBOL,
    TIME,
    TIME_WITH_LOCAL_TIME_ZONE,
    TIMESTAMP,
    TIMESTAMP_WITH_LOCAL_TIME_ZONE,
    TINYINT,
    UNKNOWN,
    VARBINARY,
    VARCHAR,
}
719
/// Specifies whether NULLs are ignored or respected within window functions, for example
/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`.
///
/// Exposed to Python as `datafusion.common.NullTreatment`; it converts one-to-one to and
/// from DataFusion's own null-treatment enum.
// Variant names use the upper-case SQL keyword spelling, which is why the naming lints
// are allowed.
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    frozen,
    eq,
    eq_int,
    name = "NullTreatment",
    module = "datafusion.common"
)]
pub enum NullTreatment {
    IGNORE_NULLS,
    RESPECT_NULLS,
}
737
738impl From<NullTreatment> for DFNullTreatment {
739    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
740        match null_treatment {
741            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
742            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
743        }
744    }
745}
746
747impl From<DFNullTreatment> for NullTreatment {
748    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
749        match null_treatment {
750            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
751            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
752        }
753    }
754}