datafusion_comet_spark_expr/hash_funcs/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Macros for hashing Arrow arrays row by row with a pluggable hash function (e.g. murmur3).
19
/// Folds every non-null value of `$column` into the running hash of its row.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - column to hash; it must provide `as_any()`.
/// * `$hashes` - mutable buffer holding one running hash per row. Its length decides how many
///   rows are visited.
/// * `$hash_method` - function called as `hash_method(&value, current_hash)`; its result
///   replaces the row's hash.
///
/// Rows holding a null keep their current hash.
///
/// # Panics
///
/// Panics with a message naming `$array_type` if `$column` is not of that type.
#[macro_export]
macro_rules! hash_array {
    ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!("Failed to downcast column to {}", stringify!($array_type))
            });
        if array.null_count() == 0 {
            // Fast path: no validity checks needed.
            for (i, hash) in $hashes.iter_mut().enumerate() {
                *hash = $hash_method(&array.value(i), *hash);
            }
        } else {
            for (i, hash) in $hashes.iter_mut().enumerate() {
                if !array.is_null(i) {
                    *hash = $hash_method(&array.value(i), *hash);
                }
            }
        }
    };
}
37
38#[macro_export]
39macro_rules! hash_array_boolean {
40    ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => {
41        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
42        if array.null_count() == 0 {
43            for (i, hash) in $hashes.iter_mut().enumerate() {
44                *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash);
45            }
46        } else {
47            for (i, hash) in $hashes.iter_mut().enumerate() {
48                if !array.is_null(i) {
49                    *hash =
50                        $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash);
51                }
52            }
53        }
54    };
55}
56
/// Folds every non-null primitive value of `$column` into the running hash of its row.
///
/// Each value is cast to `$ty` with `as` and the little-endian bytes of the result are passed to
/// `$hash_method` together with the row's current hash.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - column to hash; it must provide `as_any()`.
/// * `$ty` - numeric type every value is cast to before hashing.
/// * `$hashes` - mutable buffer holding one running hash per row; it is zipped with the
///   array's values.
/// * `$hash_method` - function called as `hash_method(bytes, current_hash)`.
///
/// Rows holding a null keep their current hash.
///
/// # Panics
///
/// Panics with a message naming `$array_type` if `$column` is not of that type.
#[macro_export]
macro_rules! hash_array_primitive {
    ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!("Failed to downcast column to {}", stringify!($array_type))
            });
        let values = array.values();

        if array.null_count() == 0 {
            // Fast path: no validity checks needed.
            for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
                *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
            }
        } else {
            for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() {
                if !array.is_null(i) {
                    *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                }
            }
        }
    };
}
76
/// Folds every non-null floating point value of `$column` into the running hash of its row.
///
/// Negative zero is hashed as the integer `0` of type `$ty2`; every other value is cast to `$ty`
/// with `as` and hashed through its little-endian bytes.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - column to hash; it must provide `as_any()`.
/// * `$ty` - floating point type values are cast to before hashing.
/// * `$ty2` - integer type used to hash `-0.0`.
/// * `$hashes` - mutable buffer holding one running hash per row; it is zipped with the
///   array's values.
/// * `$hash_method` - function called as `hash_method(bytes, current_hash)`.
///
/// Rows holding a null keep their current hash.
///
/// # Panics
///
/// Panics with a message naming `$array_type` if `$column` is not of that type.
#[macro_export]
macro_rules! hash_array_primitive_float {
    ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!("Failed to downcast column to {}", stringify!($array_type))
            });
        let values = array.values();

        if array.null_count() == 0 {
            // Fast path: no validity checks needed.
            for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
                // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                if *value == 0.0 && value.is_sign_negative() {
                    *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                } else {
                    *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                }
            }
        } else {
            for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() {
                if !array.is_null(i) {
                    // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                    if *value == 0.0 && value.is_sign_negative() {
                        *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                    } else {
                        *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                    }
                }
            }
        }
    };
}
106
/// Folds every non-null decimal of `$column` into the running hash of its row, hashing each
/// value as an `i64`.
///
/// Each value is converted with `i64::try_from` and the little-endian bytes of the result are
/// passed to `$hash_method` together with the row's current hash.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - column to hash; it must provide `as_any()`.
/// * `$hashes` - mutable buffer holding one running hash per row. Its length decides how many
///   rows are visited.
/// * `$hash_method` - function called as `hash_method(bytes, current_hash)`.
///
/// Rows holding a null keep their current hash.
///
/// # Panics
///
/// Panics with a message naming `$array_type` if `$column` is not of that type, and panics if a
/// non-null value does not fit in an `i64`.
#[macro_export]
macro_rules! hash_array_small_decimal {
    ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!("Failed to downcast column to {}", stringify!($array_type))
            });

        if array.null_count() == 0 {
            // Fast path: no validity checks needed.
            for (i, hash) in $hashes.iter_mut().enumerate() {
                *hash = $hash_method(
                    i64::try_from(array.value(i))
                        .expect("small decimal value must fit in i64")
                        .to_le_bytes(),
                    *hash,
                );
            }
        } else {
            for (i, hash) in $hashes.iter_mut().enumerate() {
                if !array.is_null(i) {
                    *hash = $hash_method(
                        i64::try_from(array.value(i))
                            .expect("small decimal value must fit in i64")
                            .to_le_bytes(),
                        *hash,
                    );
                }
            }
        }
    };
}
126
/// Folds every non-null decimal of `$column` into the running hash of its row, hashing the
/// little-endian bytes of the value as returned by the array.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - column to hash; it must provide `as_any()`.
/// * `$hashes` - mutable buffer holding one running hash per row. Its length decides how many
///   rows are visited.
/// * `$hash_method` - function called as `hash_method(bytes, current_hash)`.
///
/// Rows holding a null keep their current hash.
///
/// # Panics
///
/// Panics with a message naming `$array_type` if `$column` is not of that type.
#[macro_export]
macro_rules! hash_array_decimal {
    ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!("Failed to downcast column to {}", stringify!($array_type))
            });

        if array.null_count() == 0 {
            // Fast path: no validity checks needed.
            for (i, hash) in $hashes.iter_mut().enumerate() {
                *hash = $hash_method(array.value(i).to_le_bytes(), *hash);
            }
        } else {
            for (i, hash) in $hashes.iter_mut().enumerate() {
                if !array.is_null(i) {
                    *hash = $hash_method(array.value(i).to_le_bytes(), *hash);
                }
            }
        }
    };
}
145
/// Updates the per-row hashes in `$hashes_buffer` with the values of every column in `$arrays`.
///
/// Columns are visited in order and each one folds its values into the hashes already stored
/// in `$hashes_buffer`. The number of rows hashed is `$hashes_buffer.len()`, so the buffer
/// should be sized by the caller before invoking the macro.
///
/// * `$hash_method` - hash function handed to the per-type hashing macros.
/// * `$create_dictionary_hash_method` - function that hashes dictionary-encoded columns; it is
///   invoked with the dictionary key type as its type parameter.
///
/// The expansion contains `?` and `return Err(DataFusionError::Internal(..))`, so the call site
/// has to be a function returning a compatible `Result` with `DataFusionError` in scope.
#[macro_export]
macro_rules! create_hashes_internal {
    ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => {
        use arrow::datatypes::{DataType, TimeUnit};
        use arrow_array::{types::*, *};

        for (col_idx, col) in $arrays.iter().enumerate() {
            let is_first_col = col_idx == 0;
            match col.data_type() {
                // Booleans and integers of up to 32 bits are hashed as `i32`.
                DataType::Boolean => {
                    $crate::hash_array_boolean!(BooleanArray, col, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int8 => {
                    $crate::hash_array_primitive!(Int8Array, col, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int16 => {
                    $crate::hash_array_primitive!(Int16Array, col, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int32 => {
                    $crate::hash_array_primitive!(Int32Array, col, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int64 => {
                    $crate::hash_array_primitive!(Int64Array, col, i64, $hashes_buffer, $hash_method);
                }

                // Floating point columns; the integer type is the one used to hash -0.0.
                DataType::Float32 => {
                    $crate::hash_array_primitive_float!(Float32Array, col, f32, i32, $hashes_buffer, $hash_method);
                }
                DataType::Float64 => {
                    $crate::hash_array_primitive_float!(Float64Array, col, f64, i64, $hashes_buffer, $hash_method);
                }

                // Timestamps of every unit are hashed as `i64`.
                DataType::Timestamp(TimeUnit::Second, _) => {
                    $crate::hash_array_primitive!(TimestampSecondArray, col, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Millisecond, _) => {
                    $crate::hash_array_primitive!(TimestampMillisecondArray, col, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Microsecond, _) => {
                    $crate::hash_array_primitive!(TimestampMicrosecondArray, col, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                    $crate::hash_array_primitive!(TimestampNanosecondArray, col, i64, $hashes_buffer, $hash_method);
                }

                // Dates.
                DataType::Date32 => {
                    $crate::hash_array_primitive!(Date32Array, col, i32, $hashes_buffer, $hash_method);
                }
                DataType::Date64 => {
                    $crate::hash_array_primitive!(Date64Array, col, i64, $hashes_buffer, $hash_method);
                }

                // String and binary columns are hashed through a reference to each value.
                DataType::Utf8 => {
                    $crate::hash_array!(StringArray, col, $hashes_buffer, $hash_method);
                }
                DataType::LargeUtf8 => {
                    $crate::hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method);
                }
                DataType::Binary => {
                    $crate::hash_array!(BinaryArray, col, $hashes_buffer, $hash_method);
                }
                DataType::LargeBinary => {
                    $crate::hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method);
                }
                DataType::FixedSizeBinary(_) => {
                    $crate::hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method);
                }

                // Apache Spark hashes a small decimal (precision <= 18) as a long and any
                // other decimal through its bytes. The guarded arm has to stay first.
                DataType::Decimal128(precision, _) if *precision <= 18 => {
                    $crate::hash_array_small_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method);
                }
                DataType::Decimal128(_, _) => {
                    $crate::hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method);
                }

                // Dictionary columns are delegated to the caller-supplied function, selected
                // by the dictionary key type.
                DataType::Dictionary(index_type, _) => match **index_type {
                    DataType::Int8 => {
                        $create_dictionary_hash_method::<Int8Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::Int16 => {
                        $create_dictionary_hash_method::<Int16Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::Int32 => {
                        $create_dictionary_hash_method::<Int32Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::Int64 => {
                        $create_dictionary_hash_method::<Int64Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::UInt8 => {
                        $create_dictionary_hash_method::<UInt8Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::UInt16 => {
                        $create_dictionary_hash_method::<UInt16Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::UInt32 => {
                        $create_dictionary_hash_method::<UInt32Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    DataType::UInt64 => {
                        $create_dictionary_hash_method::<UInt64Type>(col, $hashes_buffer, is_first_col)?;
                    }
                    _ => {
                        return Err(DataFusionError::Internal(format!(
                            "Unsupported dictionary type in hasher hashing: {}",
                            col.data_type(),
                        )))
                    }
                },

                // Reported as an internal error: unsupported types should have been rejected
                // before reaching this point.
                _ => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data type in hasher: {}",
                        col.data_type()
                    )));
                }
            }
        }
    };
}
376
pub(crate) mod test_utils {

    /// Hashes the single column `$input` with `$hash_method`, starting from a clone of
    /// `$initial_seeds`, and asserts that the resulting hashes equal `$expected`.
    ///
    /// `$hash_method` is called as `hash_method(&[input], &mut hashes)` and its result is
    /// unwrapped, so an error from the hash function fails the test.
    #[macro_export]
    macro_rules! test_hashes_internal {
        ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => {
            let column = $input;
            let mut actual = $initial_seeds.clone();
            $hash_method(&[column], &mut actual).unwrap();
            assert_eq!(actual, $expected);
        };
    }

    /// Checks `$method` twice: once on `$values` as given and once with two nulls added.
    ///
    /// * `$t` - array type built from the values via `From`.
    /// * `$values` - input values; cloned first because they are moved into the array.
    /// * `$expected` - hashes expected for `$values` when every seed is 42.
    /// * `$seed_type` - integer type of the seeds and hashes.
    ///
    /// For the second run a null is inserted at index 0 and another at `len / 2`; the expected
    /// hash for each null row is the untouched seed 42. `Arc` and `ArrayRef` must be in scope
    /// at the call site.
    #[macro_export]
    macro_rules! test_hashes_with_nulls {
        ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => {
            // Take the copies now: `$values` is consumed when the first array is built.
            let mut nullable_values = $values.clone();
            let mut nullable_expected = $expected.clone();

            // First run: the values exactly as provided.
            let row_count = $values.len();
            let seeds = vec![42 as $seed_type; row_count];
            let column = Arc::new(<$t>::from($values)) as ArrayRef;
            $crate::test_hashes_internal!($method, column, seeds, $expected);

            // Second run: one null at the front and one at the midpoint.
            let midpoint = row_count / 2;
            for position in [0, midpoint] {
                nullable_values.insert(position, None);
                nullable_expected.insert(position, 42 as $seed_type);
            }
            let nullable_seeds = vec![42 as $seed_type; row_count + 2];
            let nullable_column = Arc::new(<$t>::from(nullable_values)) as ArrayRef;
            $crate::test_hashes_internal!(
                $method,
                nullable_column,
                nullable_seeds,
                nullable_expected
            );
        };
    }
}