1use arrow::array::{BinaryArray, PrimitiveArray, UInt64Array};
2use arrow::compute::utils::combine_validities_and_many;
3use polars_compute::gather::binary::take_unchecked;
4use polars_core::frame::DataFrame;
5use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
6use polars_core::prelude::PlRandomState;
7use polars_core::series::Series;
8use polars_utils::cardinality_sketch::CardinalitySketch;
9use polars_utils::hashing::HashPartitioner;
10use polars_utils::index::ChunkId;
11use polars_utils::itertools::Itertools;
12use polars_utils::vec::PushUnchecked;
13use polars_utils::IdxSize;
14
/// Hash-join / group-by keys taken from a [`DataFrame`], together with the
/// per-row hashes used to assign rows to partitions.
#[derive(Clone, Debug)]
pub enum HashKeys {
    /// Key columns row-encoded into one binary array (see [`RowEncodedKeys`]).
    RowEncoded(RowEncodedKeys),
    /// Keys kept as a single [`Series`] (see [`SingleKeys`]).
    /// NOTE(review): `HashKeys::from_df` does not construct this variant yet.
    Single(SingleKeys),
}
23
24impl HashKeys {
25 pub fn from_df(
26 df: &DataFrame,
27 random_state: PlRandomState,
28 null_is_valid: bool,
29 force_row_encoding: bool,
30 ) -> Self {
31 if df.width() > 1 || force_row_encoding {
32 let keys = df.get_columns();
33 let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();
34
35 if !null_is_valid {
36 let validities = keys
37 .iter()
38 .map(|c| c.as_materialized_series().rechunk_validity())
39 .collect_vec();
40 let combined = combine_validities_and_many(&validities);
41 keys_encoded.set_validity(combined);
42 }
43
44 let hashes = keys_encoded
49 .values_iter()
50 .map(|k| random_state.hash_one(k))
51 .collect();
52 Self::RowEncoded(RowEncodedKeys {
53 hashes: PrimitiveArray::from_vec(hashes),
54 keys: keys_encoded,
55 })
56 } else {
57 todo!()
58 }
64 }
65
66 pub fn len(&self) -> usize {
67 match self {
68 HashKeys::RowEncoded(s) => s.keys.len(),
69 HashKeys::Single(s) => s.keys.len(),
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76
77 pub fn gen_partition_idxs(
81 &self,
82 partitioner: &HashPartitioner,
83 partition_idxs: &mut [Vec<IdxSize>],
84 sketches: &mut [CardinalitySketch],
85 partition_nulls: bool,
86 ) {
87 if sketches.is_empty() {
88 match self {
89 Self::RowEncoded(s) => s.gen_partition_idxs::<false>(
90 partitioner,
91 partition_idxs,
92 sketches,
93 partition_nulls,
94 ),
95 Self::Single(s) => s.gen_partition_idxs::<false>(
96 partitioner,
97 partition_idxs,
98 sketches,
99 partition_nulls,
100 ),
101 }
102 } else {
103 match self {
104 Self::RowEncoded(s) => s.gen_partition_idxs::<true>(
105 partitioner,
106 partition_idxs,
107 sketches,
108 partition_nulls,
109 ),
110 Self::Single(s) => s.gen_partition_idxs::<true>(
111 partitioner,
112 partition_idxs,
113 sketches,
114 partition_nulls,
115 ),
116 }
117 }
118 }
119
120 pub fn gen_partitioned_gather_idxs(
123 &self,
124 partitioner: &HashPartitioner,
125 gathers_per_key: &[IdxSize],
126 gather_idxs: &mut Vec<ChunkId<32>>,
127 ) {
128 match self {
129 Self::RowEncoded(s) => {
130 s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs)
131 },
132 Self::Single(s) => {
133 s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs)
134 },
135 }
136 }
137
138 pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
141 match self {
142 Self::RowEncoded(s) => Self::RowEncoded(s.gather(idxs)),
143 Self::Single(s) => Self::Single(s.gather(idxs)),
144 }
145 }
146}
147
/// Keys row-encoded into one binary array, with one precomputed hash per row.
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    /// Hash of each row's encoded bytes. As built by `HashKeys::from_df`,
    /// null rows are hashed as well.
    pub hashes: UInt64Array,
    /// Row-encoded keys. Its validity (if any) marks null keys; partitioning
    /// skips or separately places rows whose bit is unset.
    pub keys: BinaryArray<i64>,
}
153
154impl RowEncodedKeys {
155 pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
156 &self,
157 partitioner: &HashPartitioner,
158 partition_idxs: &mut [Vec<IdxSize>],
159 sketches: &mut [CardinalitySketch],
160 partition_nulls: bool,
161 ) {
162 assert!(partition_idxs.len() == partitioner.num_partitions());
163 assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());
164 for p in partition_idxs.iter_mut() {
165 p.clear();
166 }
167
168 if let Some(validity) = self.keys.validity() {
169 for (i, (h, is_v)) in self.hashes.values_iter().zip(validity).enumerate() {
170 if is_v {
171 unsafe {
172 let p = partitioner.hash_to_partition(*h);
174 partition_idxs.get_unchecked_mut(p).push(i as IdxSize);
175 if BUILD_SKETCHES {
176 sketches.get_unchecked_mut(p).insert(*h);
177 }
178 }
179 } else if partition_nulls {
180 unsafe {
182 partition_idxs.get_unchecked_mut(0).push(i as IdxSize);
183 }
184 }
185 }
186 } else {
187 for (i, h) in self.hashes.values_iter().enumerate() {
188 unsafe {
189 let p = partitioner.hash_to_partition(*h);
191 partition_idxs.get_unchecked_mut(p).push(i as IdxSize);
192 if BUILD_SKETCHES {
193 sketches.get_unchecked_mut(p).insert(*h);
194 }
195 }
196 }
197 }
198 }
199
200 pub fn gen_partitioned_gather_idxs(
201 &self,
202 partitioner: &HashPartitioner,
203 gathers_per_key: &[IdxSize],
204 gather_idxs: &mut Vec<ChunkId<32>>,
205 ) {
206 assert!(gathers_per_key.len() == self.keys.len());
207 unsafe {
208 let mut offsets = vec![0; partitioner.num_partitions()];
209 for (hash, &n) in self.hashes.values_iter().zip(gathers_per_key) {
210 let p = partitioner.hash_to_partition(*hash);
211 let offset = *offsets.get_unchecked(p);
212 for i in offset..offset + n {
213 gather_idxs.push(ChunkId::store(p as IdxSize, i));
214 }
215 *offsets.get_unchecked_mut(p) += n;
216 }
217 }
218 }
219
220 pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
223 let mut hashes = Vec::with_capacity(idxs.len());
224 for idx in idxs {
225 hashes.push_unchecked(*self.hashes.values().get_unchecked(*idx as usize));
226 }
227 let idx_arr = arrow::ffi::mmap::slice(idxs);
228 let keys = take_unchecked(&self.keys, &idx_arr);
229 Self {
230 hashes: PrimitiveArray::from_vec(hashes),
231 keys,
232 }
233 }
234}
235
/// Keys stored as a single [`Series`], with optionally precomputed hashes.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    /// Random state for hashing the keys.
    /// NOTE(review): presumably used to compute hashes when `hashes` is
    /// `None`; confirm once partitioning is implemented.
    pub random_state: PlRandomState,
    /// Per-row hashes, if already computed; `gather` keeps them aligned with
    /// `keys`.
    pub hashes: Option<Vec<u64>>,
    /// The key values.
    pub keys: Series,
}
244
245impl SingleKeys {
246 pub fn gen_partition_idxs<const BUILD_SKETCHES: bool>(
247 &self,
248 partitioner: &HashPartitioner,
249 partition_idxs: &mut [Vec<IdxSize>],
250 _sketches: &mut [CardinalitySketch],
251 _partition_nulls: bool,
252 ) {
253 assert!(partitioner.num_partitions() == partition_idxs.len());
254 for p in partition_idxs.iter_mut() {
255 p.clear();
256 }
257
258 todo!()
259 }
260
261 #[allow(clippy::ptr_arg)] pub fn gen_partitioned_gather_idxs(
263 &self,
264 _partitioner: &HashPartitioner,
265 _gathers_per_key: &[IdxSize],
266 _gather_idxs: &mut Vec<ChunkId<32>>,
267 ) {
268 todo!()
269 }
270
271 pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self {
274 let hashes = self.hashes.as_ref().map(|hashes| {
275 let mut out = Vec::with_capacity(idxs.len());
276 for idx in idxs {
277 out.push_unchecked(*hashes.get_unchecked(*idx as usize));
278 }
279 out
280 });
281 Self {
282 random_state: self.random_state.clone(),
283 hashes,
284 keys: self.keys.take_slice_unchecked(idxs),
285 }
286 }
287}