use std::any::Any;
use std::sync::Arc;
use std::vec;
use crate::utils::make_scalar_function;
use arrow::array::{
new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray,
MutableArrayData, NullArray, OffsetSizeTrait,
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType::{List, Null},
Field,
};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::{plan_err, Result};
use datafusion_expr::binary::{
try_type_union_resolution_with_struct, type_union_resolution,
};
use datafusion_expr::TypeSignature;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
// Registers the `MakeArray` UDF with the crate's UDF boilerplate macro.
// NOTE(review): the macro is defined elsewhere in this crate; judging by its
// arguments it presumably generates an expression-builder fn named
// `make_array` (documented with the given string) and a `make_array_udf`
// accessor returning the shared UDF instance — confirm against the macro.
make_udf_expr_and_func!(
MakeArray,
make_array,
"Returns an Arrow array using the specified input expressions.",
make_array_udf
);
// Scalar UDF `make_array` (also callable as `make_list`): for every input row
// it packs the arguments' values, in argument order, into one list value.
// The `user_doc` attribute below carries the user-facing documentation that
// `ScalarUDFImpl::documentation` hands out via `self.doc()`.
#[user_doc(
doc_section(label = "Array Functions"),
description = "Returns an array using the specified input expressions.",
syntax_example = "make_array(expression1[, ..., expression_n])",
sql_example = r#"```sql
> select make_array(1, 2, 3, 4, 5);
+----------------------------------------------------------+
| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) |
+----------------------------------------------------------+
| [1, 2, 3, 4, 5] |
+----------------------------------------------------------+
```"#,
argument(
name = "expression_n",
description = "Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators."
)
)]
#[derive(Debug)]
pub struct MakeArray {
/// Accepted call shapes: no arguments, or a user-defined argument list whose
/// types are unified by this UDF's `coerce_types`.
signature: Signature,
/// Alternative SQL names for this function (`make_list`).
aliases: Vec<String>,
}
impl Default for MakeArray {
fn default() -> Self {
Self::new()
}
}
impl MakeArray {
    /// Creates the `make_array` UDF, also registered under the alias `make_list`.
    ///
    /// The function is immutable and accepts either no arguments at all or a
    /// user-defined argument list (resolved by this UDF's `coerce_types`).
    pub fn new() -> Self {
        let accepted_shapes = vec![TypeSignature::Nullary, TypeSignature::UserDefined];
        let signature = Signature::one_of(accepted_shapes, Volatility::Immutable);
        let aliases = vec!["make_list".to_string()];
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for MakeArray {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "make_array"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// With no arguments the result is [`empty_array_type`]; otherwise it is a
    /// `List` of nullable elements typed like the first argument (after
    /// `coerce_types` the arguments presumably all share that type).
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let Some(element_type) = arg_types.first() else {
            return Ok(empty_array_type());
        };
        let item = Field::new_list_field(element_type.to_owned(), true);
        Ok(List(Arc::new(item)))
    }

    /// Evaluates the call by running [`make_array_inner`] on the argument
    /// columns through `make_scalar_function`.
    fn invoke_with_args(
        &self,
        args: datafusion_expr::ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        let kernel = make_scalar_function(make_array_inner);
        kernel(&args.args)
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    /// Unifies all argument types into one common type.
    ///
    /// Struct-aware resolution is tried first. If it fails, the plain union
    /// resolution is used, and an all-`Null` result is widened to `Int64`.
    /// When neither succeeds, a planning error lists the argument types and
    /// the struct-resolution failure.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let struct_failure = match try_type_union_resolution_with_struct(arg_types) {
            Ok(coerced) => return Ok(coerced),
            Err(e) => e,
        };
        // Kept as a Vec so the `{:?}` rendering in the error message is unchanged.
        let errors = vec![struct_failure];

        match type_union_resolution(arg_types) {
            Some(common) if common.is_null() => Ok(vec![DataType::Int64; arg_types.len()]),
            Some(common) => Ok(vec![common; arg_types.len()]),
            None => plan_err!(
                "Fail to find the valid type between {:?} for {}, errors are {:?}",
                arg_types,
                self.name(),
                errors
            ),
        }
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Return type of `make_array()` called with no arguments: a `List` whose
/// nullable items are `Int64`.
pub(super) fn empty_array_type() -> DataType {
    let item = Field::new_list_field(DataType::Int64, true);
    DataType::List(Arc::new(item))
}
/// Evaluates `make_array` over already-materialized argument columns.
///
/// The list element type comes from the first argument whose type is not
/// `Null`, and the rows are assembled by [`array_array`].
///
/// If no such argument exists (no arguments at all, or only `Null`-typed
/// ones), the result is a single list row of `Int64` nulls. That row has one
/// null per slot across all inputs.
///
/// NOTE(review): in that all-`Null` case the output always has exactly one
/// row, even when the inputs have several rows. `coerce_types` widens
/// all-null argument lists to `Int64`, so this path is probably only hit with
/// zero arguments or by direct crate-internal callers — confirm.
pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
    let element_type = arrays
        .iter()
        .map(|array| array.data_type())
        .find(|dt| !dt.equals_datatype(&Null));

    if let Some(element_type) = element_type {
        return array_array::<i32>(arrays, element_type.clone());
    }

    let total_slots: usize = arrays.iter().map(|array| array.len()).sum();
    let nulls = new_null_array(&DataType::Int64, total_slots);
    Ok(Arc::new(
        SingleRowListArrayBuilder::new(nulls).build_list_array(),
    ))
}
fn array_array<O: OffsetSizeTrait>(
args: &[ArrayRef],
data_type: DataType,
) -> Result<ArrayRef> {
if args.is_empty() {
return plan_err!("Array requires at least one argument");
}
let mut data = vec![];
let mut total_len = 0;
for arg in args {
let arg_data = if arg.as_any().is::<NullArray>() {
ArrayData::new_empty(&data_type)
} else {
arg.to_data()
};
total_len += arg_data.len();
data.push(arg_data);
}
let mut offsets: Vec<O> = Vec::with_capacity(total_len);
offsets.push(O::usize_as(0));
let capacity = Capacities::Array(total_len);
let data_ref = data.iter().collect::<Vec<_>>();
let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity);
let num_rows = args[0].len();
for row_idx in 0..num_rows {
for (arr_idx, arg) in args.iter().enumerate() {
if !arg.as_any().is::<NullArray>()
&& !arg.is_null(row_idx)
&& arg.is_valid(row_idx)
{
mutable.extend(arr_idx, row_idx, row_idx + 1);
} else {
mutable.extend_nulls(1);
}
}
offsets.push(O::usize_as(mutable.len()));
}
let data = mutable.freeze();
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(data_type, true)),
OffsetBuffer::new(offsets.into()),
arrow::array::make_array(data),
None,
)?))
}