1use std::any::Any;
19use std::sync::Arc;
20
21use crate::strings::make_and_append_view;
22use crate::utils::make_scalar_function;
23use arrow::array::{
24 Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
25 StringViewArray, StringViewBuilder,
26};
27use arrow::buffer::ScalarBuffer;
28use arrow::datatypes::DataType;
29use datafusion_common::cast::as_int64_array;
30use datafusion_common::types::{
31 NativeType, logical_int32, logical_int64, logical_string,
32};
33use datafusion_common::{Result, exec_err};
34use datafusion_expr::{
35 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
36 TypeSignatureClass, Volatility,
37};
38use datafusion_macros::user_doc;
39
40#[user_doc(
41 doc_section(label = "String Functions"),
42 description = "Extracts a substring of a specified number of characters from a specific starting position in a string.",
43 syntax_example = "substr(str, start_pos[, length])",
44 alternative_syntax = "substring(str from start_pos for length)",
45 sql_example = r#"```sql
46> select substr('datafusion', 5, 3);
47+----------------------------------------------+
48| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
49+----------------------------------------------+
50| fus |
51+----------------------------------------------+
52```"#,
53 standard_argument(name = "str", prefix = "String"),
54 argument(
55 name = "start_pos",
56 description = "Character position to start the substring at. The first character in the string has a position of 1."
57 ),
58 argument(
59 name = "length",
60 description = "Number of characters to extract. If not specified, returns the rest of the string after the start position."
61 )
62)]
/// Scalar UDF implementing `substr`, also registered under the alias
/// `substring`.
///
/// The function always produces `Utf8View` output, whatever string type the
/// input column uses.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrFunc {
    /// Accepted argument lists: `(str, start_pos)` and
    /// `(str, start_pos, length)`.
    signature: Signature,
    /// Alternative names under which the function can be called.
    aliases: Vec<String>,
}
68
69impl Default for SubstrFunc {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl SubstrFunc {
76 pub fn new() -> Self {
77 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
78 let int64 = Coercion::new_implicit(
79 TypeSignatureClass::Native(logical_int64()),
80 vec![TypeSignatureClass::Native(logical_int32())],
81 NativeType::Int64,
82 );
83 Self {
84 signature: Signature::one_of(
85 vec![
86 TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
87 TypeSignature::Coercible(vec![
88 string.clone(),
89 int64.clone(),
90 int64.clone(),
91 ]),
92 ],
93 Volatility::Immutable,
94 )
95 .with_parameter_names(vec![
96 "str".to_string(),
97 "start_pos".to_string(),
98 "length".to_string(),
99 ])
100 .expect("valid parameter names"),
101 aliases: vec![String::from("substring")],
102 }
103 }
104}
105
106impl ScalarUDFImpl for SubstrFunc {
107 fn as_any(&self) -> &dyn Any {
108 self
109 }
110
111 fn name(&self) -> &str {
112 "substr"
113 }
114
115 fn signature(&self) -> &Signature {
116 &self.signature
117 }
118
119 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
121 Ok(DataType::Utf8View)
122 }
123
124 fn invoke_with_args(
125 &self,
126 args: datafusion_expr::ScalarFunctionArgs,
127 ) -> Result<ColumnarValue> {
128 make_scalar_function(substr, vec![])(&args.args)
129 }
130
131 fn aliases(&self) -> &[String] {
132 &self.aliases
133 }
134
135 fn documentation(&self) -> Option<&Documentation> {
136 self.doc()
137 }
138}
139
140fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
145 match args[0].data_type() {
146 DataType::Utf8 => {
147 let string_array = args[0].as_string::<i32>();
148 string_substr::<_>(string_array, &args[1..])
149 }
150 DataType::LargeUtf8 => {
151 let string_array = args[0].as_string::<i64>();
152 string_substr::<_>(string_array, &args[1..])
153 }
154 DataType::Utf8View => {
155 let string_array = args[0].as_string_view();
156 string_view_substr(string_array, &args[1..])
157 }
158 other => exec_err!(
159 "Unsupported data type {other:?} for function substr,\
160 expected Utf8View, Utf8 or LargeUtf8."
161 ),
162 }
163}
164
/// Translates the 1-based, character-oriented `start` / `count` arguments of
/// `substr` into byte offsets `(start, end)` such that `&input[start..end]`
/// is the requested substring.
///
/// * `input` - the string to slice.
/// * `start` - 1-based character position. Positions below 1 lie before the
///   string; they consume part of `count` but select nothing.
/// * `count` - number of characters to take, or `None` to take everything up
///   to the end of the string.
/// * `is_input_ascii_only` - when `true`, character positions are returned
///   directly as byte offsets and `input` is not decoded.
///
/// The returned offsets always satisfy `start <= end <= input.len()`.
fn get_true_start_end(
    input: &str,
    start: i64,
    count: Option<u64>,
    is_input_ascii_only: bool,
) -> (usize, usize) {
    let len = input.len();

    // Switch to a 0-based position. `i64::MIN` cannot be decremented; it is
    // kept as-is because it is clamped to 0 below anyway.
    let start = start.checked_sub(1).unwrap_or(start);

    let end = match count {
        // Cap the count at `i64::MAX` and saturate the addition: a huge count
        // simply means "until the end of the string" and must neither
        // overflow nor turn negative.
        Some(count) => start.saturating_add(count.min(i64::MAX as u64) as i64),
        None => len as i64,
    };
    let has_count = count.is_some();

    // The byte length is an upper bound for the character count, so clamping
    // against it is valid for both the ASCII and the general case.
    let start = start.clamp(0, len as i64) as usize;
    let end = end.clamp(0, len as i64) as usize;
    let char_count = end - start;

    if is_input_ascii_only {
        return (start, end);
    }

    // General case: walk the character boundaries once, first to the start
    // character and then `char_count` characters further. Running out of
    // characters maps to the end of the string.
    let mut offsets = input
        .char_indices()
        .map(|(byte_offset, _)| byte_offset)
        .skip(start);
    let st = offsets.next().unwrap_or(len);
    let ed = if !has_count {
        len
    } else if char_count == 0 {
        st
    } else {
        offsets.nth(char_count - 1).unwrap_or(len)
    };
    (st, ed)
}
227
228fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
239 string_array: &V,
240 start: &Int64Array,
241 count: Option<&Int64Array>,
242) -> bool {
243 let is_short_prefix = match count {
244 Some(count) => {
245 let short_prefix_threshold = 32.0;
246 let n_sample = 10;
247
248 let avg_prefix_len = start
251 .iter()
252 .zip(count.iter())
253 .take(n_sample)
254 .map(|(start, count)| {
255 let start = start.unwrap_or(0);
256 let count = count.unwrap_or(0);
257 start + count
259 })
260 .sum::<i64>();
261
262 avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold
263 }
264 None => false,
265 };
266
267 if is_short_prefix {
268 false
270 } else {
271 string_array.is_ascii()
272 }
273}
274
275fn string_view_substr(
278 string_view_array: &StringViewArray,
279 args: &[ArrayRef],
280) -> Result<ArrayRef> {
281 let mut views_buf = Vec::with_capacity(string_view_array.len());
282 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
283
284 let start_array = as_int64_array(&args[0])?;
285 let count_array_opt = if args.len() == 2 {
286 Some(as_int64_array(&args[1])?)
287 } else {
288 None
289 };
290
291 let enable_ascii_fast_path =
292 enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);
293
294 match args.len() {
297 1 => {
298 for ((str_opt, raw_view), start_opt) in string_view_array
299 .iter()
300 .zip(string_view_array.views().iter())
301 .zip(start_array.iter())
302 {
303 if let (Some(str), Some(start)) = (str_opt, start_opt) {
304 let (start, end) =
305 get_true_start_end(str, start, None, enable_ascii_fast_path);
306 let substr = &str[start..end];
307
308 make_and_append_view(
309 &mut views_buf,
310 &mut null_builder,
311 raw_view,
312 substr,
313 start as u32,
314 );
315 } else {
316 null_builder.append_null();
317 views_buf.push(0);
318 }
319 }
320 }
321 2 => {
322 let count_array = count_array_opt.unwrap();
323 for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
324 .iter()
325 .zip(string_view_array.views().iter())
326 .zip(start_array.iter())
327 .zip(count_array.iter())
328 {
329 if let (Some(str), Some(start), Some(count)) =
330 (str_opt, start_opt, count_opt)
331 {
332 if count < 0 {
333 return exec_err!(
334 "negative substring length not allowed: substr(<str>, {start}, {count})"
335 );
336 } else {
337 if start == i64::MIN {
338 return exec_err!(
339 "negative overflow when calculating skip value"
340 );
341 }
342 let (start, end) = get_true_start_end(
343 str,
344 start,
345 Some(count as u64),
346 enable_ascii_fast_path,
347 );
348 let substr = &str[start..end];
349
350 make_and_append_view(
351 &mut views_buf,
352 &mut null_builder,
353 raw_view,
354 substr,
355 start as u32,
356 );
357 }
358 } else {
359 null_builder.append_null();
360 views_buf.push(0);
361 }
362 }
363 }
364 other => {
365 return exec_err!(
366 "substr was called with {other} arguments. It requires 2 or 3."
367 );
368 }
369 }
370
371 let views_buf = ScalarBuffer::from(views_buf);
372 let nulls_buf = null_builder.finish();
373
374 unsafe {
379 let array = StringViewArray::new_unchecked(
380 views_buf,
381 string_view_array.data_buffers().to_vec(),
382 nulls_buf,
383 );
384 Ok(Arc::new(array) as ArrayRef)
385 }
386}
387
388fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
389where
390 V: StringArrayType<'a>,
391{
392 let start_array = as_int64_array(&args[0])?;
393 let count_array_opt = if args.len() == 2 {
394 Some(as_int64_array(&args[1])?)
395 } else {
396 None
397 };
398
399 let enable_ascii_fast_path =
400 enable_ascii_fast_path(&string_array, start_array, count_array_opt);
401
402 match args.len() {
403 1 => {
404 let iter = ArrayIter::new(string_array);
405 let mut result_builder = StringViewBuilder::new();
406 for (string, start) in iter.zip(start_array.iter()) {
407 match (string, start) {
408 (Some(string), Some(start)) => {
409 let (start, end) = get_true_start_end(
410 string,
411 start,
412 None,
413 enable_ascii_fast_path,
414 ); let substr = &string[start..end];
416 result_builder.append_value(substr);
417 }
418 _ => {
419 result_builder.append_null();
420 }
421 }
422 }
423 Ok(Arc::new(result_builder.finish()) as ArrayRef)
424 }
425 2 => {
426 let iter = ArrayIter::new(string_array);
427 let count_array = count_array_opt.unwrap();
428 let mut result_builder = StringViewBuilder::new();
429
430 for ((string, start), count) in
431 iter.zip(start_array.iter()).zip(count_array.iter())
432 {
433 match (string, start, count) {
434 (Some(string), Some(start), Some(count)) => {
435 if count < 0 {
436 return exec_err!(
437 "negative substring length not allowed: substr(<str>, {start}, {count})"
438 );
439 } else {
440 if start == i64::MIN {
441 return exec_err!(
442 "negative overflow when calculating skip value"
443 );
444 }
445 let (start, end) = get_true_start_end(
446 string,
447 start,
448 Some(count as u64),
449 enable_ascii_fast_path,
450 ); let substr = &string[start..end];
452 result_builder.append_value(substr);
453 }
454 }
455 _ => {
456 result_builder.append_null();
457 }
458 }
459 }
460 Ok(Arc::new(result_builder.finish()) as ArrayRef)
461 }
462 other => {
463 exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
464 }
465 }
466}
467
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringViewArray};
    use arrow::datatypes::DataType::Utf8View;

    use datafusion_common::{Result, ScalarValue, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substr::SubstrFunc;
    use crate::utils::test::test_function;

    /// Exercises `substr` through `test_function!` with scalar arguments.
    ///
    /// Every case lists the arguments, the expected result, and the expected
    /// output type (`Utf8View` / `StringViewArray`). The first group passes
    /// `Utf8View` scalars; the remaining cases build the string with
    /// `ScalarValue::from`.
    #[test]
    fn test_functions() -> Result<()> {
        // NULL string input yields NULL.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 0 lies before the first character: whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Multi-byte character inside a value longer than 12 bytes, with a length.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this és longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some(" é")),
            &str,
            Utf8View,
            StringViewArray
        );
        // ASCII value longer than 12 bytes, no length: rest of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this is longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some(" is longer than 12B")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Positions count characters, not bytes.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "joséésoj"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start and length both inside the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length reaching past the end is truncated to the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // From here on the string argument is built with `ScalarValue::from`.
        // Start position 0: whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Multi-byte characters, no length.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length: whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
            ],
            Ok(Some("joséésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 1 is the first character.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 2 skips one character.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("lphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 3 skips two characters.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length: whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start past the end of the string: empty result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(30i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // NULL start position yields NULL.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start and length both inside the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length reaching past the end is truncated to the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // The window covers positions 0..=4; position 0 is before the string,
        // so only the first four characters are returned.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // The window covers positions -5..=4; only positions 1..=4 are inside
        // the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // The window covers positions -5..=-2 and ends before the string
        // starts: empty result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(4i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // The window covers positions -5..=-1 and still ends before position
        // 1: empty result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // NULL start position with a length yields NULL.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // NULL length yields NULL.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // A negative length is rejected with an execution error.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
            &str,
            Utf8View,
            StringViewArray
        );
        // Multi-byte characters with a length.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("és")),
            &str,
            Utf8View,
            StringViewArray
        );
        // NOTE(review): this case is only compiled when the
        // `unicode_expressions` feature is disabled. It uses `internal_err!`,
        // which is not imported in this module, and passes the arguments as a
        // slice rather than a `Vec` like every other case — confirm that it
        // still builds in that configuration, or remove it.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            SubstrFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            internal_err!(
                "function substr requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8View,
            StringViewArray
        );
        // `i64::MIN` as start without a length: whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abc")),
                ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)),
            ],
            Ok(Some("abc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // `i64::MIN` as start together with a length is rejected.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("overflow")),
                ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            exec_err!("negative overflow when calculating skip value"),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }
}