mod boolean_lookup_table;
mod bytes_like_lookup_table;
mod primitive_lookup_table;
use crate::expressions::Literal;
use crate::expressions::case::CaseBody;
use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap;
use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap;
use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap;
use arrow::array::{Array, ArrayRef, UInt32Array, downcast_primitive};
use arrow::datatypes::DataType;
use datafusion_common::{ScalarValue, arrow_datafusion_err, plan_datafusion_err};
use indexmap::IndexMap;
use std::fmt::Debug;
/// Precomputed evaluator for a `CASE` expression whose WHEN/THEN arms (and ELSE, if
/// present) are all literals; constructed by `LiteralLookupTable::maybe_new`.
///
/// Evaluation maps every input key to an index with `lookup` and gathers the result
/// from `then_and_else_values` at those indices.
#[derive(Debug)]
pub(in super::super) struct LiteralLookupTable {
/// Maps each input key to an index into `then_and_else_values`.
lookup: Box<dyn WhenLiteralIndexMap>,
/// Position of the ELSE value in `then_and_else_values`; it is always the last element.
else_index: u32,
/// THEN values in (de-duplicated) WHEN order, followed by the ELSE value — or by a NULL
/// of the THEN type when the CASE has no ELSE.
then_and_else_values: ArrayRef,
}
impl LiteralLookupTable {
pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> {
if body.when_then_expr.is_empty() {
return None;
}
if body.when_then_expr.len() == 1 {
return None;
}
let when_then_exprs_maybe_literals = body
.when_then_expr
.iter()
.map(|(when, then)| {
let when_maybe_literal = when.as_any().downcast_ref::<Literal>();
let then_maybe_literal = then.as_any().downcast_ref::<Literal>();
when_maybe_literal.zip(then_maybe_literal)
})
.collect::<Vec<_>>();
if when_then_exprs_maybe_literals.contains(&None) {
return None;
}
let when_then_exprs_scalars = when_then_exprs_maybe_literals
.into_iter()
.flatten()
.map(|(when_lit, then_lit)| {
(when_lit.value().clone(), then_lit.value().clone())
})
.filter(|(when_lit, _)| !when_lit.is_null())
.collect::<Vec<_>>();
if when_then_exprs_scalars.is_empty() {
return None;
}
let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = {
let mut map = IndexMap::with_capacity(body.when_then_expr.len());
for (when, then) in when_then_exprs_scalars.into_iter() {
if !map.contains_key(&when) {
map.insert(when, then);
}
}
map.into_iter().unzip()
};
let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr {
let literal = else_expr.as_any().downcast_ref::<Literal>()?;
literal.value().clone()
} else {
let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else {
return None;
};
null_scalar
};
{
let when_data_type = when[0].data_type();
if when.iter().any(|l| l.data_type() != when_data_type) {
return None;
}
}
{
let data_type = then[0].data_type();
if then.iter().any(|l| l.data_type() != data_type) {
return None;
}
if else_value.data_type() != data_type {
return None;
}
}
let then_and_else_values = ScalarValue::iter_to_array(
then.iter()
.chain(std::iter::once(&else_value))
.cloned(),
)
.ok()?;
let else_index = then_and_else_values.len() as u32 - 1;
let lookup = try_creating_lookup_table(when).ok()?;
Some(Self {
lookup,
then_and_else_values,
else_index,
})
}
pub(in super::super) fn map_keys_to_values(
&self,
keys_array: &ArrayRef,
) -> datafusion_common::Result<ArrayRef> {
let take_indices = self
.lookup
.map_to_when_indices(keys_array, self.else_index)?;
let take_indices = UInt32Array::from(take_indices);
let output =
arrow::compute::take(&self.then_and_else_values, &take_indices, None)
.map_err(|e| arrow_datafusion_err!(e))?;
Ok(output)
}
}
/// Index map from key values to the positions of the de-duplicated, non-NULL WHEN
/// literals. `LiteralLookupTable` uses it to select THEN/ELSE values.
pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync {
/// Returns one index per row of `array`.
///
/// The indices are used as `take` indices into an array holding the THEN values
/// followed by the ELSE value. `else_index` is the position of that ELSE value.
/// NOTE(review): it is presumably emitted for rows matching no WHEN literal (and for
/// NULL keys) — confirm against the implementations in the sibling modules.
fn map_to_when_indices(
&self,
array: &ArrayRef,
else_index: u32,
) -> datafusion_common::Result<Vec<u32>>;
}
fn try_creating_lookup_table(
unique_non_null_literals: Vec<ScalarValue>,
) -> datafusion_common::Result<Box<dyn WhenLiteralIndexMap>> {
assert_ne!(
unique_non_null_literals.len(),
0,
"Must have at least one literal"
);
match unique_non_null_literals[0].data_type() {
DataType::Boolean => {
let lookup_table = BooleanIndexMap::try_new(unique_non_null_literals)?;
Ok(Box::new(lookup_table))
}
data_type if data_type.is_primitive() => {
macro_rules! create_matching_map {
($t:ty) => {{
let lookup_table =
PrimitiveIndexMap::<$t>::try_new(unique_non_null_literals)?;
Ok(Box::new(lookup_table))
}};
}
downcast_primitive! {
data_type => (create_matching_map),
_ => Err(plan_datafusion_err!(
"Unsupported field type for primitive: {:?}",
data_type
)),
}
}
DataType::Utf8
| DataType::LargeUtf8
| DataType::Binary
| DataType::LargeBinary
| DataType::FixedSizeBinary(_)
| DataType::Utf8View
| DataType::BinaryView => {
let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?;
Ok(Box::new(lookup_table))
}
DataType::Dictionary(_key, value)
if matches!(
value.as_ref(),
DataType::Utf8
| DataType::LargeUtf8
| DataType::Binary
| DataType::LargeBinary
| DataType::FixedSizeBinary(_)
| DataType::Utf8View
| DataType::BinaryView
) =>
{
let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?;
Ok(Box::new(lookup_table))
}
_ => Err(plan_datafusion_err!(
"Unsupported data type for lookup table: {}",
unique_non_null_literals[0].data_type()
)),
}
}