1#![allow(unsafe_op_in_unsafe_fn)]
2use std::hash::BuildHasher;
3
4use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
5use arrow::bitmap::Bitmap;
6use arrow::compute::utils::combine_validities_and_many;
7use polars_core::frame::DataFrame;
8use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
9use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
10use polars_core::series::Series;
11use polars_utils::IdxSize;
12use polars_utils::cardinality_sketch::CardinalitySketch;
13use polars_utils::hashing::HashPartitioner;
14use polars_utils::itertools::Itertools;
15use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
16use polars_utils::vec::PushUnchecked;
17
/// The physical representation [`HashKeys`] uses for a set of key columns, as
/// selected by [`hash_keys_variant_for_dtype`].
///
/// Note: the derived `PartialOrd`/`Ord` follow the declaration order below.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    /// All key columns are row-encoded into one binary array.
    RowEncoded,
    /// A single key column whose values are hashed directly.
    Single,
    /// A single `String`/`Binary` key column stored as binary views.
    Binview,
}
24
25pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
26 match dt {
27 dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,
28
29 #[cfg(feature = "dtype-decimal")]
30 DataType::Decimal(_, _) => HashKeysVariant::Single,
31 #[cfg(feature = "dtype-categorical")]
32 DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,
33
34 DataType::String | DataType::Binary => HashKeysVariant::Binview,
35
36 DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,
38
39 _ => HashKeysVariant::RowEncoded,
40 }
41}
42
// Dispatches on the dtype of the `Series` expression `$self`, binds `$ca` to the
// matching `&ChunkedArray` and evaluates `$body` with it. Temporal, decimal,
// enum and categorical dtypes are bound to their physical chunked array.
//
// Any dtype without an arm reaches `unreachable!()`.
// NOTE(review): the arm list is meant to cover the dtypes that
// `hash_keys_variant_for_dtype` maps to `HashKeysVariant::Single`; keep the two in sync.
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* },
            DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* },
            DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* },
            DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* },
            DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u128")]
            DataType::UInt128 => { let $ca = $self.u128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-f16")]
            DataType::Float16 => { let $ca = $self.f16().unwrap(); $($body)* },
            DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* },
            DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* },

            // Logical temporal types: operate on the physical chunked array.
            #[cfg(feature = "dtype-date")]
            DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-time")]
            DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* },
            // Enum/categorical: pick the physical width reported by `cat_physical()`.
            #[cfg(feature = "dtype-categorical")]
            dt @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match dt.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* },
                }
            },

            _ => unreachable!(),
        }
    }}
}
95
/// The key columns of a `DataFrame` together with a 64-bit hash per row, built
/// by [`HashKeys::from_df`].
///
/// The `RowEncoded` and `Binview` representations store precomputed hashes;
/// `Single` keeps the column and computes hashes on demand.
#[derive(Clone, Debug)]
pub enum HashKeys {
    RowEncoded(RowEncodedKeys),
    Binview(BinviewKeys),
    Single(SingleKeys),
}
105
106impl HashKeys {
107 pub fn from_df(
108 df: &DataFrame,
109 random_state: PlRandomState,
110 null_is_valid: bool,
111 force_row_encoding: bool,
112 ) -> Self {
113 let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
114 let use_row_encoding = force_row_encoding
115 || df.width() > 1
116 || first_col_variant == HashKeysVariant::RowEncoded;
117 if use_row_encoding {
118 let keys = df.columns();
119 let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();
120
121 if !null_is_valid {
122 let validities = keys
123 .iter()
124 .map(|c| c.as_materialized_series().rechunk_validity())
125 .collect_vec();
126 let combined = combine_validities_and_many(&validities);
127 keys_encoded.set_validity(combined);
128 }
129
130 let hashes = keys_encoded
135 .values_iter()
136 .map(|k| random_state.hash_one(k))
137 .collect();
138 Self::RowEncoded(RowEncodedKeys {
139 hashes: PrimitiveArray::from_vec(hashes),
140 keys: keys_encoded,
141 })
142 } else if first_col_variant == HashKeysVariant::Binview {
143 let keys = if let Ok(ca_str) = df[0].str() {
144 ca_str.as_binary()
145 } else {
146 df[0].binary().unwrap().clone()
147 };
148 let keys = keys.rechunk().downcast_as_array().clone();
149
150 let hashes = if keys.has_nulls() {
151 keys.iter()
152 .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
153 .collect()
154 } else {
155 keys.values_iter()
156 .map(|k| random_state.hash_one(k))
157 .collect()
158 };
159
160 Self::Binview(BinviewKeys {
161 hashes: PrimitiveArray::from_vec(hashes),
162 keys,
163 null_is_valid,
164 })
165 } else {
166 Self::Single(SingleKeys {
167 random_state,
168 keys: df[0].as_materialized_series().rechunk(),
169 null_is_valid,
170 })
171 }
172 }
173
174 pub fn len(&self) -> usize {
175 match self {
176 HashKeys::RowEncoded(s) => s.keys.len(),
177 HashKeys::Single(s) => s.keys.len(),
178 HashKeys::Binview(s) => s.keys.len(),
179 }
180 }
181
182 pub fn is_empty(&self) -> bool {
183 self.len() == 0
184 }
185
186 pub fn validity(&self) -> Option<&Bitmap> {
187 match self {
188 HashKeys::RowEncoded(s) => s.keys.validity(),
189 HashKeys::Single(s) => s.keys.chunks()[0].validity(),
190 HashKeys::Binview(s) => s.keys.validity(),
191 }
192 }
193
194 pub fn null_is_valid(&self) -> bool {
195 match self {
196 HashKeys::RowEncoded(_) => false,
197 HashKeys::Single(s) => s.null_is_valid,
198 HashKeys::Binview(s) => s.null_is_valid,
199 }
200 }
201
202 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
207 match self {
208 HashKeys::RowEncoded(s) => s.for_each_hash(f),
209 HashKeys::Single(s) => s.for_each_hash(f),
210 HashKeys::Binview(s) => s.for_each_hash(f),
211 }
212 }
213
214 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
223 &self,
224 subset: &[IdxSize],
225 f: F,
226 ) {
227 match self {
228 HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
229 HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
230 HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
231 }
232 }
233
234 pub fn gen_partitions(
238 &self,
239 partitioner: &HashPartitioner,
240 partitions: &mut Vec<IdxSize>,
241 partition_nulls: bool,
242 ) {
243 unsafe {
244 let null_p = if partition_nulls | self.null_is_valid() {
245 partitioner.null_partition() as IdxSize
246 } else {
247 IdxSize::MAX
248 };
249 partitions.reserve(self.len());
250 self.for_each_hash(|_idx, opt_h| {
251 partitions.push_unchecked(
252 opt_h
253 .map(|h| partitioner.hash_to_partition(h) as IdxSize)
254 .unwrap_or(null_p),
255 );
256 });
257 }
258 }
259
260 pub fn gen_idxs_per_partition(
264 &self,
265 partitioner: &HashPartitioner,
266 partition_idxs: &mut [Vec<IdxSize>],
267 sketches: &mut [CardinalitySketch],
268 partition_nulls: bool,
269 ) {
270 if sketches.is_empty() {
271 self.gen_idxs_per_partition_impl::<false>(
272 partitioner,
273 partition_idxs,
274 sketches,
275 partition_nulls | self.null_is_valid(),
276 );
277 } else {
278 self.gen_idxs_per_partition_impl::<true>(
279 partitioner,
280 partition_idxs,
281 sketches,
282 partition_nulls | self.null_is_valid(),
283 );
284 }
285 }
286
287 fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
288 &self,
289 partitioner: &HashPartitioner,
290 partition_idxs: &mut [Vec<IdxSize>],
291 sketches: &mut [CardinalitySketch],
292 partition_nulls: bool,
293 ) {
294 assert!(partition_idxs.len() == partitioner.num_partitions());
295 assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());
296
297 let null_p = partitioner.null_partition();
298 self.for_each_hash(|idx, opt_h| {
299 if let Some(h) = opt_h {
300 unsafe {
301 let p = partitioner.hash_to_partition(h);
303 partition_idxs.get_unchecked_mut(p).push(idx);
304 if BUILD_SKETCHES {
305 sketches.get_unchecked_mut(p).insert(h);
306 }
307 }
308 } else if partition_nulls {
309 unsafe {
310 partition_idxs.get_unchecked_mut(null_p).push(idx);
311 }
312 }
313 });
314 }
315
316 pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
317 self.for_each_hash(|_idx, opt_h| {
318 sketch.insert(opt_h.unwrap_or(0));
319 })
320 }
321
322 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
325 match self {
326 HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
327 HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
328 HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
329 }
330 }
331}
332
/// Key columns row-encoded into one binary array, with a precomputed hash per row.
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    /// Hash of every encoded row, null rows included. Nullness is taken from the
    /// validity of `keys`, not from this array.
    pub hashes: UInt64Array,
    /// The row-encoded keys. `HashKeys::from_df` sets a combined validity on them
    /// only when `null_is_valid` is false.
    pub keys: BinaryArray<i64>,
}
338
339impl RowEncodedKeys {
340 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
341 for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
342 }
343
344 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
347 &self,
348 subset: &[IdxSize],
349 f: F,
350 ) {
351 for_each_hash_subset_prehashed(
352 self.hashes.values().as_slice(),
353 self.keys.validity(),
354 subset,
355 f,
356 );
357 }
358
359 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
362 let idx_arr = arrow::ffi::mmap::slice(idxs);
363 Self {
364 hashes: polars_compute::gather::primitive::take_primitive_unchecked(
365 &self.hashes,
366 &idx_arr,
367 ),
368 keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
369 }
370 }
371}
372
/// A single key column whose hashes are not precomputed; they are computed with
/// `random_state` each time they are requested.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    /// Hasher state used for every per-row hash.
    pub random_state: PlRandomState,
    /// The key column (`HashKeys::from_df` stores it rechunked).
    pub keys: Series,
    /// Whether null keys are treated as ordinary key values.
    pub null_is_valid: bool,
}
380
381impl SingleKeys {
382 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
383 downcast_single_key_ca!(self.keys, |keys| {
384 for_each_hash_single(keys, &self.random_state, f);
385 })
386 }
387
388 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
391 &self,
392 subset: &[IdxSize],
393 f: F,
394 ) {
395 downcast_single_key_ca!(self.keys, |keys| {
396 for_each_hash_subset_single(keys, subset, &self.random_state, f);
397 })
398 }
399
400 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
403 Self {
404 random_state: self.random_state.clone(),
405 keys: self.keys.take_slice_unchecked(idxs),
406 null_is_valid: self.null_is_valid,
407 }
408 }
409}
410
/// A single `String`/`Binary` key column stored as binary views, with a
/// precomputed hash per row.
#[derive(Clone, Debug)]
pub struct BinviewKeys {
    /// Hash of each row. `HashKeys::from_df` stores `0` for null rows; nullness
    /// is taken from the validity of `keys`.
    pub hashes: UInt64Array,
    /// The key values (strings are stored via their binary representation).
    pub keys: BinaryViewArray,
    /// Whether null keys are treated as ordinary key values.
    pub null_is_valid: bool,
}
418
419impl BinviewKeys {
420 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
421 for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
422 }
423
424 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
427 &self,
428 subset: &[IdxSize],
429 f: F,
430 ) {
431 for_each_hash_subset_prehashed(
432 self.hashes.values().as_slice(),
433 self.keys.validity(),
434 subset,
435 f,
436 );
437 }
438
439 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
442 let idx_arr = arrow::ffi::mmap::slice(idxs);
443 Self {
444 hashes: polars_compute::gather::primitive::take_primitive_unchecked(
445 &self.hashes,
446 &idx_arr,
447 ),
448 keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
449 null_is_valid: self.null_is_valid,
450 }
451 }
452}
453
454fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
455 hashes: &[u64],
456 opt_v: Option<&Bitmap>,
457 mut f: F,
458) {
459 if let Some(validity) = opt_v {
460 for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
461 if is_v {
462 f(idx, Some(*hash))
463 } else {
464 f(idx, None)
465 }
466 }
467 } else {
468 for (idx, h) in hashes.iter().enumerate_idx() {
469 f(idx, Some(*h));
470 }
471 }
472}
473
474unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
477 hashes: &[u64],
478 opt_v: Option<&Bitmap>,
479 subset: &[IdxSize],
480 mut f: F,
481) {
482 if let Some(validity) = opt_v {
483 for idx in subset {
484 let hash = *hashes.get_unchecked(*idx as usize);
485 let is_v = validity.get_bit_unchecked(*idx as usize);
486 if is_v {
487 f(*idx, Some(hash))
488 } else {
489 f(*idx, None)
490 }
491 }
492 } else {
493 for idx in subset {
494 f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
495 }
496 }
497}
498
499pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
500where
501 T: PolarsDataType,
502 for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
503 F: FnMut(IdxSize, Option<u64>),
504{
505 let mut idx = 0;
506 if keys.has_nulls() {
507 for arr in keys.downcast_iter() {
508 for opt_k in arr.iter() {
509 f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
510 idx += 1;
511 }
512 }
513 } else {
514 for arr in keys.downcast_iter() {
515 for k in arr.values_iter() {
516 f(idx, Some(random_state.tot_hash_one(k)));
517 idx += 1;
518 }
519 }
520 }
521}
522
523unsafe fn for_each_hash_subset_single<T, F>(
526 keys: &ChunkedArray<T>,
527 subset: &[IdxSize],
528 random_state: &PlRandomState,
529 mut f: F,
530) where
531 T: PolarsDataType,
532 for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
533 F: FnMut(IdxSize, Option<u64>),
534{
535 let keys_arr = keys.downcast_as_array();
536
537 if keys_arr.has_nulls() {
538 for idx in subset {
539 let opt_k = keys_arr.get_unchecked(*idx as usize);
540 f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
541 }
542 } else {
543 for idx in subset {
544 let k = keys_arr.value_unchecked(*idx as usize);
545 f(*idx, Some(random_state.tot_hash_one(k)));
546 }
547 }
548}