datafusion_spark/function/array/
repeat.rs1use arrow::datatypes::{DataType, Field};
19use datafusion_common::utils::take_function_args;
20use datafusion_common::{Result, ScalarValue, exec_err};
21use datafusion_expr::{
22 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23};
24use datafusion_functions_nested::repeat::ArrayRepeat;
25use std::sync::Arc;
26
27use crate::function::null_utils::{
28 NullMaskResolution, apply_null_mask, compute_null_mask,
29};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkArrayRepeat {
35 signature: Signature,
36}
37
38impl Default for SparkArrayRepeat {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkArrayRepeat {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::user_defined(Volatility::Immutable),
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkArrayRepeat {
53 fn name(&self) -> &str {
54 "array_repeat"
55 }
56
57 fn signature(&self) -> &Signature {
58 &self.signature
59 }
60
61 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
62 Ok(DataType::List(Arc::new(Field::new_list_field(
63 arg_types[0].clone(),
64 true,
65 ))))
66 }
67
68 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69 spark_array_repeat(args)
70 }
71
72 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
73 let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
74
75 let second = match second_type {
77 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
78 DataType::Int64
79 }
80 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
81 DataType::UInt64
82 }
83 _ => return exec_err!("count must be an integer type"),
84 };
85
86 Ok(vec![first_type.clone(), second])
87 }
88}
89
90fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93 let ScalarFunctionArgs {
94 args: arg_values,
95 arg_fields,
96 number_rows,
97 return_field,
98 config_options,
99 } = args;
100 let return_type = return_field.data_type().clone();
101
102 let null_mask = compute_null_mask(&arg_values[1..]);
104
105 if matches!(null_mask, NullMaskResolution::ReturnNull) {
107 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
108 }
109
110 let array_repeat_func = ArrayRepeat::new();
111 let func_args = ScalarFunctionArgs {
112 args: arg_values,
113 arg_fields,
114 number_rows,
115 return_field,
116 config_options,
117 };
118 let result = array_repeat_func.invoke_with_args(func_args)?;
119
120 apply_null_mask(result, null_mask, &return_type)
121}