datafusion_spark/function/array/
slice.rs1use arrow::array::{Array, ArrayRef, Int64Builder};
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::cast::{as_int64_array, as_list_array};
21use datafusion_common::utils::ListCoercion;
22use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
23use datafusion_expr::{
24 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
25 ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
26};
27use datafusion_functions_nested::extract::array_slice_udf;
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkSlice {
36 signature: Signature,
37}
38
39impl Default for SparkSlice {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkSlice {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature {
49 type_signature: TypeSignature::ArraySignature(
50 ArrayFunctionSignature::Array {
51 arguments: vec![
52 ArrayFunctionArgument::Array,
53 ArrayFunctionArgument::Index,
54 ArrayFunctionArgument::Index,
55 ],
56 array_coercion: Some(ListCoercion::FixedSizedListToList),
57 },
58 ),
59 volatility: Volatility::Immutable,
60 parameter_names: None,
61 },
62 }
63 }
64}
65
66impl ScalarUDFImpl for SparkSlice {
67 fn as_any(&self) -> &dyn Any {
68 self
69 }
70
71 fn name(&self) -> &str {
72 "slice"
73 }
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80 internal_err!("return_field_from_args should be used instead")
81 }
82
83 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
84 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
85
86 Ok(Arc::new(Field::new(
87 "slice",
88 args.arg_fields[0].data_type().clone(),
89 nullable,
90 )))
91 }
92
93 fn invoke_with_args(
94 &self,
95 mut func_args: ScalarFunctionArgs,
96 ) -> Result<ColumnarValue> {
97 let array_len = func_args
98 .args
99 .iter()
100 .find_map(|arg| match arg {
101 ColumnarValue::Array(array) => Some(array.len()),
102 _ => None,
103 })
104 .unwrap_or(func_args.number_rows);
105
106 let arrays = func_args
107 .args
108 .iter()
109 .map(|arg| match arg {
110 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
111 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
112 })
113 .collect::<Result<Vec<_>>>()?;
114
115 let (start, end) = calculate_start_end(&arrays)?;
116
117 array_slice_udf().invoke_with_args(ScalarFunctionArgs {
118 args: vec![
119 func_args.args.swap_remove(0),
120 ColumnarValue::Array(start),
121 ColumnarValue::Array(end),
122 ],
123 arg_fields: func_args.arg_fields,
124 number_rows: func_args.number_rows,
125 return_field: func_args.return_field,
126 config_options: func_args.config_options,
127 })
128 }
129}
130
131fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
132 let [values, start, length] = take_function_args("slice", args)?;
133
134 let values_len = values.len();
135
136 let start = as_int64_array(&start)?;
137 let length = as_int64_array(&length)?;
138
139 let values = as_list_array(values)?;
140
141 let mut adjusted_start = Int64Builder::with_capacity(values_len);
142 let mut end = Int64Builder::with_capacity(values_len);
143
144 for row in 0..values_len {
145 if values.is_null(row) || start.is_null(row) || length.is_null(row) {
146 adjusted_start.append_null();
147 end.append_null();
148 continue;
149 }
150 let start = start.value(row);
151 let length = length.value(row);
152 let value_length = values.value(row).len() as i64;
153
154 if start == 0 {
155 return exec_err!("Start index must not be zero");
156 }
157 if length < 0 {
158 return exec_err!("Length must be non-negative, but got {}", length);
159 }
160
161 let adjusted_start_value = if start < 0 {
162 start + value_length + 1
163 } else {
164 start
165 };
166
167 adjusted_start.append_value(adjusted_start_value);
168 end.append_value(adjusted_start_value + (length - 1));
169 }
170
171 Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
172}