polars_expr/hash_keys.rs

#![allow(unsafe_op_in_unsafe_fn)]
use std::hash::BuildHasher;

use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and_many;
use polars_core::frame::DataFrame;
use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
use polars_core::series::Series;
use polars_utils::IdxSize;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
use polars_utils::vec::PushUnchecked;

#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    RowEncoded,
    Single,
    Binview,
}

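/// Returns the [`HashKeysVariant`] used to hash keys of the given dtype.
///
/// A rough sketch of the mapping, based on the match arms below (illustrative,
/// not compile-checked):
///
/// ```ignore
/// assert!(matches!(
///     hash_keys_variant_for_dtype(&DataType::Int64),
///     HashKeysVariant::Single
/// ));
/// assert!(matches!(
///     hash_keys_variant_for_dtype(&DataType::String),
///     HashKeysVariant::Binview
/// ));
/// // Nested and other types fall back to row encoding.
/// assert!(matches!(
///     hash_keys_variant_for_dtype(&DataType::List(Box::new(DataType::Int64))),
///     HashKeysVariant::RowEncoded
/// ));
/// ```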
pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
    match dt {
        dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,

        #[cfg(feature = "dtype-decimal")]
        DataType::Decimal(_, _) => HashKeysVariant::Single,
        #[cfg(feature = "dtype-categorical")]
        DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,

        DataType::String | DataType::Binary => HashKeysVariant::Binview,

        // TODO: more efficient encoding for these.
        DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,

        _ => HashKeysVariant::RowEncoded,
    }
}

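/// Binds `$ca` to the physical `ChunkedArray` backing a single-key Series and
/// runs `$body`. Only dtypes that map to `HashKeysVariant::Single` are expected
/// here; anything else hits the `unreachable!()` arm.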
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* },
            DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* },
            DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* },
            DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* },
            DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u128")]
            DataType::UInt128 => { let $ca = $self.u128().unwrap(); $($body)* },
            DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* },
            DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* },

            #[cfg(feature = "dtype-date")]
            DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-time")]
            DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-categorical")]
            dt @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match dt.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* },
                }
            },

            _ => unreachable!(),
        }
    }}
}

/// Represents a DataFrame plus a hash per row, intended for keys in grouping
/// or joining. Depending on the variant, the hashes may or may not be
/// physically pre-computed.
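///
/// A minimal usage sketch, assuming the `df!` macro from polars_core (the
/// column name is illustrative):
///
/// ```ignore
/// let df = polars_core::df!["id" => [1i64, 2, 3]].unwrap();
/// let keys = HashKeys::from_df(&df, PlRandomState::default(), false, false);
/// assert_eq!(keys.len(), 3);
/// ```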
#[derive(Clone, Debug)]
pub enum HashKeys {
    RowEncoded(RowEncodedKeys),
    Binview(BinviewKeys),
    Single(SingleKeys),
}

impl HashKeys {
    pub fn from_df(
        df: &DataFrame,
        random_state: PlRandomState,
        null_is_valid: bool,
        force_row_encoding: bool,
    ) -> Self {
        let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
        let use_row_encoding = force_row_encoding
            || df.width() > 1
            || first_col_variant == HashKeysVariant::RowEncoded;
        if use_row_encoding {
            let keys = df.get_columns();
            let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();

            if !null_is_valid {
                let validities = keys
                    .iter()
                    .map(|c| c.as_materialized_series().rechunk_validity())
                    .collect_vec();
                let combined = combine_validities_and_many(&validities);
                keys_encoded.set_validity(combined);
            }

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes).unwrap();

            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys: keys_encoded,
            })
        } else if first_col_variant == HashKeysVariant::Binview {
            let keys = if let Ok(ca_str) = df[0].str() {
                ca_str.as_binary()
            } else {
                df[0].binary().unwrap().clone()
            };
            let keys = keys.rechunk().downcast_as_array().clone();

            let hashes = if keys.has_nulls() {
                keys.iter()
                    .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
                    .collect()
            } else {
                keys.values_iter()
                    .map(|k| random_state.hash_one(k))
                    .collect()
            };

            Self::Binview(BinviewKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys,
                null_is_valid,
            })
        } else {
            Self::Single(SingleKeys {
                random_state,
                keys: df[0].as_materialized_series().rechunk(),
                null_is_valid,
            })
        }
    }

    pub fn len(&self) -> usize {
        match self {
            HashKeys::RowEncoded(s) => s.keys.len(),
            HashKeys::Single(s) => s.keys.len(),
            HashKeys::Binview(s) => s.keys.len(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn validity(&self) -> Option<&Bitmap> {
        match self {
            HashKeys::RowEncoded(s) => s.keys.validity(),
            HashKeys::Single(s) => s.keys.chunks()[0].validity(),
            HashKeys::Binview(s) => s.keys.validity(),
        }
    }

    pub fn null_is_valid(&self) -> bool {
        match self {
            HashKeys::RowEncoded(_) => false,
            HashKeys::Single(s) => s.null_is_valid,
            HashKeys::Binview(s) => s.null_is_valid,
        }
    }

    /// Calls f with the index and hash of each element in this HashKeys.
    ///
    /// If the element is null and null_is_valid is false, the respective hash
    /// will be None.
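    ///
    /// A minimal sketch (illustrative), counting the non-null keys:
    ///
    /// ```ignore
    /// let mut non_null = 0usize;
    /// keys.for_each_hash(|_idx, opt_h| non_null += opt_h.is_some() as usize);
    /// ```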
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash(f),
            HashKeys::Single(s) => s.for_each_hash(f),
            HashKeys::Binview(s) => s.for_each_hash(f),
        }
    }

    /// Calls f with the index and hash of each element in the given subset of
    /// indices of the HashKeys.
    ///
    /// If the element is null and null_is_valid is false, the respective hash
    /// will be None.
    ///
    /// # Safety
    /// The indices in the subset must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
        }
    }

    /// After this call partitions will be extended with the partition for each
    /// hash. Nulls are assigned the null partition if partition_nulls (or
    /// null_is_valid) is true, and IdxSize::MAX otherwise.
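    ///
    /// A minimal sketch (illustrative; construction of `partitioner` omitted):
    ///
    /// ```ignore
    /// let mut partitions = Vec::new();
    /// keys.gen_partitions(&partitioner, &mut partitions, true);
    /// assert_eq!(partitions.len(), keys.len());
    /// ```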
    pub fn gen_partitions(
        &self,
        partitioner: &HashPartitioner,
        partitions: &mut Vec<IdxSize>,
        partition_nulls: bool,
    ) {
        unsafe {
            let null_p = if partition_nulls | self.null_is_valid() {
                partitioner.null_partition() as IdxSize
            } else {
                IdxSize::MAX
            };
            partitions.reserve(self.len());
            self.for_each_hash(|_idx, opt_h| {
                partitions.push_unchecked(
                    opt_h
                        .map(|h| partitioner.hash_to_partition(h) as IdxSize)
                        .unwrap_or(null_p),
                );
            });
        }
    }

    /// After this call partition_idxs[p] will be extended with the indices of
    /// hashes that belong to partition p, and the cardinality sketches are
    /// updated accordingly.
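    ///
    /// A minimal sketch (illustrative; pass an empty `sketches` slice to skip
    /// building cardinality sketches):
    ///
    /// ```ignore
    /// let n = partitioner.num_partitions();
    /// let mut partition_idxs = vec![Vec::new(); n];
    /// let mut sketches: Vec<CardinalitySketch> = (0..n).map(|_| CardinalitySketch::new()).collect();
    /// keys.gen_idxs_per_partition(&partitioner, &mut partition_idxs, &mut sketches, false);
    /// ```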
    pub fn gen_idxs_per_partition(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        if sketches.is_empty() {
            self.gen_idxs_per_partition_impl::<false>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        } else {
            self.gen_idxs_per_partition_impl::<true>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        }
    }

    fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        assert!(partition_idxs.len() == partitioner.num_partitions());
        assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());

        let null_p = partitioner.null_partition();
        self.for_each_hash(|idx, opt_h| {
            if let Some(h) = opt_h {
                unsafe {
                    // SAFETY: we asserted above that the number of partitions matches.
                    let p = partitioner.hash_to_partition(h);
                    partition_idxs.get_unchecked_mut(p).push(idx);
                    if BUILD_SKETCHES {
                        sketches.get_unchecked_mut(p).insert(h);
                    }
                }
            } else if partition_nulls {
                unsafe {
                    partition_idxs.get_unchecked_mut(null_p).push(idx);
                }
            }
        });
    }

    pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
        self.for_each_hash(|_idx, opt_h| {
            sketch.insert(opt_h.unwrap_or(0));
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        match self {
            HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
            HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
            HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
        }
    }
}

#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    pub hashes: UInt64Array, // Always non-null, we use the validity of keys.
    pub keys: BinaryArray<i64>,
}

impl RowEncodedKeys {
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
        }
    }
}

/// Single keys without prehashing.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    pub random_state: PlRandomState,
    pub keys: Series,
    pub null_is_valid: bool,
}

impl SingleKeys {
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_single(keys, &self.random_state, f);
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_subset_single(keys, subset, &self.random_state, f);
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        Self {
            random_state: self.random_state.clone(),
            keys: self.keys.take_slice_unchecked(idxs),
            null_is_valid: self.null_is_valid,
        }
    }
}

/// Binary view keys with pre-computed hashes.
#[derive(Clone, Debug)]
pub struct BinviewKeys {
    pub hashes: UInt64Array,
    pub keys: BinaryViewArray,
    pub null_is_valid: bool,
}

impl BinviewKeys {
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
            null_is_valid: self.null_is_valid,
        }
    }
}

fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
    hashes: &[u64],
    opt_v: Option<&Bitmap>,
    mut f: F,
) {
    if let Some(validity) = opt_v {
        for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
            if is_v {
                f(idx, Some(*hash))
            } else {
                f(idx, None)
            }
        }
    } else {
        for (idx, h) in hashes.iter().enumerate_idx() {
            f(idx, Some(*h));
        }
    }
}

/// # Safety
/// The indices must be in-bounds.
unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
    hashes: &[u64],
    opt_v: Option<&Bitmap>,
    subset: &[IdxSize],
    mut f: F,
) {
    if let Some(validity) = opt_v {
        for idx in subset {
            let hash = *hashes.get_unchecked(*idx as usize);
            let is_v = validity.get_bit_unchecked(*idx as usize);
            if is_v {
                f(*idx, Some(hash))
            } else {
                f(*idx, None)
            }
        }
    } else {
        for idx in subset {
            f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
        }
    }
}

pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
where
    T: PolarsDataType,
    for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
    F: FnMut(IdxSize, Option<u64>),
{
    let mut idx = 0;
    if keys.has_nulls() {
        for arr in keys.downcast_iter() {
            for opt_k in arr.iter() {
                f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
                idx += 1;
            }
        }
    } else {
        for arr in keys.downcast_iter() {
            for k in arr.values_iter() {
                f(idx, Some(random_state.tot_hash_one(k)));
                idx += 1;
            }
        }
    }
}

/// # Safety
/// The indices must be in-bounds.
unsafe fn for_each_hash_subset_single<T, F>(
    keys: &ChunkedArray<T>,
    subset: &[IdxSize],
    random_state: &PlRandomState,
    mut f: F,
) where
    T: PolarsDataType,
    for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
    F: FnMut(IdxSize, Option<u64>),
{
    let keys_arr = keys.downcast_as_array();

    if keys_arr.has_nulls() {
        for idx in subset {
            let opt_k = keys_arr.get_unchecked(*idx as usize);
            f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
        }
    } else {
        for idx in subset {
            let k = keys_arr.value_unchecked(*idx as usize);
            f(*idx, Some(random_state.tot_hash_one(k)));
        }
    }
}
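
// A hedged end-to-end sketch, not an exhaustive test suite. It assumes the
// `polars_core::df!` macro and `PlRandomState::default()` are available in
// this crate's dev context; the column name is illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_key_hashes_sketch() {
        let df = polars_core::df!["id" => [1i64, 2, 3]].unwrap();
        let keys = HashKeys::from_df(&df, PlRandomState::default(), false, false);
        assert_eq!(keys.len(), 3);

        // Every key is non-null, so every hash should be Some.
        let mut n_hashes = 0;
        keys.for_each_hash(|_idx, opt_h| {
            assert!(opt_h.is_some());
            n_hashes += 1;
        });
        assert_eq!(n_hashes, 3);
    }
}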