use std::any::Any;
use std::sync::Arc;
use arrow::array::{
Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
};
use arrow::compute::{can_cast_types, cast};
use arrow::datatypes::DataType::{Int64, Utf8};
use arrow::datatypes::{DataType, Int64Type};
use datafusion_common::cast::as_string_array;
use datafusion_common::{DataFusionError, Result, plan_datafusion_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
/// Scalar UDF registered under the name `elt`.
///
/// Given an index followed by one or more value arguments, it returns the
/// value at the (1-based) position named by the index, rendered as `Utf8`.
/// NOTE(review): the type name suggests this mirrors Spark's `elt` — confirm
/// against the Spark semantics the crate targets.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkElt {
// Signature advertised to the planner; built in `SparkElt::new`.
signature: Signature,
}
impl Default for SparkElt {
fn default() -> Self {
SparkElt::new()
}
}
impl SparkElt {
    /// Creates the `elt` function definition.
    ///
    /// The signature is user-defined (argument types are resolved through
    /// `coerce_types`) and the function is declared immutable.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for SparkElt {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "elt"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `Utf8`, whatever the argument types are.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(Utf8)
    }

    /// Evaluates the function by delegating to the array kernel [`elt`].
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        make_scalar_function(elt, vec![])(&args.args)
    }

    /// Resolves the argument types the function is invoked with.
    ///
    /// The first argument (the index) is coerced to `Int64`; every following
    /// argument (the candidate values) is coerced to `Utf8`.
    ///
    /// # Errors
    ///
    /// Returns a planning error when fewer than two arguments are supplied,
    /// or when the index type is neither `Int64` nor castable to `Int64`.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let length = arg_types.len();
        if length < 2 {
            // The error must actually be returned: merely constructing it
            // would let execution fall through and index into an empty slice
            // (or accept a call that has no value arguments at all).
            return Err(plan_datafusion_err!(
                "ELT function expects at least 2 arguments: index, value1"
            ));
        }

        let idx_dt: &DataType = &arg_types[0];
        if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) {
            return Err(DataFusionError::Plan(format!(
                "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}"
            )));
        }

        // One Int64 for the index, then Utf8 for each value argument.
        let mut coerced = Vec::with_capacity(length);
        coerced.push(Int64);
        coerced.resize(length, Utf8);
        Ok(coerced)
    }
}
/// Array kernel for `elt`: for each row, picks the value argument whose
/// 1-based position equals the row's index.
///
/// `args[0]` must be an `Int64` array holding the indices; the remaining
/// arrays are the candidate values and are cast to `Utf8` before selection.
///
/// A row yields NULL when its index is NULL, when the index is smaller than 1
/// or larger than the number of value arguments, or when the selected value
/// itself is NULL.
///
/// # Errors
///
/// Returns an error when `args` is empty, when the first array is not
/// `Int64`, or when a value array cannot be cast to `Utf8`.
fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    // Reject an empty argument list up front instead of panicking on `args[0]`.
    let (index_arg, value_args) = args.split_first().ok_or_else(|| {
        DataFusionError::Plan(
            "ELT function expects at least 2 arguments: index, value1".to_string(),
        )
    })?;

    let idx: &PrimitiveArray<Int64Type> =
        index_arg.as_primitive_opt::<Int64Type>().ok_or_else(|| {
            DataFusionError::Plan(format!(
                "ELT function: first argument must be Int64 (got {:?})",
                index_arg.data_type()
            ))
        })?;
    let n_rows = idx.len();

    // Keep the cast arrays alive in one vector and borrow typed views from
    // them, rather than cloning every column into a fresh `Arc`.
    let casted = value_args
        .iter()
        .map(|a| cast(a, &Utf8))
        .collect::<Result<Vec<ArrayRef>, _>>()?;
    let cols = casted
        .iter()
        .map(|c| as_string_array(c))
        .collect::<Result<Vec<&StringArray>>>()?;

    let mut builder = StringBuilder::new();
    for row in 0..n_rows {
        // Translate the 1-based index into a column, if it names one:
        // negative values fail the conversion, 0 fails `checked_sub`, and
        // anything past the last value fails `get`.
        let selected = if idx.is_null(row) {
            None
        } else {
            usize::try_from(idx.value(row))
                .ok()
                .and_then(|position| position.checked_sub(1))
                .and_then(|offset| cols.get(offset))
        };

        match selected {
            Some(col) if col.is_valid(row) => builder.append_value(col.value(row)),
            _ => builder.append_null(),
        }
    }

    Ok(Arc::new(builder.finish()) as ArrayRef)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{ArrayRef, StringArray};
    use datafusion_common::DataFusionError;
    use datafusion_common::Result;
    use std::sync::Arc;

    /// Runs the `elt` kernel on `arrs` and returns the result as a `StringArray`.
    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<Arc<StringArray>> {
        let arr = elt(&arrs)?;
        let string_array = arr
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?;
        Ok(Arc::new(string_array.clone()))
    }

    /// In-range indices pick the matching column; out-of-range and NULL
    /// indices produce NULL.
    #[test]
    fn elt_utf8_basic() -> Result<()> {
        let idx = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(0),
            None,
        ]));
        let v1 = Arc::new(StringArray::from(vec![
            Some("a1"),
            Some("a2"),
            Some("a3"),
            Some("a4"),
            Some("a5"),
            Some("a6"),
        ]));
        let v2 = Arc::new(StringArray::from(vec![
            Some("b1"),
            Some("b2"),
            None,
            Some("b4"),
            Some("b5"),
            Some("b6"),
        ]));
        let v3 = Arc::new(StringArray::from(vec![
            Some("c1"),
            Some("c2"),
            Some("c3"),
            None,
            Some("c5"),
            Some("c6"),
        ]));

        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;

        assert_eq!(out.len(), 6);
        assert_eq!(out.value(0), "a1");
        assert_eq!(out.value(1), "b2");
        assert_eq!(out.value(2), "c3");
        assert!(out.is_null(3));
        assert!(out.is_null(4));
        assert!(out.is_null(5));
        Ok(())
    }

    /// Non-string value columns are rendered as strings.
    #[test]
    fn elt_int64_basic() -> Result<()> {
        let idx = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(2)]));
        let v1 = Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)]));
        let v2 = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;

        assert_eq!(out.len(), 3);
        assert_eq!(out.value(0), "100");
        assert_eq!(out.value(1), "20");
        assert_eq!(out.value(2), "300");
        Ok(())
    }

    /// Indices above the value count, negative indices, and zero all give NULL.
    #[test]
    fn elt_out_of_range_all_null() -> Result<()> {
        let idx = Arc::new(Int64Array::from(vec![Some(5), Some(-1), Some(0)]));
        let v1 = Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
        let v2 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;

        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert!(out.is_null(2));
        Ok(())
    }

    /// A valid index that selects a NULL value yields NULL, while other rows
    /// are unaffected.
    #[test]
    fn elt_selected_value_null() -> Result<()> {
        let idx = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(1)]));
        let v1 = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
        let v2 = Arc::new(StringArray::from(vec![None, Some("b"), Some("d")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;

        assert_eq!(out.len(), 3);
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert_eq!(out.value(2), "c");
        Ok(())
    }

    /// The output array type is `Utf8`.
    #[test]
    fn elt_utf8_returns_utf8() -> Result<()> {
        let idx = Arc::new(Int64Array::from(vec![Some(1)]));
        let v1 = Arc::new(StringArray::from(vec![Some("scala")]));
        let v2 = Arc::new(StringArray::from(vec![Some("java")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;

        assert_eq!(out.data_type(), &Utf8);
        Ok(())
    }
}