// polars_expr/hash_keys.rs
1#![allow(unsafe_op_in_unsafe_fn)]
2use std::hash::BuildHasher;
3
4use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
5use arrow::bitmap::Bitmap;
6use arrow::compute::utils::combine_validities_and_many;
7use polars_core::frame::DataFrame;
8use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
9use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
10use polars_core::series::Series;
11use polars_utils::IdxSize;
12use polars_utils::cardinality_sketch::CardinalitySketch;
13use polars_utils::hashing::HashPartitioner;
14use polars_utils::itertools::Itertools;
15use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
16use polars_utils::vec::PushUnchecked;
17
/// The hashing strategy chosen for a set of key columns; see
/// [`hash_keys_variant_for_dtype`].
// NOTE: derived `PartialOrd`/`Ord` depend on variant declaration order.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    /// Keys are row-encoded into a binary buffer which is then hashed.
    RowEncoded,
    /// A single fixed-width physical column; hashed on demand.
    Single,
    /// A single string/binary column stored as a binary view array.
    Binview,
}
24
25pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
26    match dt {
27        dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,
28
29        #[cfg(feature = "dtype-decimal")]
30        DataType::Decimal(_, _) => HashKeysVariant::Single,
31        #[cfg(feature = "dtype-categorical")]
32        DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,
33
34        DataType::String | DataType::Binary => HashKeysVariant::Binview,
35
36        // TODO: more efficient encoding for these.
37        DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,
38
39        _ => HashKeysVariant::RowEncoded,
40    }
41}
42
/// Dispatches on `$self.dtype()` and binds `$ca` to the physical
/// `ChunkedArray` of the concrete primitive type before evaluating `$body`.
///
/// Only dtypes classified as [`HashKeysVariant::Single`] by
/// [`hash_keys_variant_for_dtype`] are expected here; any other dtype hits
/// the `unreachable!()` arm.
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* },
            DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* },
            DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* },
            DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* },
            DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u128")]
            DataType::UInt128 => { let $ca = $self.u128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-f16")]
            DataType::Float16 => { let $ca = $self.f16().unwrap(); $($body)* },
            DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* },
            DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* },

            // Logical types are hashed through their physical representation.
            #[cfg(feature = "dtype-date")]
            DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-time")]
            DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* },
            // Categoricals dispatch further on their physical integer width.
            #[cfg(feature = "dtype-categorical")]
            dt @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match dt.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* },
                }
            },

            _ => unreachable!(),
        }
    }}
}
95
/// Represents a DataFrame plus a hash per row, intended for keys in grouping
/// or joining. The hashes may or may not actually be physically pre-computed,
/// this depends per type.
#[derive(Clone, Debug)]
pub enum HashKeys {
    /// One or more key columns, row-encoded, with pre-computed hashes.
    RowEncoded(RowEncodedKeys),
    /// A single string/binary key column with pre-computed hashes.
    Binview(BinviewKeys),
    /// A single fixed-width key column; hashes are computed on the fly.
    Single(SingleKeys),
}
105
impl HashKeys {
    /// Builds [`HashKeys`] from the key columns of `df`.
    ///
    /// Row encoding is used when forced, when there are multiple key columns,
    /// or when the first column's dtype requires it; otherwise a cheaper
    /// single-column representation is chosen. When `null_is_valid` is false,
    /// null keys are marked invalid through the validity bitmap (row-encoded)
    /// or flagged via `null_is_valid` (binview/single).
    ///
    /// NOTE(review): assumes `df` has at least one column — `df[0]` would
    /// fail otherwise; confirm all callers guarantee this.
    pub fn from_df(
        df: &DataFrame,
        random_state: PlRandomState,
        null_is_valid: bool,
        force_row_encoding: bool,
    ) -> Self {
        let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
        let use_row_encoding = force_row_encoding
            || df.width() > 1
            || first_col_variant == HashKeysVariant::RowEncoded;
        if use_row_encoding {
            let keys = df.columns();
            let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();

            if !null_is_valid {
                // A row is invalid if any of its key columns is null: AND the
                // per-column validities together and attach the result.
                let validities = keys
                    .iter()
                    .map(|c| c.as_materialized_series().rechunk_validity())
                    .collect_vec();
                let combined = combine_validities_and_many(&validities);
                keys_encoded.set_validity(combined);
            }

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.columns(), Some(random_state), &mut hashes).unwrap();

            // Hash every encoded row (including rows marked invalid above;
            // validity is consulted separately when iterating hashes).
            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys: keys_encoded,
            })
        } else if first_col_variant == HashKeysVariant::Binview {
            // Strings are viewed as binary so both dtypes share one path.
            let keys = if let Ok(ca_str) = df[0].str() {
                ca_str.as_binary()
            } else {
                df[0].binary().unwrap().clone()
            };
            let keys = keys.rechunk().downcast_as_array().clone();

            // Null rows get placeholder hash 0; the keys' validity bitmap
            // remains the source of truth for nullness.
            let hashes = if keys.has_nulls() {
                keys.iter()
                    .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
                    .collect()
            } else {
                keys.values_iter()
                    .map(|k| random_state.hash_one(k))
                    .collect()
            };

            Self::Binview(BinviewKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys,
                null_is_valid,
            })
        } else {
            // Single fixed-width column: store the (rechunked) series and
            // hash lazily with the given random state.
            Self::Single(SingleKeys {
                random_state,
                keys: df[0].as_materialized_series().rechunk(),
                null_is_valid,
            })
        }
    }

    /// Number of rows (keys).
    pub fn len(&self) -> usize {
        match self {
            HashKeys::RowEncoded(s) => s.keys.len(),
            HashKeys::Single(s) => s.keys.len(),
            HashKeys::Binview(s) => s.keys.len(),
        }
    }

    /// Returns true if there are no keys.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The validity bitmap of the keys, if any rows are null.
    pub fn validity(&self) -> Option<&Bitmap> {
        match self {
            HashKeys::RowEncoded(s) => s.keys.validity(),
            // Single keys are rechunked on construction, so chunk 0 is the
            // whole series.
            HashKeys::Single(s) => s.keys.chunks()[0].validity(),
            HashKeys::Binview(s) => s.keys.validity(),
        }
    }

    /// Whether null keys should be treated as valid (joinable/groupable) keys.
    /// Row-encoded keys encode this in their validity bitmap instead.
    pub fn null_is_valid(&self) -> bool {
        match self {
            HashKeys::RowEncoded(_) => false,
            HashKeys::Single(s) => s.null_is_valid,
            HashKeys::Binview(s) => s.null_is_valid,
        }
    }

    /// Calls f with the index of and hash of each element in this HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash(f),
            HashKeys::Single(s) => s.for_each_hash(f),
            HashKeys::Binview(s) => s.for_each_hash(f),
        }
    }

    /// Calls f with the index of and hash of each element in the given
    /// subset of indices of the HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    ///
    /// # Safety
    /// The indices in the subset must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
        }
    }

    /// After this call partitions will be extended with the partition for each
    /// hash. Nulls are assigned IdxSize::MAX or a specific partition depending
    /// on whether partition_nulls is true.
    pub fn gen_partitions(
        &self,
        partitioner: &HashPartitioner,
        partitions: &mut Vec<IdxSize>,
        partition_nulls: bool,
    ) {
        unsafe {
            // Nulls that count as valid keys always get a real partition.
            let null_p = if partition_nulls | self.null_is_valid() {
                partitioner.null_partition() as IdxSize
            } else {
                IdxSize::MAX
            };
            // SAFETY: reserve(len) above guarantees capacity for exactly one
            // push_unchecked per element visited by for_each_hash.
            partitions.reserve(self.len());
            self.for_each_hash(|_idx, opt_h| {
                partitions.push_unchecked(
                    opt_h
                        .map(|h| partitioner.hash_to_partition(h) as IdxSize)
                        .unwrap_or(null_p),
                );
            });
        }
    }

    /// After this call partition_idxs[p] will be extended with the indices of
    /// hashes that belong to partition p, and the cardinality sketches are
    /// updated accordingly.
    pub fn gen_idxs_per_partition(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        // Monomorphize on whether sketches are maintained so the hot loop
        // has no per-element branch for it.
        if sketches.is_empty() {
            self.gen_idxs_per_partition_impl::<false>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        } else {
            self.gen_idxs_per_partition_impl::<true>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        }
    }

    // Shared implementation of gen_idxs_per_partition; BUILD_SKETCHES selects
    // at compile time whether cardinality sketches are updated.
    fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        assert!(partition_idxs.len() == partitioner.num_partitions());
        assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());

        let null_p = partitioner.null_partition();
        self.for_each_hash(|idx, opt_h| {
            if let Some(h) = opt_h {
                unsafe {
                    // SAFETY: we assured the number of partitions matches.
                    let p = partitioner.hash_to_partition(h);
                    partition_idxs.get_unchecked_mut(p).push(idx);
                    if BUILD_SKETCHES {
                        sketches.get_unchecked_mut(p).insert(h);
                    }
                }
            } else if partition_nulls {
                unsafe {
                    // SAFETY: null_p came from the partitioner and the slice
                    // length was asserted above.
                    partition_idxs.get_unchecked_mut(null_p).push(idx);
                }
            }
        });
    }

    /// Folds every row's hash into the sketch; null rows contribute hash 0.
    pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
        self.for_each_hash(|_idx, opt_h| {
            sketch.insert(opt_h.unwrap_or(0));
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        match self {
            HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
            HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
            HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
        }
    }
}
332
/// Keys row-encoded into a single binary column, with pre-computed hashes.
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    pub hashes: UInt64Array, // Always non-null, we use the validity of keys.
    pub keys: BinaryArray<i64>,
}
338
impl RowEncodedKeys {
    /// Calls `f` with `(index, Some(hash))` per row, or `(index, None)` for
    /// rows marked null in the keys' validity bitmap.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// Like [`Self::for_each_hash`], restricted to the given subset of indices.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of `idxs` as an arrow index array for the take kernels.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
        }
    }
}
372
/// Single keys without prehashing.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    /// Hasher used to compute hashes on demand.
    pub random_state: PlRandomState,
    /// The key column; rechunked to a single chunk in [`HashKeys::from_df`].
    pub keys: Series,
    /// Whether null keys should be treated as valid keys.
    pub null_is_valid: bool,
}
380
impl SingleKeys {
    /// Calls `f` with `(index, Some(hash))` per row, hashing on the fly;
    /// null rows yield `(index, None)`.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_single(keys, &self.random_state, f);
        })
    }

    /// Like [`Self::for_each_hash`], restricted to the given subset of indices.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_subset_single(keys, subset, &self.random_state, f);
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        Self {
            // Cloning the random state preserves hash compatibility between
            // the original and the gathered keys.
            random_state: self.random_state.clone(),
            keys: self.keys.take_slice_unchecked(idxs),
            null_is_valid: self.null_is_valid,
        }
    }
}
410
/// String/binary view keys with pre-computed hashes.
#[derive(Clone, Debug)]
pub struct BinviewKeys {
    /// Hash per row; null rows hold a placeholder value (0 at construction),
    /// the validity of `keys` is authoritative for nullness.
    pub hashes: UInt64Array,
    pub keys: BinaryViewArray,
    /// Whether null keys should be treated as valid keys.
    pub null_is_valid: bool,
}
418
impl BinviewKeys {
    /// Calls `f` with `(index, Some(hash))` per row, or `(index, None)` for
    /// rows marked null in the keys' validity bitmap.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// Like [`Self::for_each_hash`], restricted to the given subset of indices.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of `idxs` as an arrow index array for the take kernels.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
            null_is_valid: self.null_is_valid,
        }
    }
}
453
454fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
455    hashes: &[u64],
456    opt_v: Option<&Bitmap>,
457    mut f: F,
458) {
459    if let Some(validity) = opt_v {
460        for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
461            if is_v {
462                f(idx, Some(*hash))
463            } else {
464                f(idx, None)
465            }
466        }
467    } else {
468        for (idx, h) in hashes.iter().enumerate_idx() {
469            f(idx, Some(*h));
470        }
471    }
472}
473
474/// # Safety
475/// The indices must be in-bounds.
476unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
477    hashes: &[u64],
478    opt_v: Option<&Bitmap>,
479    subset: &[IdxSize],
480    mut f: F,
481) {
482    if let Some(validity) = opt_v {
483        for idx in subset {
484            let hash = *hashes.get_unchecked(*idx as usize);
485            let is_v = validity.get_bit_unchecked(*idx as usize);
486            if is_v {
487                f(*idx, Some(hash))
488            } else {
489                f(*idx, None)
490            }
491        }
492    } else {
493        for idx in subset {
494            f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
495        }
496    }
497}
498
499pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
500where
501    T: PolarsDataType,
502    for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
503    F: FnMut(IdxSize, Option<u64>),
504{
505    let mut idx = 0;
506    if keys.has_nulls() {
507        for arr in keys.downcast_iter() {
508            for opt_k in arr.iter() {
509                f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
510                idx += 1;
511            }
512        }
513    } else {
514        for arr in keys.downcast_iter() {
515            for k in arr.values_iter() {
516                f(idx, Some(random_state.tot_hash_one(k)));
517                idx += 1;
518            }
519        }
520    }
521}
522
523/// # Safety
524/// The indices must be in-bounds.
525unsafe fn for_each_hash_subset_single<T, F>(
526    keys: &ChunkedArray<T>,
527    subset: &[IdxSize],
528    random_state: &PlRandomState,
529    mut f: F,
530) where
531    T: PolarsDataType,
532    for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
533    F: FnMut(IdxSize, Option<u64>),
534{
535    let keys_arr = keys.downcast_as_array();
536
537    if keys_arr.has_nulls() {
538        for idx in subset {
539            let opt_k = keys_arr.get_unchecked(*idx as usize);
540            f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
541        }
542    } else {
543        for idx in subset {
544            let k = keys_arr.value_unchecked(*idx as usize);
545            f(*idx, Some(random_state.tot_hash_one(k)));
546        }
547    }
548}