use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap;
use arrow::array::{
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, PrimitiveArray,
};
use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, i256};
use datafusion_common::{HashMap, ScalarValue, internal_err};
use half::f16;
use std::fmt::Debug;
use std::hash::Hash;
#[derive(Clone)]
pub(super) struct PrimitiveIndexMap<T>
where
T: ArrowPrimitiveType,
T::Native: ToHashableKey,
{
data_type: DataType,
map: HashMap<<T::Native as ToHashableKey>::HashableKey, u32>,
}
impl<T> Debug for PrimitiveIndexMap<T>
where
T: ArrowPrimitiveType,
T::Native: ToHashableKey,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PrimitiveIndexMap")
.field("map", &self.map)
.finish()
}
}
impl<T> PrimitiveIndexMap<T>
where
T: ArrowPrimitiveType,
T::Native: ToHashableKey,
{
pub(super) fn try_new(
unique_non_null_literals: Vec<ScalarValue>,
) -> datafusion_common::Result<Self> {
let input = ScalarValue::iter_to_array(unique_non_null_literals)?;
if input.null_count() > 0 {
return internal_err!("Literal values for WHEN clauses cannot contain nulls");
}
let map = input
.as_primitive::<T>()
.values()
.iter()
.enumerate()
.map(|(map_index, value)| (value.into_hashable_key(), map_index as u32))
.collect();
Ok(Self {
map,
data_type: input.data_type().clone(),
})
}
fn map_primitive_array_to_when_indices(
&self,
array: &PrimitiveArray<T>,
else_index: u32,
) -> datafusion_common::Result<Vec<u32>> {
let indices = array
.into_iter()
.map(|value| match value {
Some(value) => self
.map
.get(&value.into_hashable_key())
.copied()
.unwrap_or(else_index),
None => else_index,
})
.collect::<Vec<u32>>();
Ok(indices)
}
}
impl<T> WhenLiteralIndexMap for PrimitiveIndexMap<T>
where
T: ArrowPrimitiveType,
T::Native: ToHashableKey,
{
fn map_to_when_indices(
&self,
array: &ArrayRef,
else_index: u32,
) -> datafusion_common::Result<Vec<u32>> {
match array.data_type() {
dt if dt == &self.data_type => {
let primitive_array = array.as_primitive::<T>();
self.map_primitive_array_to_when_indices(primitive_array, else_index)
}
DataType::Dictionary(_, value_type)
if value_type.as_ref() == &self.data_type =>
{
let converted = arrow::compute::cast(array.as_ref(), &self.data_type)?;
self.map_primitive_array_to_when_indices(
converted.as_primitive::<T>(),
else_index,
)
}
_ => internal_err!(
"PrimitiveIndexMap expected array of type {:?} but got {:?}",
self.data_type,
array.data_type()
),
}
}
}
pub(super) trait ToHashableKey: ArrowNativeTypeOp {
type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync;
fn into_hashable_key(self) -> Self::HashableKey;
}
macro_rules! impl_to_hashable_key {
(@single_already_hashable | $t:ty) => {
impl ToHashableKey for $t {
type HashableKey = $t;
#[inline]
fn into_hashable_key(self) -> Self::HashableKey {
self
}
}
};
(@already_hashable | $($t:ty),+ $(,)?) => {
$(
impl_to_hashable_key!(@single_already_hashable | $t);
)+
};
(@float | $t:ty => $hashable:ty) => {
impl ToHashableKey for $t {
type HashableKey = $hashable;
#[inline]
fn into_hashable_key(self) -> Self::HashableKey {
self.to_bits()
}
}
};
}
impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano);
impl_to_hashable_key!(@float | f16 => u16);
impl_to_hashable_key!(@float | f32 => u32);
impl_to_hashable_key!(@float | f64 => u64);
#[cfg(test)]
mod tests {
use super::ToHashableKey;
use arrow::array::downcast_primitive;
#[test]
fn should_implement_to_hashable_key_for_all_primitives() {
#[derive(Debug, Default)]
struct ExampleSet<T>
where
T: arrow::datatypes::ArrowPrimitiveType,
T::Native: ToHashableKey,
{
_map: std::collections::HashSet<<T::Native as ToHashableKey>::HashableKey>,
}
macro_rules! create_matching_set {
($t:ty) => {{
let _lookup_table = ExampleSet::<$t> {
_map: Default::default(),
};
return;
}};
}
let data_type = arrow::datatypes::DataType::Float16;
downcast_primitive! {
data_type => (create_matching_set),
_ => panic!("not implemented for {data_type}"),
}
}
}