Skip to main content

datafusion_python/common/
data_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::Array;
21use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use datafusion::common::ScalarValue;
23use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
24use pyo3::exceptions::{PyNotImplementedError, PyValueError};
25use pyo3::prelude::*;
26
/// A [`ScalarValue`] wrapped in a Python object. This struct allows for conversion
/// from a variety of Python objects into a [`ScalarValue`]. See
/// `FromPyArrow::from_pyarrow_bound` for conversion details.
///
/// This is a newtype: the wrapped value is the public field `.0`, and `From`
/// conversions are provided in both directions between `PyScalarValue` and
/// `ScalarValue`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
pub struct PyScalarValue(pub ScalarValue);
32
33impl From<ScalarValue> for PyScalarValue {
34    fn from(value: ScalarValue) -> Self {
35        Self(value)
36    }
37}
38impl From<PyScalarValue> for ScalarValue {
39    fn from(value: PyScalarValue) -> Self {
40        value.0
41    }
42}
43
/// Category of an expression, exposed to Python as `datafusion.common.RexType`.
///
/// The variant names suggest the categories: alias, literal, function call,
/// column reference, scalar subquery, and a catch-all `Other`.
// NOTE(review): the code that assigns these categories is not in this file;
// confirm the exact meaning of each variant there.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    from_py_object,
    frozen,
    eq,
    eq_int,
    name = "RexType",
    module = "datafusion.common"
)]
pub enum RexType {
    Alias,
    Literal,
    Call,
    Reference,
    ScalarSubquery,
    Other,
}
61
/// These bindings tie together several disparate type systems:
///
/// * SQL types, used in SQL strings and by the RDBMS itself,
/// * Rust types, used by the DataFusion code,
/// * Arrow types, which represent the underlying Arrow format,
/// * Python types, which represent the type on the Python side.
///
/// It is important to keep all of these types in a single, manageable location.
/// This structure holds one such mapping and gives developers a single place
/// to translate a type from one system to another.
// TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now
#[derive(Debug, Clone)]
#[pyclass(
    from_py_object,
    name = "DataTypeMap",
    module = "datafusion.common",
    subclass
)]
pub struct DataTypeMap {
    /// The Arrow side of the mapping, wrapped so it can be passed to Python.
    #[pyo3(get, set)]
    pub arrow_type: PyDataType,
    /// The Python side of the mapping.
    #[pyo3(get, set)]
    pub python_type: PythonType,
    /// The SQL side of the mapping.
    #[pyo3(get, set)]
    pub sql_type: SqlType,
}
87
88impl DataTypeMap {
89    fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self {
90        DataTypeMap {
91            arrow_type: PyDataType {
92                data_type: arrow_type,
93            },
94            python_type,
95            sql_type,
96        }
97    }
98
99    pub fn map_from_arrow_type(arrow_type: &DataType) -> Result<DataTypeMap, PyErr> {
100        match arrow_type {
101            DataType::Null => Ok(DataTypeMap::new(
102                DataType::Null,
103                PythonType::None,
104                SqlType::NULL,
105            )),
106            DataType::Boolean => Ok(DataTypeMap::new(
107                DataType::Boolean,
108                PythonType::Bool,
109                SqlType::BOOLEAN,
110            )),
111            DataType::Int8 => Ok(DataTypeMap::new(
112                DataType::Int8,
113                PythonType::Int,
114                SqlType::TINYINT,
115            )),
116            DataType::Int16 => Ok(DataTypeMap::new(
117                DataType::Int16,
118                PythonType::Int,
119                SqlType::SMALLINT,
120            )),
121            DataType::Int32 => Ok(DataTypeMap::new(
122                DataType::Int32,
123                PythonType::Int,
124                SqlType::INTEGER,
125            )),
126            DataType::Int64 => Ok(DataTypeMap::new(
127                DataType::Int64,
128                PythonType::Int,
129                SqlType::BIGINT,
130            )),
131            DataType::UInt8 => Ok(DataTypeMap::new(
132                DataType::UInt8,
133                PythonType::Int,
134                SqlType::TINYINT,
135            )),
136            DataType::UInt16 => Ok(DataTypeMap::new(
137                DataType::UInt16,
138                PythonType::Int,
139                SqlType::SMALLINT,
140            )),
141            DataType::UInt32 => Ok(DataTypeMap::new(
142                DataType::UInt32,
143                PythonType::Int,
144                SqlType::INTEGER,
145            )),
146            DataType::UInt64 => Ok(DataTypeMap::new(
147                DataType::UInt64,
148                PythonType::Int,
149                SqlType::BIGINT,
150            )),
151            DataType::Float16 => Ok(DataTypeMap::new(
152                DataType::Float16,
153                PythonType::Float,
154                SqlType::FLOAT,
155            )),
156            DataType::Float32 => Ok(DataTypeMap::new(
157                DataType::Float32,
158                PythonType::Float,
159                SqlType::FLOAT,
160            )),
161            DataType::Float64 => Ok(DataTypeMap::new(
162                DataType::Float64,
163                PythonType::Float,
164                SqlType::FLOAT,
165            )),
166            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
167                DataType::Timestamp(*unit, tz.clone()),
168                PythonType::Datetime,
169                SqlType::DATE,
170            )),
171            DataType::Date32 => Ok(DataTypeMap::new(
172                DataType::Date32,
173                PythonType::Datetime,
174                SqlType::DATE,
175            )),
176            DataType::Date64 => Ok(DataTypeMap::new(
177                DataType::Date64,
178                PythonType::Datetime,
179                SqlType::DATE,
180            )),
181            DataType::Time32(unit) => Ok(DataTypeMap::new(
182                DataType::Time32(*unit),
183                PythonType::Datetime,
184                SqlType::DATE,
185            )),
186            DataType::Time64(unit) => Ok(DataTypeMap::new(
187                DataType::Time64(*unit),
188                PythonType::Datetime,
189                SqlType::DATE,
190            )),
191            DataType::Duration(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
192            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
193                DataType::Interval(*interval_unit),
194                PythonType::Datetime,
195                match interval_unit {
196                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
197                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
198                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
199                },
200            )),
201            DataType::Binary => Ok(DataTypeMap::new(
202                DataType::Binary,
203                PythonType::Bytes,
204                SqlType::BINARY,
205            )),
206            DataType::FixedSizeBinary(_) => {
207                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
208            }
209            DataType::LargeBinary => Ok(DataTypeMap::new(
210                DataType::LargeBinary,
211                PythonType::Bytes,
212                SqlType::BINARY,
213            )),
214            DataType::Utf8 => Ok(DataTypeMap::new(
215                DataType::Utf8,
216                PythonType::Str,
217                SqlType::VARCHAR,
218            )),
219            DataType::LargeUtf8 => Ok(DataTypeMap::new(
220                DataType::LargeUtf8,
221                PythonType::Str,
222                SqlType::VARCHAR,
223            )),
224            DataType::List(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
225            DataType::FixedSizeList(_, _) => {
226                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
227            }
228            DataType::LargeList(_) => {
229                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
230            }
231            DataType::Struct(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
232            DataType::Union(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
233            DataType::Dictionary(_, _) => {
234                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
235            }
236            DataType::Decimal32(precision, scale) => Ok(DataTypeMap::new(
237                DataType::Decimal32(*precision, *scale),
238                PythonType::Float,
239                SqlType::DECIMAL,
240            )),
241            DataType::Decimal64(precision, scale) => Ok(DataTypeMap::new(
242                DataType::Decimal64(*precision, *scale),
243                PythonType::Float,
244                SqlType::DECIMAL,
245            )),
246            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
247                DataType::Decimal128(*precision, *scale),
248                PythonType::Float,
249                SqlType::DECIMAL,
250            )),
251            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
252                DataType::Decimal256(*precision, *scale),
253                PythonType::Float,
254                SqlType::DECIMAL,
255            )),
256            DataType::Map(_, _) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
257            DataType::RunEndEncoded(_, _) => {
258                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
259            }
260            DataType::BinaryView => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
261            DataType::Utf8View => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
262            DataType::ListView(_) => Err(PyNotImplementedError::new_err(format!("{arrow_type:?}"))),
263            DataType::LargeListView(_) => {
264                Err(PyNotImplementedError::new_err(format!("{arrow_type:?}")))
265            }
266        }
267    }
268
269    /// Generate the `DataTypeMap` from a `ScalarValue` instance
270    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
271        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
272    }
273
274    /// Maps a `ScalarValue` to an Arrow `DataType`
275    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
276        match scalar_val {
277            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
278            ScalarValue::Float16(_) => Ok(DataType::Float16),
279            ScalarValue::Float32(_) => Ok(DataType::Float32),
280            ScalarValue::Float64(_) => Ok(DataType::Float64),
281            ScalarValue::Decimal32(_, precision, scale) => {
282                Ok(DataType::Decimal32(*precision, *scale))
283            }
284            ScalarValue::Decimal64(_, precision, scale) => {
285                Ok(DataType::Decimal64(*precision, *scale))
286            }
287            ScalarValue::Decimal128(_, precision, scale) => {
288                Ok(DataType::Decimal128(*precision, *scale))
289            }
290            ScalarValue::Decimal256(_, precision, scale) => {
291                Ok(DataType::Decimal256(*precision, *scale))
292            }
293            ScalarValue::Dictionary(data_type, scalar_type) => {
294                // Call this function again to map the dictionary scalar_value to an Arrow type
295                Ok(DataType::Dictionary(
296                    Box::new(*data_type.clone()),
297                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
298                ))
299            }
300            ScalarValue::Int8(_) => Ok(DataType::Int8),
301            ScalarValue::Int16(_) => Ok(DataType::Int16),
302            ScalarValue::Int32(_) => Ok(DataType::Int32),
303            ScalarValue::Int64(_) => Ok(DataType::Int64),
304            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
305            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
306            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
307            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
308            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
309            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
310            ScalarValue::Binary(_) => Ok(DataType::Binary),
311            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
312            ScalarValue::Date32(_) => Ok(DataType::Date32),
313            ScalarValue::Date64(_) => Ok(DataType::Date64),
314            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
315            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
316            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
317            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
318            ScalarValue::Null => Ok(DataType::Null),
319            ScalarValue::TimestampSecond(_, tz) => {
320                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
321            }
322            ScalarValue::TimestampMillisecond(_, tz) => {
323                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
324            }
325            ScalarValue::TimestampMicrosecond(_, tz) => {
326                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
327            }
328            ScalarValue::TimestampNanosecond(_, tz) => {
329                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
330            }
331            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
332            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
333            ScalarValue::IntervalMonthDayNano(..) => {
334                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
335            }
336            ScalarValue::List(arr) => Ok(arr.data_type().to_owned()),
337            ScalarValue::Struct(_fields) => Err(PyNotImplementedError::new_err(
338                "ScalarValue::Struct".to_string(),
339            )),
340            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
341            ScalarValue::FixedSizeList(_array_ref) => {
342                // The FieldRef was removed from ScalarValue::FixedSizeList in
343                // https://github.com/apache/arrow-datafusion/pull/8221, so we can no
344                // longer convert back to a DataType here
345                Err(PyNotImplementedError::new_err(
346                    "ScalarValue::FixedSizeList".to_string(),
347                ))
348            }
349            ScalarValue::LargeList(_) => Err(PyNotImplementedError::new_err(
350                "ScalarValue::LargeList".to_string(),
351            )),
352            ScalarValue::DurationSecond(_) => Ok(DataType::Duration(TimeUnit::Second)),
353            ScalarValue::DurationMillisecond(_) => Ok(DataType::Duration(TimeUnit::Millisecond)),
354            ScalarValue::DurationMicrosecond(_) => Ok(DataType::Duration(TimeUnit::Microsecond)),
355            ScalarValue::DurationNanosecond(_) => Ok(DataType::Duration(TimeUnit::Nanosecond)),
356            ScalarValue::Union(_, _, _) => Err(PyNotImplementedError::new_err(
357                "ScalarValue::LargeList".to_string(),
358            )),
359            ScalarValue::Utf8View(_) => Ok(DataType::Utf8View),
360            ScalarValue::BinaryView(_) => Ok(DataType::BinaryView),
361            ScalarValue::Map(_) => Err(PyNotImplementedError::new_err(
362                "ScalarValue::Map".to_string(),
363            )),
364            ScalarValue::RunEndEncoded(field1, field2, _) => Ok(DataType::RunEndEncoded(
365                Arc::clone(field1),
366                Arc::clone(field2),
367            )),
368        }
369    }
370}
371
372#[pymethods]
373impl DataTypeMap {
374    #[new]
375    pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self {
376        DataTypeMap {
377            arrow_type,
378            python_type,
379            sql_type,
380        }
381    }
382
383    #[staticmethod]
384    #[pyo3(name = "from_parquet_type_str")]
385    /// When using pyarrow.parquet.read_metadata().schema.column(x).physical_type you are presented
386    /// with a String type for schema rather than an object type. Here we make a best effort
387    /// to convert that to a physical type.
388    pub fn py_map_from_parquet_type_str(parquet_str_type: String) -> PyResult<DataTypeMap> {
389        let arrow_dtype = match parquet_str_type.to_lowercase().as_str() {
390            "boolean" => Ok(DataType::Boolean),
391            "int32" => Ok(DataType::Int32),
392            "int64" => Ok(DataType::Int64),
393            "int96" => {
394                // Int96 is an old parquet datatype that is now deprecated. We convert to nanosecond timestamp
395                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
396            }
397            "float" => Ok(DataType::Float32),
398            "double" => Ok(DataType::Float64),
399            "byte_array" => Ok(DataType::Utf8),
400            _ => Err(PyValueError::new_err(format!(
401                "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}"
402            ))),
403        };
404        DataTypeMap::map_from_arrow_type(&arrow_dtype?)
405    }
406
407    #[staticmethod]
408    #[pyo3(name = "arrow")]
409    pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult<DataTypeMap> {
410        DataTypeMap::map_from_arrow_type(&arrow_type.data_type)
411    }
412
413    #[staticmethod]
414    #[pyo3(name = "arrow_str")]
415    pub fn py_map_from_arrow_type_str(arrow_type_str: String) -> PyResult<DataTypeMap> {
416        let data_type = PyDataType::py_map_from_arrow_type_str(arrow_type_str);
417        DataTypeMap::map_from_arrow_type(&data_type?.data_type)
418    }
419
420    #[staticmethod]
421    #[pyo3(name = "sql")]
422    pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult<DataTypeMap> {
423        match sql_type {
424            SqlType::ANY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
425            SqlType::ARRAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
426            SqlType::BIGINT => Ok(DataTypeMap::new(
427                DataType::Int64,
428                PythonType::Int,
429                SqlType::BIGINT,
430            )),
431            SqlType::BINARY => Ok(DataTypeMap::new(
432                DataType::Binary,
433                PythonType::Bytes,
434                SqlType::BINARY,
435            )),
436            SqlType::BOOLEAN => Ok(DataTypeMap::new(
437                DataType::Boolean,
438                PythonType::Bool,
439                SqlType::BOOLEAN,
440            )),
441            SqlType::CHAR => Ok(DataTypeMap::new(
442                DataType::UInt8,
443                PythonType::Int,
444                SqlType::CHAR,
445            )),
446            SqlType::COLUMN_LIST => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
447            SqlType::CURSOR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
448            SqlType::DATE => Ok(DataTypeMap::new(
449                DataType::Date64,
450                PythonType::Datetime,
451                SqlType::DATE,
452            )),
453            SqlType::DECIMAL => Ok(DataTypeMap::new(
454                DataType::Decimal128(1, 1),
455                PythonType::Float,
456                SqlType::DECIMAL,
457            )),
458            SqlType::DISTINCT => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
459            SqlType::DOUBLE => Ok(DataTypeMap::new(
460                DataType::Decimal256(1, 1),
461                PythonType::Float,
462                SqlType::DOUBLE,
463            )),
464            SqlType::DYNAMIC_STAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
465            SqlType::FLOAT => Ok(DataTypeMap::new(
466                DataType::Decimal128(1, 1),
467                PythonType::Float,
468                SqlType::FLOAT,
469            )),
470            SqlType::GEOMETRY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
471            SqlType::INTEGER => Ok(DataTypeMap::new(
472                DataType::Int8,
473                PythonType::Int,
474                SqlType::INTEGER,
475            )),
476            SqlType::INTERVAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
477            SqlType::INTERVAL_DAY => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
478            SqlType::INTERVAL_DAY_HOUR => {
479                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
480            }
481            SqlType::INTERVAL_DAY_MINUTE => {
482                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
483            }
484            SqlType::INTERVAL_DAY_SECOND => {
485                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
486            }
487            SqlType::INTERVAL_HOUR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
488            SqlType::INTERVAL_HOUR_MINUTE => {
489                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
490            }
491            SqlType::INTERVAL_HOUR_SECOND => {
492                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
493            }
494            SqlType::INTERVAL_MINUTE => {
495                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
496            }
497            SqlType::INTERVAL_MINUTE_SECOND => {
498                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
499            }
500            SqlType::INTERVAL_MONTH => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
501            SqlType::INTERVAL_SECOND => {
502                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
503            }
504            SqlType::INTERVAL_YEAR => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
505            SqlType::INTERVAL_YEAR_MONTH => {
506                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
507            }
508            SqlType::MAP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
509            SqlType::MULTISET => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
510            SqlType::NULL => Ok(DataTypeMap::new(
511                DataType::Null,
512                PythonType::None,
513                SqlType::NULL,
514            )),
515            SqlType::OTHER => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
516            SqlType::REAL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
517            SqlType::ROW => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
518            SqlType::SARG => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
519            SqlType::SMALLINT => Ok(DataTypeMap::new(
520                DataType::Int16,
521                PythonType::Int,
522                SqlType::SMALLINT,
523            )),
524            SqlType::STRUCTURED => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
525            SqlType::SYMBOL => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
526            SqlType::TIME => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
527            SqlType::TIME_WITH_LOCAL_TIME_ZONE => {
528                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
529            }
530            SqlType::TIMESTAMP => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
531            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => {
532                Err(PyNotImplementedError::new_err(format!("{sql_type:?}")))
533            }
534            SqlType::TINYINT => Ok(DataTypeMap::new(
535                DataType::Int8,
536                PythonType::Int,
537                SqlType::TINYINT,
538            )),
539            SqlType::UNKNOWN => Err(PyNotImplementedError::new_err(format!("{sql_type:?}"))),
540            SqlType::VARBINARY => Ok(DataTypeMap::new(
541                DataType::LargeBinary,
542                PythonType::Bytes,
543                SqlType::VARBINARY,
544            )),
545            SqlType::VARCHAR => Ok(DataTypeMap::new(
546                DataType::Utf8,
547                PythonType::Str,
548                SqlType::VARCHAR,
549            )),
550        }
551    }
552
553    /// Unfortunately PyO3 does not allow for us to expose the DataType as an enum since
554    /// we cannot directly annotate the Enum instance of dependency code. Therefore, here
555    /// we provide an enum to mimic it.
556    #[pyo3(name = "friendly_arrow_type_name")]
557    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
558        Ok(match &self.arrow_type.data_type {
559            DataType::Null => "Null",
560            DataType::Boolean => "Boolean",
561            DataType::Int8 => "Int8",
562            DataType::Int16 => "Int16",
563            DataType::Int32 => "Int32",
564            DataType::Int64 => "Int64",
565            DataType::UInt8 => "UInt8",
566            DataType::UInt16 => "UInt16",
567            DataType::UInt32 => "UInt32",
568            DataType::UInt64 => "UInt64",
569            DataType::Float16 => "Float16",
570            DataType::Float32 => "Float32",
571            DataType::Float64 => "Float64",
572            DataType::Timestamp(_, _) => "Timestamp",
573            DataType::Date32 => "Date32",
574            DataType::Date64 => "Date64",
575            DataType::Time32(_) => "Time32",
576            DataType::Time64(_) => "Time64",
577            DataType::Duration(_) => "Duration",
578            DataType::Interval(_) => "Interval",
579            DataType::Binary => "Binary",
580            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
581            DataType::LargeBinary => "LargeBinary",
582            DataType::Utf8 => "Utf8",
583            DataType::LargeUtf8 => "LargeUtf8",
584            DataType::List(_) => "List",
585            DataType::FixedSizeList(_, _) => "FixedSizeList",
586            DataType::LargeList(_) => "LargeList",
587            DataType::Struct(_) => "Struct",
588            DataType::Union(_, _) => "Union",
589            DataType::Dictionary(_, _) => "Dictionary",
590            DataType::Decimal32(_, _) => "Decimal32",
591            DataType::Decimal64(_, _) => "Decimal64",
592            DataType::Decimal128(_, _) => "Decimal128",
593            DataType::Decimal256(_, _) => "Decimal256",
594            DataType::Map(_, _) => "Map",
595            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
596            DataType::BinaryView => "BinaryView",
597            DataType::Utf8View => "Utf8View",
598            DataType::ListView(_) => "ListView",
599            DataType::LargeListView(_) => "LargeListView",
600        })
601    }
602}
603
/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass`.
/// Since `DataType` exists in another crate, we cannot implement that trait for it here,
/// so we wrap `DataType` as `PyDataType`. This type exists solely to satisfy those
/// constraints and is exposed to Python as `datafusion.common.DataType`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    from_py_object,
    frozen,
    name = "DataType",
    module = "datafusion.common"
)]
pub struct PyDataType {
    /// The wrapped Arrow data type.
    pub data_type: DataType,
}
617
618impl PyDataType {
619    /// There are situations when obtaining dtypes on the Python side where the Arrow type
620    /// is presented as a String rather than an actual DataType. This function is used to
621    /// convert that String to a DataType for the Python side to use.
622    pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult<PyDataType> {
623        // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]"
624        let arrow_str_type = match arrow_str_type.find('[') {
625            Some(index) => arrow_str_type[0..index].to_string(),
626            None => arrow_str_type, // Return early if ',' is not found.
627        };
628
629        let arrow_dtype = match arrow_str_type.to_lowercase().as_str() {
630            "bool" => Ok(DataType::Boolean),
631            "boolean" => Ok(DataType::Boolean),
632            "uint8" => Ok(DataType::UInt8),
633            "uint16" => Ok(DataType::UInt16),
634            "uint32" => Ok(DataType::UInt32),
635            "uint64" => Ok(DataType::UInt64),
636            "int8" => Ok(DataType::Int8),
637            "int16" => Ok(DataType::Int16),
638            "int32" => Ok(DataType::Int32),
639            "int64" => Ok(DataType::Int64),
640            "float" => Ok(DataType::Float32),
641            "double" => Ok(DataType::Float64),
642            "float16" => Ok(DataType::Float16),
643            "float32" => Ok(DataType::Float32),
644            "float64" => Ok(DataType::Float64),
645            "datetime64" => Ok(DataType::Date64),
646            "object" => Ok(DataType::Utf8),
647            _ => Err(PyValueError::new_err(format!(
648                "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}"
649            ))),
650        };
651        Ok(PyDataType {
652            data_type: arrow_dtype?,
653        })
654    }
655}
656
657impl From<PyDataType> for DataType {
658    fn from(data_type: PyDataType) -> DataType {
659        data_type.data_type
660    }
661}
662
663impl From<DataType> for PyDataType {
664    fn from(data_type: DataType) -> PyDataType {
665        PyDataType { data_type }
666    }
667}
668
/// Represents the possible Python types that can be mapped to the SQL types.
///
/// Exposed to Python as `datafusion.common.PythonType`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    from_py_object,
    frozen,
    eq,
    eq_int,
    name = "PythonType",
    module = "datafusion.common"
)]
pub enum PythonType {
    Array,
    Bool,
    Bytes,
    Datetime,
    Float,
    Int,
    List,
    None,
    Object,
    Str,
}
691
/// SQL type names that can appear on the SQL side of a [`DataTypeMap`].
///
/// Exposed to Python as `datafusion.common.SqlType`. Not every variant has an Arrow
/// counterpart: `DataTypeMap.sql` raises `NotImplementedError` for those without one.
// NOTE(review): several variants (e.g. SARG, SYMBOL, DYNAMIC_STAR, COLUMN_LIST) are not
// ANSI SQL types; the list appears to mirror Apache Calcite's `SqlTypeName` — confirm.
// The variants deliberately use upper-case SQL spellings, hence the two `allow`s below.
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    from_py_object,
    frozen,
    eq,
    eq_int,
    name = "SqlType",
    module = "datafusion.common"
)]
pub enum SqlType {
    ANY,
    ARRAY,
    BIGINT,
    BINARY,
    BOOLEAN,
    CHAR,
    COLUMN_LIST,
    CURSOR,
    DATE,
    DECIMAL,
    DISTINCT,
    DOUBLE,
    DYNAMIC_STAR,
    FLOAT,
    GEOMETRY,
    INTEGER,
    INTERVAL,
    INTERVAL_DAY,
    INTERVAL_DAY_HOUR,
    INTERVAL_DAY_MINUTE,
    INTERVAL_DAY_SECOND,
    INTERVAL_HOUR,
    INTERVAL_HOUR_MINUTE,
    INTERVAL_HOUR_SECOND,
    INTERVAL_MINUTE,
    INTERVAL_MINUTE_SECOND,
    INTERVAL_MONTH,
    INTERVAL_SECOND,
    INTERVAL_YEAR,
    INTERVAL_YEAR_MONTH,
    MAP,
    MULTISET,
    NULL,
    OTHER,
    REAL,
    ROW,
    SARG,
    SMALLINT,
    STRUCTURED,
    SYMBOL,
    TIME,
    TIME_WITH_LOCAL_TIME_ZONE,
    TIMESTAMP,
    TIMESTAMP_WITH_LOCAL_TIME_ZONE,
    TINYINT,
    UNKNOWN,
    VARBINARY,
    VARCHAR,
}
756
/// Specifies whether NULLs are ignored or respected within window functions,
/// for example `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`.
///
/// Exposed to Python as `datafusion.common.NullTreatment`. It converts to and from
/// DataFusion's `DFNullTreatment`.
// Variants use the upper-case SQL spelling, hence the two `allow`s below.
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(
    from_py_object,
    frozen,
    eq,
    eq_int,
    name = "NullTreatment",
    module = "datafusion.common"
)]
pub enum NullTreatment {
    // `IGNORE NULLS`; corresponds to `DFNullTreatment::IgnoreNulls`.
    IGNORE_NULLS,
    // `RESPECT NULLS`; corresponds to `DFNullTreatment::RespectNulls`.
    RESPECT_NULLS,
}
775
776impl From<NullTreatment> for DFNullTreatment {
777    fn from(null_treatment: NullTreatment) -> DFNullTreatment {
778        match null_treatment {
779            NullTreatment::IGNORE_NULLS => DFNullTreatment::IgnoreNulls,
780            NullTreatment::RESPECT_NULLS => DFNullTreatment::RespectNulls,
781        }
782    }
783}
784
785impl From<DFNullTreatment> for NullTreatment {
786    fn from(null_treatment: DFNullTreatment) -> NullTreatment {
787        match null_treatment {
788            DFNullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
789            DFNullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
790        }
791    }
792}