datafusion_comet_spark_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_array::{
19    cast::as_primitive_array,
20    types::{Int32Type, TimestampMicrosecondType},
21};
22use arrow_schema::{ArrowError, DataType, TimeUnit, DECIMAL128_MAX_PRECISION};
23use std::sync::Arc;
24
25use crate::timezone::Tz;
26use arrow::{
27    array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
28    temporal_conversions::as_datetime,
29};
30use arrow_array::types::TimestampMillisecondType;
31use arrow_data::decimal::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
32use chrono::{DateTime, Offset, TimeZone};
33
34/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
35/// to apply timezone offset.
36//
37//  We consider the following cases:
38//
39//  | --------------------- | ------------ | ----------------- | -------------------------------- |
40//  | Conversion            | Input array  | Timezone          | Output array                     |
41//  | --------------------- | ------------ | ----------------- | -------------------------------- |
42//  | Timestamp ->          | Array in UTC | Timezone of input | A timestamp with the timezone    |
43//  |  Utf8 or Date32       |              |                   | offset applied and timezone      |
44//  |                       |              |                   | removed                          |
45//  | --------------------- | ------------ | ----------------- | -------------------------------- |
46//  | Timestamp ->          | Array in UTC | Timezone of input | Same as input array              |
47//  |  Timestamp  w/Timezone|              |                   |                                  |
48//  | --------------------- | ------------ | ----------------- | -------------------------------- |
49//  | Timestamp_ntz ->      | Array in     | Timezone of input | Same as input array              |
50//  |   Utf8 or Date32      | timezone     |                   |                                  |
51//  |                       | session local|                   |                                  |
52//  |                       | timezone     |                   |                                  |
53//  | --------------------- | ------------ | ----------------- | -------------------------------- |
54//  | Timestamp_ntz ->      | Array in     | Timezone of input |  Array in UTC and timezone       |
55//  |  Timestamp w/Timezone | session local|                   |  specified in input              |
56//  |                       | timezone     |                   |                                  |
57//  | --------------------- | ------------ | ----------------- | -------------------------------- |
58//  | Timestamp(_ntz) ->    |                                                                     |
59//  |        Any other type |              Not Supported                                          |
60//  | --------------------- | ------------ | ----------------- | -------------------------------- |
61//
62pub fn array_with_timezone(
63    array: ArrayRef,
64    timezone: String,
65    to_type: Option<&DataType>,
66) -> Result<ArrayRef, ArrowError> {
67    match array.data_type() {
68        DataType::Timestamp(_, None) => {
69            assert!(!timezone.is_empty());
70            match to_type {
71                Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
72                Some(DataType::Timestamp(_, Some(_))) => {
73                    timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
74                }
75                _ => {
76                    // Not supported
77                    panic!(
78                        "Cannot convert from {:?} to {:?}",
79                        array.data_type(),
80                        to_type.unwrap()
81                    )
82                }
83            }
84        }
85        DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
86            assert!(!timezone.is_empty());
87            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
88            let array_with_timezone = array.clone().with_timezone(timezone.clone());
89            let array = Arc::new(array_with_timezone) as ArrayRef;
90            match to_type {
91                Some(DataType::Utf8) | Some(DataType::Date32) => {
92                    pre_timestamp_cast(array, timezone)
93                }
94                _ => Ok(array),
95            }
96        }
97        DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
98            assert!(!timezone.is_empty());
99            let array = as_primitive_array::<TimestampMillisecondType>(&array);
100            let array_with_timezone = array.clone().with_timezone(timezone.clone());
101            let array = Arc::new(array_with_timezone) as ArrayRef;
102            match to_type {
103                Some(DataType::Utf8) | Some(DataType::Date32) => {
104                    pre_timestamp_cast(array, timezone)
105                }
106                _ => Ok(array),
107            }
108        }
109        DataType::Dictionary(_, value_type)
110            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
111        {
112            let dict = as_dictionary_array::<Int32Type>(&array);
113            let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
114            let array_with_timezone =
115                array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?;
116            let dict = dict.with_values(array_with_timezone);
117            Ok(Arc::new(dict))
118        }
119        _ => Ok(array),
120    }
121}
122
123fn datetime_cast_err(value: i64) -> ArrowError {
124    ArrowError::CastError(format!(
125        "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
126    ))
127}
128
129/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
130/// a Timestamp(Microsecond, Some<_>) array.
131/// The understanding is that the input array has time in the timezone specified in the second
132/// argument.
133/// Parameters:
134///     array - input array of timestamp without timezone
135///     tz - timezone of the values in the input array
136///     to_timezone - timezone to change the input values to
137fn timestamp_ntz_to_timestamp(
138    array: ArrayRef,
139    tz: &str,
140    to_timezone: Option<&str>,
141) -> Result<ArrayRef, ArrowError> {
142    assert!(!tz.is_empty());
143    match array.data_type() {
144        DataType::Timestamp(TimeUnit::Microsecond, None) => {
145            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
146            let tz: Tz = tz.parse()?;
147            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
148                as_datetime::<TimestampMicrosecondType>(value)
149                    .ok_or_else(|| datetime_cast_err(value))
150                    .map(|local_datetime| {
151                        let datetime: DateTime<Tz> =
152                            tz.from_local_datetime(&local_datetime).unwrap();
153                        datetime.timestamp_micros()
154                    })
155            })?;
156            let array_with_tz = if let Some(to_tz) = to_timezone {
157                array.with_timezone(to_tz)
158            } else {
159                array
160            };
161            Ok(Arc::new(array_with_tz))
162        }
163        DataType::Timestamp(TimeUnit::Millisecond, None) => {
164            let array = as_primitive_array::<TimestampMillisecondType>(&array);
165            let tz: Tz = tz.parse()?;
166            let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
167                as_datetime::<TimestampMillisecondType>(value)
168                    .ok_or_else(|| datetime_cast_err(value))
169                    .map(|local_datetime| {
170                        let datetime: DateTime<Tz> =
171                            tz.from_local_datetime(&local_datetime).unwrap();
172                        datetime.timestamp_millis()
173                    })
174            })?;
175            let array_with_tz = if let Some(to_tz) = to_timezone {
176                array.with_timezone(to_tz)
177            } else {
178                array
179            };
180            Ok(Arc::new(array_with_tz))
181        }
182        _ => Ok(array),
183    }
184}
185
186/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String.
187fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
188    assert!(!timezone.is_empty());
189    match array.data_type() {
190        DataType::Timestamp(_, _) => {
191            // Spark doesn't output timezone while casting timestamp to string, but arrow's cast
192            // kernel does if timezone exists. So we need to apply offset of timezone to array
193            // timestamp value and remove timezone from array datatype.
194            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
195
196            let tz: Tz = timezone.parse()?;
197            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
198                as_datetime::<TimestampMicrosecondType>(value)
199                    .ok_or_else(|| datetime_cast_err(value))
200                    .map(|datetime| {
201                        let offset = tz.offset_from_utc_datetime(&datetime).fix();
202                        let datetime = datetime + offset;
203                        datetime.and_utc().timestamp_micros()
204                    })
205            })?;
206
207            Ok(Arc::new(array))
208        }
209        _ => Ok(array),
210    }
211}
212
213/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
214/// instead of Err to avoid the cost of formatting the error strings and is
215/// optimized to remove a memcpy that exists in the original function
216/// we can remove this code once we upgrade to a version of arrow-rs that
217/// includes https://github.com/apache/arrow-rs/pull/6419
218#[inline]
219pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
220    precision <= DECIMAL128_MAX_PRECISION
221        && value >= MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
222        && value <= MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
223}
224
// The branch-hint helpers below are borrowed from the hashbrown crate:
//   https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs

/// An intentionally empty function marked `#[cold]`.
///
/// On stable Rust, `#[cold]` gives an equivalent effect to a branch-likelihood hint: the
/// attribute suggests that the function is unlikely to be called, so any branch that calls it
/// is treated as the unlikely one.
#[cold]
#[inline]
pub fn cold() {
    // Nothing to do: only the `#[cold]` attribute matters.
}
233
234#[inline]
235pub fn likely(b: bool) -> bool {
236    if !b {
237        cold();
238    }
239    b
240}
241#[inline]
242pub fn unlikely(b: bool) -> bool {
243    if b {
244        cold();
245    }
246    b
247}