datafusion_spark/function/array/
slice.rs1use arrow::array::{Array, ArrayRef, Int64Builder};
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::cast::{as_int64_array, as_list_array};
21use datafusion_common::utils::ListCoercion;
22use datafusion_common::{
23 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
24};
25use datafusion_expr::{
26 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
27 ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
28};
29use datafusion_functions_nested::extract::array_slice_udf;
30use std::sync::Arc;
31
32#[derive(Debug, PartialEq, Eq, Hash)]
36pub struct SparkSlice {
37 signature: Signature,
38}
39
40impl Default for SparkSlice {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkSlice {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature {
50 type_signature: TypeSignature::ArraySignature(
51 ArrayFunctionSignature::Array {
52 arguments: vec![
53 ArrayFunctionArgument::Array,
54 ArrayFunctionArgument::Index,
55 ArrayFunctionArgument::Index,
56 ],
57 array_coercion: Some(ListCoercion::FixedSizedListToList),
58 },
59 ),
60 volatility: Volatility::Immutable,
61 parameter_names: None,
62 },
63 }
64 }
65}
66
67impl ScalarUDFImpl for SparkSlice {
68 fn name(&self) -> &str {
69 "slice"
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77 internal_err!("return_field_from_args should be used instead")
78 }
79
80 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
81 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
82
83 let data_type = match args.arg_fields[0].data_type() {
84 DataType::Null => {
85 DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
86 }
87 dt => dt.clone(),
88 };
89
90 Ok(Arc::new(Field::new("slice", data_type, nullable)))
91 }
92
93 fn invoke_with_args(
94 &self,
95 mut func_args: ScalarFunctionArgs,
96 ) -> Result<ColumnarValue> {
97 if func_args.args[0].data_type() == DataType::Null {
98 return Ok(ColumnarValue::Scalar(ScalarValue::new_null_list(
99 DataType::Null,
100 true,
101 1,
102 )));
103 }
104
105 let array_len = func_args
106 .args
107 .iter()
108 .find_map(|arg| match arg {
109 ColumnarValue::Array(array) => Some(array.len()),
110 _ => None,
111 })
112 .unwrap_or(func_args.number_rows);
113
114 let arrays = func_args
115 .args
116 .iter()
117 .map(|arg| match arg {
118 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
119 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
120 })
121 .collect::<Result<Vec<_>>>()?;
122
123 let (start, end) = calculate_start_end(&arrays)?;
124
125 array_slice_udf().invoke_with_args(ScalarFunctionArgs {
126 args: vec![
127 func_args.args.swap_remove(0),
128 ColumnarValue::Array(start),
129 ColumnarValue::Array(end),
130 ],
131 arg_fields: func_args.arg_fields,
132 number_rows: func_args.number_rows,
133 return_field: func_args.return_field,
134 config_options: func_args.config_options,
135 })
136 }
137}
138
139fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
140 let [values, start, length] = take_function_args("slice", args)?;
141
142 let values_len = values.len();
143
144 let start = as_int64_array(&start)?;
145 let length = as_int64_array(&length)?;
146
147 let values = as_list_array(values)?;
148
149 let mut adjusted_start = Int64Builder::with_capacity(values_len);
150 let mut end = Int64Builder::with_capacity(values_len);
151
152 for row in 0..values_len {
153 if values.is_null(row) || start.is_null(row) || length.is_null(row) {
154 adjusted_start.append_null();
155 end.append_null();
156 continue;
157 }
158 let start = start.value(row);
159 let length = length.value(row);
160 let value_length = values.value(row).len() as i64;
161
162 if start == 0 {
163 return exec_err!("Start index must not be zero");
164 }
165 if length < 0 {
166 return exec_err!("Length must be non-negative, but got {}", length);
167 }
168
169 let adjusted_start_value = if start < 0 {
170 start + value_length + 1
171 } else {
172 start
173 };
174
175 if adjusted_start_value < 1 {
179 adjusted_start.append_value(1);
180 end.append_value(0);
181 continue;
182 }
183
184 adjusted_start.append_value(adjusted_start_value);
185 end.append_value(adjusted_start_value + (length - 1));
186 }
187
188 Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_list_array;
    use datafusion_expr::ReturnFieldArgs;

    /// Data type the function reports and produces for a `Null` first argument.
    fn null_list_type() -> DataType {
        DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
    }

    /// A `Null`-typed first argument must be reported as a list of `Null`.
    #[test]
    fn test_spark_slice_function_when_input_is_null() {
        let slice = SparkSlice::new();
        let arg_fields: Vec<Arc<Field>> = vec![
            Arc::new(Field::new("a", DataType::Null, true)),
            Arc::new(Field::new("s", DataType::Int64, true)),
            Arc::new(Field::new("l", DataType::Int64, true)),
        ];
        let out = slice
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[],
            })
            .unwrap();
        assert_eq!(out.data_type(), &null_list_type());
    }

    /// Invoking the function on a `NullArray` must yield a null list value.
    #[test]
    fn test_spark_slice_function_when_input_array_is_null() {
        let input_args = vec![
            ColumnarValue::Array(Arc::new(NullArray::new(1))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
        ];

        let args = ScalarFunctionArgs {
            args: input_args,
            // One field per argument, mirroring the three values above.
            arg_fields: vec![
                Arc::new(Field::new("item", DataType::Null, true)),
                Arc::new(Field::new("start", DataType::Int64, true)),
                Arc::new(Field::new("length", DataType::Int64, true)),
            ],
            number_rows: 1,
            return_field: Arc::new(Field::new("slice", null_list_type(), true)),
            config_options: Arc::new(Default::default()),
        };
        let slice = SparkSlice::new();
        let result = slice.invoke_with_args(args).unwrap();
        let arr = result.to_array(1).unwrap();
        let list = as_list_array(&arr).unwrap();
        assert_eq!(arr.data_type(), &null_list_type());
        assert!(list.is_null(0));
    }
}
249}