datafusion_comet_spark_expr/utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_array::{
19 cast::as_primitive_array,
20 types::{Int32Type, TimestampMicrosecondType},
21};
22use arrow_schema::{ArrowError, DataType, TimeUnit, DECIMAL128_MAX_PRECISION};
23use std::sync::Arc;
24
25use crate::timezone::Tz;
26use arrow::{
27 array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
28 temporal_conversions::as_datetime,
29};
30use arrow_array::types::TimestampMillisecondType;
31use arrow_data::decimal::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
32use chrono::{DateTime, Offset, TimeZone};
33
34/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
35/// to apply timezone offset.
36//
37// We consider the following cases:
38//
39// | --------------------- | ------------ | ----------------- | -------------------------------- |
40// | Conversion | Input array | Timezone | Output array |
41// | --------------------- | ------------ | ----------------- | -------------------------------- |
42// | Timestamp -> | Array in UTC | Timezone of input | A timestamp with the timezone |
43// | Utf8 or Date32 | | | offset applied and timezone |
44// | | | | removed |
45// | --------------------- | ------------ | ----------------- | -------------------------------- |
46// | Timestamp -> | Array in UTC | Timezone of input | Same as input array |
47// | Timestamp w/Timezone| | | |
48// | --------------------- | ------------ | ----------------- | -------------------------------- |
49// | Timestamp_ntz -> | Array in | Timezone of input | Same as input array |
50// | Utf8 or Date32 | timezone | | |
51// | | session local| | |
52// | | timezone | | |
53// | --------------------- | ------------ | ----------------- | -------------------------------- |
54// | Timestamp_ntz -> | Array in | Timezone of input | Array in UTC and timezone |
55// | Timestamp w/Timezone | session local| | specified in input |
56// | | timezone | | |
57// | --------------------- | ------------ | ----------------- | -------------------------------- |
58// | Timestamp(_ntz) -> | |
59// | Any other type | Not Supported |
60// | --------------------- | ------------ | ----------------- | -------------------------------- |
61//
62pub fn array_with_timezone(
63 array: ArrayRef,
64 timezone: String,
65 to_type: Option<&DataType>,
66) -> Result<ArrayRef, ArrowError> {
67 match array.data_type() {
68 DataType::Timestamp(_, None) => {
69 assert!(!timezone.is_empty());
70 match to_type {
71 Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
72 Some(DataType::Timestamp(_, Some(_))) => {
73 timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
74 }
75 Some(DataType::Timestamp(_, None)) => {
76 timestamp_ntz_to_timestamp(array, timezone.as_str(), None)
77 }
78 _ => {
79 // Not supported
80 panic!(
81 "Cannot convert from {:?} to {:?}",
82 array.data_type(),
83 to_type.unwrap()
84 )
85 }
86 }
87 }
88 DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
89 assert!(!timezone.is_empty());
90 let array = as_primitive_array::<TimestampMicrosecondType>(&array);
91 let array_with_timezone = array.clone().with_timezone(timezone.clone());
92 let array = Arc::new(array_with_timezone) as ArrayRef;
93 match to_type {
94 Some(DataType::Utf8) | Some(DataType::Date32) => {
95 pre_timestamp_cast(array, timezone)
96 }
97 _ => Ok(array),
98 }
99 }
100 DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
101 assert!(!timezone.is_empty());
102 let array = as_primitive_array::<TimestampMillisecondType>(&array);
103 let array_with_timezone = array.clone().with_timezone(timezone.clone());
104 let array = Arc::new(array_with_timezone) as ArrayRef;
105 match to_type {
106 Some(DataType::Utf8) | Some(DataType::Date32) => {
107 pre_timestamp_cast(array, timezone)
108 }
109 _ => Ok(array),
110 }
111 }
112 DataType::Dictionary(_, value_type)
113 if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
114 {
115 let dict = as_dictionary_array::<Int32Type>(&array);
116 let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
117 let array_with_timezone =
118 array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?;
119 let dict = dict.with_values(array_with_timezone);
120 Ok(Arc::new(dict))
121 }
122 _ => Ok(array),
123 }
124}
125
126fn datetime_cast_err(value: i64) -> ArrowError {
127 ArrowError::CastError(format!(
128 "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
129 ))
130}
131
132/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
133/// a Timestamp(Microsecond, Some<_>) array.
134/// The understanding is that the input array has time in the timezone specified in the second
135/// argument.
136/// Parameters:
137/// array - input array of timestamp without timezone
138/// tz - timezone of the values in the input array
139/// to_timezone - timezone to change the input values to
140fn timestamp_ntz_to_timestamp(
141 array: ArrayRef,
142 tz: &str,
143 to_timezone: Option<&str>,
144) -> Result<ArrayRef, ArrowError> {
145 assert!(!tz.is_empty());
146 match array.data_type() {
147 DataType::Timestamp(TimeUnit::Microsecond, None) => {
148 let array = as_primitive_array::<TimestampMicrosecondType>(&array);
149 let tz: Tz = tz.parse()?;
150 let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
151 as_datetime::<TimestampMicrosecondType>(value)
152 .ok_or_else(|| datetime_cast_err(value))
153 .map(|local_datetime| {
154 let datetime: DateTime<Tz> =
155 tz.from_local_datetime(&local_datetime).unwrap();
156 datetime.timestamp_micros()
157 })
158 })?;
159 let array_with_tz = if let Some(to_tz) = to_timezone {
160 array.with_timezone(to_tz)
161 } else {
162 array
163 };
164 Ok(Arc::new(array_with_tz))
165 }
166 DataType::Timestamp(TimeUnit::Millisecond, None) => {
167 let array = as_primitive_array::<TimestampMillisecondType>(&array);
168 let tz: Tz = tz.parse()?;
169 let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
170 as_datetime::<TimestampMillisecondType>(value)
171 .ok_or_else(|| datetime_cast_err(value))
172 .map(|local_datetime| {
173 let datetime: DateTime<Tz> =
174 tz.from_local_datetime(&local_datetime).unwrap();
175 datetime.timestamp_millis()
176 })
177 })?;
178 let array_with_tz = if let Some(to_tz) = to_timezone {
179 array.with_timezone(to_tz)
180 } else {
181 array
182 };
183 Ok(Arc::new(array_with_tz))
184 }
185 _ => Ok(array),
186 }
187}
188
189/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String.
190fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
191 assert!(!timezone.is_empty());
192 match array.data_type() {
193 DataType::Timestamp(_, _) => {
194 // Spark doesn't output timezone while casting timestamp to string, but arrow's cast
195 // kernel does if timezone exists. So we need to apply offset of timezone to array
196 // timestamp value and remove timezone from array datatype.
197 let array = as_primitive_array::<TimestampMicrosecondType>(&array);
198
199 let tz: Tz = timezone.parse()?;
200 let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
201 as_datetime::<TimestampMicrosecondType>(value)
202 .ok_or_else(|| datetime_cast_err(value))
203 .map(|datetime| {
204 let offset = tz.offset_from_utc_datetime(&datetime).fix();
205 let datetime = datetime + offset;
206 datetime.and_utc().timestamp_micros()
207 })
208 })?;
209
210 Ok(Arc::new(array))
211 }
212 _ => Ok(array),
213 }
214}
215
216/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
217/// instead of Err to avoid the cost of formatting the error strings and is
218/// optimized to remove a memcpy that exists in the original function
219/// we can remove this code once we upgrade to a version of arrow-rs that
220/// includes https://github.com/apache/arrow-rs/pull/6419
221#[inline]
222pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
223 precision <= DECIMAL128_MAX_PRECISION
224 && value >= MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
225 && value <= MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
226}
227
// The branch-hint helpers below (`cold`, `likely`, `unlikely`) come from the hashbrown crate:
// https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs

/// An empty function marked `#[cold]`.
///
/// On stable Rust, calling a `#[cold]` function on a code path signals to the compiler that
/// the path is unlikely to run, which gives an effect equivalent to a branch-likelihood hint.
#[cold]
#[inline]
pub fn cold() {}
236
237#[inline]
238pub fn likely(b: bool) -> bool {
239 if !b {
240 cold();
241 }
242 b
243}
244#[inline]
245pub fn unlikely(b: bool) -> bool {
246 if b {
247 cold();
248 }
249 b
250}