use std::any::Any;
use std::sync::{Arc, OnceLock};
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, GenericListArray, MapArray, OffsetSizeTrait, StructArray,
};
use datafusion::arrow::buffer::OffsetBuffer;
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use crate::struct_expansion::map_data_type;
/// Scalar UDF that converts an array of 2-field structs (key/value entries)
/// into an Arrow `Map`. Registered under the name `hamelin_map_from_entries`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct MapFromEntriesUdf {
    // Single-argument signature accepting any type; the actual type is
    // validated in `return_type` / the conversion helpers.
    signature: Signature,
}
impl Default for MapFromEntriesUdf {
fn default() -> Self {
Self::new()
}
}
impl MapFromEntriesUdf {
    /// Creates the UDF with a single-argument, immutable signature.
    pub fn new() -> Self {
        let signature = Signature::new(TypeSignature::Any(1), Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for MapFromEntriesUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_map_from_entries"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Resolves the output type: `List<Struct<K, V>>` (or `LargeList`) maps
    /// to `Map<K, V>` via `map_data_type`.
    ///
    /// Returns an execution error when the argument is missing, is not a
    /// (large) list, or its element type is not a 2-field struct.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Guard against an empty slice instead of panicking on `arg_types[0]`;
        // mirrors the arity check performed in `invoke_with_args`.
        let Some(arg_type) = arg_types.first() else {
            return exec_err!("map_from_entries expects exactly 1 argument, got 0");
        };
        let field = match arg_type {
            DataType::List(field) | DataType::LargeList(field) => field,
            _ => {
                return exec_err!(
                    "map_from_entries expects array type, got {:?}",
                    arg_type
                )
            }
        };
        match field.data_type() {
            DataType::Struct(fields) if fields.len() == 2 => Ok(map_data_type(
                fields[0].data_type().clone(),
                fields[1].data_type().clone(),
            )),
            _ => exec_err!(
                "map_from_entries expects array of 2-field structs, got array of {:?}",
                field.data_type()
            ),
        }
    }

    /// Evaluates the UDF: a scalar input yields a single map scalar, an
    /// array input yields a `MapArray` with one map per row.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        if args.len() != 1 {
            return exec_err!(
                "map_from_entries expects exactly 1 argument, got {}",
                args.len()
            );
        }
        match &args[0] {
            ColumnarValue::Scalar(scalar) => {
                let result = convert_scalar_to_map(scalar)?;
                Ok(ColumnarValue::Scalar(result))
            }
            ColumnarValue::Array(array) => {
                let result = convert_array_to_maps(array)?;
                Ok(ColumnarValue::Array(result))
            }
        }
    }
}
/// Converts a single list scalar of `Struct<K, V>` entries into a map scalar.
///
/// Null or empty list scalars produce a typed null map derived from the
/// list's element type. Non-list scalars are an execution error.
fn convert_scalar_to_map(scalar: &ScalarValue) -> Result<ScalarValue> {
    let entries = match scalar {
        ScalarValue::List(arr) => {
            // A list scalar wraps a one-row list array; row 0 holds the value.
            if arr.is_empty() || arr.is_null(0) {
                return create_null_map_scalar_from_type(arr.value_type());
            }
            arr.value(0)
        }
        ScalarValue::LargeList(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                return create_null_map_scalar_from_type(arr.value_type());
            }
            arr.value(0)
        }
        _ => {
            return exec_err!(
                "map_from_entries expects list type, got {:?}",
                scalar.data_type()
            )
        }
    };
    // Use `as_struct_opt` so a non-struct element type surfaces as an
    // execution error instead of the panic `as_struct()` would raise.
    let Some(struct_array) = entries.as_struct_opt() else {
        return exec_err!(
            "map_from_entries expects array of 2-field structs, got array of {:?}",
            entries.data_type()
        );
    };
    convert_struct_array_to_map_scalar(struct_array)
}
/// Builds a single-map scalar from a struct array of (key, value) entries.
///
/// The first struct column becomes the map keys (declared non-nullable) and
/// the second the map values (nullable). An empty struct array yields an
/// empty (not null) map.
fn convert_struct_array_to_map_scalar(struct_array: &StructArray) -> Result<ScalarValue> {
    if struct_array.num_columns() != 2 {
        return exec_err!(
            "map_from_entries expects struct with 2 fields, got {}",
            struct_array.num_columns()
        );
    }
    let keys = struct_array.column(0);
    let values = struct_array.column(1);
    let entry_fields = Fields::from(vec![
        Field::new("key", keys.data_type().clone(), false),
        Field::new("value", values.data_type().clone(), true),
    ]);
    // Use `try_new` on BOTH paths (the original panicked via `new` on the
    // non-empty path) so invalid data — e.g. null keys under the
    // non-nullable "key" field — becomes an execution error, not a panic.
    let entries = if struct_array.is_empty() {
        StructArray::try_new(
            entry_fields.clone(),
            vec![
                datafusion::arrow::array::new_empty_array(entry_fields[0].data_type()),
                datafusion::arrow::array::new_empty_array(entry_fields[1].data_type()),
            ],
            None,
        )
    } else {
        StructArray::try_new(
            entry_fields.clone(),
            vec![Arc::clone(keys), Arc::clone(values)],
            struct_array.nulls().cloned(),
        )
    }
    .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))?;
    // A single map row spanning every entry (length 0 for the empty case).
    let map_array = MapArray::try_new(
        Arc::new(Field::new("entries", DataType::Struct(entry_fields), false)),
        OffsetBuffer::from_lengths([entries.len()]),
        entries,
        None,
        false,
    )
    .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))?;
    ScalarValue::try_from_array(&map_array, 0)
}
/// Produces a typed null map scalar for the given list element type.
///
/// When the element type is a 2-field struct, the map's key/value types are
/// taken from its fields; otherwise both fall back to `Utf8`.
fn create_null_map_scalar_from_type(element_type: DataType) -> Result<ScalarValue> {
    let (key_type, value_type) = match &element_type {
        DataType::Struct(fields) if fields.len() == 2 => (
            fields[0].data_type().clone(),
            fields[1].data_type().clone(),
        ),
        _ => (DataType::Utf8, DataType::Utf8),
    };
    let entry_struct = DataType::Struct(Fields::from(vec![
        Field::new("key", key_type, false),
        Field::new("value", value_type, true),
    ]));
    let entries_field = Arc::new(Field::new("entries", entry_struct, false));
    ScalarValue::try_new_null(&DataType::Map(entries_field, false))
}
/// Dispatches list / large-list input arrays to the row-wise map conversion.
fn convert_array_to_maps(array: &ArrayRef) -> Result<ArrayRef> {
    match array.data_type() {
        DataType::List(_) => convert_list_to_maps::<i32>(array.as_list()),
        DataType::LargeList(_) => convert_list_to_maps::<i64>(array.as_list()),
        other => exec_err!("map_from_entries expects array type, got {:?}", other),
    }
}
/// Converts every row of a list-of-structs array into a map, producing a
/// `MapArray` with one map per input row. Null input rows become null maps
/// (represented as an empty offset span plus a cleared validity bit).
fn convert_list_to_maps<O: OffsetSizeTrait>(list_array: &GenericListArray<O>) -> Result<ArrayRef> {
    let struct_fields = match list_array.value_type() {
        DataType::Struct(fields) if fields.len() == 2 => fields,
        other => {
            return exec_err!(
                "map_from_entries expects array of 2-field structs, got array of {:?}",
                other
            )
        }
    };
    let key_type = struct_fields[0].data_type().clone();
    let value_type = struct_fields[1].data_type().clone();
    // Flatten all entries via per-element ScalarValue conversion.
    // NOTE(review): this is a row-at-a-time path; fine for modest inputs.
    let mut all_keys: Vec<ScalarValue> = Vec::new();
    let mut all_values: Vec<ScalarValue> = Vec::new();
    let mut offsets: Vec<i32> = Vec::with_capacity(list_array.len() + 1);
    offsets.push(0);
    let mut validity: Vec<bool> = Vec::with_capacity(list_array.len());
    for i in 0..list_array.len() {
        if list_array.is_null(i) {
            // Null row: repeat the previous offset so the span is empty.
            validity.push(false);
            offsets.push(*offsets.last().expect("offsets is never empty"));
            continue;
        }
        validity.push(true);
        let entries = list_array.value(i);
        let struct_array = entries.as_struct();
        let keys = struct_array.column(0);
        let values = struct_array.column(1);
        for j in 0..struct_array.len() {
            all_keys.push(ScalarValue::try_from_array(keys, j)?);
            all_values.push(ScalarValue::try_from_array(values, j)?);
        }
        // MapArray offsets are i32; fail loudly instead of silently
        // truncating (`as i32`) when a LargeList input accumulates more
        // than i32::MAX total entries.
        let end = i32::try_from(all_keys.len()).map_err(|_| {
            datafusion::common::DataFusionError::Execution(
                "map_from_entries: total number of map entries exceeds i32::MAX".to_string(),
            )
        })?;
        offsets.push(end);
    }
    let keys_array = if all_keys.is_empty() {
        datafusion::arrow::array::new_empty_array(&key_type)
    } else {
        ScalarValue::iter_to_array(all_keys.into_iter())?
    };
    let values_array = if all_values.is_empty() {
        datafusion::arrow::array::new_empty_array(&value_type)
    } else {
        ScalarValue::iter_to_array(all_values.into_iter())?
    };
    let struct_array = StructArray::new(
        Fields::from(vec![
            Field::new("key", key_type.clone(), false),
            Field::new("value", value_type.clone(), true),
        ]),
        vec![keys_array, values_array],
        None,
    );
    // Only materialize a null buffer when at least one row was null.
    let null_buffer = if validity.iter().all(|&b| b) {
        None
    } else {
        Some(datafusion::arrow::buffer::NullBuffer::from(validity))
    };
    let map_array = MapArray::new(
        Arc::new(Field::new(
            "entries",
            struct_array.data_type().clone(),
            false,
        )),
        OffsetBuffer::new(offsets.into()),
        struct_array,
        null_buffer,
        false,
    );
    Ok(Arc::new(map_array))
}
/// Lazily-initialized shared instance of the UDF.
static MAP_FROM_ENTRIES_UDF: OnceLock<ScalarUDF> = OnceLock::new();

/// Returns the shared `ScalarUDF` wrapper for [`MapFromEntriesUdf`],
/// initializing it on first use.
pub fn map_from_entries_udf() -> ScalarUDF {
    let udf = MAP_FROM_ENTRIES_UDF
        .get_or_init(|| ScalarUDF::new_from_impl(MapFromEntriesUdf::default()));
    udf.clone()
}