Skip to main content

datafusion_spark/function/array/
slice.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, ArrayRef, Int64Builder};
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::cast::{as_int64_array, as_list_array};
21use datafusion_common::utils::ListCoercion;
22use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
23use datafusion_expr::{
24    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
25    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
26};
27use datafusion_functions_nested::extract::array_slice_udf;
28use std::any::Any;
29use std::sync::Arc;
30
31/// Spark slice function implementation
32/// Main difference from DataFusion's array_slice is that the third argument is the length of the slice and not the end index.
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#slice>
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkSlice {
36    signature: Signature,
37}
38
39impl Default for SparkSlice {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkSlice {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature {
49                type_signature: TypeSignature::ArraySignature(
50                    ArrayFunctionSignature::Array {
51                        arguments: vec![
52                            ArrayFunctionArgument::Array,
53                            ArrayFunctionArgument::Index,
54                            ArrayFunctionArgument::Index,
55                        ],
56                        array_coercion: Some(ListCoercion::FixedSizedListToList),
57                    },
58                ),
59                volatility: Volatility::Immutable,
60                parameter_names: None,
61            },
62        }
63    }
64}
65
66impl ScalarUDFImpl for SparkSlice {
67    fn as_any(&self) -> &dyn Any {
68        self
69    }
70
71    fn name(&self) -> &str {
72        "slice"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80        internal_err!("return_field_from_args should be used instead")
81    }
82
83    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
84        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
85
86        Ok(Arc::new(Field::new(
87            "slice",
88            args.arg_fields[0].data_type().clone(),
89            nullable,
90        )))
91    }
92
93    fn invoke_with_args(
94        &self,
95        mut func_args: ScalarFunctionArgs,
96    ) -> Result<ColumnarValue> {
97        let array_len = func_args
98            .args
99            .iter()
100            .find_map(|arg| match arg {
101                ColumnarValue::Array(array) => Some(array.len()),
102                _ => None,
103            })
104            .unwrap_or(func_args.number_rows);
105
106        let arrays = func_args
107            .args
108            .iter()
109            .map(|arg| match arg {
110                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
111                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
112            })
113            .collect::<Result<Vec<_>>>()?;
114
115        let (start, end) = calculate_start_end(&arrays)?;
116
117        array_slice_udf().invoke_with_args(ScalarFunctionArgs {
118            args: vec![
119                func_args.args.swap_remove(0),
120                ColumnarValue::Array(start),
121                ColumnarValue::Array(end),
122            ],
123            arg_fields: func_args.arg_fields,
124            number_rows: func_args.number_rows,
125            return_field: func_args.return_field,
126            config_options: func_args.config_options,
127        })
128    }
129}
130
131fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
132    let [values, start, length] = take_function_args("slice", args)?;
133
134    let values_len = values.len();
135
136    let start = as_int64_array(&start)?;
137    let length = as_int64_array(&length)?;
138
139    let values = as_list_array(values)?;
140
141    let mut adjusted_start = Int64Builder::with_capacity(values_len);
142    let mut end = Int64Builder::with_capacity(values_len);
143
144    for row in 0..values_len {
145        if values.is_null(row) || start.is_null(row) || length.is_null(row) {
146            adjusted_start.append_null();
147            end.append_null();
148            continue;
149        }
150        let start = start.value(row);
151        let length = length.value(row);
152        let value_length = values.value(row).len() as i64;
153
154        if start == 0 {
155            return exec_err!("Start index must not be zero");
156        }
157        if length < 0 {
158            return exec_err!("Length must be non-negative, but got {}", length);
159        }
160
161        let adjusted_start_value = if start < 0 {
162            start + value_length + 1
163        } else {
164            start
165        };
166
167        adjusted_start.append_value(adjusted_start_value);
168        end.append_value(adjusted_start_value + (length - 1));
169    }
170
171    Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
172}