use std::{any::Any, fmt::Display, hash::Hash, sync::Arc};
use ahash::RandomState;
use arrow::{
array::{BooleanArray, UInt64Array},
buffer::MutableBuffer,
datatypes::{DataType, Schema},
util::bit_util,
};
use datafusion_common::{Result, internal_datafusion_err, internal_err};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::{
DynHash, PhysicalExpr, PhysicalExprRef,
};
use crate::{hash_utils::create_hashes, joins::utils::JoinHashMapType};
/// An `ahash` [`RandomState`] bundled with the four seed values used to build it.
///
/// The seeds are kept next to the hasher state so they can be read back later
/// through [`SeededRandomState::seeds`].
#[derive(Clone, Debug)]
pub struct SeededRandomState {
/// Hasher state constructed from `seeds`.
random_state: RandomState,
/// The `(k0, k1, k2, k3)` values given to [`SeededRandomState::with_seeds`].
seeds: (u64, u64, u64, u64),
}
impl SeededRandomState {
    /// Builds a hasher state from the four seeds and remembers the seeds
    /// themselves so they can be retrieved with [`Self::seeds`].
    pub const fn with_seeds(k0: u64, k1: u64, k2: u64, k3: u64) -> Self {
        let seeds = (k0, k1, k2, k3);
        let random_state = RandomState::with_seeds(k0, k1, k2, k3);
        Self {
            random_state,
            seeds,
        }
    }

    /// Borrows the wrapped [`RandomState`].
    pub fn random_state(&self) -> &RandomState {
        let Self { random_state, .. } = self;
        random_state
    }

    /// Returns the seeds as `(k0, k1, k2, k3)`, in the order they were supplied.
    pub fn seeds(&self) -> (u64, u64, u64, u64) {
        let Self { seeds, .. } = self;
        *seeds
    }
}
/// Physical expression that evaluates `on_columns` against a batch, passes the
/// resulting arrays to `create_hashes`, and returns the hashes as a
/// `UInt64Array`.
pub struct HashExpr {
/// Child expressions whose evaluated arrays are hashed.
on_columns: Vec<PhysicalExprRef>,
/// Seeded hasher state handed to `create_hashes`.
random_state: SeededRandomState,
/// Label written by the `Display`, `Debug` and `fmt_sql` implementations.
description: String,
}
impl HashExpr {
pub fn new(
on_columns: Vec<PhysicalExprRef>,
random_state: SeededRandomState,
description: String,
) -> Self {
Self {
on_columns,
random_state,
description,
}
}
pub fn on_columns(&self) -> &[PhysicalExprRef] {
&self.on_columns
}
pub fn seeds(&self) -> (u64, u64, u64, u64) {
self.random_state.seeds()
}
pub fn description(&self) -> &str {
&self.description
}
}
impl std::fmt::Debug for HashExpr {
    /// Writes `description(col1, col2, ..., [k0,k1,k2,k3])`, rendering each
    /// column with its `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Assemble the comma-separated column list up front.
        let mut cols = String::new();
        for (position, column) in self.on_columns.iter().enumerate() {
            if position > 0 {
                cols.push_str(", ");
            }
            cols.push_str(&column.to_string());
        }

        let (s1, s2, s3, s4) = self.seeds();
        write!(f, "{}({cols}, [{s1},{s2},{s3},{s4}])", self.description)
    }
}
impl Hash for HashExpr {
    /// Feeds the columns, the description and the seeds into `state`, in that
    /// order. These are the same fields `PartialEq` compares.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.on_columns.dyn_hash(state);
        Hash::hash(&self.description, state);
        let seeds = self.seeds();
        Hash::hash(&seeds, state);
    }
}
impl PartialEq for HashExpr {
    /// Two hash expressions are equal when their columns, description and
    /// seeds all match.
    fn eq(&self, other: &Self) -> bool {
        if self.on_columns != other.on_columns {
            return false;
        }
        if self.description != other.description {
            return false;
        }
        self.seeds() == other.seeds()
    }
}
// Marker impl: the comparison itself lives in `PartialEq for HashExpr`
// (columns, description and seeds).
impl Eq for HashExpr {}
impl Display for HashExpr {
    /// Displays only the description label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.description)
    }
}
impl PhysicalExpr for HashExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The hashed column expressions, in order.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        let mut refs = Vec::with_capacity(self.on_columns.len());
        refs.extend(self.on_columns.iter());
        refs
    }

    /// Rebuilds the expression over `children`, keeping the hasher state and
    /// the description.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let random_state = self.random_state.clone();
        let description = self.description.clone();
        let rebuilt = HashExpr::new(children, random_state, description);
        Ok(Arc::new(rebuilt))
    }

    /// Always `UInt64`, regardless of the input schema.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    /// Reported as never nullable.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(false)
    }

    /// Evaluates every column against `batch`, hashes the resulting arrays
    /// with the seeded state and returns one `u64` per row.
    ///
    /// # Errors
    /// Propagates the first error from evaluating a column, converting it to
    /// an array, or from `create_hashes`.
    fn evaluate(
        &self,
        batch: &arrow::record_batch::RecordBatch,
    ) -> Result<ColumnarValue> {
        let num_rows = batch.num_rows();

        // Materialize each key column; stop at the first failure.
        let mut key_arrays = Vec::with_capacity(self.on_columns.len());
        for column in &self.on_columns {
            let array = column.evaluate(batch)?.into_array(num_rows)?;
            key_arrays.push(array);
        }

        let mut hashes = vec![0; num_rows];
        let state = self.random_state.random_state();
        create_hashes(&key_arrays, state, &mut hashes)?;

        let hashes = UInt64Array::from(hashes);
        Ok(ColumnarValue::Array(Arc::new(hashes)))
    }

    /// Writes only the description label.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.description)
    }
}
/// Physical expression that evaluates `hash_expr` and reports, for each row,
/// whether `hash_map` returns any matched index for that row's hash value.
pub struct HashTableLookupExpr {
/// Child expression; `evaluate` requires it to produce a `UInt64Array`.
hash_expr: PhysicalExprRef,
/// Hash map that is probed. Equality compares it by pointer identity.
hash_map: Arc<dyn JoinHashMapType>,
/// Label written by the `Display`, `Debug` and `fmt_sql` implementations.
description: String,
}
impl HashTableLookupExpr {
pub fn new(
hash_expr: PhysicalExprRef,
hash_map: Arc<dyn JoinHashMapType>,
description: String,
) -> Self {
Self {
hash_expr,
hash_map,
description,
}
}
}
impl std::fmt::Debug for HashTableLookupExpr {
    /// Writes `description(<debug of hash_expr>)`; the hash map is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            hash_expr,
            description,
            ..
        } = self;
        write!(f, "{description}({hash_expr:?})")
    }
}
impl Hash for HashTableLookupExpr {
    /// Hashes the child expression, the description, and the identity
    /// (allocation address) of the hash map.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.hash_expr.dyn_hash(state);
        self.description.hash(state);
        // `PartialEq` compares the maps with `Arc::ptr_eq`, which only looks
        // at the data address. `Arc::as_ptr` on a trait object is a fat
        // pointer whose `Hash` impl also feeds in the vtable metadata, so
        // strip the metadata to keep `a == b` implying `hash(a) == hash(b)`.
        Arc::as_ptr(&self.hash_map).cast::<()>().hash(state);
    }
}
impl PartialEq for HashTableLookupExpr {
    /// Equal when the child expressions and descriptions match and both
    /// values share the same hash map allocation. Map contents are not
    /// compared.
    fn eq(&self, other: &Self) -> bool {
        if self.hash_expr.as_ref() != other.hash_expr.as_ref() {
            return false;
        }
        if self.description != other.description {
            return false;
        }
        Arc::ptr_eq(&self.hash_map, &other.hash_map)
    }
}
// Marker impl: the comparison itself lives in
// `PartialEq for HashTableLookupExpr` (child expression, description, and
// hash map pointer identity).
impl Eq for HashTableLookupExpr {}
impl Display for HashTableLookupExpr {
    /// Displays only the description label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.description)
    }
}
impl PhysicalExpr for HashTableLookupExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The hash expression is the only child.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.hash_expr]
    }

    /// Rebuilds the expression around a new hash expression, sharing the same
    /// hash map and description.
    ///
    /// # Errors
    /// Returns an internal error unless exactly one child is supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        if children.len() != 1 {
            return internal_err!(
                "HashTableLookupExpr expects exactly 1 child, got {}",
                children.len()
            );
        }
        Ok(Arc::new(HashTableLookupExpr::new(
            Arc::clone(&children[0]),
            Arc::clone(&self.hash_map),
            self.description.clone(),
        )))
    }

    /// Always `Boolean`, regardless of the input schema.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Reported as never nullable.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(false)
    }

    /// Evaluates the hash expression on `batch` and returns a boolean array
    /// whose bit `i` is set when the hash map yields at least one matched
    /// index for row `i`'s hash value.
    ///
    /// # Errors
    /// Propagates errors from the hash expression, and returns an internal
    /// error if that expression does not produce a `UInt64Array`.
    fn evaluate(
        &self,
        batch: &arrow::record_batch::RecordBatch,
    ) -> Result<ColumnarValue> {
        let num_rows = batch.num_rows();
        let hash_array = self.hash_expr.evaluate(batch)?.into_array(num_rows)?;
        // Build the error lazily: `ok_or` would construct it on every call,
        // including the common success path.
        let hash_array = hash_array
            .as_any()
            .downcast_ref::<UInt64Array>()
            .ok_or_else(|| {
                internal_datafusion_err!(
                    "HashTableLookupExpr expects UInt64Array from hash expression"
                )
            })?;

        // One bit per row, initially all false.
        let mut buf = MutableBuffer::from_len_zeroed(bit_util::ceil(num_rows, 8));
        for (idx, hash_value) in hash_array.values().iter().enumerate() {
            // Probe a single hash at a time; any matched index marks the row.
            let (matched_indices, _) = self
                .hash_map
                .get_matched_indices(Box::new(std::iter::once((idx, hash_value))), None);
            if !matched_indices.is_empty() {
                bit_util::set_bit(buf.as_slice_mut(), idx);
            }
        }

        Ok(ColumnarValue::Array(Arc::new(
            BooleanArray::new_from_packed(buf, 0, num_rows),
        )))
    }

    /// Writes only the description label.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.description)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::joins::join_hash_map::JoinHashMapU32;
    use datafusion_physical_expr::expressions::Column;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    /// Seeds shared by every fixture that does not test seed differences.
    const SEEDS: (u64, u64, u64, u64) = (1, 2, 3, 4);

    /// Hashes `value` with the standard library's default hasher.
    fn compute_hash<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Column reference expression named `name` at position `index`.
    fn column(name: &str, index: usize) -> PhysicalExprRef {
        Arc::new(Column::new(name, index))
    }

    /// `HashExpr` over `columns` with the given seeds and description.
    fn make_hash_expr(
        columns: Vec<PhysicalExprRef>,
        seeds: (u64, u64, u64, u64),
        description: &str,
    ) -> HashExpr {
        let (k0, k1, k2, k3) = seeds;
        HashExpr::new(
            columns,
            SeededRandomState::with_seeds(k0, k1, k2, k3),
            description.to_string(),
        )
    }

    /// Single-column `HashExpr` labelled `inner_hash`, as a trait object.
    fn inner_hash(column: &PhysicalExprRef) -> PhysicalExprRef {
        Arc::new(make_hash_expr(
            vec![Arc::clone(column)],
            SEEDS,
            "inner_hash",
        ))
    }

    /// Fresh hash map created with capacity 10.
    fn new_hash_map() -> Arc<dyn JoinHashMapType> {
        Arc::new(JoinHashMapU32::with_capacity(10))
    }

    /// Lookup expression sharing the given hash expression and map.
    fn make_lookup(
        hash_expr: &PhysicalExprRef,
        hash_map: &Arc<dyn JoinHashMapType>,
        description: &str,
    ) -> HashTableLookupExpr {
        HashTableLookupExpr::new(
            Arc::clone(hash_expr),
            Arc::clone(hash_map),
            description.to_string(),
        )
    }

    #[test]
    fn test_hash_expr_eq_same() {
        let col_a = column("a", 0);
        let col_b = column("b", 1);
        let build = || {
            make_hash_expr(
                vec![Arc::clone(&col_a), Arc::clone(&col_b)],
                SEEDS,
                "test_hash",
            )
        };
        let (expr1, expr2) = (build(), build());
        assert_eq!(expr1, expr2);
    }

    #[test]
    fn test_hash_expr_eq_different_columns() {
        let col_a = column("a", 0);
        let col_b = column("b", 1);
        let col_c = column("c", 2);
        let expr1 = make_hash_expr(
            vec![Arc::clone(&col_a), Arc::clone(&col_b)],
            SEEDS,
            "test_hash",
        );
        let expr2 = make_hash_expr(
            vec![Arc::clone(&col_a), Arc::clone(&col_c)],
            SEEDS,
            "test_hash",
        );
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_expr_eq_different_description() {
        let col_a = column("a", 0);
        let expr1 = make_hash_expr(vec![Arc::clone(&col_a)], SEEDS, "hash_one");
        let expr2 = make_hash_expr(vec![Arc::clone(&col_a)], SEEDS, "hash_two");
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_expr_eq_different_seeds() {
        let col_a = column("a", 0);
        let expr1 = make_hash_expr(vec![Arc::clone(&col_a)], SEEDS, "test_hash");
        let expr2 =
            make_hash_expr(vec![Arc::clone(&col_a)], (5, 6, 7, 8), "test_hash");
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_expr_hash_consistency() {
        let col_a = column("a", 0);
        let col_b = column("b", 1);
        let build = || {
            make_hash_expr(
                vec![Arc::clone(&col_a), Arc::clone(&col_b)],
                SEEDS,
                "test_hash",
            )
        };
        let (expr1, expr2) = (build(), build());
        assert_eq!(expr1, expr2);
        assert_eq!(compute_hash(&expr1), compute_hash(&expr2));
    }

    #[test]
    fn test_hash_table_lookup_expr_eq_same() {
        let hash_expr = inner_hash(&column("a", 0));
        let hash_map = new_hash_map();
        let expr1 = make_lookup(&hash_expr, &hash_map, "lookup");
        let expr2 = make_lookup(&hash_expr, &hash_map, "lookup");
        assert_eq!(expr1, expr2);
    }

    #[test]
    fn test_hash_table_lookup_expr_eq_different_hash_expr() {
        let hash_expr1 = inner_hash(&column("a", 0));
        let hash_expr2 = inner_hash(&column("b", 1));
        let hash_map = new_hash_map();
        let expr1 = make_lookup(&hash_expr1, &hash_map, "lookup");
        let expr2 = make_lookup(&hash_expr2, &hash_map, "lookup");
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_table_lookup_expr_eq_different_description() {
        let hash_expr = inner_hash(&column("a", 0));
        let hash_map = new_hash_map();
        let expr1 = make_lookup(&hash_expr, &hash_map, "lookup_one");
        let expr2 = make_lookup(&hash_expr, &hash_map, "lookup_two");
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_table_lookup_expr_eq_different_hash_map() {
        let hash_expr = inner_hash(&column("a", 0));
        // Two separate allocations, so pointer identity differs.
        let hash_map1 = new_hash_map();
        let hash_map2 = new_hash_map();
        let expr1 = make_lookup(&hash_expr, &hash_map1, "lookup");
        let expr2 = make_lookup(&hash_expr, &hash_map2, "lookup");
        assert_ne!(expr1, expr2);
    }

    #[test]
    fn test_hash_table_lookup_expr_hash_consistency() {
        let hash_expr = inner_hash(&column("a", 0));
        let hash_map = new_hash_map();
        let expr1 = make_lookup(&hash_expr, &hash_map, "lookup");
        let expr2 = make_lookup(&hash_expr, &hash_map, "lookup");
        assert_eq!(expr1, expr2);
        assert_eq!(compute_hash(&expr1), compute_hash(&expr2));
    }
}