Skip to main content

datafusion_spark/function/array/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field};
19use datafusion_common::utils::take_function_args;
20use datafusion_common::{Result, ScalarValue, exec_err};
21use datafusion_expr::{
22    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23};
24use datafusion_functions_nested::repeat::ArrayRepeat;
25use std::sync::Arc;
26
27use crate::function::null_utils::{
28    NullMaskResolution, apply_null_mask, compute_null_mask,
29};
30
31/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL count: in Spark if the count is NULL, the result is NULL.
32/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkArrayRepeat {
35    signature: Signature,
36}
37
38impl Default for SparkArrayRepeat {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkArrayRepeat {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::user_defined(Volatility::Immutable),
48        }
49    }
50}
51
52impl ScalarUDFImpl for SparkArrayRepeat {
53    fn name(&self) -> &str {
54        "array_repeat"
55    }
56
57    fn signature(&self) -> &Signature {
58        &self.signature
59    }
60
61    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
62        Ok(DataType::List(Arc::new(Field::new_list_field(
63            arg_types[0].clone(),
64            true,
65        ))))
66    }
67
68    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69        spark_array_repeat(args)
70    }
71
72    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
73        let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
74
75        // Coerce the second argument to Int64/UInt64 if it's a numeric type
76        let second = match second_type {
77            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
78                DataType::Int64
79            }
80            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
81                DataType::UInt64
82            }
83            _ => return exec_err!("count must be an integer type"),
84        };
85
86        Ok(vec![first_type.clone(), second])
87    }
88}
89
90/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL
91/// if the count argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
92fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93    let ScalarFunctionArgs {
94        args: arg_values,
95        arg_fields,
96        number_rows,
97        return_field,
98        config_options,
99    } = args;
100    let return_type = return_field.data_type().clone();
101
102    // A NULL element should be repeated into the array, not cause a NULL result.
103    let null_mask = compute_null_mask(&arg_values[1..]);
104
105    // If count is null then return NULL immediately
106    if matches!(null_mask, NullMaskResolution::ReturnNull) {
107        return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
108    }
109
110    let array_repeat_func = ArrayRepeat::new();
111    let func_args = ScalarFunctionArgs {
112        args: arg_values,
113        arg_fields,
114        number_rows,
115        return_field,
116        config_options,
117    };
118    let result = array_repeat_func.invoke_with_args(func_args)?;
119
120    apply_null_mask(result, null_mask, &return_type)
121}