datafusion_comet_spark_expr/hash_funcs/
utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utility macros for hashing array columns, shared by hash functions such as murmur3.
19
/// Folds every non-null element of a column into the matching slot of `$hashes`.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$hashes` - buffer of running hashes. Slot `i` is passed as the seed for
///   row `i` and then overwritten with the result. Its length decides how many
///   rows are visited.
/// * `$hash_method` - called as `$hash_method(&value, seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array {
    ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let typed = match $column.as_any().downcast_ref::<$array_type>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast column to {}. Actual data type: {:?}.",
                stringify!($array_type),
                $column.data_type()
            ),
        };
        if typed.null_count() != 0 {
            // At least one null: consult the validity of each row before hashing it.
            $hashes
                .iter_mut()
                .enumerate()
                .filter(|(row, _)| !typed.is_null(*row))
                .for_each(|(row, slot)| *slot = $hash_method(&typed.value(row), *slot));
        } else {
            // No nulls at all: skip the per-row validity check entirely.
            $hashes
                .iter_mut()
                .enumerate()
                .for_each(|(row, slot)| *slot = $hash_method(&typed.value(row), *slot));
        }
    };
}
46
/// Hashes every non-null element of a boolean column into `$hashes`.
///
/// Each value is first converted with `$hash_input_type::from(value)` and the
/// little-endian bytes of that converted value are what gets hashed.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$hash_input_type` - type the boolean is converted into before hashing.
/// * `$hashes` - buffer of running hashes; slot `i` seeds row `i` and receives
///   the result. Its length decides how many rows are visited.
/// * `$hash_method` - called as `$hash_method(bytes, seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array_boolean {
    ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => {
        let typed = match $column.as_any().downcast_ref::<$array_type>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast column to {}. Actual data type: {:?}.",
                stringify!($array_type),
                $column.data_type()
            ),
        };
        if typed.null_count() != 0 {
            // At least one null: only rows that are valid get hashed.
            $hashes
                .iter_mut()
                .enumerate()
                .filter(|(row, _)| !typed.is_null(*row))
                .for_each(|(row, slot)| {
                    let widened = $hash_input_type::from(typed.value(row));
                    *slot = $hash_method(widened.to_le_bytes(), *slot);
                });
        } else {
            // No nulls: every row is hashed without a validity check.
            $hashes.iter_mut().enumerate().for_each(|(row, slot)| {
                let widened = $hash_input_type::from(typed.value(row));
                *slot = $hash_method(widened.to_le_bytes(), *slot);
            });
        }
    };
}
74
/// Hashes every non-null element of a primitive column into `$hashes`.
///
/// Each value is cast with `as $ty` and the little-endian bytes of the cast
/// value are what gets hashed.
///
/// * `$array_type` - concrete array type that `$column` is downcast to; it must
///   expose its raw values through `values()`.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$ty` - primitive type each value is cast to before hashing.
/// * `$hashes` - buffer of running hashes; slot `i` seeds row `i` and receives
///   the result. It is zipped with the values, so the shorter of the two
///   decides how many rows are visited.
/// * `$hash_method` - called as `$hash_method(bytes, seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array_primitive {
    ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => {
        let typed = match $column.as_any().downcast_ref::<$array_type>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast column to {}. Actual data type: {:?}.",
                stringify!($array_type),
                $column.data_type()
            ),
        };
        let raw = typed.values();

        if typed.null_count() != 0 {
            // At least one null: pair each slot with its value, then drop invalid rows.
            $hashes
                .iter_mut()
                .zip(raw.iter())
                .enumerate()
                .filter(|(row, _)| !typed.is_null(*row))
                .for_each(|(_, (slot, value))| {
                    *slot = $hash_method((*value as $ty).to_le_bytes(), *slot);
                });
        } else {
            // No nulls: walk slots and values in lockstep without a validity check.
            $hashes.iter_mut().zip(raw.iter()).for_each(|(slot, value)| {
                *slot = $hash_method((*value as $ty).to_le_bytes(), *slot);
            });
        }
    };
}
103
/// Hashes every non-null element of a floating-point column into `$hashes`,
/// normalising the values Spark normalises before hashing.
///
/// * `-0.0` is hashed as the integer zero of type `$ty2`.
/// * Every NaN, whatever its sign or payload bits, is hashed as the single
///   canonical quiet NaN that Java's `Float.floatToIntBits` /
///   `Double.doubleToLongBits` report, since Spark hashes those bit patterns.
/// * Any other value is cast with `as $ty` and hashed from its little-endian bytes.
///
/// * `$array_type` - concrete array type that `$column` is downcast to; it must
///   expose its raw values through `values()`.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$ty` - floating-point type of the hashed value (`f32` or `f64`).
/// * `$ty2` - integer type used for the zero that replaces `-0.0`.
/// * `$hashes` - buffer of running hashes; slot `i` seeds row `i` and receives
///   the result. It is zipped with the values, so the shorter of the two
///   decides how many rows are visited.
/// * `$hash_method` - called as `$hash_method(bytes, seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array_primitive_float {
    ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast column to {}. Actual data type: {:?}.",
                    stringify!($array_type),
                    $column.data_type()
                )
            });
        let values = array.values();

        // Canonical quiet NaN: exponent bits all set plus the top mantissa bit
        // (0x7fc00000 for f32, 0x7ff8000000000000 for f64). Built from the
        // infinity bit pattern because the bits of `$ty::NAN` are not guaranteed.
        let infinity_bits = $ty::INFINITY.to_bits();
        let canonical_nan_bits = infinity_bits | (infinity_bits >> 1);

        if array.null_count() == 0 {
            for (hash, value) in $hashes.iter_mut().zip(values.iter()) {
                if *value == 0.0 && value.is_sign_negative() {
                    // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                    *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                } else if value.is_nan() {
                    // Spark hashes the canonical NaN bits regardless of payload.
                    *hash = $hash_method(canonical_nan_bits.to_le_bytes(), *hash);
                } else {
                    *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                }
            }
        } else {
            for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() {
                if !array.is_null(i) {
                    if *value == 0.0 && value.is_sign_negative() {
                        // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression.
                        *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash);
                    } else if value.is_nan() {
                        // Spark hashes the canonical NaN bits regardless of payload.
                        *hash = $hash_method(canonical_nan_bits.to_le_bytes(), *hash);
                    } else {
                        *hash = $hash_method((*value as $ty).to_le_bytes(), *hash);
                    }
                }
            }
        }
    };
}
142
/// Hashes every non-null element of a decimal column whose values fit in an
/// `i64`.
///
/// Each value is narrowed with `i64::try_from` and the little-endian bytes of
/// the resulting `i64` are what gets hashed.
///
/// * `$array_type` - concrete array type that `$column` is downcast to.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$hashes` - buffer of running hashes; slot `i` seeds row `i` and receives
///   the result. Its length decides how many rows are visited.
/// * `$hash_method` - called as `$hash_method(bytes, seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Errors
///
/// The expansion uses `?`: when a value does not fit in an `i64`, the enclosing
/// function returns early with a `DataFusionError::Execution` carrying the
/// conversion error's message. `DataFusionError` must be in scope at the call site.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array_small_decimal {
    ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let typed = match $column.as_any().downcast_ref::<$array_type>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast column to {}. Actual data type: {:?}.",
                stringify!($array_type),
                $column.data_type()
            ),
        };

        if typed.null_count() != 0 {
            // At least one null: skip invalid rows before attempting the conversion.
            for (row, slot) in $hashes.iter_mut().enumerate() {
                if typed.is_null(row) {
                    continue;
                }
                let unscaled = i64::try_from(typed.value(row))
                    .map_err(|e| DataFusionError::Execution(e.to_string()))?;
                *slot = $hash_method(unscaled.to_le_bytes(), *slot);
            }
        } else {
            // No nulls: every row is converted and hashed.
            for (row, slot) in $hashes.iter_mut().enumerate() {
                let unscaled = i64::try_from(typed.value(row))
                    .map_err(|e| DataFusionError::Execution(e.to_string()))?;
                *slot = $hash_method(unscaled.to_le_bytes(), *slot);
            }
        }
    };
}
180
/// Hashes every non-null element of a wide decimal column into `$hashes`, using
/// the byte encoding Spark hashes for decimals that do not fit in a long.
///
/// Spark hashes `BigInteger.toByteArray()` of the unscaled value: the
/// big-endian two's-complement representation with the minimum number of bytes
/// (at least one, keeping one sign bit). The fixed-width big-endian bytes of
/// each value are therefore stripped of redundant leading sign-extension bytes
/// and the remaining slice is passed to `$hash_method`.
///
/// * `$array_type` - concrete array type that `$column` is downcast to; its
///   values must provide `to_be_bytes()`.
/// * `$column` - the column; must expose `as_any()` and `data_type()`.
/// * `$hashes` - buffer of running hashes; slot `i` seeds row `i` and receives
///   the result. Its length decides how many rows are visited.
/// * `$hash_method` - called as `$hash_method(&bytes[..], seed)`.
///
/// Slots that belong to null rows are left untouched.
///
/// # Panics
///
/// Panics when `$column` cannot be downcast to `$array_type`.
#[macro_export]
macro_rules! hash_array_decimal {
    ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => {
        let array = $column
            .as_any()
            .downcast_ref::<$array_type>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast column to {}. Actual data type: {:?}.",
                    stringify!($array_type),
                    $column.data_type()
                )
            });

        // Number of leading bytes that only repeat the sign. A leading byte is
        // redundant when it equals the sign fill (0x00 or 0xFF) and the byte
        // after it already carries the same sign bit. The last byte is never
        // dropped, so zero encodes as [0x00] and -1 as [0xFF].
        let redundant_sign_bytes = |be: &[u8]| -> usize {
            let fill: u8 = if (be[0] & 0x80) != 0 { 0xFF } else { 0x00 };
            be.windows(2)
                .take_while(|pair| pair[0] == fill && (pair[1] & 0x80) == (fill & 0x80))
                .count()
        };

        if array.null_count() == 0 {
            for (i, hash) in $hashes.iter_mut().enumerate() {
                let be = array.value(i).to_be_bytes();
                let skip = redundant_sign_bytes(&be[..]);
                *hash = $hash_method(&be[skip..], *hash);
            }
        } else {
            for (i, hash) in $hashes.iter_mut().enumerate() {
                if !array.is_null(i) {
                    let be = array.value(i).to_be_bytes();
                    let skip = redundant_sign_bytes(&be[..]);
                    *hash = $hash_method(&be[skip..], *hash);
                }
            }
        }
    };
}
208
/// Updates one running hash per row by folding in, column after column, the
/// value each column holds at that row.
///
/// * `$arrays` - the columns to hash, visited in order through `.iter()`.
/// * `$hashes_buffer` - the running hash of every row. Its length determines
///   how many rows are hashed, so it should be pre-sized appropriately.
/// * `$hash_method` - hash function handed to the per-type helper macros for
///   every non-dictionary column.
/// * `$create_dictionary_hash_method` - function, generic over the dictionary
///   key type, that hashes dictionary columns. It is called with the column,
///   `$hashes_buffer` and a flag telling whether the column is the first one.
///
/// The expansion contains `?` and `return Err(..)`, so it must be used inside a
/// function that returns a `Result` whose error can hold a `DataFusionError`,
/// and `DataFusionError` must be in scope at the call site. Data types without
/// a matching arm produce `DataFusionError::Internal`.
#[macro_export]
macro_rules! create_hashes_internal {
    ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => {
        use arrow::array::{types::*, *};
        use arrow::datatypes::{DataType, TimeUnit};

        for (position, column) in $arrays.iter().enumerate() {
            let is_first_column = position == 0;
            match column.data_type() {
                // Booleans and the integer types up to 32 bits are hashed as `i32`.
                DataType::Boolean => {
                    $crate::hash_array_boolean!(BooleanArray, column, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int8 => {
                    $crate::hash_array_primitive!(Int8Array, column, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int16 => {
                    $crate::hash_array_primitive!(Int16Array, column, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int32 => {
                    $crate::hash_array_primitive!(Int32Array, column, i32, $hashes_buffer, $hash_method);
                }
                DataType::Int64 => {
                    $crate::hash_array_primitive!(Int64Array, column, i64, $hashes_buffer, $hash_method);
                }
                // Floats go through the variant that special-cases -0.0.
                DataType::Float32 => {
                    $crate::hash_array_primitive_float!(Float32Array, column, f32, i32, $hashes_buffer, $hash_method);
                }
                DataType::Float64 => {
                    $crate::hash_array_primitive_float!(Float64Array, column, f64, i64, $hashes_buffer, $hash_method);
                }
                // Timestamps of every unit are hashed as `i64`.
                DataType::Timestamp(TimeUnit::Second, _) => {
                    $crate::hash_array_primitive!(TimestampSecondArray, column, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Millisecond, _) => {
                    $crate::hash_array_primitive!(TimestampMillisecondArray, column, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Microsecond, _) => {
                    $crate::hash_array_primitive!(TimestampMicrosecondArray, column, i64, $hashes_buffer, $hash_method);
                }
                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                    $crate::hash_array_primitive!(TimestampNanosecondArray, column, i64, $hashes_buffer, $hash_method);
                }
                DataType::Date32 => {
                    $crate::hash_array_primitive!(Date32Array, column, i32, $hashes_buffer, $hash_method);
                }
                DataType::Date64 => {
                    $crate::hash_array_primitive!(Date64Array, column, i64, $hashes_buffer, $hash_method);
                }
                // String and binary columns are hashed from the value itself.
                DataType::Utf8 => {
                    $crate::hash_array!(StringArray, column, $hashes_buffer, $hash_method);
                }
                DataType::LargeUtf8 => {
                    $crate::hash_array!(LargeStringArray, column, $hashes_buffer, $hash_method);
                }
                DataType::Binary => {
                    $crate::hash_array!(BinaryArray, column, $hashes_buffer, $hash_method);
                }
                DataType::LargeBinary => {
                    $crate::hash_array!(LargeBinaryArray, column, $hashes_buffer, $hash_method);
                }
                DataType::FixedSizeBinary(_) => {
                    $crate::hash_array!(FixedSizeBinaryArray, column, $hashes_buffer, $hash_method);
                }
                // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it.
                // Else, turn it into bytes and hash it.
                DataType::Decimal128(precision, _) if *precision <= 18 => {
                    $crate::hash_array_small_decimal!(Decimal128Array, column, $hashes_buffer, $hash_method);
                }
                DataType::Decimal128(_, _) => {
                    $crate::hash_array_decimal!(Decimal128Array, column, $hashes_buffer, $hash_method);
                }
                // Dictionary columns are delegated, dispatching on the key type.
                DataType::Dictionary(index_type, _) => match **index_type {
                    DataType::Int8 => {
                        $create_dictionary_hash_method::<Int8Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::Int16 => {
                        $create_dictionary_hash_method::<Int16Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::Int32 => {
                        $create_dictionary_hash_method::<Int32Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::Int64 => {
                        $create_dictionary_hash_method::<Int64Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::UInt8 => {
                        $create_dictionary_hash_method::<UInt8Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::UInt16 => {
                        $create_dictionary_hash_method::<UInt16Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::UInt32 => {
                        $create_dictionary_hash_method::<UInt32Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    DataType::UInt64 => {
                        $create_dictionary_hash_method::<UInt64Type>(column, $hashes_buffer, is_first_column)?;
                    }
                    _ => {
                        return Err(DataFusionError::Internal(format!(
                            "Unsupported dictionary type in hasher hashing: {}",
                            column.data_type(),
                        )));
                    }
                },
                _ => {
                    // This is internal because we should have caught this before.
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data type in hasher: {}",
                        column.data_type()
                    )));
                }
            }
        }
    };
}
439
pub(crate) mod test_utils {

    /// Runs `$hash_method` over a single input column and checks the result.
    ///
    /// * `$hash_method` - called as `$hash_method(&[input], &mut hashes)`; its
    ///   result is unwrapped, so an `Err` fails the test.
    /// * `$input` - the column to hash.
    /// * `$initial_seeds` - starting hash values; cloned, so the original is
    ///   not modified.
    /// * `$expected` - the hashes the buffer must equal afterwards.
    #[macro_export]
    macro_rules! test_hashes_internal {
        ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => {
            let column = $input;
            let mut actual = $initial_seeds.clone();
            $hash_method(&[column], &mut actual).unwrap();
            assert_eq!(actual, $expected);
        };
    }

    /// Checks `$method` twice: once on `$values` as given and once on a copy
    /// with two null entries inserted.
    ///
    /// Every row is seeded with `42`. The nulls are inserted at index `0` and
    /// then at index `len / 2`, and the expected hash at both positions is the
    /// untouched seed.
    ///
    /// * `$method` - hash function under test.
    /// * `$t` - array type built from the values through `From`.
    /// * `$values` - input values; consumed by the null-free run.
    /// * `$expected` - expected hashes for the null-free input.
    /// * `$seed_type` - numeric type of the hash values.
    ///
    /// `Arc` and `ArrayRef` must be in scope at the call site.
    #[macro_export]
    macro_rules! test_hashes_with_nulls {
        ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => {
            let seed = 42 as $seed_type;
            let row_count = $values.len();
            let midpoint = row_count / 2;

            // Build the null-containing fixtures first: `$values` is moved below.
            let mut values_with_nulls = $values.clone();
            values_with_nulls.insert(0, None);
            values_with_nulls.insert(midpoint, None);
            let mut expected_with_nulls = $expected.clone();
            expected_with_nulls.insert(0, seed);
            expected_with_nulls.insert(midpoint, seed);

            // First run: the input exactly as given, without nulls.
            let dense_seeds = vec![seed; row_count];
            let dense_input = Arc::new(<$t>::from($values)) as ArrayRef;
            $crate::test_hashes_internal!($method, dense_input, dense_seeds, $expected);

            // Second run: the same input with the two nulls inserted.
            let nullable_seeds = vec![seed; row_count + 2];
            let nullable_input = Arc::new(<$t>::from(values_with_nulls)) as ArrayRef;
            $crate::test_hashes_internal!(
                $method,
                nullable_input,
                nullable_seeds,
                expected_with_nulls
            );
        };
    }
}