datafusion_python/common/
data_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::arrow::array::Array;
19use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
20use datafusion::common::{DataFusionError, ScalarValue};
21use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
22use pyo3::{exceptions::PyValueError, prelude::*};
23
24use crate::errors::py_datafusion_err;
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
27pub struct PyScalarValue(pub ScalarValue);
28
29impl From<ScalarValue> for PyScalarValue {
30    fn from(value: ScalarValue) -> Self {
31        Self(value)
32    }
33}
34impl From<PyScalarValue> for ScalarValue {
35    fn from(value: PyScalarValue) -> Self {
36        value.0
37    }
38}
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
41#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")]
42pub enum RexType {
43    Alias,
44    Literal,
45    Call,
46    Reference,
47    ScalarSubquery,
48    Other,
49}
50
51/// These bindings are tying together several disparate systems.
52/// You have SQL types for the SQL strings and RDBMS systems itself.
53/// Rust types for the DataFusion code
54/// Arrow types which represents the underlying arrow format
55/// Python types which represent the type in Python
56/// It is important to keep all of those types in a single
57/// and manageable location. Therefore this structure exists
58/// to map those types and provide a simple place for developers
59/// to map types from one system to another.
60#[derive(Debug, Clone)]
61#[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
62pub struct DataTypeMap {
63    #[pyo3(get, set)]
64    pub arrow_type: PyDataType,
65    #[pyo3(get, set)]
66    pub python_type: PythonType,
67    #[pyo3(get, set)]
68    pub sql_type: SqlType,
69}
70
71impl DataTypeMap {
72    fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self {
73        DataTypeMap {
74            arrow_type: PyDataType {
75                data_type: arrow_type,
76            },
77            python_type,
78            sql_type,
79        }
80    }
81
82    pub fn map_from_arrow_type(arrow_type: &DataType) -> Result<DataTypeMap, PyErr> {
83        match arrow_type {
84            DataType::Null => Ok(DataTypeMap::new(
85                DataType::Null,
86                PythonType::None,
87                SqlType::NULL,
88            )),
89            DataType::Boolean => Ok(DataTypeMap::new(
90                DataType::Boolean,
91                PythonType::Bool,
92                SqlType::BOOLEAN,
93            )),
94            DataType::Int8 => Ok(DataTypeMap::new(
95                DataType::Int8,
96                PythonType::Int,
97                SqlType::TINYINT,
98            )),
99            DataType::Int16 => Ok(DataTypeMap::new(
100                DataType::Int16,
101                PythonType::Int,
102                SqlType::SMALLINT,
103            )),
104            DataType::Int32 => Ok(DataTypeMap::new(
105                DataType::Int32,
106                PythonType::Int,
107                SqlType::INTEGER,
108            )),
109            DataType::Int64 => Ok(DataTypeMap::new(
110                DataType::Int64,
111                PythonType::Int,
112                SqlType::BIGINT,
113            )),
114            DataType::UInt8 => Ok(DataTypeMap::new(
115                DataType::UInt8,
116                PythonType::Int,
117                SqlType::TINYINT,
118            )),
119            DataType::UInt16 => Ok(DataTypeMap::new(
120                DataType::UInt16,
121                PythonType::Int,
122                SqlType::SMALLINT,
123            )),
124            DataType::UInt32 => Ok(DataTypeMap::new(
125                DataType::UInt32,
126                PythonType::Int,
127                SqlType::INTEGER,
128            )),
129            DataType::UInt64 => Ok(DataTypeMap::new(
130                DataType::UInt64,
131                PythonType::Int,
132                SqlType::BIGINT,
133            )),
134            DataType::Float16 => Ok(DataTypeMap::new(
135                DataType::Float16,
136                PythonType::Float,
137                SqlType::FLOAT,
138            )),
139            DataType::Float32 => Ok(DataTypeMap::new(
140                DataType::Float32,
141                PythonType::Float,
142                SqlType::FLOAT,
143            )),
144            DataType::Float64 => Ok(DataTypeMap::new(
145                DataType::Float64,
146                PythonType::Float,
147                SqlType::FLOAT,
148            )),
149            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
150                DataType::Timestamp(*unit, tz.clone()),
151                PythonType::Datetime,
152                SqlType::DATE,
153            )),
154            DataType::Date32 => Ok(DataTypeMap::new(
155                DataType::Date32,
156                PythonType::Datetime,
157                SqlType::DATE,
158            )),
159            DataType::Date64 => Ok(DataTypeMap::new(
160                DataType::Date64,
161                PythonType::Datetime,
162                SqlType::DATE,
163            )),
164            DataType::Time32(unit) => Ok(DataTypeMap::new(
165                DataType::Time32(*unit),
166                PythonType::Datetime,
167                SqlType::DATE,
168            )),
169            DataType::Time64(unit) => Ok(DataTypeMap::new(
170                DataType::Time64(*unit),
171                PythonType::Datetime,
172                SqlType::DATE,
173            )),
174            DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
175                format!("{:?}", arrow_type),
176            ))),
177            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
178                DataType::Interval(*interval_unit),
179                PythonType::Datetime,
180                match interval_unit {
181                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
182                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
183                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
184                },
185            )),
186            DataType::Binary => Ok(DataTypeMap::new(
187                DataType::Binary,
188                PythonType::Bytes,
189                SqlType::BINARY,
190            )),
191            DataType::FixedSizeBinary(_) => Err(py_datafusion_err(
192                DataFusionError::NotImplemented(format!("{:?}", arrow_type)),
193            )),
194            DataType::LargeBinary => Ok(DataTypeMap::new(
195                DataType::LargeBinary,
196                PythonType::Bytes,
197                SqlType::BINARY,
198            )),
199            DataType::Utf8 => Ok(DataTypeMap::new(
200                DataType::Utf8,
201                PythonType::Str,
202                SqlType::VARCHAR,
203            )),
204            DataType::LargeUtf8 => Ok(DataTypeMap::new(
205                DataType::LargeUtf8,
206                PythonType::Str,
207                SqlType::VARCHAR,
208            )),
209            DataType::List(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
210                "{:?}",
211                arrow_type
212            )))),
213            DataType::FixedSizeList(_, _) => Err(py_datafusion_err(
214                DataFusionError::NotImplemented(format!("{:?}", arrow_type)),
215            )),
216            DataType::LargeList(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
217                format!("{:?}", arrow_type),
218            ))),
219            DataType::Struct(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
220                format!("{:?}", arrow_type),
221            ))),
222            DataType::Union(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
223                format!("{:?}", arrow_type),
224            ))),
225            DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
226                format!("{:?}", arrow_type),
227            ))),
228            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
229                DataType::Decimal128(*precision, *scale),
230                PythonType::Float,
231                SqlType::DECIMAL,
232            )),
233            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
234                DataType::Decimal256(*precision, *scale),
235                PythonType::Float,
236                SqlType::DECIMAL,
237            )),
238            DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
239                format!("{:?}", arrow_type),
240            ))),
241            DataType::RunEndEncoded(_, _) => Err(py_datafusion_err(
242                DataFusionError::NotImplemented(format!("{:?}", arrow_type)),
243            )),
244            DataType::BinaryView => Err(py_datafusion_err(DataFusionError::NotImplemented(
245                format!("{:?}", arrow_type),
246            ))),
247            DataType::Utf8View => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
248                "{:?}",
249                arrow_type
250            )))),
251            DataType::ListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
252                format!("{:?}", arrow_type),
253            ))),
254            DataType::LargeListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
255                format!("{:?}", arrow_type),
256            ))),
257        }
258    }
259
260    /// Generate the `DataTypeMap` from a `ScalarValue` instance
261    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
262        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
263    }
264
265    /// Maps a `ScalarValue` to an Arrow `DataType`
266    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
267        match scalar_val {
268            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
269            ScalarValue::Float16(_) => Ok(DataType::Float16),
270            ScalarValue::Float32(_) => Ok(DataType::Float32),
271            ScalarValue::Float64(_) => Ok(DataType::Float64),
272            ScalarValue::Decimal128(_, precision, scale) => {
273                Ok(DataType::Decimal128(*precision, *scale))
274            }
275            ScalarValue::Decimal256(_, precision, scale) => {
276                Ok(DataType::Decimal256(*precision, *scale))
277            }
278            ScalarValue::Dictionary(data_type, scalar_type) => {
279                // Call this function again to map the dictionary scalar_value to an Arrow type
280                Ok(DataType::Dictionary(
281                    Box::new(*data_type.clone()),
282                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
283                ))
284            }
285            ScalarValue::Int8(_) => Ok(DataType::Int8),
286            ScalarValue::Int16(_) => Ok(DataType::Int16),
287            ScalarValue::Int32(_) => Ok(DataType::Int32),
288            ScalarValue::Int64(_) => Ok(DataType::Int64),
289            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
290            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
291            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
292            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
293            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
294            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
295            ScalarValue::Binary(_) => Ok(DataType::Binary),
296            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
297            ScalarValue::Date32(_) => Ok(DataType::Date32),
298            ScalarValue::Date64(_) => Ok(DataType::Date64),
299            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
300            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
301            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
302            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
303            ScalarValue::Null => Ok(DataType::Null),
304            ScalarValue::TimestampSecond(_, tz) => {
305                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
306            }
307            ScalarValue::TimestampMillisecond(_, tz) => {
308                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
309            }
310            ScalarValue::TimestampMicrosecond(_, tz) => {
311                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
312            }
313            ScalarValue::TimestampNanosecond(_, tz) => {
314                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
315            }
316            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
317            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
318            ScalarValue::IntervalMonthDayNano(..) => {
319                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
320            }
321            ScalarValue::List(arr) => Ok(arr.data_type().to_owned()),
322            ScalarValue::Struct(_fields) => Err(py_datafusion_err(
323                DataFusionError::NotImplemented("ScalarValue::Struct".to_string()),
324            )),
325            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
326            ScalarValue::FixedSizeList(_array_ref) => {
327                // The FieldRef was removed from ScalarValue::FixedSizeList in
328                // https://github.com/apache/arrow-datafusion/pull/8221, so we can no
329                // longer convert back to a DataType here
330                Err(py_datafusion_err(DataFusionError::NotImplemented(
331                    "ScalarValue::FixedSizeList".to_string(),
332                )))
333            }
334            ScalarValue::LargeList(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
335                "ScalarValue::LargeList".to_string(),
336            ))),
337            ScalarValue::DurationSecond(_) => Ok(DataType::Duration(TimeUnit::Second)),
338            ScalarValue::DurationMillisecond(_) => Ok(DataType::Duration(TimeUnit::Millisecond)),
339            ScalarValue::DurationMicrosecond(_) => Ok(DataType::Duration(TimeUnit::Microsecond)),
340            ScalarValue::DurationNanosecond(_) => Ok(DataType::Duration(TimeUnit::Nanosecond)),
341            ScalarValue::Union(_, _, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
342                "ScalarValue::LargeList".to_string(),
343            ))),
344            ScalarValue::Utf8View(_) => Ok(DataType::Utf8View),
345            ScalarValue::BinaryView(_) => Ok(DataType::BinaryView),
346            ScalarValue::Map(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
347                "ScalarValue::Map".to_string(),
348            ))),
349        }
350    }
351}
352
353#[pymethods]
354impl DataTypeMap {
355    #[new]
356    pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self {
357        DataTypeMap {
358            arrow_type,
359            python_type,
360            sql_type,
361        }
362    }
363
364    #[staticmethod]
365    #[pyo3(name = "from_parquet_type_str")]
366    /// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
367    /// with a String type for schema rather than an object type. Here we make a best effort
368    /// to convert that to a physical type.
369    pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
370        let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
371            "boolean" => Ok(DataType::Boolean),
372            "int32" => Ok(DataType::Int32),
373            "int64" => Ok(DataType::Int64),
374            "int96" => {
375                // Int96 is an old parquet datatype that is now deprecated. We convert to nanosecond timestamp
376                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
377            }
378            "float" => Ok(DataType::Float32),
379            "double" => Ok(DataType::Float64),
380            "byte_array" => Ok(DataType::Utf8),
381            _ => Err(PyValueError::new_err(format!(
382                "Unable to determine Arrow Data Type from Parquet String type: {:?}",
383                parquet_str_type
384            ))),
385        };
386        DataTypeMap::map_from_arrow_type(&arrow_dtype?)
387    }
388
389    #[staticmethod]
390    #[pyo3(name = "arrow")]
391    pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
392        DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
393    }
394
395    #[staticmethod]
396    #[pyo3(name = "arrow_str")]
397    pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
398        let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
399        DataTypeMap::map_from_arrow_type(&data_type?.data_type)
400    }
401
402    #[staticmethod]
403    #[pyo3(name = "sql")]
404    pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
405        match sql_type {
406            SqlType::ANY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
407                "{:?}",
408                sql_type
409            )))),
410            SqlType::ARRAY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
411                "{:?}",
412                sql_type
413            )))),
414            SqlType::BIGINT => Ok(DataTypeMap::new(
415                DataType::Int64,
416                PythonType::Int,
417                SqlType::BIGINT,
418            )),
419            SqlType::BINARY => Ok(DataTypeMap::new(
420                DataType::Binary,
421                PythonType::Bytes,
422                SqlType::BINARY,
423            )),
424            SqlType::BOOLEAN => Ok(DataTypeMap::new(
425                DataType::Boolean,
426                PythonType::Bool,
427                SqlType::BOOLEAN,
428            )),
429            SqlType::CHAR => Ok(DataTypeMap::new(
430                DataType::UInt8,
431                PythonType::Int,
432                SqlType::CHAR,
433            )),
434            SqlType::COLUMN_LIST => Err(py_datafusion_err(DataFusionError::NotImplemented(
435                format!("{:?}", sql_type),
436            ))),
437            SqlType::CURSOR => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
438                "{:?}",
439                sql_type
440            )))),
441            SqlType::DATE => Ok(DataTypeMap::new(
442                DataType::Date64,
443                PythonType::Datetime,
444                SqlType::DATE,
445            )),
446            SqlType::DECIMAL => Ok(DataTypeMap::new(
447                DataType::Decimal128(1, 1),
448                PythonType::Float,
449                SqlType::DECIMAL,
450            )),
451            SqlType::DISTINCT => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
452                "{:?}",
453                sql_type
454            )))),
455            SqlType::DOUBLE => Ok(DataTypeMap::new(
456                DataType::Decimal256(1, 1),
457                PythonType::Float,
458                SqlType::DOUBLE,
459            )),
460            SqlType::DYNAMIC_STAR => Err(py_datafusion_err(DataFusionError::NotImplemented(
461                format!("{:?}", sql_type),
462            ))),
463            SqlType::FLOAT => Ok(DataTypeMap::new(
464                DataType::Decimal128(1, 1),
465                PythonType::Float,
466                SqlType::FLOAT,
467            )),
468            SqlType::GEOMETRY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
469                "{:?}",
470                sql_type
471            )))),
472            SqlType::INTEGER => Ok(DataTypeMap::new(
473                DataType::Int8,
474                PythonType::Int,
475                SqlType::INTEGER,
476            )),
477            SqlType::INTERVAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
478                "{:?}",
479                sql_type
480            )))),
481            SqlType::INTERVAL_DAY => Err(py_datafusion_err(DataFusionError::NotImplemented(
482                format!("{:?}", sql_type),
483            ))),
484            SqlType::INTERVAL_DAY_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented(
485                format!("{:?}", sql_type),
486            ))),
487            SqlType::INTERVAL_DAY_MINUTE => Err(py_datafusion_err(
488                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
489            )),
490            SqlType::INTERVAL_DAY_SECOND => Err(py_datafusion_err(
491                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
492            )),
493            SqlType::INTERVAL_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented(
494                format!("{:?}", sql_type),
495            ))),
496            SqlType::INTERVAL_HOUR_MINUTE => Err(py_datafusion_err(
497                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
498            )),
499            SqlType::INTERVAL_HOUR_SECOND => Err(py_datafusion_err(
500                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
501            )),
502            SqlType::INTERVAL_MINUTE => Err(py_datafusion_err(DataFusionError::NotImplemented(
503                format!("{:?}", sql_type),
504            ))),
505            SqlType::INTERVAL_MINUTE_SECOND => Err(py_datafusion_err(
506                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
507            )),
508            SqlType::INTERVAL_MONTH => Err(py_datafusion_err(DataFusionError::NotImplemented(
509                format!("{:?}", sql_type),
510            ))),
511            SqlType::INTERVAL_SECOND => Err(py_datafusion_err(DataFusionError::NotImplemented(
512                format!("{:?}", sql_type),
513            ))),
514            SqlType::INTERVAL_YEAR => Err(py_datafusion_err(DataFusionError::NotImplemented(
515                format!("{:?}", sql_type),
516            ))),
517            SqlType::INTERVAL_YEAR_MONTH => Err(py_datafusion_err(
518                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
519            )),
520            SqlType::MAP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
521                "{:?}",
522                sql_type
523            )))),
524            SqlType::MULTISET => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
525                "{:?}",
526                sql_type
527            )))),
528            SqlType::NULL => Ok(DataTypeMap::new(
529                DataType::Null,
530                PythonType::None,
531                SqlType::NULL,
532            )),
533            SqlType::OTHER => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
534                "{:?}",
535                sql_type
536            )))),
537            SqlType::REAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
538                "{:?}",
539                sql_type
540            )))),
541            SqlType::ROW => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
542                "{:?}",
543                sql_type
544            )))),
545            SqlType::SARG => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
546                "{:?}",
547                sql_type
548            )))),
549            SqlType::SMALLINT => Ok(DataTypeMap::new(
550                DataType::Int16,
551                PythonType::Int,
552                SqlType::SMALLINT,
553            )),
554            SqlType::STRUCTURED => Err(py_datafusion_err(DataFusionError::NotImplemented(
555                format!("{:?}", sql_type),
556            ))),
557            SqlType::SYMBOL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
558                "{:?}",
559                sql_type
560            )))),
561            SqlType::TIME => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
562                "{:?}",
563                sql_type
564            )))),
565            SqlType::TIME_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err(
566                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
567            )),
568            SqlType::TIMESTAMP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
569                "{:?}",
570                sql_type
571            )))),
572            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err(
573                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
574            )),
575            SqlType::TINYINT => Ok(DataTypeMap::new(
576                DataType::Int8,
577                PythonType::Int,
578                SqlType::TINYINT,
579            )),
580            SqlType::UNKNOWN => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
581                "{:?}",
582                sql_type
583            )))),
584            SqlType::VARBINARY => Ok(DataTypeMap::new(
585                DataType::LargeBinary,
586                PythonType::Bytes,
587                SqlType::VARBINARY,
588            )),
589            SqlType::VARCHAR => Ok(DataTypeMap::new(
590                DataType::Utf8,
591                PythonType::Str,
592                SqlType::VARCHAR,
593            )),
594        }
595    }
596
597    /// Unfortunately PyO3 does not allow for us to expose the DataType as an enum since
598    /// we cannot directly annotae the Enum instance of dependency code. Therefore, here
599    /// we provide an enum to mimic it.
600    #[pyo3(name = "friendly_arrow_type_name")]
601    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
602        Ok(match &self.arrow_type.data_type {
603            DataType::Null => "Null",
604            DataType::Boolean => "Boolean",
605            DataType::Int8 => "Int8",
606            DataType::Int16 => "Int16",
607            DataType::Int32 => "Int32",
608            DataType::Int64 => "Int64",
609            DataType::UInt8 => "UInt8",
610            DataType::UInt16 => "UInt16",
611            DataType::UInt32 => "UInt32",
612            DataType::UInt64 => "UInt64",
613            DataType::Float16 => "Float16",
614            DataType::Float32 => "Float32",
615            DataType::Float64 => "Float64",
616            DataType::Timestamp(_, _) => "Timestamp",
617            DataType::Date32 => "Date32",
618            DataType::Date64 => "Date64",
619            DataType::Time32(_) => "Time32",
620            DataType::Time64(_) => "Time64",
621            DataType::Duration(_) => "Duration",
622            DataType::Interval(_) => "Interval",
623            DataType::Binary => "Binary",
624            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
625            DataType::LargeBinary => "LargeBinary",
626            DataType::Utf8 => "Utf8",
627            DataType::LargeUtf8 => "LargeUtf8",
628            DataType::List(_) => "List",
629            DataType::FixedSizeList(_, _) => "FixedSizeList",
630            DataType::LargeList(_) => "LargeList",
631            DataType::Struct(_) => "Struct",
632            DataType::Union(_, _) => "Union",
633            DataType::Dictionary(_, _) => "Dictionary",
634            DataType::Decimal128(_, _) => "Decimal128",
635            DataType::Decimal256(_, _) => "Decimal256",
636            DataType::Map(_, _) => "Map",
637            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
638            DataType::BinaryView => "BinaryView",
639            DataType::Utf8View => "Utf8View",
640            DataType::ListView(_) => "ListView",
641            DataType::LargeListView(_) => "LargeListView",
642        })
643    }
644}
645
646/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass`
647/// Since `DataType` exists in another package we cannot make that happen here so we wrap
648/// `DataType` as `PyDataType` This exists solely to satisfy those constraints.
649#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
650#[pyclass(name = "DataType", module = "datafusion.common")]
651pub struct PyDataType {
652    pub data_type: DataType,
653}
654
655impl PyDataType {
656    /// There are situations when obtaining dtypes on the Python side where the Arrow type
657    /// is presented as a String rather than an actual DataType. This function is used to
658    /// convert that String to a DataType for the Python side to use.
659    pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
660        // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]"
661        let arrow_str_type = match arrow_str_type.find('[') {
662            Some(index) => arrow_str_type[0..index].to_string(),
663            None => arrow_str_type, // Return early if ',' is not found.
664        };
665
666        let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
667            "bool" => Ok(DataType::Boolean),
668            "boolean" => Ok(DataType::Boolean),
669            "uint8" => Ok(DataType::UInt8),
670            "uint16" => Ok(DataType::UInt16),
671            "uint32" => Ok(DataType::UInt32),
672            "uint64" => Ok(DataType::UInt64),
673            "int8" => Ok(DataType::Int8),
674            "int16" => Ok(DataType::Int16),
675            "int32" => Ok(DataType::Int32),
676            "int64" => Ok(DataType::Int64),
677            "float" => Ok(DataType::Float32),
678            "double" => Ok(DataType::Float64),
679            "float16" => Ok(DataType::Float16),
680            "float32" => Ok(DataType::Float32),
681            "float64" => Ok(DataType::Float64),
682            "datetime64" => Ok(DataType::Date64),
683            "object" => Ok(DataType::Utf8),
684            _ => Err(PyValueError::new_err(format!(
685                "Unable to determine Arrow Data Type from Arrow String type: {:?}",
686                arrow_str_type
687            ))),
688        };
689        Ok(PyDataType {
690            data_type: arrow_dtype?,
691        })
692    }
693}
694
695impl From<PyDataType> for DataType {
696    fn from(data_type: PyDataType) -> DataType {
697        data_type.data_type
698    }
699}
700
701impl From<DataType> for PyDataType {
702    fn from(data_type: DataType) -> PyDataType {
703        PyDataType { data_type }
704    }
705}
706
707/// Represents the possible Python types that can be mapped to the SQL types
708#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
709#[pyclass(eq, eq_int, name = "PythonType", module = "datafusion.common")]
710pub enum PythonType {
711    Array,
712    Bool,
713    Bytes,
714    Datetime,
715    Float,
716    Int,
717    List,
718    None,
719    Object,
720    Str,
721}
722
723/// Represents the types that are possible for DataFusion to parse
724/// from a SQL query. Aka "SqlType" and are valid values for
725/// ANSI SQL
726#[allow(non_camel_case_types)]
727#[allow(clippy::upper_case_acronyms)]
728#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
729#[pyclass(eq, eq_int, name = "SqlType", module = "datafusion.common")]
730pub enum SqlType {
731    ANY,
732    ARRAY,
733    BIGINT,
734    BINARY,
735    BOOLEAN,
736    CHAR,
737    COLUMN_LIST,
738    CURSOR,
739    DATE,
740    DECIMAL,
741    DISTINCT,
742    DOUBLE,
743    DYNAMIC_STAR,
744    FLOAT,
745    GEOMETRY,
746    INTEGER,
747    INTERVAL,
748    INTERVAL_DAY,
749    INTERVAL_DAY_HOUR,
750    INTERVAL_DAY_MINUTE,
751    INTERVAL_DAY_SECOND,
752    INTERVAL_HOUR,
753    INTERVAL_HOUR_MINUTE,
754    INTERVAL_HOUR_SECOND,
755    INTERVAL_MINUTE,
756    INTERVAL_MINUTE_SECOND,
757    INTERVAL_MONTH,
758    INTERVAL_SECOND,
759    INTERVAL_YEAR,
760    INTERVAL_YEAR_MONTH,
761    MAP,
762    MULTISET,
763    NULL,
764    OTHER,
765    REAL,
766    ROW,
767    SARG,
768    SMALLINT,
769    STRUCTURED,
770    SYMBOL,
771    TIME,
772    TIME_WITH_LOCAL_TIME_ZONE,
773    TIMESTAMP,
774    TIMESTAMP_WITH_LOCAL_TIME_ZONE,
775    TINYINT,
776    UNKNOWN,
777    VARBINARY,
778    VARCHAR,
779}
780
781/// Specifies Ignore / Respect NULL within window functions.
782/// For example
783/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
784#[allow(non_camel_case_types)]
785#[allow(clippy::upper_case_acronyms)]
786#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
787#[pyclass(eq, eq_int, name = "NullTreatment", module = "datafusion.common")]
788pub enum NullTreatment {
789    IGNORE_NULLS,
790    RESPECT_NULLS,
791}
792
793impl From<NullTreatment> for DFNullTreatment {
794    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
795        match null_treatment {
796            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
797            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
798        }
799    }
800}
801
802impl From<DFNullTreatment> for NullTreatment {
803    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
804        match null_treatment {
805            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
806            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
807        }
808    }
809}