use arrow::array::builder::GenericStringBuilder;
use arrow::array::cast::as_dictionary_array;
use arrow::array::types::Int32Type;
use arrow::array::{make_array, Array, DictionaryArray};
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::fmt::Write;
use std::sync::Arc;
/// Right-pads every string in the input column with spaces up to the requested length.
///
/// Delegates to the shared implementation with truncation disabled, so values that are
/// already at least as long as the requested length are passed through unchanged.
///
/// # Errors
///
/// Returns whatever error the shared implementation reports for unsupported arguments
/// or data types.
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let truncate = false;
    spark_read_side_padding2(args, truncate)
}
/// Right-pads every string in the input column with spaces up to the requested length,
/// cutting longer values down to that length.
///
/// Delegates to the shared implementation with truncation enabled.
///
/// # Errors
///
/// Returns whatever error the shared implementation reports for unsupported arguments
/// or data types.
pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let truncate = true;
    spark_read_side_padding2(args, truncate)
}
/// Shared entry point behind `spark_read_side_padding` and `spark_rpad`.
///
/// Accepts exactly two arguments: a string array followed by a non-null `Int32` scalar
/// holding the target length. The array may be `Utf8`, `LargeUtf8`, or a dictionary with
/// `Int32` keys whose values are `Utf8` or `LargeUtf8`. For dictionaries only the values
/// are padded; the original keys are reused as-is.
///
/// `truncate` is forwarded to the per-array kernel and decides whether values longer
/// than the target length are cut down to it.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when the argument list has any other shape or the
/// array has an unsupported data type (including dictionaries with non-`Int32` keys or
/// non-string values), and propagates any error from the padding kernel or from
/// rebuilding the dictionary.
fn spark_read_side_padding2(
    args: &[ColumnarValue],
    truncate: bool,
) -> Result<ColumnarValue, DataFusionError> {
    match args {
        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
            match array.data_type() {
                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
                DataType::LargeUtf8 => {
                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
                }
                // `as_dictionary_array::<Int32Type>` panics on any other key type, so the
                // arm is restricted to the layouts it can actually handle; everything
                // else falls through to the "Unsupported data type" error below.
                DataType::Dictionary(key_type, value_type)
                    if key_type.as_ref() == &DataType::Int32
                        && matches!(value_type.as_ref(), DataType::Utf8 | DataType::LargeUtf8) =>
                {
                    let dict = as_dictionary_array::<Int32Type>(array);
                    let col = if value_type.as_ref() == &DataType::Utf8 {
                        spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
                    } else {
                        spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
                    };
                    // The kernel returns an array, so the row-count argument is irrelevant.
                    let values = col.to_array(0)?;
                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
                    Ok(ColumnarValue::Array(make_array(result.into())))
                }
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {other:?} for function rpad/read_side_padding",
                ))),
            }
        }
        other => Err(DataFusionError::Internal(format!(
            "Unsupported arguments {other:?} for function rpad/read_side_padding",
        ))),
    }
}
/// Pads each value of a string array with trailing spaces to `length` characters.
///
/// * `array` - must downcast to a generic string array with offset type `T`.
/// * `length` - target length in characters (`char`s); negative values are treated as 0.
/// * `truncate` - when `true`, values with at least `length` characters are cut to exactly
///   `length` characters; when `false` they are copied unchanged.
///
/// Null entries stay null. The result is always returned as `ColumnarValue::Array`.
///
/// # Errors
///
/// Returns an error if `array` cannot be downcast to the expected string array type or
/// if writing into the builder fails.
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
    array: &ArrayRef,
    length: i32,
    truncate: bool,
) -> Result<ColumnarValue, DataFusionError> {
    let input = as_generic_string_array::<T>(array)?;
    let target = 0.max(length) as usize;
    // One shared run of spaces; each row borrows the tail it needs.
    let padding = " ".repeat(target);
    let row_count = input.len();
    let mut out = GenericStringBuilder::<T>::with_capacity(row_count, row_count * target);

    for entry in input.iter() {
        match entry {
            None => out.append_null(),
            Some(text) => {
                let char_count = text.chars().count();
                if char_count < target {
                    // Stage the original text in the pending value, then let
                    // `append_value` add the missing spaces and close the value.
                    out.write_str(text)?;
                    out.append_value(&padding[char_count..]);
                } else if truncate {
                    // Byte offset of the first character past the limit; the whole
                    // string when it has exactly `target` characters.
                    let end = text
                        .char_indices()
                        .nth(target)
                        .map_or(text.len(), |(byte_pos, _)| byte_pos);
                    out.append_value(&text[..end]);
                } else {
                    out.append_value(text);
                }
            }
        }
    }

    let padded = out.finish();
    Ok(ColumnarValue::Array(Arc::new(padded)))
}