#![allow(clippy::unnecessary_cast)] #![allow(unsafe_op_in_unsafe_fn)]
use arrow::array::{Array, View};
use polars_buffer::Buffer;
use polars_compute::binview_index_map::{BinaryViewIndexMap, Entry};
use polars_utils::idx_vec::UnitVec;
use polars_utils::itertools::Itertools;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::unitvec;
use super::*;
use crate::hash_keys::HashKeys;
/// Index table keyed by binary-view values.
///
/// Non-null keys live in `idx_map`; rows whose key is null are kept apart in
/// `null_keys`.
pub struct BinviewKeyIdxTable {
// One entry per distinct non-null key, holding every row index inserted for
// that key. The values are stored as `u64` cells: the low bits carry the row
// index (cast back to `IdxSize` on read) and bit 63 of the *first* cell of
// each entry is used as a "this key was matched" flag by the probe methods.
idx_map: BinaryViewIndexMap<UnitVec<RelaxedCell<u64>>>,
// Row index assigned to the next inserted row; advanced by the subset length
// on every `insert_keys_subset` call.
idx_offset: IdxSize,
// Row indices of inserted rows whose key was null (only recorded when the
// insert call asks for it or nulls are valid join keys).
null_keys: Vec<IdxSize>,
// Set once a null probe key has been matched with match-marking enabled;
// `unmarked_keys` skips `null_keys` when this is true.
nulls_emitted: RelaxedCell<bool>,
}
impl BinviewKeyIdxTable {
/// Creates an empty table: no keys, no recorded null rows, a row offset of
/// zero and the "nulls emitted" flag cleared.
pub fn new() -> Self {
    let idx_map = BinaryViewIndexMap::default();
    let null_keys = Vec::new();
    let nulls_emitted = RelaxedCell::from(false);

    Self {
        idx_map,
        idx_offset: 0,
        null_keys,
        nulls_emitted,
    }
}
/// Probes the table for a single non-null key.
///
/// On a hit, every row index stored for the key is appended to
/// `table_match` (with the mark bit masked off) and `key_idx` is appended
/// to `probe_match` once per such row. When `MARK_MATCHES` is set, the top
/// bit of the entry's first cell is raised to remember that the key was
/// matched.
///
/// Returns `true` if the key is present in the table, `false` otherwise.
///
/// # Safety
/// NOTE(review): presumably `key` must be a view that is valid for
/// `buffers`, as needed by `get_view` — confirm against
/// `BinaryViewIndexMap`.
#[inline(always)]
unsafe fn probe_one<const MARK_MATCHES: bool>(
    &self,
    key_idx: IdxSize,
    hash: u64,
    key: &View,
    buffers: &[Buffer<u8>],
    table_match: &mut Vec<IdxSize>,
    probe_match: &mut Vec<IdxSize>,
) -> bool {
    // Bit 63 of the first cell doubles as the "matched" flag.
    const MARK_BIT: u64 = 1 << 63;

    let Some(idxs) = (unsafe { self.idx_map.get_view(hash, key, buffers) }) else {
        return false;
    };

    for cell in &idxs[..] {
        table_match.push((cell.load() & !MARK_BIT) as IdxSize);
        probe_match.push(key_idx);
    }

    if MARK_MATCHES {
        // Entries are only ever created with one element, so slot 0 exists.
        let head = unsafe { idxs.get_unchecked(0) };
        let head_val = head.load();
        // Plain load + store: setting the flag twice yields the same value.
        if (head_val & MARK_BIT) == 0 {
            head.store(head_val | MARK_BIT);
        }
    }

    true
}
/// Probes a stream of `(probe row index, hash, optional key)` triples.
///
/// * A `Some` key is looked up through `probe_one`.
/// * A `None` (null) key matches every row in `null_keys` when
///   `NULL_IS_VALID` is set; with `MARK_MATCHES` this also raises
///   `nulls_emitted`. Otherwise a null key never matches.
/// * With `EMIT_UNMATCHED`, a probe key without any match produces a single
///   pair whose table side is `IdxSize::MAX`.
///
/// Iteration stops after the key that makes `table_match.len()` reach
/// `limit`. Returns the number of keys consumed from `keys`.
///
/// # Safety
/// NOTE(review): presumably every yielded view must be valid for `buffers`
/// — confirm against `probe_one` / `BinaryViewIndexMap::get_view`.
unsafe fn probe_impl<
    'a,
    const MARK_MATCHES: bool,
    const EMIT_UNMATCHED: bool,
    const NULL_IS_VALID: bool,
>(
    &self,
    keys: impl Iterator<Item = (IdxSize, u64, Option<&'a View>)>,
    buffers: &[Buffer<u8>],
    table_match: &mut Vec<IdxSize>,
    probe_match: &mut Vec<IdxSize>,
    limit: IdxSize,
) -> IdxSize {
    let mut keys_processed: IdxSize = 0;

    for (key_idx, hash, key) in keys {
        let found_match = match key {
            Some(view) => self.probe_one::<MARK_MATCHES>(
                key_idx,
                hash,
                view,
                buffers,
                table_match,
                probe_match,
            ),
            None if NULL_IS_VALID => {
                // Pair this probe row with every stored null row.
                let num_nulls = self.null_keys.len();
                table_match.extend_from_slice(&self.null_keys);
                probe_match.resize(probe_match.len() + num_nulls, key_idx);

                if MARK_MATCHES && !self.nulls_emitted.load() {
                    self.nulls_emitted.store(true);
                }
                num_nulls != 0
            },
            None => false,
        };

        if EMIT_UNMATCHED && !found_match {
            table_match.push(IdxSize::MAX);
            probe_match.push(key_idx);
        }

        keys_processed += 1;
        // The limit is checked only after a key is fully handled, so the
        // output may overshoot `limit` by the matches of the last key.
        if table_match.len() >= limit as usize {
            break;
        }
    }

    keys_processed
}
/// Forwards to the `probe_impl` instantiation whose const parameters equal
/// the runtime flags `mark_matches`, `emit_unmatched` and `null_is_valid`.
///
/// All other arguments and the return value are passed through unchanged.
///
/// # Safety
/// Same requirements as `probe_impl`.
#[allow(clippy::too_many_arguments)]
unsafe fn probe_dispatch<'a>(
    &self,
    keys: impl Iterator<Item = (IdxSize, u64, Option<&'a View>)>,
    buffers: &[Buffer<u8>],
    table_match: &mut Vec<IdxSize>,
    probe_match: &mut Vec<IdxSize>,
    mark_matches: bool,
    emit_unmatched: bool,
    null_is_valid: bool,
    limit: IdxSize,
) -> IdxSize {
    // One call site per flag combination; only the const arguments differ.
    macro_rules! run {
        ($mark:literal, $unmatched:literal, $nulls:literal) => {
            self.probe_impl::<$mark, $unmatched, $nulls>(
                keys,
                buffers,
                table_match,
                probe_match,
                limit,
            )
        };
    }

    match (mark_matches, emit_unmatched, null_is_valid) {
        (false, false, false) => run!(false, false, false),
        (false, false, true) => run!(false, false, true),
        (false, true, false) => run!(false, true, false),
        (false, true, true) => run!(false, true, true),
        (true, false, false) => run!(true, false, false),
        (true, false, true) => run!(true, false, true),
        (true, true, false) => run!(true, true, false),
        (true, true, true) => run!(true, true, true),
    }
}
}
impl IdxTable for BinviewKeyIdxTable {
/// Returns a new, empty `BinviewKeyIdxTable` boxed as a trait object.
/// Nothing is copied from `self`.
fn new_empty(&self) -> Box<dyn IdxTable> {
Box::new(Self::new())
}
/// Forwards the capacity hint `additional` to the underlying key map.
/// The separate `null_keys` vector is not reserved.
fn reserve(&mut self, additional: usize) {
self.idx_map.reserve(additional);
}
/// Returns the length of the key map, i.e. the number of entries it holds.
/// Rows with a null key are stored in `null_keys` and are not counted here.
fn num_keys(&self) -> IdxSize {
self.idx_map.len()
}
/// Not implemented for this table: always panics via `unimplemented!()`.
/// `insert_keys_subset` is the insertion path this type provides.
fn insert_keys(&mut self, _hash_keys: &HashKeys, _track_unmatchable: bool) {
unimplemented!()
}
/// Inserts the rows of `hash_keys` selected by `subset` into the table.
///
/// The `i`-th selected row receives the table row index `idx_offset + i`.
/// Rows with a valid key are appended to that key's entry in the map
/// (creating the entry if needed). Rows with a null key are recorded in
/// `null_keys` when `track_unmatchable` is set or the keys declare nulls
/// as valid; otherwise they are dropped. Afterwards `idx_offset` is
/// advanced by `subset.len()`.
///
/// # Panics
/// Panics if the new row offset overflows `usize` or is not strictly
/// below `IdxSize::MAX`. Reaches `unreachable!()` if `hash_keys` is not
/// the `Binview` variant.
///
/// # Safety
/// Every index in `subset` is used for unchecked access into the hashes,
/// views and validity of `hash_keys`, so each must be in bounds for them.
unsafe fn insert_keys_subset(
    &mut self,
    hash_keys: &HashKeys,
    subset: &[IdxSize],
    track_unmatchable: bool,
) {
    let HashKeys::Binview(hash_keys) = hash_keys else {
        unreachable!()
    };

    let new_idx_offset = (self.idx_offset as usize)
        .checked_add(subset.len())
        .unwrap();
    assert!(
        new_idx_offset < IdxSize::MAX as usize,
        "overly large index in BinviewKeyIdxTable"
    );

    let base = self.idx_offset;
    let keep_nulls = track_unmatchable | hash_keys.null_is_valid;
    let idx_map = &mut self.idx_map;
    let null_keys = &mut self.null_keys;

    unsafe {
        let buffers = hash_keys.keys.data_buffers();
        let views = hash_keys.keys.views();

        // Appends table row `idx` to the entry of the key found at `row`.
        let mut add_row = |row: usize, idx: IdxSize| {
            let hash = hash_keys.hashes.value_unchecked(row);
            let key = views.get_unchecked(row);
            let cell = RelaxedCell::from(idx as u64);
            match idx_map.entry_view(hash, *key, buffers) {
                Entry::Occupied(o) => {
                    o.into_mut().push(cell);
                },
                Entry::Vacant(v) => {
                    v.insert(unitvec![cell]);
                },
            }
        };

        match hash_keys.keys.validity() {
            Some(validity) => {
                for (i, subset_idx) in subset.iter().enumerate_idx() {
                    let row = *subset_idx as usize;
                    let idx = base + i;
                    if validity.get_bit_unchecked(row) {
                        add_row(row, idx);
                    } else if keep_nulls {
                        null_keys.push(idx);
                    }
                }
            },
            None => {
                // No validity mask: every selected row has a valid key.
                for (i, subset_idx) in subset.iter().enumerate_idx() {
                    add_row(*subset_idx as usize, base + i);
                }
            },
        }
    }

    self.idx_offset = new_idx_offset as IdxSize;
}
/// Not implemented for this table: always panics via `unimplemented!()`.
/// `probe_subset` is the probing path this type provides.
fn probe(
&self,
_hash_keys: &HashKeys,
_table_match: &mut Vec<IdxSize>,
_probe_match: &mut Vec<IdxSize>,
_mark_matches: bool,
_emit_unmatched: bool,
_limit: IdxSize,
) -> IdxSize {
unimplemented!()
}
/// Probes the table with the rows of `hash_keys` selected by `subset`.
///
/// For each selected row the stored hash and view are looked up and fed to
/// `probe_dispatch`, tagged with the row's own index as the probe index.
/// If the keys carry a validity mask, rows whose bit is unset are probed
/// as null keys and the keys' `null_is_valid` flag decides whether they can
/// match. Without a mask every row is non-null, so `false` is passed for
/// that flag (it has no effect in that case).
///
/// Returns the number of selected rows that were consumed.
///
/// # Panics
/// Reaches `unreachable!()` if `hash_keys` is not the `Binview` variant.
///
/// # Safety
/// Every index in `subset` is used for unchecked access into the hashes,
/// views and validity of `hash_keys`, so each must be in bounds for them.
unsafe fn probe_subset(
    &self,
    hash_keys: &HashKeys,
    subset: &[IdxSize],
    table_match: &mut Vec<IdxSize>,
    probe_match: &mut Vec<IdxSize>,
    mark_matches: bool,
    emit_unmatched: bool,
    limit: IdxSize,
) -> IdxSize {
    let HashKeys::Binview(hash_keys) = hash_keys else {
        unreachable!()
    };

    unsafe {
        let buffers = hash_keys.keys.data_buffers();
        let views = hash_keys.keys.views();
        let hashes = &hash_keys.hashes;

        match hash_keys.keys.validity() {
            Some(validity) => {
                let keys = subset.iter().map(|&row| {
                    let pos = row as usize;
                    let hash = hashes.value_unchecked(pos);
                    let view = if validity.get_bit_unchecked(pos) {
                        Some(views.get_unchecked(pos))
                    } else {
                        None
                    };
                    (row, hash, view)
                });
                self.probe_dispatch(
                    keys,
                    buffers,
                    table_match,
                    probe_match,
                    mark_matches,
                    emit_unmatched,
                    hash_keys.null_is_valid,
                    limit,
                )
            },
            None => {
                let keys = subset.iter().map(|&row| {
                    let pos = row as usize;
                    let hash = hashes.value_unchecked(pos);
                    (row, hash, Some(views.get_unchecked(pos)))
                });
                self.probe_dispatch(
                    keys,
                    buffers,
                    table_match,
                    probe_match,
                    mark_matches,
                    emit_unmatched,
                    false,
                    limit,
                )
            },
        }
    }
}
/// Collects table row indices belonging to keys that were never marked as
/// matched, resuming at position `offset`.
///
/// `out` is cleared first. While `nulls_emitted` is unset, positions
/// `0..null_keys.len()` address the stored null rows and the map entries
/// follow after them; once it is set, positions address map entries only.
/// For every visited map entry whose first cell has bit 63 clear, all of
/// its row indices are appended to `out`.
///
/// Stops once `out.len()` reaches `limit` (checked after each whole map
/// entry, so the last entry may overshoot) or the map is exhausted.
///
/// Returns the number of positions consumed: one per null row copied plus
/// one per map entry visited, whether or not that entry was emitted.
fn unmarked_keys(
    &self,
    out: &mut Vec<IdxSize>,
    mut offset: IdxSize,
    limit: IdxSize,
) -> IdxSize {
    // Bit 63 of an entry's first cell is the "matched" flag.
    const MARK_BIT: u64 = 1 << 63;

    let cap = limit as usize;
    out.clear();

    let mut keys_processed: IdxSize = 0;

    if !self.nulls_emitted.load() {
        let num_nulls = self.null_keys.len();
        if (offset as usize) < num_nulls {
            let pending = &self.null_keys[offset as usize..];
            let take = pending.len().min(cap);
            out.extend_from_slice(&pending[..take]);

            let copied = out.len() as IdxSize;
            keys_processed += copied;
            offset += copied;
            if out.len() >= cap {
                return keys_processed;
            }
        }
        // Re-base the position so it indexes into the map from here on.
        offset -= num_nulls as IdxSize;
    }

    while let Some((_, _, idxs)) = self.idx_map.get_index(offset) {
        // Entries are only ever created with one element, so slot 0 exists.
        let head = unsafe { idxs.get_unchecked(0) }.load();
        if (head & MARK_BIT) == 0 {
            out.extend(
                idxs[..]
                    .iter()
                    .map(|cell| (cell.load() & !MARK_BIT) as IdxSize),
            );
        }

        keys_processed += 1;
        offset += 1;
        if out.len() >= cap {
            break;
        }
    }

    keys_processed
}
}