use std::hash::{Hash, Hasher};
use rustc_hash::FxHasher;
use crate::core::{Row, Value};
/// Sentinel stored in `HashEntry::next` to mark the last entry of a bucket
/// chain. (Bucket heads use a negative `i32` for "empty" instead.)
const EMPTY: u32 = u32::MAX;
/// Lower bound on the number of buckets: used by `JoinHashTable::empty` and as
/// the minimum in `JoinHashTable::with_capacity`.
const MIN_BUCKETS: usize = 16;
/// One inserted key stored in the hash table.
///
/// Entries that land in the same bucket form a singly linked list through
/// `next`, which holds the position of the following entry in the table's
/// entry vector.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct HashEntry {
    /// Full 64-bit hash of the join key.
    hash: u64,
    /// Caller-supplied row index (the position in the input slice when the
    /// table is created through `JoinHashTable::build`).
    row_idx: u32,
    /// Position of the next entry in the same bucket chain, or `EMPTY` when
    /// this is the last one.
    next: u32,
}

impl HashEntry {
    /// Bundles a key hash, a row index and a chain link into an entry.
    #[inline]
    fn new(hash: u64, row_idx: u32, next: u32) -> Self {
        HashEntry { hash, row_idx, next }
    }
}
/// Chained hash table for the build side of a hash join.
///
/// The table does not store rows. It maps a 64-bit key hash to the row indices
/// that were inserted with that hash. All entries live in one contiguous
/// vector; entries sharing a bucket are linked through `HashEntry::next`.
///
/// Probing filters on the full hash only, so two different keys with the same
/// hash are both returned. Callers still have to compare the actual key values
/// (for example with `verify_key_equality`).
pub struct JoinHashTable {
    /// For each bucket, the index into `entries` of the most recently inserted
    /// entry, or `-1` when the bucket is empty.
    bucket_heads: Vec<i32>,
    /// All entries, in insertion order.
    entries: Vec<HashEntry>,
    /// `bucket_count - 1`. The constructors always pick a power-of-two bucket
    /// count, so `hash & bucket_mask` selects a bucket.
    bucket_mask: u64,
    /// Number of entries inserted so far.
    len: usize,
}

impl JoinHashTable {
    /// Creates an empty table sized for `row_count` insertions.
    ///
    /// The bucket count is `row_count * 4 / 3`, raised to at least
    /// `MIN_BUCKETS` and rounded up to a power of two, so inserting
    /// `row_count` entries keeps the load factor at or below 0.75.
    pub fn with_capacity(row_count: usize) -> Self {
        let bucket_count = (row_count * 4 / 3).max(MIN_BUCKETS).next_power_of_two();
        let bucket_mask = (bucket_count - 1) as u64;
        Self {
            bucket_heads: vec![-1; bucket_count],
            entries: Vec::with_capacity(row_count),
            bucket_mask,
            len: 0,
        }
    }

    /// Creates an empty table with `MIN_BUCKETS` buckets and no space reserved
    /// for entries.
    pub fn empty() -> Self {
        Self {
            bucket_heads: vec![-1; MIN_BUCKETS],
            entries: Vec::new(),
            bucket_mask: (MIN_BUCKETS - 1) as u64,
            len: 0,
        }
    }

    /// Builds a table from `rows`, hashing the columns listed in `key_indices`
    /// with `hash_row_keys`. Each entry records the row's position in `rows`.
    ///
    /// # Panics
    ///
    /// Panics if `rows` holds more entries than the table can index
    /// (see [`JoinHashTable::insert`]).
    pub fn build(rows: &[Row], key_indices: &[usize]) -> Self {
        let mut table = Self::with_capacity(rows.len());
        for (idx, row) in rows.iter().enumerate() {
            let hash = hash_row_keys(row, key_indices);
            // `idx as u32` cannot truncate: `insert` panics once the table
            // holds more than `i32::MAX` entries, long before `u32` overflows.
            table.insert(hash, idx as u32);
        }
        table
    }

    /// Same as [`JoinHashTable::build`], and additionally passes every key
    /// hash to `bloom_builder.insert_raw_hash`.
    ///
    /// # Panics
    ///
    /// Panics if `rows` holds more entries than the table can index
    /// (see [`JoinHashTable::insert`]).
    pub fn build_with_bloom(
        rows: &[Row],
        key_indices: &[usize],
        bloom_builder: &mut crate::optimizer::bloom::BloomFilterBuilder,
    ) -> Self {
        let mut table = Self::with_capacity(rows.len());
        for (idx, row) in rows.iter().enumerate() {
            let hash = hash_row_keys(row, key_indices);
            // See `build` for why this cast is lossless.
            table.insert(hash, idx as u32);
            bloom_builder.insert_raw_hash(hash);
        }
        table
    }

    /// Inserts `row_idx` under `hash`.
    ///
    /// The new entry becomes the head of its bucket chain, so a later probe
    /// sees entries of a bucket from newest to oldest. The table never
    /// resizes; the bucket count is fixed at construction.
    ///
    /// # Panics
    ///
    /// Panics if the table already holds more than `i32::MAX` entries. Bucket
    /// heads are `i32` with negative values meaning "empty", so a larger entry
    /// index would wrap negative and silently drop the bucket's whole chain.
    #[inline]
    pub fn insert(&mut self, hash: u64, row_idx: u32) {
        assert!(
            self.len <= i32::MAX as usize,
            "JoinHashTable cannot hold more than {} entries",
            i32::MAX as u64 + 1
        );

        let bucket = (hash & self.bucket_mask) as usize;
        let old_head = self.bucket_heads[bucket];
        // Guaranteed non-negative by the assertion above.
        let entry_idx = self.len as i32;
        let next = if old_head >= 0 {
            old_head as u32
        } else {
            EMPTY
        };
        self.entries.push(HashEntry::new(hash, row_idx, next));
        self.bucket_heads[bucket] = entry_idx;
        self.len += 1;
    }

    /// Returns an iterator over the row indices of all entries whose stored
    /// hash equals `hash`.
    ///
    /// Only hashes are compared; the caller must still verify key equality.
    #[inline]
    pub fn probe(&self, hash: u64) -> ProbeIter<'_> {
        let bucket = (hash & self.bucket_mask) as usize;
        let first = self.bucket_heads[bucket];
        ProbeIter {
            table: self,
            hash,
            current: first,
        }
    }

    /// Number of entries inserted so far.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when nothing has been inserted.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of buckets in the table.
    #[inline]
    pub fn bucket_count(&self) -> usize {
        self.bucket_heads.len()
    }

    /// Average number of entries per bucket (`len / bucket_count`).
    #[inline]
    pub fn load_factor(&self) -> f64 {
        self.len as f64 / self.bucket_heads.len() as f64
    }
}
pub struct ProbeIter<'a> {
table: &'a JoinHashTable,
hash: u64,
current: i32,
}
impl Iterator for ProbeIter<'_> {
type Item = usize;
#[inline]
fn next(&mut self) -> Option<usize> {
while self.current >= 0 {
let entry = &self.table.entries[self.current as usize];
self.current = if entry.next == EMPTY {
-1
} else {
entry.next as i32
};
if entry.hash == self.hash {
return Some(entry.row_idx as usize);
}
}
None
}
}
/// Hashes the join key formed by the columns in `key_indices`, fetching each
/// value through `get_value`.
///
/// A key made of exactly one `Value::Integer` takes a multiplicative shortcut.
/// Every other key is fed column by column through an `FxHasher`, with a fixed
/// marker standing in for any column `get_value` returns `None` for.
///
/// `hash_row_keys` applies the same scheme to a `Row`; the two must stay in
/// sync for build-side and probe-side hashes to agree.
#[inline]
pub fn hash_keys_with<'a, F>(key_indices: &[usize], get_value: F) -> u64
where
    F: Fn(usize) -> Option<&'a Value>,
{
    // Single integer key: skip the hasher entirely.
    if let [only] = key_indices {
        if let Some(Value::Integer(i)) = get_value(*only) {
            return (*i as u64).wrapping_mul(0x517cc1b727220a95);
        }
    }

    let mut state = FxHasher::default();
    for &column in key_indices {
        match get_value(column) {
            Some(value) => hash_value(&mut state, value),
            None => 0xDEADBEEF_u64.hash(&mut state),
        }
    }
    state.finish()
}
/// Hashes the join key of `row` formed by the columns in `key_indices`.
///
/// A key made of exactly one `Value::Integer` takes a multiplicative shortcut.
/// Every other key is fed column by column through an `FxHasher`, with a fixed
/// marker standing in for any column `row.get` returns `None` for.
///
/// This mirrors `hash_keys_with`; the two must stay in sync for build-side and
/// probe-side hashes to agree.
#[inline]
pub fn hash_row_keys(row: &Row, key_indices: &[usize]) -> u64 {
    // Single integer key: skip the hasher entirely.
    if let [only] = key_indices {
        if let Some(Value::Integer(i)) = row.get(*only) {
            return (*i as u64).wrapping_mul(0x517cc1b727220a95);
        }
    }

    let mut state = FxHasher::default();
    for cell in key_indices.iter().map(|&column| row.get(column)) {
        match cell {
            Some(value) => hash_value(&mut state, value),
            None => 0xDEADBEEF_u64.hash(&mut state),
        }
    }
    state.finish()
}
/// Feeds one value into `hasher`: first a one-byte tag identifying the
/// variant, then the variant's payload (nothing for `Null`).
///
/// The tag keeps values of different variants from hashing alike when their
/// payloads happen to produce the same bytes. Floats are hashed by bit pattern.
#[inline]
fn hash_value<H: Hasher>(hasher: &mut H, value: &Value) {
    let tag: u8 = match value {
        Value::Integer(_) => 1,
        Value::Float(_) => 2,
        Value::Text(_) => 3,
        Value::Boolean(_) => 4,
        Value::Null(_) => 5,
        Value::Timestamp(_) => 6,
        Value::Extension(_) => 10,
    };
    tag.hash(hasher);

    match value {
        Value::Integer(i) => i.hash(hasher),
        Value::Float(f) => f.to_bits().hash(hasher),
        Value::Text(s) => s.hash(hasher),
        Value::Boolean(b) => b.hash(hasher),
        // Null carries no payload worth hashing; the tag alone is enough.
        Value::Null(_) => {}
        Value::Timestamp(ts) => ts.timestamp_nanos_opt().hash(hasher),
        Value::Extension(data) => data.hash(hasher),
    }
}
#[inline]
pub fn verify_key_equality(row1: &Row, row2: &Row, indices1: &[usize], indices2: &[usize]) -> bool {
debug_assert_eq!(indices1.len(), indices2.len());
if indices1.len() == 1 {
let v1 = row1.get(indices1[0]);
let v2 = row2.get(indices2[0]);
return match (v1, v2) {
(Some(Value::Integer(a)), Some(Value::Integer(b))) => a == b,
(Some(a), Some(b)) => values_equal(a, b),
_ => false,
};
}
for (&idx1, &idx2) in indices1.iter().zip(indices2.iter()) {
let v1 = row1.get(idx1);
let v2 = row2.get(idx2);
match (v1, v2) {
(Some(Value::Integer(a)), Some(Value::Integer(b))) => {
if a != b {
return false;
}
}
(Some(a), Some(b)) => {
if !values_equal(a, b) {
return false;
}
}
(None, None) => {
return false;
}
_ => return false,
}
}
true
}
/// Join-key equality for two values.
///
/// * Same-variant values compare by payload, except that `Null` never equals
///   `Null`.
/// * Floats compare by bit pattern (`to_bits`), so `0.0` and `-0.0` differ and
///   NaNs with identical bits are equal.
/// * An integer and a float compare equal when the integer, converted to
///   `f64`, has the same bit pattern as the float.
/// * Any other combination of variants is unequal.
#[inline]
fn values_equal(a: &Value, b: &Value) -> bool {
    match a {
        Value::Integer(x) => match b {
            Value::Integer(y) => x == y,
            Value::Float(y) => (*x as f64).to_bits() == y.to_bits(),
            _ => false,
        },
        Value::Float(x) => match b {
            Value::Float(y) => x.to_bits() == y.to_bits(),
            Value::Integer(y) => x.to_bits() == (*y as f64).to_bits(),
            _ => false,
        },
        Value::Text(x) => matches!(b, Value::Text(y) if x == y),
        Value::Boolean(x) => matches!(b, Value::Boolean(y) if x == y),
        // NULL is never equal to anything, including another NULL.
        Value::Null(_) => false,
        Value::Timestamp(x) => matches!(b, Value::Timestamp(y) if x == y),
        Value::Extension(x) => matches!(b, Value::Extension(y) if x == y),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row whose columns are all integers.
    fn make_row(values: Vec<i64>) -> Row {
        Row::from_values(values.into_iter().map(Value::integer).collect())
    }

    #[test]
    fn test_basic_insert_and_probe() {
        let mut table = JoinHashTable::with_capacity(4);
        // Row indices 0 and 2 share hash 100.
        for (row_idx, hash) in [100_u64, 200, 100, 300].iter().enumerate() {
            table.insert(*hash, row_idx as u32);
        }
        assert_eq!(table.len(), 4);

        let hits: Vec<usize> = table.probe(100).collect();
        assert_eq!(hits.len(), 2);
        assert!(hits.contains(&0));
        assert!(hits.contains(&2));

        let hits: Vec<usize> = table.probe(200).collect();
        assert_eq!(hits, vec![1]);

        // A hash that was never inserted yields nothing.
        let hits: Vec<usize> = table.probe(999).collect();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_build_from_rows() {
        // Rows 0 and 2 have the same value in the key column (column 0).
        let rows = vec![
            make_row(vec![1, 10]),
            make_row(vec![2, 20]),
            make_row(vec![1, 30]),
            make_row(vec![3, 40]),
        ];
        let key_indices = vec![0];

        let table = JoinHashTable::build(&rows, &key_indices);
        assert_eq!(table.len(), 4);

        let key_hash = hash_row_keys(&rows[0], &key_indices);
        let hits: Vec<usize> = table.probe(key_hash).collect();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_empty_table() {
        let table = JoinHashTable::empty();
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);

        let hits: Vec<usize> = table.probe(100).collect();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_load_factor() {
        let mut table = JoinHashTable::with_capacity(100);
        for i in 0..100_u32 {
            table.insert(u64::from(i), i);
        }

        let load = table.load_factor();
        assert!(
            load > 0.3 && load <= 1.0,
            "Load factor {} out of expected range",
            load
        );
        assert_eq!(table.len(), 100);
    }

    #[test]
    fn test_verify_key_equality() {
        let base = Row::from_values(vec![Value::integer(1), Value::text("hello")]);
        let same = Row::from_values(vec![Value::integer(1), Value::text("hello")]);
        let different = Row::from_values(vec![Value::integer(2), Value::text("hello")]);

        assert!(verify_key_equality(&base, &same, &[0, 1], &[0, 1]));
        assert!(!verify_key_equality(&base, &different, &[0, 1], &[0, 1]));
    }

    #[test]
    fn test_hash_row_keys() {
        let row1 = make_row(vec![1, 2, 3]);
        let row2 = make_row(vec![1, 2, 3]);
        let row3 = make_row(vec![1, 2, 4]);
        let indices = vec![0, 1];

        // Identical rows hash identically.
        assert_eq!(
            hash_row_keys(&row1, &indices),
            hash_row_keys(&row2, &indices)
        );

        // Column 2 is not part of the key, so changing it leaves the hash alone.
        let row4 = make_row(vec![1, 2, 999]);
        assert_eq!(
            hash_row_keys(&row1, &indices),
            hash_row_keys(&row4, &indices)
        );

        // Once column 2 is part of the key, the differing value shows up.
        assert_ne!(hash_row_keys(&row1, &[0, 2]), hash_row_keys(&row3, &[0, 2]));
    }

    #[test]
    fn test_chain_collision() {
        // Four buckets only, so hashes 0, 4, 8 and 12 all land in bucket 0.
        let mut table = JoinHashTable {
            bucket_heads: vec![-1; 4],
            entries: Vec::new(),
            bucket_mask: 3,
            len: 0,
        };
        let hashes = [0_u64, 4, 8, 12];
        for (row_idx, hash) in hashes.iter().enumerate() {
            table.insert(*hash, row_idx as u32);
        }

        // Sharing a chain must not make entries match each other's hashes.
        for hash in hashes.iter() {
            assert_eq!(table.probe(*hash).count(), 1);
        }
    }
}