datafusion_comet_spark_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, TimeUnit, DECIMAL128_MAX_PRECISION};
19use arrow::{
20    array::{
21        cast::as_primitive_array,
22        types::{Int32Type, TimestampMicrosecondType},
23        BooleanBufferBuilder,
24    },
25    buffer::BooleanBuffer,
26};
27use datafusion::logical_expr::EmitTo;
28use std::sync::Arc;
29
30use crate::timezone::Tz;
31use arrow::array::types::TimestampMillisecondType;
32use arrow::datatypes::{MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION};
33use arrow::error::ArrowError;
34use arrow::{
35    array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
36    temporal_conversions::as_datetime,
37};
38use chrono::{DateTime, Offset, TimeZone};
39
40/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
41/// to apply timezone offset.
42//
43//  We consider the following cases:
44//
45//  | --------------------- | ------------ | ----------------- | -------------------------------- |
46//  | Conversion            | Input array  | Timezone          | Output array                     |
47//  | --------------------- | ------------ | ----------------- | -------------------------------- |
48//  | Timestamp ->          | Array in UTC | Timezone of input | A timestamp with the timezone    |
49//  |  Utf8 or Date32       |              |                   | offset applied and timezone      |
50//  |                       |              |                   | removed                          |
51//  | --------------------- | ------------ | ----------------- | -------------------------------- |
52//  | Timestamp ->          | Array in UTC | Timezone of input | Same as input array              |
53//  |  Timestamp  w/Timezone|              |                   |                                  |
54//  | --------------------- | ------------ | ----------------- | -------------------------------- |
55//  | Timestamp_ntz ->      | Array in     | Timezone of input | Same as input array              |
56//  |   Utf8 or Date32      | timezone     |                   |                                  |
57//  |                       | session local|                   |                                  |
58//  |                       | timezone     |                   |                                  |
59//  | --------------------- | ------------ | ----------------- | -------------------------------- |
60//  | Timestamp_ntz ->      | Array in     | Timezone of input |  Array in UTC and timezone       |
61//  |  Timestamp w/Timezone | session local|                   |  specified in input              |
62//  |                       | timezone     |                   |                                  |
63//  | --------------------- | ------------ | ----------------- | -------------------------------- |
64//  | Timestamp(_ntz) ->    |                                                                     |
65//  |        Any other type |              Not Supported                                          |
66//  | --------------------- | ------------ | ----------------- | -------------------------------- |
67//
68pub fn array_with_timezone(
69    array: ArrayRef,
70    timezone: String,
71    to_type: Option<&DataType>,
72) -> Result<ArrayRef, ArrowError> {
73    match array.data_type() {
74        DataType::Timestamp(_, None) => {
75            assert!(!timezone.is_empty());
76            match to_type {
77                Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
78                Some(DataType::Timestamp(_, Some(_))) => {
79                    timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
80                }
81                _ => {
82                    // Not supported
83                    panic!(
84                        "Cannot convert from {:?} to {:?}",
85                        array.data_type(),
86                        to_type.unwrap()
87                    )
88                }
89            }
90        }
91        DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
92            assert!(!timezone.is_empty());
93            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
94            let array_with_timezone = array.clone().with_timezone(timezone.clone());
95            let array = Arc::new(array_with_timezone) as ArrayRef;
96            match to_type {
97                Some(DataType::Utf8) | Some(DataType::Date32) => {
98                    pre_timestamp_cast(array, timezone)
99                }
100                _ => Ok(array),
101            }
102        }
103        DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
104            assert!(!timezone.is_empty());
105            let array = as_primitive_array::<TimestampMillisecondType>(&array);
106            let array_with_timezone = array.clone().with_timezone(timezone.clone());
107            let array = Arc::new(array_with_timezone) as ArrayRef;
108            match to_type {
109                Some(DataType::Utf8) | Some(DataType::Date32) => {
110                    pre_timestamp_cast(array, timezone)
111                }
112                _ => Ok(array),
113            }
114        }
115        DataType::Dictionary(_, value_type)
116            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
117        {
118            let dict = as_dictionary_array::<Int32Type>(&array);
119            let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
120            let array_with_timezone =
121                array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?;
122            let dict = dict.with_values(array_with_timezone);
123            Ok(Arc::new(dict))
124        }
125        _ => Ok(array),
126    }
127}
128
129fn datetime_cast_err(value: i64) -> ArrowError {
130    ArrowError::CastError(format!(
131        "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
132    ))
133}
134
135/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
136/// a Timestamp(Microsecond, Some<_>) array.
137/// The understanding is that the input array has time in the timezone specified in the second
138/// argument.
139/// Parameters:
140///     array - input array of timestamp without timezone
141///     tz - timezone of the values in the input array
142///     to_timezone - timezone to change the input values to
143fn timestamp_ntz_to_timestamp(
144    array: ArrayRef,
145    tz: &str,
146    to_timezone: Option<&str>,
147) -> Result<ArrayRef, ArrowError> {
148    assert!(!tz.is_empty());
149    match array.data_type() {
150        DataType::Timestamp(TimeUnit::Microsecond, None) => {
151            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
152            let tz: Tz = tz.parse()?;
153            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
154                as_datetime::<TimestampMicrosecondType>(value)
155                    .ok_or_else(|| datetime_cast_err(value))
156                    .map(|local_datetime| {
157                        let datetime: DateTime<Tz> =
158                            tz.from_local_datetime(&local_datetime).unwrap();
159                        datetime.timestamp_micros()
160                    })
161            })?;
162            let array_with_tz = if let Some(to_tz) = to_timezone {
163                array.with_timezone(to_tz)
164            } else {
165                array
166            };
167            Ok(Arc::new(array_with_tz))
168        }
169        DataType::Timestamp(TimeUnit::Millisecond, None) => {
170            let array = as_primitive_array::<TimestampMillisecondType>(&array);
171            let tz: Tz = tz.parse()?;
172            let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
173                as_datetime::<TimestampMillisecondType>(value)
174                    .ok_or_else(|| datetime_cast_err(value))
175                    .map(|local_datetime| {
176                        let datetime: DateTime<Tz> =
177                            tz.from_local_datetime(&local_datetime).unwrap();
178                        datetime.timestamp_millis()
179                    })
180            })?;
181            let array_with_tz = if let Some(to_tz) = to_timezone {
182                array.with_timezone(to_tz)
183            } else {
184                array
185            };
186            Ok(Arc::new(array_with_tz))
187        }
188        _ => Ok(array),
189    }
190}
191
192/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String.
193fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
194    assert!(!timezone.is_empty());
195    match array.data_type() {
196        DataType::Timestamp(_, _) => {
197            // Spark doesn't output timezone while casting timestamp to string, but arrow's cast
198            // kernel does if timezone exists. So we need to apply offset of timezone to array
199            // timestamp value and remove timezone from array datatype.
200            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
201
202            let tz: Tz = timezone.parse()?;
203            let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
204                as_datetime::<TimestampMicrosecondType>(value)
205                    .ok_or_else(|| datetime_cast_err(value))
206                    .map(|datetime| {
207                        let offset = tz.offset_from_utc_datetime(&datetime).fix();
208                        let datetime = datetime + offset;
209                        datetime.and_utc().timestamp_micros()
210                    })
211            })?;
212
213            Ok(Arc::new(array))
214        }
215        _ => Ok(array),
216    }
217}
218
219/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
220/// instead of Err to avoid the cost of formatting the error strings and is
221/// optimized to remove a memcpy that exists in the original function
222/// we can remove this code once we upgrade to a version of arrow-rs that
223/// includes https://github.com/apache/arrow-rs/pull/6419
224#[inline]
225pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
226    precision <= DECIMAL128_MAX_PRECISION
227        && value >= MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
228        && value <= MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize]
229}
230
231/// Build a boolean buffer from the state and reset the state, based on the emit_to
232/// strategy.
233pub fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
234    let bool_state: BooleanBuffer = state.finish();
235
236    match emit_to {
237        EmitTo::All => bool_state,
238        EmitTo::First(n) => {
239            state.append_buffer(&bool_state.slice(*n, bool_state.len() - n));
240            bool_state.slice(0, *n)
241        }
242    }
243}
244
245// These are borrowed from hashbrown crate:
246//   https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
247
/// Intentionally empty function used as a branch hint.
///
/// On stable we can use `#[cold]` to get an equivalent effect: this attribute suggests that
/// the function is unlikely to be called.
#[cold]
#[inline]
pub fn cold() {}
253
254#[inline]
255pub fn likely(b: bool) -> bool {
256    if !b {
257        cold();
258    }
259    b
260}
261#[inline]
262pub fn unlikely(b: bool) -> bool {
263    if b {
264        cold();
265    }
266    b
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BooleanBuffer` from the bit range `0..len` of the packed `bytes`.
    fn packed_bits(len: usize, bytes: &[u8]) -> BooleanBuffer {
        let mut builder = BooleanBufferBuilder::new(0);
        builder.append_packed_range(0..len, bytes);
        builder.finish()
    }

    /// Emitting the first nine of sixteen bits must return exactly those bits and leave the
    /// other seven in the builder for the following `EmitTo::All`.
    #[test]
    fn test_build_bool_state() {
        let mut state = BooleanBufferBuilder::new(0);
        state.append_packed_range(0..16, &[0x42u8, 0x39u8]);

        let expected_head = packed_bits(9, &[0x42u8, 0x01u8]);
        let expected_tail = packed_bits(7, &[0x1cu8]);

        let head = build_bool_state(&mut state, &EmitTo::First(9));
        assert_eq!(expected_head, head);

        let tail = build_bool_state(&mut state, &EmitTo::All);
        assert_eq!(expected_tail, tail);
    }
}