1#![allow(unsafe_op_in_unsafe_fn)]
2use std::hash::BuildHasher;
3
4use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
5use arrow::bitmap::Bitmap;
6use arrow::compute::utils::combine_validities_and_many;
7use polars_core::frame::DataFrame;
8use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
9use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
10use polars_core::series::Series;
11use polars_utils::IdxSize;
12use polars_utils::cardinality_sketch::CardinalitySketch;
13use polars_utils::hashing::HashPartitioner;
14use polars_utils::itertools::Itertools;
15use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
16use polars_utils::vec::PushUnchecked;
17
/// Physical layout chosen for a set of key columns (see `hash_keys_variant_for_dtype`).
///
/// The variant order is significant: the derived `Ord` ranks
/// `RowEncoded < Single < Binview`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    /// All key columns are row-encoded into one binary value per row.
    RowEncoded,
    /// A single column hashed directly from its physical (fixed-width) values.
    Single,
    /// A single string or binary column hashed as binary views.
    Binview,
}
24
25pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
26 match dt {
27 dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,
28
29 #[cfg(feature = "dtype-decimal")]
30 DataType::Decimal(_, _) => HashKeysVariant::Single,
31 #[cfg(feature = "dtype-categorical")]
32 DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,
33
34 DataType::String | DataType::Binary => HashKeysVariant::Binview,
35
36 DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,
38
39 _ => HashKeysVariant::RowEncoded,
40 }
41}
42
/// Evaluates `body` with `$ca` bound to the physical `ChunkedArray` behind the
/// single-key `Series` expression `$self`.
///
/// Logical dtypes (date/time/datetime/duration, decimal, enum/categorical) are unwrapped
/// to their physical array before `body` runs. A dtype without an arm here reaches
/// `unreachable!()`. In this file the macro is only applied to `SingleKeys::keys`, which
/// `HashKeys::from_df` builds only for columns classified as [`HashKeysVariant::Single`].
///
/// Note: `$self` is evaluated more than once (for the dtype and again for the downcast).
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        // NOTE(review): the arms below use qualified `DataType::...` paths, so this glob
        // import appears unused by the macro itself.
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => {
                let $ca = $self.i8().unwrap();
                $($body)*
            },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => {
                let $ca = $self.i16().unwrap();
                $($body)*
            },
            DataType::Int32 => {
                let $ca = $self.i32().unwrap();
                $($body)*
            },
            DataType::Int64 => {
                let $ca = $self.i64().unwrap();
                $($body)*
            },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => {
                let $ca = $self.u8().unwrap();
                $($body)*
            },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => {
                let $ca = $self.u16().unwrap();
                $($body)*
            },
            DataType::UInt32 => {
                let $ca = $self.u32().unwrap();
                $($body)*
            },
            DataType::UInt64 => {
                let $ca = $self.u64().unwrap();
                $($body)*
            },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => {
                let $ca = $self.i128().unwrap();
                $($body)*
            },
            #[cfg(feature = "dtype-u128")]
            DataType::UInt128 => {
                let $ca = $self.u128().unwrap();
                $($body)*
            },
            DataType::Float32 => {
                let $ca = $self.f32().unwrap();
                $($body)*
            },
            DataType::Float64 => {
                let $ca = $self.f64().unwrap();
                $($body)*
            },

            // Temporal dtypes: hash the underlying physical integers.
            #[cfg(feature = "dtype-date")]
            DataType::Date => {
                let $ca = $self.date().unwrap().physical();
                $($body)*
            },
            #[cfg(feature = "dtype-time")]
            DataType::Time => {
                let $ca = $self.time().unwrap().physical();
                $($body)*
            },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => {
                let $ca = $self.datetime().unwrap().physical();
                $($body)*
            },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => {
                let $ca = $self.duration().unwrap().physical();
                $($body)*
            },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => {
                let $ca = $self.decimal().unwrap().physical();
                $($body)*
            },

            // Enum/categorical: dispatch on the width of the physical category codes.
            #[cfg(feature = "dtype-categorical")]
            cat_dtype @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match cat_dtype.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => {
                        let $ca = $self.cat8().unwrap().physical();
                        $($body)*
                    },
                    CategoricalPhysical::U16 => {
                        let $ca = $self.cat16().unwrap().physical();
                        $($body)*
                    },
                    CategoricalPhysical::U32 => {
                        let $ca = $self.cat32().unwrap().physical();
                        $($body)*
                    },
                }
            },

            _ => unreachable!(),
        }
    }}
}
93
94#[derive(Clone, Debug)]
98pub enum HashKeys {
99 RowEncoded(RowEncodedKeys),
100 Binview(BinviewKeys),
101 Single(SingleKeys),
102}
103
104impl HashKeys {
105 pub fn from_df(
106 df: &DataFrame,
107 random_state: PlRandomState,
108 null_is_valid: bool,
109 force_row_encoding: bool,
110 ) -> Self {
111 let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
112 let use_row_encoding = force_row_encoding
113 || df.width() > 1
114 || first_col_variant == HashKeysVariant::RowEncoded;
115 if use_row_encoding {
116 let keys = df.get_columns();
117 let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();
118
119 if !null_is_valid {
120 let validities = keys
121 .iter()
122 .map(|c| c.as_materialized_series().rechunk_validity())
123 .collect_vec();
124 let combined = combine_validities_and_many(&validities);
125 keys_encoded.set_validity(combined);
126 }
127
128 let hashes = keys_encoded
133 .values_iter()
134 .map(|k| random_state.hash_one(k))
135 .collect();
136 Self::RowEncoded(RowEncodedKeys {
137 hashes: PrimitiveArray::from_vec(hashes),
138 keys: keys_encoded,
139 })
140 } else if first_col_variant == HashKeysVariant::Binview {
141 let keys = if let Ok(ca_str) = df[0].str() {
142 ca_str.as_binary()
143 } else {
144 df[0].binary().unwrap().clone()
145 };
146 let keys = keys.rechunk().downcast_as_array().clone();
147
148 let hashes = if keys.has_nulls() {
149 keys.iter()
150 .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
151 .collect()
152 } else {
153 keys.values_iter()
154 .map(|k| random_state.hash_one(k))
155 .collect()
156 };
157
158 Self::Binview(BinviewKeys {
159 hashes: PrimitiveArray::from_vec(hashes),
160 keys,
161 null_is_valid,
162 })
163 } else {
164 Self::Single(SingleKeys {
165 random_state,
166 keys: df[0].as_materialized_series().rechunk(),
167 null_is_valid,
168 })
169 }
170 }
171
172 pub fn len(&self) -> usize {
173 match self {
174 HashKeys::RowEncoded(s) => s.keys.len(),
175 HashKeys::Single(s) => s.keys.len(),
176 HashKeys::Binview(s) => s.keys.len(),
177 }
178 }
179
180 pub fn is_empty(&self) -> bool {
181 self.len() == 0
182 }
183
184 pub fn validity(&self) -> Option<&Bitmap> {
185 match self {
186 HashKeys::RowEncoded(s) => s.keys.validity(),
187 HashKeys::Single(s) => s.keys.chunks()[0].validity(),
188 HashKeys::Binview(s) => s.keys.validity(),
189 }
190 }
191
192 pub fn null_is_valid(&self) -> bool {
193 match self {
194 HashKeys::RowEncoded(_) => false,
195 HashKeys::Single(s) => s.null_is_valid,
196 HashKeys::Binview(s) => s.null_is_valid,
197 }
198 }
199
200 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
205 match self {
206 HashKeys::RowEncoded(s) => s.for_each_hash(f),
207 HashKeys::Single(s) => s.for_each_hash(f),
208 HashKeys::Binview(s) => s.for_each_hash(f),
209 }
210 }
211
212 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
221 &self,
222 subset: &[IdxSize],
223 f: F,
224 ) {
225 match self {
226 HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
227 HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
228 HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
229 }
230 }
231
232 pub fn gen_partitions(
236 &self,
237 partitioner: &HashPartitioner,
238 partitions: &mut Vec<IdxSize>,
239 partition_nulls: bool,
240 ) {
241 unsafe {
242 let null_p = if partition_nulls | self.null_is_valid() {
243 partitioner.null_partition() as IdxSize
244 } else {
245 IdxSize::MAX
246 };
247 partitions.reserve(self.len());
248 self.for_each_hash(|_idx, opt_h| {
249 partitions.push_unchecked(
250 opt_h
251 .map(|h| partitioner.hash_to_partition(h) as IdxSize)
252 .unwrap_or(null_p),
253 );
254 });
255 }
256 }
257
258 pub fn gen_idxs_per_partition(
262 &self,
263 partitioner: &HashPartitioner,
264 partition_idxs: &mut [Vec<IdxSize>],
265 sketches: &mut [CardinalitySketch],
266 partition_nulls: bool,
267 ) {
268 if sketches.is_empty() {
269 self.gen_idxs_per_partition_impl::<false>(
270 partitioner,
271 partition_idxs,
272 sketches,
273 partition_nulls | self.null_is_valid(),
274 );
275 } else {
276 self.gen_idxs_per_partition_impl::<true>(
277 partitioner,
278 partition_idxs,
279 sketches,
280 partition_nulls | self.null_is_valid(),
281 );
282 }
283 }
284
285 fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
286 &self,
287 partitioner: &HashPartitioner,
288 partition_idxs: &mut [Vec<IdxSize>],
289 sketches: &mut [CardinalitySketch],
290 partition_nulls: bool,
291 ) {
292 assert!(partition_idxs.len() == partitioner.num_partitions());
293 assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());
294
295 let null_p = partitioner.null_partition();
296 self.for_each_hash(|idx, opt_h| {
297 if let Some(h) = opt_h {
298 unsafe {
299 let p = partitioner.hash_to_partition(h);
301 partition_idxs.get_unchecked_mut(p).push(idx);
302 if BUILD_SKETCHES {
303 sketches.get_unchecked_mut(p).insert(h);
304 }
305 }
306 } else if partition_nulls {
307 unsafe {
308 partition_idxs.get_unchecked_mut(null_p).push(idx);
309 }
310 }
311 });
312 }
313
314 pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
315 self.for_each_hash(|_idx, opt_h| {
316 sketch.insert(opt_h.unwrap_or(0));
317 })
318 }
319
320 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
323 match self {
324 HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
325 HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
326 HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
327 }
328 }
329}
330
331#[derive(Clone, Debug)]
332pub struct RowEncodedKeys {
333 pub hashes: UInt64Array, pub keys: BinaryArray<i64>,
335}
336
337impl RowEncodedKeys {
338 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
339 for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
340 }
341
342 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
345 &self,
346 subset: &[IdxSize],
347 f: F,
348 ) {
349 for_each_hash_subset_prehashed(
350 self.hashes.values().as_slice(),
351 self.keys.validity(),
352 subset,
353 f,
354 );
355 }
356
357 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
360 let idx_arr = arrow::ffi::mmap::slice(idxs);
361 Self {
362 hashes: polars_compute::gather::primitive::take_primitive_unchecked(
363 &self.hashes,
364 &idx_arr,
365 ),
366 keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
367 }
368 }
369}
370
371#[derive(Clone, Debug)]
373pub struct SingleKeys {
374 pub random_state: PlRandomState,
375 pub keys: Series,
376 pub null_is_valid: bool,
377}
378
379impl SingleKeys {
380 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
381 downcast_single_key_ca!(self.keys, |keys| {
382 for_each_hash_single(keys, &self.random_state, f);
383 })
384 }
385
386 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
389 &self,
390 subset: &[IdxSize],
391 f: F,
392 ) {
393 downcast_single_key_ca!(self.keys, |keys| {
394 for_each_hash_subset_single(keys, subset, &self.random_state, f);
395 })
396 }
397
398 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
401 Self {
402 random_state: self.random_state.clone(),
403 keys: self.keys.take_slice_unchecked(idxs),
404 null_is_valid: self.null_is_valid,
405 }
406 }
407}
408
409#[derive(Clone, Debug)]
411pub struct BinviewKeys {
412 pub hashes: UInt64Array,
413 pub keys: BinaryViewArray,
414 pub null_is_valid: bool,
415}
416
417impl BinviewKeys {
418 pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
419 for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
420 }
421
422 pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
425 &self,
426 subset: &[IdxSize],
427 f: F,
428 ) {
429 for_each_hash_subset_prehashed(
430 self.hashes.values().as_slice(),
431 self.keys.validity(),
432 subset,
433 f,
434 );
435 }
436
437 pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
440 let idx_arr = arrow::ffi::mmap::slice(idxs);
441 Self {
442 hashes: polars_compute::gather::primitive::take_primitive_unchecked(
443 &self.hashes,
444 &idx_arr,
445 ),
446 keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
447 null_is_valid: self.null_is_valid,
448 }
449 }
450}
451
452fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
453 hashes: &[u64],
454 opt_v: Option<&Bitmap>,
455 mut f: F,
456) {
457 if let Some(validity) = opt_v {
458 for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
459 if is_v {
460 f(idx, Some(*hash))
461 } else {
462 f(idx, None)
463 }
464 }
465 } else {
466 for (idx, h) in hashes.iter().enumerate_idx() {
467 f(idx, Some(*h));
468 }
469 }
470}
471
472unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
475 hashes: &[u64],
476 opt_v: Option<&Bitmap>,
477 subset: &[IdxSize],
478 mut f: F,
479) {
480 if let Some(validity) = opt_v {
481 for idx in subset {
482 let hash = *hashes.get_unchecked(*idx as usize);
483 let is_v = validity.get_bit_unchecked(*idx as usize);
484 if is_v {
485 f(*idx, Some(hash))
486 } else {
487 f(*idx, None)
488 }
489 }
490 } else {
491 for idx in subset {
492 f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
493 }
494 }
495}
496
497pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
498where
499 T: PolarsDataType,
500 for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
501 F: FnMut(IdxSize, Option<u64>),
502{
503 let mut idx = 0;
504 if keys.has_nulls() {
505 for arr in keys.downcast_iter() {
506 for opt_k in arr.iter() {
507 f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
508 idx += 1;
509 }
510 }
511 } else {
512 for arr in keys.downcast_iter() {
513 for k in arr.values_iter() {
514 f(idx, Some(random_state.tot_hash_one(k)));
515 idx += 1;
516 }
517 }
518 }
519}
520
521unsafe fn for_each_hash_subset_single<T, F>(
524 keys: &ChunkedArray<T>,
525 subset: &[IdxSize],
526 random_state: &PlRandomState,
527 mut f: F,
528) where
529 T: PolarsDataType,
530 for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
531 F: FnMut(IdxSize, Option<u64>),
532{
533 let keys_arr = keys.downcast_as_array();
534
535 if keys_arr.has_nulls() {
536 for idx in subset {
537 let opt_k = keys_arr.get_unchecked(*idx as usize);
538 f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
539 }
540 } else {
541 for idx in subset {
542 let k = keys_arr.value_unchecked(*idx as usize);
543 f(*idx, Some(random_state.tot_hash_one(k)));
544 }
545 }
546}