use std::hash::{BuildHasher, Hash};

use hashbrown::hash_map::{Entry, RawEntryMut};
use hashbrown::HashMap;
use polars_utils::{flatten, HashSingle};
use rayon::prelude::*;

use super::GroupsProxy;
use crate::datatypes::PlHashMap;
use crate::frame::groupby::{GroupsIdx, IdxItem};
use crate::prelude::compare_inner::PartialEqInner;
use crate::prelude::*;
use crate::utils::{split_df, CustomIterTools};
use crate::vector_hasher::{
    df_rows_to_hashes_threaded, this_partition, AsU64, IdBuildHasher, IdxHash,
};
use crate::POOL;

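/// Merge the per-thread group tuples `(first index, all indices)` into a single `GroupsProxy`.
/// If `sorted` is true, the groups are ordered by the index of their first occurrence.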
fn finish_group_order(mut out: Vec<Vec<IdxItem>>, sorted: bool) -> GroupsProxy {
    if sorted {
        // a single partition can be taken as-is, no need to flatten
        let mut out = if out.len() == 1 {
            out.pop().unwrap()
        } else {
            // pre-sort every partition;
            // this makes the final single-threaded sort much faster
            POOL.install(|| {
                out.par_iter_mut()
                    .for_each(|g| g.sort_unstable_by_key(|g| g.0))
            });

            flatten(&out, None)
        };
        out.sort_unstable_by_key(|g| g.0);
        let mut idx = GroupsIdx::from_iter(out.into_iter());
        idx.sorted = true;
        GroupsProxy::Idx(idx)
    } else {
        // a single partition can be taken as-is, no need to flatten
        if out.len() == 1 {
            GroupsProxy::Idx(GroupsIdx::from(out.pop().unwrap()))
        } else {
            // flattens
            GroupsProxy::Idx(GroupsIdx::from(out))
        }
    }
}

// We must strike a balance between cache coherence and resizing costs.
// Overallocation seems a lot more expensive than resizing, so we start reasonably small.
pub(crate) const HASHMAP_INIT_SIZE: usize = 512;

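/// Compute group tuples over a single-threaded iterator of keys.
/// For every unique key we store the index of its first occurrence together with the indices
/// of all its occurrences.
///
/// A minimal sketch of the produced grouping (illustrative only; the `GroupsIdx` internals are
/// simplified here):
///
/// ```ignore
/// // keys:   ["a", "b", "a", "b", "c"]
/// // groups: (first occurrence index, all indices)
/// // "a" -> (0, [0, 2])
/// // "b" -> (1, [1, 3])
/// // "c" -> (4, [4])
/// let groups = groupby(["a", "b", "a", "b", "c"].iter(), false);
/// ```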
pub(crate) fn groupby<T>(a: impl Iterator<Item = T>, sorted: bool) -> GroupsProxy
where
    T: Hash + Eq,
{
    let mut hash_tbl: PlHashMap<T, (IdxSize, Vec<IdxSize>)> =
        PlHashMap::with_capacity(HASHMAP_INIT_SIZE);
    let mut cnt = 0;
    a.for_each(|k| {
        let idx = cnt;
        cnt += 1;
        let entry = hash_tbl.entry(k);

        match entry {
            Entry::Vacant(entry) => {
                entry.insert((idx, vec![idx]));
            }
            Entry::Occupied(mut entry) => {
                let v = entry.get_mut();
                v.1.push(idx);
            }
        }
    });
    if sorted {
        let mut groups = hash_tbl
            .into_iter()
            .map(|(_k, v)| v)
            .collect_trusted::<Vec<_>>();
        groups.sort_unstable_by_key(|g| g.0);
        let mut idx: GroupsIdx = groups.into_iter().collect();
        idx.sorted = true;
        GroupsProxy::Idx(idx)
    } else {
        GroupsProxy::Idx(hash_tbl.into_values().collect())
    }
}

/// Determine group tuples over different threads. The hash of the key is used to determine the partition.
/// Note that rehashing the keys should be cheap and the keys should be small, to allow efficient
/// rehashing and improved cache locality.
///
/// Besides numeric values, we can also use this for pre-hashed strings. Such a key is simply a ptr
/// to the str plus a precomputed hash; the hash is used for rehashing and the str for equality.
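///
/// A sketch of the partitioning idea (assuming `this_partition` selects hashes whose low bits
/// match the thread number; `n_partitions` must be a power of two so this can be a bitmask):
///
/// ```ignore
/// // every thread scans all keys but only inserts the ones that belong to its partition
/// let belongs_here = (key.as_u64() & (n_partitions - 1)) == thread_no;
/// ```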
pub(crate) fn groupby_threaded_num<T, IntoSlice>(
    keys: Vec<IntoSlice>,
    group_size_hint: usize,
    n_partitions: u64,
    sorted: bool,
) -> GroupsProxy
where
    T: Send + Hash + Eq + Sync + Copy + AsU64,
    IntoSlice: AsRef<[T]> + Send + Sync,
{
    assert!(n_partitions.is_power_of_two());

    // We will create a hashtable in every thread.
    // We use the hash to partition the keys to the matching hashtable.
    // Every thread traverses all keys/hashes and ignores the ones that don't fall in its partition.
    let out = POOL
        .install(|| {
            (0..n_partitions).into_par_iter().map(|thread_no| {
                let mut hash_tbl: PlHashMap<T, (IdxSize, Vec<IdxSize>)> =
                    PlHashMap::with_capacity(HASHMAP_INIT_SIZE);

                let mut offset = 0;
                for keys in &keys {
                    let keys = keys.as_ref();
                    let len = keys.len() as IdxSize;
                    let hasher = hash_tbl.hasher().clone();

                    let mut cnt = 0;
                    keys.iter().for_each(|k| {
                        let idx = cnt + offset;
                        cnt += 1;

                        if this_partition(k.as_u64(), thread_no, n_partitions) {
                            let hash = hasher.hash_single(k);
                            let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, k);

                            match entry {
                                RawEntryMut::Vacant(entry) => {
                                    let mut tuples = Vec::with_capacity(group_size_hint);
                                    tuples.push(idx);
                                    entry.insert_with_hasher(hash, *k, (idx, tuples), |k| {
                                        hasher.hash_single(k)
                                    });
                                }
                                RawEntryMut::Occupied(mut entry) => {
                                    let v = entry.get_mut();
                                    v.1.push(idx);
                                }
                            }
                        }
                    });
                    offset += len;
                }
                hash_tbl
                    .into_iter()
                    .map(|(_k, v)| v)
                    .collect_trusted::<Vec<_>>()
            })
        })
        .collect::<Vec<_>>();
    finish_group_order(out, sorted)
}

/// Utility function used as the comparison function in the hashmap.
/// The rationale is that equality is an AND operation, so its probability of success
/// declines rapidly with the number of keys. Instead of first copying an entire row from both
/// sides and then comparing, we compare value by value so that failures are caught as early
/// as possible.
///
/// # Safety
/// Doesn't check any bounds
#[inline]
pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool {
    for s in keys.get_columns() {
        if !s.equal_element(idx_a, idx_b, s) {
            return false;
        }
    }
    true
}

/// Populate a multiple-key hashmap with row indexes.
/// Instead of the keys (which could be very large), only the row indexes are stored.
/// To check whether two rows are equal, the original DataFrame is also passed by reference.
/// When a hash collision occurs, the stored indexes act as pointers into the rows and the rows
/// themselves are compared for equality.
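///
/// A hypothetical call site (mirroring how `populate_multiple_key_hashmap2` is used further
/// below), grouping row indices per unique key row:
///
/// ```ignore
/// populate_multiple_key_hashmap(
///     &mut hash_tbl,
///     idx,
///     hash,
///     &keys_df,
///     || (idx, vec![idx]), // vacant: start a new group with this row index
///     |v| v.1.push(idx),   // occupied: append the row index to the existing group
/// );
/// ```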
pub(crate) fn populate_multiple_key_hashmap<V, H, F, G>(
    hash_tbl: &mut HashMap<IdxHash, V, H>,
    // row index
    idx: IdxSize,
    // hash
    original_h: u64,
    // keys of the hash table (will not be inserted, the indexes will be used)
    // the keys are needed for the equality check
    keys: &DataFrame,
    // value to insert
    vacant_fn: G,
    // function that gets a mutable ref to the occupied value in the hash table
    occupied_fn: F,
) where
    G: Fn() -> V,
    F: Fn(&mut V),
    H: BuildHasher,
{
    let entry = hash_tbl
        .raw_entry_mut()
        // Uses the idx to probe rows in the original DataFrame of keys in order to find an
        // equal entry. This does not invalidate the hashmap, as this equality function is not
        // used during rehashing/resize (the keys are already known to be unique then);
        // an equality function is only needed during insertion and probing.
        .from_hash(original_h, |idx_hash| {
            // first check the hash values
            // before we incur a cache miss
            idx_hash.hash == original_h && {
                let key_idx = idx_hash.idx;
                // Safety:
                // indices in a groupby operation are always in bounds.
                unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) }
            }
        });
    match entry {
        RawEntryMut::Vacant(entry) => {
            entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn());
        }
        RawEntryMut::Occupied(mut entry) => {
            let (_k, v) = entry.get_key_value_mut();
            occupied_fn(v);
        }
    }
}

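/// Compare two rows, column by column, through the pre-built `PartialEqInner` comparators.
///
/// # Safety
/// Doesn't check any bounds.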
#[inline]
pub(crate) unsafe fn compare_keys<'a>(
    keys_cmp: &'a [Box<dyn PartialEqInner + 'a>],
    idx_a: usize,
    idx_b: usize,
) -> bool {
    for cmp in keys_cmp {
        if !cmp.eq_element_unchecked(idx_a, idx_b) {
            return false;
        }
    }
    true
}

// Differs from `populate_multiple_key_hashmap` in that this one uses `PartialEqInner` trait
// objects, which is faster when there are multiple chunks. Not yet used in joins.
pub(crate) fn populate_multiple_key_hashmap2<'a, V, H, F, G>(
    hash_tbl: &mut HashMap<IdxHash, V, H>,
    // row index
    idx: IdxSize,
    // hash
    original_h: u64,
    // keys of the hash table (will not be inserted, the indexes will be used)
    // the keys are needed for the equality check
    keys_cmp: &'a [Box<dyn PartialEqInner + 'a>],
    // value to insert
    vacant_fn: G,
    // function that gets a mutable ref to the occupied value in the hash table
    occupied_fn: F,
) where
    G: Fn() -> V,
    F: Fn(&mut V),
    H: BuildHasher,
{
    let entry = hash_tbl
        .raw_entry_mut()
        // Uses the idx to probe rows in the original DataFrame of keys in order to find an
        // equal entry. This does not invalidate the hashmap, as this equality function is not
        // used during rehashing/resize (the keys are already known to be unique then);
        // an equality function is only needed during insertion and probing.
        .from_hash(original_h, |idx_hash| {
            // first check the hash values before we incur
            // cache misses
            original_h == idx_hash.hash && {
                let key_idx = idx_hash.idx;
                // Safety:
                // indices in a groupby operation are always in bounds.
                unsafe { compare_keys(keys_cmp, key_idx as usize, idx as usize) }
            }
        });
    match entry {
        RawEntryMut::Vacant(entry) => {
            entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn());
        }
        RawEntryMut::Occupied(mut entry) => {
            let (_k, v) = entry.get_key_value_mut();
            occupied_fn(v);
        }
    }
}

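/// Determine group tuples over multiple key columns, in parallel.
/// The rows are hashed up front; every thread only inserts the hashes that fall in its
/// partition. Because the hash table stores row indexes rather than the keys themselves,
/// colliding hashes are resolved by comparing the actual rows through `PartialEqInner`
/// trait objects.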
pub(crate) fn groupby_threaded_multiple_keys_flat(
    mut keys: DataFrame,
    n_partitions: usize,
    sorted: bool,
) -> PolarsResult<GroupsProxy> {
    let dfs = split_df(&mut keys, n_partitions).unwrap();
    let (hashes, _random_state) = df_rows_to_hashes_threaded(&dfs, None)?;
    let n_partitions = n_partitions as u64;

    // trait object to compare inner types.
    let keys_cmp = keys
        .iter()
        .map(|s| s.into_partial_eq_inner())
        .collect::<Vec<_>>();

    // We will create a hashtable in every thread.
    // We use the hash to partition the keys to the matching hashtable.
    // Every thread traverses all keys/hashes and ignores the ones that don't fall in its partition.
    let groups = POOL
        .install(|| {
            (0..n_partitions).into_par_iter().map(|thread_no| {
                let hashes = &hashes;

                let mut hash_tbl: HashMap<IdxHash, (IdxSize, Vec<IdxSize>), IdBuildHasher> =
                    HashMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default());

                let mut offset = 0;
                for hashes in hashes {
                    let len = hashes.len() as IdxSize;

                    let mut idx = 0;
                    for hashes_chunk in hashes.data_views() {
                        for &h in hashes_chunk {
                            // partition hashes by thread number,
                            // so only a subset of the hashes goes into this hashmap
                            if this_partition(h, thread_no, n_partitions) {
                                let idx = idx + offset;
                                populate_multiple_key_hashmap2(
                                    &mut hash_tbl,
                                    idx,
                                    h,
                                    &keys_cmp,
                                    || (idx, vec![idx]),
                                    |v| v.1.push(idx),
                                );
                            }
                            idx += 1;
                        }
                    }

                    offset += len;
                }
                hash_tbl.into_iter().map(|(_k, v)| v).collect::<Vec<_>>()
            })
        })
        .collect::<Vec<_>>();
    Ok(finish_group_order(groups, sorted))
}