datafusion_spark/function/array/
repeat.rs1use arrow::datatypes::{DataType, Field};
19use datafusion_common::utils::take_function_args;
20use datafusion_common::{Result, ScalarValue, exec_err};
21use datafusion_expr::{
22 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23};
24use datafusion_functions_nested::repeat::ArrayRepeat;
25use std::any::Any;
26use std::sync::Arc;
27
28use crate::function::null_utils::{
29 NullMaskResolution, apply_null_mask, compute_null_mask,
30};
31
32#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkArrayRepeat {
36 signature: Signature,
37}
38
39impl Default for SparkArrayRepeat {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkArrayRepeat {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::user_defined(Volatility::Immutable),
49 }
50 }
51}
52
53impl ScalarUDFImpl for SparkArrayRepeat {
54 fn as_any(&self) -> &dyn Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "array_repeat"
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67 Ok(DataType::List(Arc::new(Field::new_list_field(
68 arg_types[0].clone(),
69 true,
70 ))))
71 }
72
73 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
74 spark_array_repeat(args)
75 }
76
77 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
78 let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
79
80 let second = match second_type {
82 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
83 DataType::Int64
84 }
85 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
86 DataType::UInt64
87 }
88 _ => return exec_err!("count must be an integer type"),
89 };
90
91 Ok(vec![first_type.clone(), second])
92 }
93}
94
95fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98 let ScalarFunctionArgs {
99 args: arg_values,
100 arg_fields,
101 number_rows,
102 return_field,
103 config_options,
104 } = args;
105 let return_type = return_field.data_type().clone();
106
107 let null_mask = compute_null_mask(&arg_values, number_rows)?;
109
110 if matches!(null_mask, NullMaskResolution::ReturnNull) {
112 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
113 }
114
115 let array_repeat_func = ArrayRepeat::new();
117 let func_args = ScalarFunctionArgs {
118 args: arg_values,
119 arg_fields,
120 number_rows,
121 return_field,
122 config_options,
123 };
124 let result = array_repeat_func.invoke_with_args(func_args)?;
125
126 apply_null_mask(result, null_mask, &return_type)
128}