datafusion_comet_spark_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_array::{
19    cast::as_primitive_array,
20    types::{Int32Type, TimestampMicrosecondType},
21};
22use arrow_schema::{ArrowError, DataType, TimeUnit, DECIMAL128_MAX_PRECISION};
23use std::sync::Arc;
24
25use crate::timezone::Tz;
26use arrow::{
27    array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
28    temporal_conversions::as_datetime,
29};
30use arrow_array::types::TimestampMillisecondType;
31use arrow_data::decimal::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
32use chrono::{DateTime, Offset, TimeZone};
33
34/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
35/// to apply timezone offset.
36//
37//  We consider the following cases:
38//
39//  | --------------------- | ------------ | ----------------- | -------------------------------- |
40//  | Conversion            | Input array  | Timezone          | Output array                     |
41//  | --------------------- | ------------ | ----------------- | -------------------------------- |
42//  | Timestamp ->          | Array in UTC | Timezone of input | A timestamp with the timezone    |
43//  |  Utf8 or Date32       |              |                   | offset applied and timezone      |
44//  |                       |              |                   | removed                          |
45//  | --------------------- | ------------ | ----------------- | -------------------------------- |
46//  | Timestamp ->          | Array in UTC | Timezone of input | Same as input array              |
47//  |  Timestamp  w/Timezone|              |                   |                                  |
48//  | --------------------- | ------------ | ----------------- | -------------------------------- |
49//  | Timestamp_ntz ->      | Array in     | Timezone of input | Same as input array              |
50//  |   Utf8 or Date32      | timezone     |                   |                                  |
51//  |                       | session local|                   |                                  |
52//  |                       | timezone     |                   |                                  |
53//  | --------------------- | ------------ | ----------------- | -------------------------------- |
54//  | Timestamp_ntz ->      | Array in     | Timezone of input |  Array in UTC and timezone       |
55//  |  Timestamp w/Timezone | session local|                   |  specified in input              |
56//  |                       | timezone     |                   |                                  |
57//  | --------------------- | ------------ | ----------------- | -------------------------------- |
58//  | Timestamp(_ntz) ->    |                                                                     |
59//  |        Any other type |              Not Supported                                          |
60//  | --------------------- | ------------ | ----------------- | -------------------------------- |
61//
62pub fn array_with_timezone(
63    array: ArrayRef,
64    timezone: String,
65    to_type: Option<&DataType>,
66) -> Result<ArrayRef, ArrowError> {
67    match array.data_type() {
68        DataType::Timestamp(_, None) => {
69            assert!(!timezone.is_empty());
70            match to_type {
71                Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
72                Some(DataType::Timestamp(_, Some(_))) => {
73                    timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
74                }
75                Some(DataType::Timestamp(_, None)) => {
76                    timestamp_ntz_to_timestamp(array, timezone.as_str(), None)
77                }
78                _ => {
79                    // Not supported
80                    panic!(
81                        "Cannot convert from {:?} to {:?}",
82                        array.data_type(),
83                        to_type.unwrap()
84                    )
85                }
86            }
87        }
88        DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
89            assert!(!timezone.is_empty());
90            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
91            let array_with_timezone = array.clone().with_timezone(timezone.clone());
92            let array = Arc::new(array_with_timezone) as ArrayRef;
93            match to_type {
94                Some(DataType::Utf8) | Some(DataType::Date32) => {
95                    pre_timestamp_cast(array, timezone)
96                }
97                _ => Ok(array),
98            }
99        }
100        DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
101            assert!(!timezone.is_empty());
102            let array = as_primitive_array::<TimestampMillisecondType>(&array);
103            let array_with_timezone = array.clone().with_timezone(timezone.clone());
104            let array = Arc::new(array_with_timezone) as ArrayRef;
105            match to_type {
106                Some(DataType::Utf8) | Some(DataType::Date32) => {
107                    pre_timestamp_cast(array, timezone)
108                }
109                _ => Ok(array),
110            }
111        }
112        DataType::Dictionary(_, value_type)
113            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
114        {
115            let dict = as_dictionary_array::<Int32Type>(&array);
116            let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
117            let array_with_timezone =
118                array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?;
119            let dict = dict.with_values(array_with_timezone);
120            Ok(Arc::new(dict))
121        }
122        _ => Ok(array),
123    }
124}
125
126fn datetime_cast_err(value: i64) -> ArrowError {
127    ArrowError::CastError(format!(
128        "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
129    ))
130}
131
132/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
133/// a Timestamp(Microsecond, Some<_>) array.
134/// The understanding is that the input array has time in the timezone specified in the second
135/// argument.
136/// Parameters:
137///     array - input array of timestamp without timezone
138///     tz - timezone of the values in the input array
139///     to_timezone - timezone to change the input values to
140fn timestamp_ntz_to_timestamp(
141    array: ArrayRef,
142    tz: &str,
143    to_timezone: Option<&str>,
144) -> Result<ArrayRef, ArrowError> {
145    assert!(!tz.is_empty());
146    match array.data_type() {
147        DataType::Timestamp(TimeUnit::Microsecond, None) => {
148            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
149            let tz: Tz = tz.parse()?;
150            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
151                as_datetime::<TimestampMicrosecondType>(value)
152                    .ok_or_else(|| datetime_cast_err(value))
153                    .map(|local_datetime| {
154                        let datetime: DateTime<Tz> =
155                            tz.from_local_datetime(&local_datetime).unwrap();
156                        datetime.timestamp_micros()
157                    })
158            })?;
159            let array_with_tz = if let Some(to_tz) = to_timezone {
160                array.with_timezone(to_tz)
161            } else {
162                array
163            };
164            Ok(Arc::new(array_with_tz))
165        }
166        DataType::Timestamp(TimeUnit::Millisecond, None) => {
167            let array = as_primitive_array::<TimestampMillisecondType>(&array);
168            let tz: Tz = tz.parse()?;
169            let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
170                as_datetime::<TimestampMillisecondType>(value)
171                    .ok_or_else(|| datetime_cast_err(value))
172                    .map(|local_datetime| {
173                        let datetime: DateTime<Tz> =
174                            tz.from_local_datetime(&local_datetime).unwrap();
175                        datetime.timestamp_millis()
176                    })
177            })?;
178            let array_with_tz = if let Some(to_tz) = to_timezone {
179                array.with_timezone(to_tz)
180            } else {
181                array
182            };
183            Ok(Arc::new(array_with_tz))
184        }
185        _ => Ok(array),
186    }
187}
188
189/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String.
190fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
191    assert!(!timezone.is_empty());
192    match array.data_type() {
193        DataType::Timestamp(_, _) => {
194            // Spark doesn't output timezone while casting timestamp to string, but arrow's cast
195            // kernel does if timezone exists. So we need to apply offset of timezone to array
196            // timestamp value and remove timezone from array datatype.
197            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
198
199            let tz: Tz = timezone.parse()?;
200            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
201                as_datetime::<TimestampMicrosecondType>(value)
202                    .ok_or_else(|| datetime_cast_err(value))
203                    .map(|datetime| {
204                        let offset = tz.offset_from_utc_datetime(&datetime).fix();
205                        let datetime = datetime + offset;
206                        datetime.and_utc().timestamp_micros()
207                    })
208            })?;
209
210            Ok(Arc::new(array))
211        }
212        _ => Ok(array),
213    }
214}
215
216/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
217/// instead of Err to avoid the cost of formatting the error strings and is
218/// optimized to remove a memcpy that exists in the original function
219/// we can remove this code once we upgrade to a version of arrow-rs that
220/// includes https://github.com/apache/arrow-rs/pull/6419
221#[inline]
222pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
223    precision <= DECIMAL128_MAX_PRECISION
224        && value >= MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
225        && value <= MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
226}
227
// The branch-hint helpers below are borrowed from the hashbrown crate:
//   https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs

/// An empty function marked `#[cold]`.
///
/// On stable Rust, `#[cold]` gives an equivalent effect to a branch-likelihood hint. The
/// attribute suggests that the function is unlikely to be called, so any branch that calls
/// it is treated as the unlikely path.
#[inline]
#[cold]
pub fn cold() {}

/// Returns `b` unchanged, hinting that `b` is expected to be `true`.
#[inline]
pub fn likely(b: bool) -> bool {
    match b {
        true => true,
        false => {
            cold();
            false
        }
    }
}

/// Returns `b` unchanged, hinting that `b` is expected to be `false`.
#[inline]
pub fn unlikely(b: bool) -> bool {
    match b {
        true => {
            cold();
            true
        }
        false => false,
    }
}