use arrow::array::Array;
use polars_row::RowEncodingOptions;
use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry};
use polars_utils::itertools::Itertools;
use polars_utils::vec::PushUnchecked;
use self::row_encode::get_row_encoding_context;
use super::*;
use crate::hash_keys::HashKeys;
/// Grouper backed by a [`BytesIndexMap`] keyed by row-encoded key bytes.
///
/// The index of a key's entry in the map is what the methods of this type
/// hand out as that key's group index.
#[derive(Default)]
pub struct RowEncodedHashGrouper {
// Keyed by (hash, encoded key bytes); the value type is `()`, so the map
// stores nothing per key beyond the key itself and its entry index.
idx_map: BytesIndexMap<()>,
}
impl RowEncodedHashGrouper {
    /// Creates a grouper whose key map is empty.
    pub fn new() -> Self {
        let idx_map = BytesIndexMap::new();
        Self { idx_map }
    }

    /// Returns the entry index for `key`, creating a new entry when the key
    /// is not yet present in the map.
    ///
    /// `hash` is passed straight to the map alongside the key bytes.
    fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize {
        let slot = match self.idx_map.entry(hash, key) {
            // Already known: hand back the existing entry's index.
            Entry::Occupied(found) => return found.index(),
            Entry::Vacant(slot) => slot,
        };
        // Read the index before `insert` consumes the vacant entry.
        let group_idx = slot.index();
        slot.insert(());
        group_idx
    }

    /// Reports whether the map holds an entry for `key` under `hash`.
    fn contains_key(&self, hash: u64, key: &[u8]) -> bool {
        self.idx_map.contains_key(hash, key)
    }

    /// Decodes row-encoded `key_rows` into a [`DataFrame`] with one column per
    /// field of `key_schema`.
    ///
    /// Each column is decoded as the Arrow form of the field's physical dtype
    /// using unsorted row-encoding options, and is then converted to the
    /// field's dtype through `from_physical_unchecked`.
    ///
    /// # Panics
    ///
    /// Panics if building a `Series` from a decoded array fails or if the
    /// conversion from the physical representation returns an error.
    fn finalize_keys(&self, key_schema: &Schema, mut key_rows: Vec<&[u8]>) -> DataFrame {
        // One pass over the schema yields both the Arrow dtype and the
        // row-encoding context of every key column.
        let (arrow_dtypes, enc_contexts): (Vec<_>, Vec<_>) = key_schema
            .iter()
            .map(|(_name, dtype)| {
                (
                    dtype.to_physical().to_arrow(CompatLevel::newest()),
                    get_row_encoding_context(dtype),
                )
            })
            .unzip();
        let enc_options = vec![RowEncodingOptions::new_unsorted(); arrow_dtypes.len()];

        // SAFETY: NOTE(review) assumes `key_rows` were produced by encoding
        // columns of `key_schema` with these same options - confirm at callers.
        let decoded = unsafe {
            polars_row::decode::decode_rows(
                &mut key_rows,
                &enc_options,
                &enc_contexts,
                &arrow_dtypes,
            )
        };

        let mut columns = Vec::with_capacity(decoded.len());
        for ((name, dtype), array) in key_schema.iter().zip(decoded) {
            let physical = Series::try_from((name.clone(), array)).unwrap();
            // SAFETY: NOTE(review) `array` was decoded with the physical dtype
            // derived from `dtype` above, which is presumably what
            // `from_physical_unchecked` requires - confirm.
            let logical = unsafe { physical.from_physical_unchecked(dtype) }.unwrap();
            columns.push(logical.into_column());
        }

        // SAFETY: NOTE(review) all columns are decoded from the same row set,
        // so they presumably share one height - confirm.
        unsafe { DataFrame::new_unchecked_infer_height(columns) }
    }
}
impl Grouper for RowEncodedHashGrouper {
    /// Returns a boxed, freshly constructed grouper of this same type.
    fn new_empty(&self) -> Box<dyn Grouper> {
        let empty = Self::new();
        Box::new(empty)
    }

    /// Forwards the capacity request to the underlying key map.
    fn reserve(&mut self, additional: usize) {
        self.idx_map.reserve(additional);
    }

    /// Number of entries currently held by the key map.
    fn num_groups(&self) -> IdxSize {
        self.idx_map.len()
    }

    /// Inserts the keys visited by `keys.for_each_hash_subset(subset, ..)`.
    ///
    /// Rows for which no hash is reported are skipped. When `group_idxs` is
    /// provided, the group index of every inserted row is appended to it.
    ///
    /// # Safety
    ///
    /// The row indices handed to the callback (presumably the entries of
    /// `subset`) are used for unchecked access into `keys`, so they must be
    /// in bounds. The callback must also run at most `subset.len()` times,
    /// because that is all the space reserved in `group_idxs`.
    ///
    /// # Panics
    ///
    /// Panics if `keys` is not `HashKeys::RowEncoded`.
    unsafe fn insert_keys_subset(
        &mut self,
        keys: &HashKeys,
        subset: &[IdxSize],
        group_idxs: Option<&mut Vec<IdxSize>>,
    ) {
        let HashKeys::RowEncoded(row_keys) = keys else {
            unreachable!()
        };
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe {
            match group_idxs {
                Some(out) => {
                    // Reserve up front so the unchecked pushes below stay in capacity.
                    out.reserve(subset.len());
                    row_keys.for_each_hash_subset(subset, |row, maybe_hash| {
                        if let Some(hash) = maybe_hash {
                            let encoded = row_keys.keys.value_unchecked(row as usize);
                            let group = self.insert_key(hash, encoded);
                            out.push_unchecked(group);
                        }
                    });
                },
                None => {
                    row_keys.for_each_hash_subset(subset, |row, maybe_hash| {
                        if let Some(hash) = maybe_hash {
                            let encoded = row_keys.keys.value_unchecked(row as usize);
                            self.insert_key(hash, encoded);
                        }
                    });
                },
            }
        }
    }

    /// Decodes all stored keys, in the order yielded by
    /// `idx_map.iter_hash_keys()`, into a [`DataFrame`] shaped by `schema`.
    fn get_keys_in_group_order(&self, schema: &Schema) -> DataFrame {
        let mut rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize);
        // SAFETY: capacity for `idx_map.len()` items was allocated above;
        // NOTE(review) this assumes `iter_hash_keys` yields at most that many
        // items - confirm.
        unsafe {
            self.idx_map.iter_hash_keys().for_each(|(_hash, key)| {
                rows.push_unchecked(key);
            });
        }
        self.finalize_keys(schema, rows)
    }

    /// Appends to `probe_matches` the row index of every key whose presence
    /// in the grouper of its hash partition differs from `invert`.
    ///
    /// Rows with a null key are treated as not present.
    ///
    /// # Safety
    ///
    /// - Every element of `groupers` must really be a
    ///   `RowEncodedHashGrouper`; they are reinterpreted through an unchecked
    ///   pointer cast.
    /// - `partitioner.hash_to_partition` must return an index below
    ///   `groupers.len()`; only the partition count is asserted.
    /// - `keys.keys` must have at least as many rows as `keys.hashes`.
    ///
    /// # Panics
    ///
    /// Panics if `keys` is not `HashKeys::RowEncoded`, or if the partition
    /// count differs from `groupers.len()`.
    unsafe fn probe_partitioned_groupers(
        &self,
        groupers: &[Box<dyn Grouper>],
        keys: &HashKeys,
        partitioner: &HashPartitioner,
        invert: bool,
        probe_matches: &mut Vec<IdxSize>,
    ) {
        let HashKeys::RowEncoded(row_keys) = keys else {
            unreachable!()
        };
        assert!(partitioner.num_partitions() == groupers.len());
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe {
            // Looks `key` up in the grouper that owns `hash`'s partition.
            let in_partition = |hash: u64, key: &[u8]| -> bool {
                let part = partitioner.hash_to_partition(hash);
                let erased: &dyn Grouper = &**groupers.get_unchecked(part);
                let concrete = &*(erased as *const dyn Grouper as *const Self);
                concrete.contains_key(hash, key)
            };

            if !row_keys.keys.has_nulls() {
                // No nulls: hashes and key values can be walked in lock-step.
                let pairs = row_keys
                    .hashes
                    .values_iter()
                    .zip(row_keys.keys.values_iter());
                for (row, (hash, key)) in pairs.enumerate_idx() {
                    if in_partition(*hash, key) != invert {
                        probe_matches.push(row);
                    }
                }
            } else {
                for (row, hash) in row_keys.hashes.values_iter().enumerate_idx() {
                    let found = match row_keys.keys.get_unchecked(row as usize) {
                        Some(key) => in_partition(*hash, key),
                        None => false,
                    };
                    if found != invert {
                        probe_matches.push(row);
                    }
                }
            }
        }
    }

    /// Pushes one bit per key row into `contains_key`: whether the key's
    /// presence in the grouper of its hash partition differs from `invert`.
    ///
    /// Rows with a null key are treated as not present.
    ///
    /// # Safety
    ///
    /// - Every element of `groupers` must really be a
    ///   `RowEncodedHashGrouper`; they are reinterpreted through an unchecked
    ///   pointer cast.
    /// - `partitioner.hash_to_partition` must return an index below
    ///   `groupers.len()`; only the partition count is asserted.
    /// - `keys.keys` must have at least as many rows as `keys.hashes`.
    ///
    /// # Panics
    ///
    /// Panics if `keys` is not `HashKeys::RowEncoded`, or if the partition
    /// count differs from `groupers.len()`.
    unsafe fn contains_key_partitioned_groupers(
        &self,
        groupers: &[Box<dyn Grouper>],
        keys: &HashKeys,
        partitioner: &HashPartitioner,
        invert: bool,
        contains_key: &mut BitmapBuilder,
    ) {
        let HashKeys::RowEncoded(row_keys) = keys else {
            unreachable!()
        };
        assert!(partitioner.num_partitions() == groupers.len());
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe {
            // Looks `key` up in the grouper that owns `hash`'s partition.
            let in_partition = |hash: u64, key: &[u8]| -> bool {
                let part = partitioner.hash_to_partition(hash);
                let erased: &dyn Grouper = &**groupers.get_unchecked(part);
                let concrete = &*(erased as *const dyn Grouper as *const Self);
                concrete.contains_key(hash, key)
            };

            if !row_keys.keys.has_nulls() {
                // No nulls: hashes and key values can be walked in lock-step.
                let pairs = row_keys
                    .hashes
                    .values_iter()
                    .zip(row_keys.keys.values_iter());
                for (hash, key) in pairs {
                    contains_key.push(in_partition(*hash, key) != invert);
                }
            } else {
                for (row, hash) in row_keys.hashes.values_iter().enumerate_idx() {
                    let found = match row_keys.keys.get_unchecked(row as usize) {
                        Some(key) => in_partition(*hash, key),
                        None => false,
                    };
                    contains_key.push(found != invert);
                }
            }
        }
    }

    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}