polars_expr/hash_keys.rs

use arrow::array::{BinaryArray, PrimitiveArray, UInt64Array};
use arrow::compute::utils::combine_validities_and_many;
use polars_compute::gather::binary::take_unchecked;
use polars_core::frame::DataFrame;
use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
use polars_core::prelude::PlRandomState;
use polars_core::series::Series;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::index::ChunkId;
use polars_utils::itertools::Itertools;
use polars_utils::vec::PushUnchecked;
use polars_utils::IdxSize;

/// Represents a DataFrame plus one hash per row, intended for the key columns
/// of a grouping or join. Whether the hashes are physically pre-computed
/// depends on the key type.
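/// `RowEncoded` keys always store their hashes physically; `Single` keys may
/// skip pre-hashing for cheap-to-hash types (see [`SingleKeys`]).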
#[derive(Clone, Debug)]
pub enum HashKeys {
    RowEncoded(RowEncodedKeys),
    Single(SingleKeys),
}

impl HashKeys {
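    /// Builds `HashKeys` from the key columns of `df`.
    ///
    /// A minimal usage sketch; the `df!` input is hypothetical and the
    /// example is not compiled here:
    ///
    /// ```ignore
    /// // Two key columns, so the row-encoding path below is taken.
    /// let df = polars_core::df!("a" => &[1i64, 2], "b" => &["x", "y"])?;
    /// let keys = HashKeys::from_df(&df, PlRandomState::default(), true, false);
    /// assert_eq!(keys.len(), df.height());
    /// ```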
    pub fn from_df(
        df: &DataFrame,
        random_state: PlRandomState,
        null_is_valid: bool,
        force_row_encoding: bool,
    ) -> Self {
        if df.width() > 1 || force_row_encoding {
            let keys = df.get_columns();
            let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();

            if !null_is_valid {
                let validities = keys
                    .iter()
                    .map(|c| c.as_materialized_series().rechunk_validity())
                    .collect_vec();
                let combined = combine_validities_and_many(&validities);
                keys_encoded.set_validity(combined);
            }

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes).unwrap();

            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys: keys_encoded,
            })
        } else {
            todo!()
            // Self::Single(SingleKeys {
            //     random_state,
            //     hashes: todo!(),
            //     keys: df[0].as_materialized_series().clone(),
            // })
        }
    }

    pub fn len(&self) -> usize {
        match self {
            HashKeys::RowEncoded(s) => s.keys.len(),
            HashKeys::Single(s) => s.keys.len(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// After this call partition_idxs[p] will contain the indices of the
    /// rows whose hashes belong to partition p. If the sketches slice is
    /// non-empty, the cardinality sketch of each partition is updated
    /// accordingly.
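    ///
    /// Passing an empty `sketches` slice skips sketch building entirely. A
    /// hypothetical call (the `HashPartitioner::new(num_partitions, seed)`
    /// constructor shown is an assumption):
    ///
    /// ```ignore
    /// let partitioner = HashPartitioner::new(8, 0);
    /// let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
    /// keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut [], true);
    /// ```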
    pub fn gen_partition_idxs(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        if sketches.is_empty() {
            match self {
                Self::RowEncoded(s) => s.gen_partition_idxs::<false>(
                    partitioner,
                    partition_idxs,
                    sketches,
                    partition_nulls,
                ),
                Self::Single(s) => s.gen_partition_idxs::<false>(
                    partitioner,
                    partition_idxs,
                    sketches,
                    partition_nulls,
                ),
            }
        } else {
            match self {
                Self::RowEncoded(s) => s.gen_partition_idxs::<true>(
                    partitioner,
                    partition_idxs,
                    sketches,
                    partition_nulls,
                ),
                Self::Single(s) => s.gen_partition_idxs::<true>(
                    partitioner,
                    partition_idxs,
                    sketches,
                    partition_nulls,
                ),
            }
        }
    }

    /// Generates indices for a chunked gather such that the ith key gathers
    /// the next gathers_per_key[i] elements from the chunk of the partition
    /// the ith key hashes to.
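    ///
    /// For example, with keys hashing to partitions [0, 1, 0] and
    /// gathers_per_key = [2, 1, 1], the generated ChunkIds are
    /// (0, 0), (0, 1), (1, 0), (0, 2): the first key claims offsets 0..2 of
    /// chunk 0, the second key offset 0 of chunk 1, and the third key the
    /// next free offset (2) of chunk 0.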
    pub fn gen_partitioned_gather_idxs(
        &self,
        partitioner: &HashPartitioner,
        gathers_per_key: &[IdxSize],
        gather_idxs: &mut Vec<ChunkId<32>>,
    ) {
        match self {
            Self::RowEncoded(s) => {
                s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs)
            },
            Self::Single(s) => {
                s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs)
            },
        }
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        match self {
            Self::RowEncoded(s) => Self::RowEncoded(s.gather(idxs)),
            Self::Single(s) => Self::Single(s.gather(idxs)),
        }
    }
}

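/// Keys materialized as one binary blob per row (see
/// `_get_rows_encoded_unordered`), with a precomputed hash per row.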
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    pub hashes: UInt64Array,
    pub keys: BinaryArray<i64>,
}

impl RowEncodedKeys {
    pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        assert!(partition_idxs.len() == partitioner.num_partitions());
        assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());
        for p in partition_idxs.iter_mut() {
            p.clear();
        }

        if let Some(validity) = self.keys.validity() {
            for (i, (h, is_v)) in self.hashes.values_iter().zip(validity).enumerate() {
                if is_v {
                    unsafe {
                        // SAFETY: we assured the number of partitions matches.
                        let p = partitioner.hash_to_partition(*h);
                        partition_idxs.get_unchecked_mut(p).push(i as IdxSize);
                        if BUILD_SKETCHES {
                            sketches.get_unchecked_mut(p).insert(*h);
                        }
                    }
                } else if partition_nulls {
                    // Arbitrarily put nulls in partition 0.
                    unsafe {
                        partition_idxs.get_unchecked_mut(0).push(i as IdxSize);
                    }
                }
            }
        } else {
            for (i, h) in self.hashes.values_iter().enumerate() {
                unsafe {
                    // SAFETY: we assured the number of partitions matches.
                    let p = partitioner.hash_to_partition(*h);
                    partition_idxs.get_unchecked_mut(p).push(i as IdxSize);
                    if BUILD_SKETCHES {
                        sketches.get_unchecked_mut(p).insert(*h);
                    }
                }
            }
        }
    }

    pub fn gen_partitioned_gather_idxs(
        &self,
        partitioner: &HashPartitioner,
        gathers_per_key: &[IdxSize],
        gather_idxs: &mut Vec<ChunkId<32>>,
    ) {
        assert!(gathers_per_key.len() == self.keys.len());
        unsafe {
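            // offsets[p] tracks how many elements of partition p's chunk
            // have already been claimed by earlier keys.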
            let mut offsets = vec![0; partitioner.num_partitions()];
            for (hash, &n) in self.hashes.values_iter().zip(gathers_per_key) {
                let p = partitioner.hash_to_partition(*hash);
                let offset = *offsets.get_unchecked(p);
                for i in offset..offset + n {
                    gather_idxs.push(ChunkId::store(p as IdxSize, i));
                }
                *offsets.get_unchecked_mut(p) += n;
            }
        }
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        let mut hashes = Vec::with_capacity(idxs.len());
        for idx in idxs {
            hashes.push_unchecked(*self.hashes.values().get_unchecked(*idx as usize));
        }
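        // View the indices as an arrow array without copying so the binary
        // take kernel can be reused for the row-encoded keys.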
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        let keys = take_unchecked(&self.keys, &idx_arr);
        Self {
            hashes: PrimitiveArray::from_vec(hashes),
            keys,
        }
    }
}

/// Single-column keys. Hashes are not pre-computed for boolean and integer
/// keys, only for strings and nested types.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    pub random_state: PlRandomState,
    pub hashes: Option<Vec<u64>>,
    pub keys: Series,
}

impl SingleKeys {
    pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        _sketches: &mut [CardinalitySketch],
        _partition_nulls: bool,
    ) {
        assert!(partitioner.num_partitions() == partition_idxs.len());
        for p in partition_idxs.iter_mut() {
            p.clear();
        }

        todo!()
    }

    #[allow(clippy::ptr_arg)] // Remove when implemented.
    pub fn gen_partitioned_gather_idxs(
        &self,
        _partitioner: &HashPartitioner,
        _gathers_per_key: &[IdxSize],
        _gather_idxs: &mut Vec<ChunkId<32>>,
    ) {
        todo!()
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
        let hashes = self.hashes.as_ref().map(|hashes| {
            let mut out = Vec::with_capacity(idxs.len());
            for idx in idxs {
                out.push_unchecked(*hashes.get_unchecked(*idx as usize));
            }
            out
        });
        Self {
            random_state: self.random_state.clone(),
            hashes,
            keys: self.keys.take_slice_unchecked(idxs),
        }
    }
}
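
// A minimal end-to-end sketch, kept `#[ignore]`d because the exact
// `HashPartitioner::new(num_partitions, seed)` constructor and the `df!`
// inputs used here are assumptions; adjust to the actual APIs if they differ.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore]
    fn partition_idxs_cover_all_rows() {
        let df = polars_core::df!(
            "a" => &[1i64, 2, 3],
            "b" => &["x", "y", "z"],
        )
        .unwrap();
        // Two key columns force the row-encoded path.
        let keys = HashKeys::from_df(&df, PlRandomState::default(), true, false);
        assert_eq!(keys.len(), 3);

        let partitioner = HashPartitioner::new(4, 0); // assumed constructor
        let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
        // An empty sketch slice disables cardinality sketching.
        keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut [], true);

        // Every row lands in exactly one partition.
        let total: usize = partition_idxs.iter().map(|p| p.len()).sum();
        assert_eq!(total, 3);
    }
}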