Skip to main content

datafusion_spark/function/array/
slice.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, ArrayRef, Int64Builder};
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::cast::{as_int64_array, as_list_array};
21use datafusion_common::utils::ListCoercion;
22use datafusion_common::{
23    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
24};
25use datafusion_expr::{
26    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
27    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
28};
29use datafusion_functions_nested::extract::array_slice_udf;
30use std::sync::Arc;
31
/// Spark `slice` function implementation.
///
/// The main difference from DataFusion's `array_slice` is that the third
/// argument is the length of the slice and not the end index. Evaluation
/// converts `(start, length)` into a `(start, end)` pair and then delegates
/// to `array_slice`.
///
/// A start index of zero and a negative length are rejected with an
/// execution error; a negative start index is resolved relative to the end
/// of the list.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#slice>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkSlice {
    // Argument signature returned from `ScalarUDFImpl::signature`.
    signature: Signature,
}
39
40impl Default for SparkSlice {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SparkSlice {
47    pub fn new() -> Self {
48        Self {
49            signature: Signature {
50                type_signature: TypeSignature::ArraySignature(
51                    ArrayFunctionSignature::Array {
52                        arguments: vec![
53                            ArrayFunctionArgument::Array,
54                            ArrayFunctionArgument::Index,
55                            ArrayFunctionArgument::Index,
56                        ],
57                        array_coercion: Some(ListCoercion::FixedSizedListToList),
58                    },
59                ),
60                volatility: Volatility::Immutable,
61                parameter_names: None,
62            },
63        }
64    }
65}
66
67impl ScalarUDFImpl for SparkSlice {
68    fn name(&self) -> &str {
69        "slice"
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77        internal_err!("return_field_from_args should be used instead")
78    }
79
80    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
81        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
82
83        let data_type = match args.arg_fields[0].data_type() {
84            DataType::Null => {
85                DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
86            }
87            dt => dt.clone(),
88        };
89
90        Ok(Arc::new(Field::new("slice", data_type, nullable)))
91    }
92
93    fn invoke_with_args(
94        &self,
95        mut func_args: ScalarFunctionArgs,
96    ) -> Result<ColumnarValue> {
97        if func_args.args[0].data_type() == DataType::Null {
98            return Ok(ColumnarValue::Scalar(ScalarValue::new_null_list(
99                DataType::Null,
100                true,
101                1,
102            )));
103        }
104
105        let array_len = func_args
106            .args
107            .iter()
108            .find_map(|arg| match arg {
109                ColumnarValue::Array(array) => Some(array.len()),
110                _ => None,
111            })
112            .unwrap_or(func_args.number_rows);
113
114        let arrays = func_args
115            .args
116            .iter()
117            .map(|arg| match arg {
118                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
119                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
120            })
121            .collect::<Result<Vec<_>>>()?;
122
123        let (start, end) = calculate_start_end(&arrays)?;
124
125        array_slice_udf().invoke_with_args(ScalarFunctionArgs {
126            args: vec![
127                func_args.args.swap_remove(0),
128                ColumnarValue::Array(start),
129                ColumnarValue::Array(end),
130            ],
131            arg_fields: func_args.arg_fields,
132            number_rows: func_args.number_rows,
133            return_field: func_args.return_field,
134            config_options: func_args.config_options,
135        })
136    }
137}
138
139fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
140    let [values, start, length] = take_function_args("slice", args)?;
141
142    let values_len = values.len();
143
144    let start = as_int64_array(&start)?;
145    let length = as_int64_array(&length)?;
146
147    let values = as_list_array(values)?;
148
149    let mut adjusted_start = Int64Builder::with_capacity(values_len);
150    let mut end = Int64Builder::with_capacity(values_len);
151
152    for row in 0..values_len {
153        if values.is_null(row) || start.is_null(row) || length.is_null(row) {
154            adjusted_start.append_null();
155            end.append_null();
156            continue;
157        }
158        let start = start.value(row);
159        let length = length.value(row);
160        let value_length = values.value(row).len() as i64;
161
162        if start == 0 {
163            return exec_err!("Start index must not be zero");
164        }
165        if length < 0 {
166            return exec_err!("Length must be non-negative, but got {}", length);
167        }
168
169        let adjusted_start_value = if start < 0 {
170            start + value_length + 1
171        } else {
172            start
173        };
174
175        // Spark returns an empty array when the adjusted start lands before
176        // position 1 (e.g. slice([1], -2, 2)). array_slice would otherwise
177        // treat 0 the same as 1 and return the first element.
178        if adjusted_start_value < 1 {
179            adjusted_start.append_value(1);
180            end.append_value(0);
181            continue;
182        }
183
184        adjusted_start.append_value(adjusted_start_value);
185        end.append_value(adjusted_start_value + (length - 1));
186    }
187
188    Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, ListArray, NullArray};
    use arrow::datatypes::{Field, Int64Type};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_list_array;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds the three-argument input for `calculate_start_end` from plain
    /// vectors of list rows, start positions and lengths.
    fn start_end_args(
        lists: Vec<Option<Vec<Option<i64>>>>,
        starts: Vec<Option<i64>>,
        lengths: Vec<Option<i64>>,
    ) -> Vec<ArrayRef> {
        let values: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(lists));
        let start: ArrayRef = Arc::new(Int64Array::from(starts));
        let length: ArrayRef = Arc::new(Int64Array::from(lengths));
        vec![values, start, length]
    }

    // A `Null`-typed first argument is reported as a list of nullable nulls.
    #[test]
    fn test_spark_slice_function_when_input_is_null() {
        let slice = SparkSlice::new();
        let arg_fields: Vec<Arc<Field>> = vec![
            Arc::new(Field::new("a", DataType::Null, true)),
            Arc::new(Field::new("s", DataType::Int64, true)),
            Arc::new(Field::new("l", DataType::Int64, true)),
        ];
        let out = slice
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[],
            })
            .unwrap();
        assert_eq!(
            out.data_type(),
            &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
        );
    }

    // Invoking with a `Null`-typed array yields a null list.
    #[test]
    fn test_spark_slice_function_when_input_array_is_null() {
        let input_args = vec![
            ColumnarValue::Array(Arc::new(NullArray::new(1))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
        ];

        let args = ScalarFunctionArgs {
            args: input_args,
            // One field per argument, mirroring `args`.
            arg_fields: vec![
                Arc::new(Field::new("item", DataType::Null, true)),
                Arc::new(Field::new("start", DataType::Int64, false)),
                Arc::new(Field::new("length", DataType::Int64, false)),
            ],
            number_rows: 1,
            return_field: Arc::new(Field::new(
                "slice",
                DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))),
                true,
            )),
            config_options: Arc::new(Default::default()),
        };
        let slice = SparkSlice::new();
        let result = slice.invoke_with_args(args).unwrap();
        let arr = result.to_array(1).unwrap();
        let list = as_list_array(&arr).unwrap();
        assert_eq!(
            arr.data_type(),
            &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
        );
        assert!(list.is_null(0));
    }

    // Positive start, negative start, and a negative start that lands before
    // the first element (which must produce the empty `(1, 0)` range).
    #[test]
    fn test_calculate_start_end_converts_length_to_end_index() {
        let args = start_end_args(
            vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(1)]),
            ],
            vec![Some(2), Some(-2), Some(-2)],
            vec![Some(2), Some(2), Some(2)],
        );

        let (start, end) = calculate_start_end(&args).unwrap();
        let start = as_int64_array(&start).unwrap();
        let end = as_int64_array(&end).unwrap();

        assert_eq!(start.values().to_vec(), vec![2_i64, 2, 1]);
        assert_eq!(end.values().to_vec(), vec![3_i64, 3, 0]);
    }

    // A null in any of the three inputs produces null start and end.
    #[test]
    fn test_calculate_start_end_propagates_nulls() {
        let args = start_end_args(
            vec![None, Some(vec![Some(1)]), Some(vec![Some(1)])],
            vec![Some(1), None, Some(1)],
            vec![Some(1), Some(1), None],
        );

        let (start, end) = calculate_start_end(&args).unwrap();

        assert_eq!(start.null_count(), 3);
        assert_eq!(end.null_count(), 3);
    }

    // Zero start and negative length are both rejected.
    #[test]
    fn test_calculate_start_end_rejects_invalid_arguments() {
        let zero_start =
            start_end_args(vec![Some(vec![Some(1)])], vec![Some(0)], vec![Some(1)]);
        assert!(calculate_start_end(&zero_start).is_err());

        let negative_length =
            start_end_args(vec![Some(vec![Some(1)])], vec![Some(1)], vec![Some(-1)]);
        assert!(calculate_start_end(&negative_length).is_err());
    }
}