use datafusion::logical_expr::expr::Case as DFCase;
use datafusion::logical_expr::{lit, Expr as DFExpr};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_nested::expr_fn as nested_fn;
use datafusion_functions_nested::map::map_udf;
use hamelin_lib::func::defs::{
GetMap, MapEmpty, MapFromArrays, MapFromKeyValue, MapFromPairs, MapKeys, MapValues,
};
use super::DataFusionTranslationRegistry;
pub fn register(registry: &mut DataFusionTranslationRegistry) {
registry.register::<MapEmpty>(|_params| {
let empty_keys = datafusion_functions_nested::expr_fn::make_array(vec![]);
let empty_values = datafusion_functions_nested::expr_fn::make_array(vec![]);
Ok(map_udf().call(vec![empty_keys, empty_values]))
});
registry.register::<MapFromArrays>(|mut params| {
let keys = params.take()?.expr;
let values = params.take()?.expr;
let not_null = keys.clone().is_not_null().and(values.clone().is_not_null());
let array_len = crate::udf::hamelin_array_length_udf();
let same_len = array_len
.call(vec![keys.clone()])
.eq(array_len.call(vec![values.clone()]));
let condition = not_null.and(same_len);
let map_expr = map_udf().call(vec![keys, values]);
Ok(DFExpr::Case(DFCase {
expr: None,
when_then_expr: vec![(Box::new(condition), Box::new(map_expr))],
else_expr: Some(Box::new(lit(datafusion::common::ScalarValue::Null))),
}))
});
registry.register::<MapFromPairs>(|mut params| {
let pairs_array = params.take()?.expr;
Ok(crate::udf::map_from_entries_udf().call(vec![pairs_array]))
});
registry.register::<MapFromKeyValue>(|mut params| {
let mut keys = vec![];
let mut values = vec![];
while let Ok(pair) = params.take() {
let pair_expr = pair.expr;
keys.push(pair_expr.clone().field("f0"));
values.push(pair_expr.field("f1"));
}
let keys_array = datafusion_functions_nested::expr_fn::make_array(keys);
let values_array = datafusion_functions_nested::expr_fn::make_array(values);
Ok(map_udf().call(vec![keys_array, values_array]))
});
registry.register::<GetMap>(|mut params| {
let map_expr = params.take()?.expr;
let key = params.take()?.expr;
let extracted = nested_fn::map_extract(map_expr, key);
Ok(nested_fn::array_element(
extracted,
datafusion::logical_expr::lit(1i64),
))
});
registry.register::<MapKeys>(|mut params| {
let map_expr = params.take()?.expr;
Ok(nested_fn::map_keys(map_expr))
});
registry.register::<MapValues>(|mut params| {
let map_expr = params.take()?.expr;
Ok(nested_fn::map_values(map_expr))
});
}