1use arrow::array::{
19 Array, ArrayAccessor, ArrayBuilder, ArrayRef, AsArray, BinaryViewBuilder,
20 GenericBinaryBuilder, GenericStringBuilder, Int64Array, OffsetSizeTrait,
21 StringViewBuilder,
22};
23use arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::{Field, FieldRef};
25use datafusion_common::cast::as_int64_array;
26use datafusion_common::types::{
27 NativeType, logical_int32, logical_int64, logical_string,
28};
29use datafusion_common::{Result, exec_err};
30use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass};
31use datafusion_expr::{
32 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
33 Volatility,
34};
35use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end};
36use datafusion_functions::utils::make_scalar_function;
37use std::sync::Arc;
38
39#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct SparkSubstring {
47 signature: Signature,
48 aliases: Vec<String>,
49}
50
51impl Default for SparkSubstring {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl SparkSubstring {
58 pub fn new() -> Self {
59 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
60 let binary = Coercion::new_exact(TypeSignatureClass::Binary);
61 let int64 = Coercion::new_implicit(
62 TypeSignatureClass::Native(logical_int64()),
63 vec![TypeSignatureClass::Native(logical_int32())],
64 NativeType::Int64,
65 );
66 Self {
67 signature: Signature::one_of(
68 vec![
69 TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
70 TypeSignature::Coercible(vec![
71 string.clone(),
72 int64.clone(),
73 int64.clone(),
74 ]),
75 TypeSignature::Coercible(vec![binary.clone(), int64.clone()]),
76 TypeSignature::Coercible(vec![
77 binary.clone(),
78 int64.clone(),
79 int64.clone(),
80 ]),
81 ],
82 Volatility::Immutable,
83 )
84 .with_parameter_names(vec![
85 "str".to_string(),
86 "pos".to_string(),
87 "length".to_string(),
88 ])
89 .expect("valid parameter names"),
90 aliases: vec![String::from("substr")],
91 }
92 }
93}
94
95impl ScalarUDFImpl for SparkSubstring {
96 fn name(&self) -> &str {
97 "substring"
98 }
99
100 fn signature(&self) -> &Signature {
101 &self.signature
102 }
103
104 fn aliases(&self) -> &[String] {
105 &self.aliases
106 }
107
108 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 make_scalar_function(spark_substring, vec![])(&args.args)
110 }
111
112 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113 datafusion_common::internal_err!(
114 "return_type should not be called for Spark substring"
115 )
116 }
117
118 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
119 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
121
122 Ok(Arc::new(Field::new(
123 "substring",
124 args.arg_fields[0].data_type().clone(),
125 nullable,
126 )))
127 }
128}
129
130fn spark_substring(args: &[ArrayRef]) -> Result<ArrayRef> {
131 let start_array = as_int64_array(&args[1])?;
132 let length_array = if args.len() > 2 {
133 Some(as_int64_array(&args[2])?)
134 } else {
135 None
136 };
137
138 match args[0].data_type() {
139 DataType::Utf8 => {
140 let array = args[0].as_string::<i32>();
141 let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
142 spark_substring_generic(
143 &array,
144 start_array,
145 length_array,
146 GenericStringBuilder::<i32>::new(),
147 is_ascii,
148 )
149 }
150 DataType::LargeUtf8 => {
151 let array = args[0].as_string::<i64>();
152 let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
153 spark_substring_generic(
154 &array,
155 start_array,
156 length_array,
157 GenericStringBuilder::<i64>::new(),
158 is_ascii,
159 )
160 }
161 DataType::Utf8View => {
162 let array = args[0].as_string_view();
163 let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
164 spark_substring_generic(
165 &array,
166 start_array,
167 length_array,
168 StringViewBuilder::new(),
169 is_ascii,
170 )
171 }
172 DataType::Binary => spark_substring_generic(
176 &args[0].as_binary::<i32>(),
177 start_array,
178 length_array,
179 GenericBinaryBuilder::<i32>::new(),
180 true,
181 ),
182 DataType::LargeBinary => spark_substring_generic(
183 &args[0].as_binary::<i64>(),
184 start_array,
185 length_array,
186 GenericBinaryBuilder::<i64>::new(),
187 true,
188 ),
189 DataType::BinaryView => spark_substring_generic(
190 &args[0].as_binary_view(),
191 start_array,
192 length_array,
193 BinaryViewBuilder::new(),
194 true,
195 ),
196 other => exec_err!(
197 "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8, LargeUtf8, Binary, LargeBinary or BinaryView."
198 ),
199 }
200}
201
/// Translates a Spark start position into the 1-based start used by the
/// slicing code.
///
/// * `0` and every positive position map to `max(start, 1)`, so `0` addresses
///   the first element.
/// * A negative position counts back from the end of a value with `len`
///   elements (`-1` is the last element). When its magnitude exceeds `len`
///   the result is zero or negative.
///
/// A `len` that does not fit in `i64` is treated as `i64::MAX`.
#[inline]
fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 {
    if start < 0 {
        let total = i64::try_from(len).unwrap_or(i64::MAX);
        // `total >= 0` and `start < 0`, so neither addition can overflow.
        return total + start + 1;
    }

    if start == 0 { 1 } else { start }
}
222
223trait SubstringItem {
224 fn positional_len(&self, is_ascii: bool) -> usize;
228
229 fn byte_range(
232 &self,
233 adjusted_start: i64,
234 len: Option<i64>,
235 is_ascii: bool,
236 ) -> Result<(usize, usize)>;
237
238 fn byte_slice(&self, start: usize, end: usize) -> &Self;
239}
240
241impl SubstringItem for str {
242 fn positional_len(&self, is_ascii: bool) -> usize {
243 if is_ascii {
244 self.len()
245 } else {
246 self.chars().count()
247 }
248 }
249
250 fn byte_range(
251 &self,
252 adjusted_start: i64,
253 len: Option<i64>,
254 is_ascii: bool,
255 ) -> Result<(usize, usize)> {
256 get_true_start_end(self, adjusted_start, len, is_ascii)
257 }
258
259 fn byte_slice(&self, start: usize, end: usize) -> &Self {
260 &self[start..end]
261 }
262}
263
264impl SubstringItem for [u8] {
265 fn positional_len(&self, _is_ascii: bool) -> usize {
266 self.len()
267 }
268
269 fn byte_range(
270 &self,
271 adjusted_start: i64,
272 len: Option<i64>,
273 _is_ascii: bool,
274 ) -> Result<(usize, usize)> {
275 let byte_len = self.len();
276 let start0 = adjusted_start.saturating_sub(1);
277 let end0 = match len {
278 Some(l) => start0.saturating_add(l),
279 None => byte_len as i64,
280 };
281 let byte_len_i64 = byte_len as i64;
282 Ok((
283 start0.clamp(0, byte_len_i64) as usize,
284 end0.clamp(0, byte_len_i64) as usize,
285 ))
286 }
287
288 fn byte_slice(&self, start: usize, end: usize) -> &Self {
289 &self[start..end]
290 }
291}
292
293trait SubstringBuilder: ArrayBuilder {
294 type Item: SubstringItem + ?Sized;
295 fn append_value(&mut self, val: &Self::Item);
296 fn append_null(&mut self);
297 fn append_empty(&mut self);
300}
301
302impl<O: OffsetSizeTrait> SubstringBuilder for GenericStringBuilder<O> {
303 type Item = str;
304 fn append_value(&mut self, val: &str) {
305 GenericStringBuilder::append_value(self, val);
306 }
307 fn append_null(&mut self) {
308 GenericStringBuilder::append_null(self);
309 }
310 fn append_empty(&mut self) {
311 GenericStringBuilder::append_value(self, "");
312 }
313}
314
315impl SubstringBuilder for StringViewBuilder {
316 type Item = str;
317 fn append_value(&mut self, val: &str) {
318 StringViewBuilder::append_value(self, val);
319 }
320 fn append_null(&mut self) {
321 StringViewBuilder::append_null(self);
322 }
323 fn append_empty(&mut self) {
324 StringViewBuilder::append_value(self, "");
325 }
326}
327
328impl<O: OffsetSizeTrait> SubstringBuilder for GenericBinaryBuilder<O> {
329 type Item = [u8];
330 fn append_value(&mut self, val: &[u8]) {
331 GenericBinaryBuilder::append_value(self, val);
332 }
333 fn append_null(&mut self) {
334 GenericBinaryBuilder::append_null(self);
335 }
336 fn append_empty(&mut self) {
337 GenericBinaryBuilder::append_value(self, &[]);
338 }
339}
340
341impl SubstringBuilder for BinaryViewBuilder {
342 type Item = [u8];
343 fn append_value(&mut self, val: &[u8]) {
344 BinaryViewBuilder::append_value(self, val);
345 }
346 fn append_null(&mut self) {
347 BinaryViewBuilder::append_null(self);
348 }
349 fn append_empty(&mut self) {
350 BinaryViewBuilder::append_value(self, []);
351 }
352}
353
354fn spark_substring_generic<'a, Source, Item, Builder>(
360 array: &Source,
361 start_array: &Int64Array,
362 length_array: Option<&Int64Array>,
363 mut builder: Builder,
364 is_ascii: bool,
365) -> Result<ArrayRef>
366where
367 Source: ArrayAccessor<Item = &'a Item>,
368 Item: SubstringItem + ?Sized + 'a,
369 Builder: SubstringBuilder<Item = Item>,
370{
371 for i in 0..array.len() {
372 if array.is_null(i) || start_array.is_null(i) {
373 builder.append_null();
374 continue;
375 }
376
377 if let Some(len_arr) = length_array
378 && len_arr.is_null(i)
379 {
380 builder.append_null();
381 continue;
382 }
383
384 let value = array.value(i);
385 let start = start_array.value(i);
386 let len_opt = length_array.map(|arr| arr.value(i));
387
388 if let Some(len) = len_opt
390 && len < 0
391 {
392 builder.append_empty();
393 continue;
394 }
395
396 let positional_len = value.positional_len(is_ascii);
397 let adjusted_start = spark_start_to_datafusion_start(start, positional_len);
398 let (byte_start, byte_end) =
399 value.byte_range(adjusted_start, len_opt, is_ascii)?;
400 builder.append_value(value.byte_slice(byte_start, byte_end));
401 }
402
403 Ok(builder.finish())
404}