Skip to main content

datafusion_spark/function/array/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field};
19use datafusion_common::utils::take_function_args;
20use datafusion_common::{Result, ScalarValue, exec_err};
21use datafusion_expr::{
22    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23};
24use datafusion_functions_nested::repeat::ArrayRepeat;
25use std::any::Any;
26use std::sync::Arc;
27
28use crate::function::null_utils::{
29    NullMaskResolution, apply_null_mask, compute_null_mask,
30};
31
32/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL.
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkArrayRepeat {
36    signature: Signature,
37}
38
39impl Default for SparkArrayRepeat {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkArrayRepeat {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::user_defined(Volatility::Immutable),
49        }
50    }
51}
52
53impl ScalarUDFImpl for SparkArrayRepeat {
54    fn as_any(&self) -> &dyn Any {
55        self
56    }
57
58    fn name(&self) -> &str {
59        "array_repeat"
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67        Ok(DataType::List(Arc::new(Field::new_list_field(
68            arg_types[0].clone(),
69            true,
70        ))))
71    }
72
73    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
74        spark_array_repeat(args)
75    }
76
77    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
78        let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
79
80        // Coerce the second argument to Int64/UInt64 if it's a numeric type
81        let second = match second_type {
82            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
83                DataType::Int64
84            }
85            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
86                DataType::UInt64
87            }
88            _ => return exec_err!("count must be an integer type"),
89        };
90
91        Ok(vec![first_type.clone(), second])
92    }
93}
94
95/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL
96/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
97fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98    let ScalarFunctionArgs {
99        args: arg_values,
100        arg_fields,
101        number_rows,
102        return_field,
103        config_options,
104    } = args;
105    let return_type = return_field.data_type().clone();
106
107    // Step 1: Check for NULL mask in incoming args
108    let null_mask = compute_null_mask(&arg_values, number_rows)?;
109
110    // If any argument is null then return NULL immediately
111    if matches!(null_mask, NullMaskResolution::ReturnNull) {
112        return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
113    }
114
115    // Step 2: Delegate to DataFusion's array_repeat
116    let array_repeat_func = ArrayRepeat::new();
117    let func_args = ScalarFunctionArgs {
118        args: arg_values,
119        arg_fields,
120        number_rows,
121        return_field,
122        config_options,
123    };
124    let result = array_repeat_func.invoke_with_args(func_args)?;
125
126    // Step 3: Apply NULL mask to result
127    apply_null_mask(result, null_mask, &return_type)
128}