datafusion_comet_spark_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    cast::as_primitive_array,
20    types::{Int32Type, TimestampMicrosecondType},
21};
22use arrow::datatypes::{DataType, TimeUnit, DECIMAL128_MAX_PRECISION};
23use std::sync::Arc;
24
25use crate::timezone::Tz;
26use arrow::array::types::TimestampMillisecondType;
27use arrow::datatypes::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
28use arrow::error::ArrowError;
29use arrow::{
30    array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
31    temporal_conversions::as_datetime,
32};
33use chrono::{DateTime, Offset, TimeZone};
34
35/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
36/// to apply timezone offset.
37//
38//  We consider the following cases:
39//
40//  | --------------------- | ------------ | ----------------- | -------------------------------- |
41//  | Conversion            | Input array  | Timezone          | Output array                     |
42//  | --------------------- | ------------ | ----------------- | -------------------------------- |
43//  | Timestamp ->          | Array in UTC | Timezone of input | A timestamp with the timezone    |
44//  |  Utf8 or Date32       |              |                   | offset applied and timezone      |
45//  |                       |              |                   | removed                          |
46//  | --------------------- | ------------ | ----------------- | -------------------------------- |
47//  | Timestamp ->          | Array in UTC | Timezone of input | Same as input array              |
48//  |  Timestamp  w/Timezone|              |                   |                                  |
49//  | --------------------- | ------------ | ----------------- | -------------------------------- |
50//  | Timestamp_ntz ->      | Array in     | Timezone of input | Same as input array              |
51//  |   Utf8 or Date32      | timezone     |                   |                                  |
52//  |                       | session local|                   |                                  |
53//  |                       | timezone     |                   |                                  |
54//  | --------------------- | ------------ | ----------------- | -------------------------------- |
55//  | Timestamp_ntz ->      | Array in     | Timezone of input |  Array in UTC and timezone       |
56//  |  Timestamp w/Timezone | session local|                   |  specified in input              |
57//  |                       | timezone     |                   |                                  |
58//  | --------------------- | ------------ | ----------------- | -------------------------------- |
59//  | Timestamp(_ntz) ->    |                                                                     |
60//  |        Any other type |              Not Supported                                          |
61//  | --------------------- | ------------ | ----------------- | -------------------------------- |
62//
63pub fn array_with_timezone(
64    array: ArrayRef,
65    timezone: String,
66    to_type: Option<&DataType>,
67) -> Result<ArrayRef, ArrowError> {
68    match array.data_type() {
69        DataType::Timestamp(_, None) => {
70            assert!(!timezone.is_empty());
71            match to_type {
72                Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
73                Some(DataType::Timestamp(_, Some(_))) => {
74                    timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
75                }
76                _ => {
77                    // Not supported
78                    panic!(
79                        "Cannot convert from {:?} to {:?}",
80                        array.data_type(),
81                        to_type.unwrap()
82                    )
83                }
84            }
85        }
86        DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
87            assert!(!timezone.is_empty());
88            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
89            let array_with_timezone = array.clone().with_timezone(timezone.clone());
90            let array = Arc::new(array_with_timezone) as ArrayRef;
91            match to_type {
92                Some(DataType::Utf8) | Some(DataType::Date32) => {
93                    pre_timestamp_cast(array, timezone)
94                }
95                _ => Ok(array),
96            }
97        }
98        DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
99            assert!(!timezone.is_empty());
100            let array = as_primitive_array::<TimestampMillisecondType>(&array);
101            let array_with_timezone = array.clone().with_timezone(timezone.clone());
102            let array = Arc::new(array_with_timezone) as ArrayRef;
103            match to_type {
104                Some(DataType::Utf8) | Some(DataType::Date32) => {
105                    pre_timestamp_cast(array, timezone)
106                }
107                _ => Ok(array),
108            }
109        }
110        DataType::Dictionary(_, value_type)
111            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
112        {
113            let dict = as_dictionary_array::<Int32Type>(&array);
114            let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
115            let array_with_timezone =
116                array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?;
117            let dict = dict.with_values(array_with_timezone);
118            Ok(Arc::new(dict))
119        }
120        _ => Ok(array),
121    }
122}
123
124fn datetime_cast_err(value: i64) -> ArrowError {
125    ArrowError::CastError(format!(
126        "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
127    ))
128}
129
130/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
131/// a Timestamp(Microsecond, Some<_>) array.
132/// The understanding is that the input array has time in the timezone specified in the second
133/// argument.
134/// Parameters:
135///     array - input array of timestamp without timezone
136///     tz - timezone of the values in the input array
137///     to_timezone - timezone to change the input values to
138fn timestamp_ntz_to_timestamp(
139    array: ArrayRef,
140    tz: &str,
141    to_timezone: Option<&str>,
142) -> Result<ArrayRef, ArrowError> {
143    assert!(!tz.is_empty());
144    match array.data_type() {
145        DataType::Timestamp(TimeUnit::Microsecond, None) => {
146            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
147            let tz: Tz = tz.parse()?;
148            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
149                as_datetime::<TimestampMicrosecondType>(value)
150                    .ok_or_else(|| datetime_cast_err(value))
151                    .map(|local_datetime| {
152                        let datetime: DateTime<Tz> =
153                            tz.from_local_datetime(&local_datetime).unwrap();
154                        datetime.timestamp_micros()
155                    })
156            })?;
157            let array_with_tz = if let Some(to_tz) = to_timezone {
158                array.with_timezone(to_tz)
159            } else {
160                array
161            };
162            Ok(Arc::new(array_with_tz))
163        }
164        DataType::Timestamp(TimeUnit::Millisecond, None) => {
165            let array = as_primitive_array::<TimestampMillisecondType>(&array);
166            let tz: Tz = tz.parse()?;
167            let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
168                as_datetime::<TimestampMillisecondType>(value)
169                    .ok_or_else(|| datetime_cast_err(value))
170                    .map(|local_datetime| {
171                        let datetime: DateTime<Tz> =
172                            tz.from_local_datetime(&local_datetime).unwrap();
173                        datetime.timestamp_millis()
174                    })
175            })?;
176            let array_with_tz = if let Some(to_tz) = to_timezone {
177                array.with_timezone(to_tz)
178            } else {
179                array
180            };
181            Ok(Arc::new(array_with_tz))
182        }
183        _ => Ok(array),
184    }
185}
186
187/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String.
188fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
189    assert!(!timezone.is_empty());
190    match array.data_type() {
191        DataType::Timestamp(_, _) => {
192            // Spark doesn't output timezone while casting timestamp to string, but arrow's cast
193            // kernel does if timezone exists. So we need to apply offset of timezone to array
194            // timestamp value and remove timezone from array datatype.
195            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
196
197            let tz: Tz = timezone.parse()?;
198            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
199                as_datetime::<TimestampMicrosecondType>(value)
200                    .ok_or_else(|| datetime_cast_err(value))
201                    .map(|datetime| {
202                        let offset = tz.offset_from_utc_datetime(&datetime).fix();
203                        let datetime = datetime + offset;
204                        datetime.and_utc().timestamp_micros()
205                    })
206            })?;
207
208            Ok(Arc::new(array))
209        }
210        _ => Ok(array),
211    }
212}
213
214/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
215/// instead of Err to avoid the cost of formatting the error strings and is
216/// optimized to remove a memcpy that exists in the original function
217/// we can remove this code once we upgrade to a version of arrow-rs that
218/// includes https://github.com/apache/arrow-rs/pull/6419
219#[inline]
220pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
221    precision <= DECIMAL128_MAX_PRECISION
222        && value >= MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
223        && value <= MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
224}
225
// Branch-prediction hints borrowed from the hashbrown crate:
//   https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
//
// On stable Rust, the `#[cold]` attribute gives an equivalent effect: it tells the compiler
// that calls to the annotated function are unlikely, so a branch that calls it is treated as
// the unlikely path.

/// Empty function marked `#[cold]`; calling it marks the surrounding branch as unlikely.
#[inline]
#[cold]
pub fn cold() {}

/// Hints that `b` is expected to be `true`, and returns `b` unchanged.
#[inline]
pub fn likely(b: bool) -> bool {
    if b {
        true
    } else {
        cold();
        false
    }
}

/// Hints that `b` is expected to be `false`, and returns `b` unchanged.
#[inline]
pub fn unlikely(b: bool) -> bool {
    if !b {
        return false;
    }
    cold();
    true
}