1use std::any::Any;
19use std::sync::Arc;
20
21use crate::strings::make_and_append_view;
22use crate::utils::make_scalar_function;
23use arrow::array::{
24 Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
25 StringViewArray, StringViewBuilder,
26};
27use arrow::buffer::ScalarBuffer;
28use arrow::datatypes::DataType;
29use datafusion_common::cast::as_int64_array;
30use datafusion_common::{exec_err, plan_err, Result};
31use datafusion_expr::{
32 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_macros::user_doc;
35
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Extracts a substring of a specified number of characters from a specific starting position in a string.",
    syntax_example = "substr(str, start_pos[, length])",
    alternative_syntax = "substring(str from start_pos for length)",
    sql_example = r#"```sql
> select substr('datafusion', 5, 3);
+----------------------------------------------+
| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
+----------------------------------------------+
| fus |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "start_pos",
        description = "Character position to start the substring at. The first character in the string has a position of 1."
    ),
    argument(
        name = "length",
        description = "Number of characters to extract. If not specified, returns the rest of the string after the start position."
    )
)]
/// Scalar UDF type for the SQL `substr` function.
///
/// The user-facing documentation (description, syntax, arguments, example) is
/// declared in the `#[user_doc]` attribute on this type.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrFunc {
    /// Function signature (argument typing and volatility) of this UDF.
    signature: Signature,
    /// Additional names under which the function is exposed.
    aliases: Vec<String>,
}
64
65impl Default for SubstrFunc {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl SubstrFunc {
72 pub fn new() -> Self {
73 Self {
74 signature: Signature::user_defined(Volatility::Immutable)
75 .with_parameter_names(vec![
76 "str".to_string(),
77 "start_pos".to_string(),
78 "length".to_string(),
79 ])
80 .expect("valid parameter names"),
81 aliases: vec![String::from("substring")],
82 }
83 }
84}
85
86impl ScalarUDFImpl for SubstrFunc {
87 fn as_any(&self) -> &dyn Any {
88 self
89 }
90
91 fn name(&self) -> &str {
92 "substr"
93 }
94
95 fn signature(&self) -> &Signature {
96 &self.signature
97 }
98
99 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
101 Ok(DataType::Utf8View)
102 }
103
104 fn invoke_with_args(
105 &self,
106 args: datafusion_expr::ScalarFunctionArgs,
107 ) -> Result<ColumnarValue> {
108 make_scalar_function(substr, vec![])(&args.args)
109 }
110
111 fn aliases(&self) -> &[String] {
112 &self.aliases
113 }
114
115 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
116 if arg_types.len() < 2 || arg_types.len() > 3 {
117 return plan_err!(
118 "The {} function requires 2 or 3 arguments, but got {}.",
119 self.name(),
120 arg_types.len()
121 );
122 }
123 let first_data_type = match &arg_types[0] {
124 DataType::Null => Ok(DataType::Utf8),
125 DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()),
126 DataType::Dictionary(key_type, value_type) => {
127 if key_type.is_integer() {
128 match value_type.as_ref() {
129 DataType::Null => Ok(DataType::Utf8),
130 DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()),
131 _ => plan_err!(
132 "The first argument of the {} function can only be a string, but got {:?}.",
133 self.name(),
134 arg_types[0]
135 ),
136 }
137 } else {
138 plan_err!(
139 "The first argument of the {} function can only be a string, but got {:?}.",
140 self.name(),
141 arg_types[0]
142 )
143 }
144 }
145 _ => plan_err!(
146 "The first argument of the {} function can only be a string, but got {:?}.",
147 self.name(),
148 arg_types[0]
149 )
150 }?;
151
152 if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) {
153 return plan_err!(
154 "The second argument of the {} function can only be an integer, but got {:?}.",
155 self.name(),
156 arg_types[1]
157 );
158 }
159
160 if arg_types.len() == 3
161 && ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2])
162 {
163 return plan_err!(
164 "The third argument of the {} function can only be an integer, but got {:?}.",
165 self.name(),
166 arg_types[2]
167 );
168 }
169
170 if arg_types.len() == 2 {
171 Ok(vec![first_data_type.to_owned(), DataType::Int64])
172 } else {
173 Ok(vec![
174 first_data_type.to_owned(),
175 DataType::Int64,
176 DataType::Int64,
177 ])
178 }
179 }
180
181 fn documentation(&self) -> Option<&Documentation> {
182 self.doc()
183 }
184}
185
186pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
191 match args[0].data_type() {
192 DataType::Utf8 => {
193 let string_array = args[0].as_string::<i32>();
194 string_substr::<_>(string_array, &args[1..])
195 }
196 DataType::LargeUtf8 => {
197 let string_array = args[0].as_string::<i64>();
198 string_substr::<_>(string_array, &args[1..])
199 }
200 DataType::Utf8View => {
201 let string_array = args[0].as_string_view();
202 string_view_substr(string_array, &args[1..])
203 }
204 other => exec_err!(
205 "Unsupported data type {other:?} for function substr,\
206 expected Utf8View, Utf8 or LargeUtf8."
207 ),
208 }
209}
210
/// Converts a 1-based character `start` and an optional character `count` into
/// a byte range `(start, end)` that can be used to slice `input`.
///
/// * `start` may be zero or negative; positions left of the string are clamped
///   to its beginning, so part of `count` can be consumed before the string starts.
/// * `count == None` means "up to the end of the string".
/// * `is_input_ascii_only` asserts that `input` is ASCII, in which case character
///   positions equal byte positions and no decoding is done.
///
/// The returned range always satisfies `start <= end <= input.len()` and lies on
/// character boundaries.
fn get_true_start_end(
    input: &str,
    start: i64,
    count: Option<u64>,
    is_input_ascii_only: bool,
) -> (usize, usize) {
    // 1-based -> 0-based. `i64::MIN` cannot be decremented; it is kept as is
    // because it is clamped to 0 below anyway.
    let start = start.checked_sub(1).unwrap_or(start);
    let input_len = input.len() as i64;

    // Saturate instead of overflowing: a huge `count` simply means "to the end".
    let end = match count {
        Some(count) => {
            let count = i64::try_from(count).unwrap_or(i64::MAX);
            start.saturating_add(count)
        }
        None => input_len,
    };
    let count_to_end = count.is_some();

    let start = start.clamp(0, input_len) as usize;
    let end = end.clamp(0, input_len) as usize;
    // `end >= start` holds because `count` is non-negative and clamping is monotonic.
    let count = end - start;

    // For ASCII-only input, character indices are byte indices.
    if is_input_ascii_only {
        return (start, end);
    }

    // Otherwise translate character indices to byte indices in a single pass
    // over the string. Both bounds default to the end of the string, which is
    // the result when the requested position lies past the last character.
    let (mut st, mut ed) = (input.len(), input.len());
    let mut start_counting = false;
    let mut cnt = 0;
    for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
        if char_cnt == start {
            st = byte_cnt;
            if count_to_end {
                start_counting = true;
            } else {
                // No count: the substring runs to the end of the string.
                break;
            }
        }
        if start_counting {
            if cnt == count {
                ed = byte_cnt;
                break;
            }
            cnt += 1;
        }
    }
    (st, ed)
}
273
274fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
285 string_array: &V,
286 start: &Int64Array,
287 count: Option<&Int64Array>,
288) -> bool {
289 let is_short_prefix = match count {
290 Some(count) => {
291 let short_prefix_threshold = 32.0;
292 let n_sample = 10;
293
294 let avg_prefix_len = start
297 .iter()
298 .zip(count.iter())
299 .take(n_sample)
300 .map(|(start, count)| {
301 let start = start.unwrap_or(0);
302 let count = count.unwrap_or(0);
303 start + count
305 })
306 .sum::<i64>();
307
308 avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold
309 }
310 None => false,
311 };
312
313 if is_short_prefix {
314 false
316 } else {
317 string_array.is_ascii()
318 }
319}
320
321fn string_view_substr(
324 string_view_array: &StringViewArray,
325 args: &[ArrayRef],
326) -> Result<ArrayRef> {
327 let mut views_buf = Vec::with_capacity(string_view_array.len());
328 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
329
330 let start_array = as_int64_array(&args[0])?;
331 let count_array_opt = if args.len() == 2 {
332 Some(as_int64_array(&args[1])?)
333 } else {
334 None
335 };
336
337 let enable_ascii_fast_path =
338 enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);
339
340 match args.len() {
343 1 => {
344 for ((str_opt, raw_view), start_opt) in string_view_array
345 .iter()
346 .zip(string_view_array.views().iter())
347 .zip(start_array.iter())
348 {
349 if let (Some(str), Some(start)) = (str_opt, start_opt) {
350 let (start, end) =
351 get_true_start_end(str, start, None, enable_ascii_fast_path);
352 let substr = &str[start..end];
353
354 make_and_append_view(
355 &mut views_buf,
356 &mut null_builder,
357 raw_view,
358 substr,
359 start as u32,
360 );
361 } else {
362 null_builder.append_null();
363 views_buf.push(0);
364 }
365 }
366 }
367 2 => {
368 let count_array = count_array_opt.unwrap();
369 for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
370 .iter()
371 .zip(string_view_array.views().iter())
372 .zip(start_array.iter())
373 .zip(count_array.iter())
374 {
375 if let (Some(str), Some(start), Some(count)) =
376 (str_opt, start_opt, count_opt)
377 {
378 if count < 0 {
379 return exec_err!(
380 "negative substring length not allowed: substr(<str>, {start}, {count})"
381 );
382 } else {
383 if start == i64::MIN {
384 return exec_err!(
385 "negative overflow when calculating skip value"
386 );
387 }
388 let (start, end) = get_true_start_end(
389 str,
390 start,
391 Some(count as u64),
392 enable_ascii_fast_path,
393 );
394 let substr = &str[start..end];
395
396 make_and_append_view(
397 &mut views_buf,
398 &mut null_builder,
399 raw_view,
400 substr,
401 start as u32,
402 );
403 }
404 } else {
405 null_builder.append_null();
406 views_buf.push(0);
407 }
408 }
409 }
410 other => {
411 return exec_err!(
412 "substr was called with {other} arguments. It requires 2 or 3."
413 )
414 }
415 }
416
417 let views_buf = ScalarBuffer::from(views_buf);
418 let nulls_buf = null_builder.finish();
419
420 unsafe {
425 let array = StringViewArray::new_unchecked(
426 views_buf,
427 string_view_array.data_buffers().to_vec(),
428 nulls_buf,
429 );
430 Ok(Arc::new(array) as ArrayRef)
431 }
432}
433
434fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
435where
436 V: StringArrayType<'a>,
437{
438 let start_array = as_int64_array(&args[0])?;
439 let count_array_opt = if args.len() == 2 {
440 Some(as_int64_array(&args[1])?)
441 } else {
442 None
443 };
444
445 let enable_ascii_fast_path =
446 enable_ascii_fast_path(&string_array, start_array, count_array_opt);
447
448 match args.len() {
449 1 => {
450 let iter = ArrayIter::new(string_array);
451 let mut result_builder = StringViewBuilder::new();
452 for (string, start) in iter.zip(start_array.iter()) {
453 match (string, start) {
454 (Some(string), Some(start)) => {
455 let (start, end) = get_true_start_end(
456 string,
457 start,
458 None,
459 enable_ascii_fast_path,
460 ); let substr = &string[start..end];
462 result_builder.append_value(substr);
463 }
464 _ => {
465 result_builder.append_null();
466 }
467 }
468 }
469 Ok(Arc::new(result_builder.finish()) as ArrayRef)
470 }
471 2 => {
472 let iter = ArrayIter::new(string_array);
473 let count_array = count_array_opt.unwrap();
474 let mut result_builder = StringViewBuilder::new();
475
476 for ((string, start), count) in
477 iter.zip(start_array.iter()).zip(count_array.iter())
478 {
479 match (string, start, count) {
480 (Some(string), Some(start), Some(count)) => {
481 if count < 0 {
482 return exec_err!(
483 "negative substring length not allowed: substr(<str>, {start}, {count})"
484 );
485 } else {
486 if start == i64::MIN {
487 return exec_err!(
488 "negative overflow when calculating skip value"
489 );
490 }
491 let (start, end) = get_true_start_end(
492 string,
493 start,
494 Some(count as u64),
495 enable_ascii_fast_path,
496 ); let substr = &string[start..end];
498 result_builder.append_value(substr);
499 }
500 }
501 _ => {
502 result_builder.append_null();
503 }
504 }
505 }
506 Ok(Arc::new(result_builder.finish()) as ArrayRef)
507 }
508 other => {
509 exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
510 }
511 }
512}
513
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringViewArray};
    use arrow::datatypes::DataType::Utf8View;

    use datafusion_common::{exec_err, Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substr::SubstrFunc;
    use crate::utils::test::test_function;

    // Scalar `Utf8` string argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    // Scalar `Utf8View` string argument; `None` is a null string.
    fn utf8_view(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(value.map(String::from)))
    }

    // Scalar `Int64` argument.
    fn int(value: i64) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    // Null `Int64` argument.
    fn null_int() -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int64(None))
    }

    // Runs one `substr` case. Every case expects a `Utf8View` result held in a
    // `StringViewArray`, so only the arguments and the expectation vary.
    macro_rules! check_substr {
        ($args:expr, $expected:expr) => {
            test_function!(
                SubstrFunc::new(),
                $args,
                $expected,
                &str,
                Utf8View,
                StringViewArray
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Utf8View inputs.
        check_substr!(vec![utf8_view(None), int(1)], Ok(None));
        check_substr!(
            vec![utf8_view(Some("alphabet")), int(0)],
            Ok(Some("alphabet"))
        );
        check_substr!(
            vec![utf8_view(Some("this és longer than 12B")), int(5), int(2)],
            Ok(Some(" é"))
        );
        check_substr!(
            vec![utf8_view(Some("this is longer than 12B")), int(5)],
            Ok(Some(" is longer than 12B"))
        );
        check_substr!(vec![utf8_view(Some("joséésoj")), int(5)], Ok(Some("ésoj")));
        check_substr!(
            vec![utf8_view(Some("alphabet")), int(3), int(2)],
            Ok(Some("ph"))
        );
        check_substr!(
            vec![utf8_view(Some("alphabet")), int(3), int(20)],
            Ok(Some("phabet"))
        );

        // Utf8 inputs, start position only.
        check_substr!(vec![utf8("alphabet"), int(0)], Ok(Some("alphabet")));
        check_substr!(vec![utf8("joséésoj"), int(5)], Ok(Some("ésoj")));
        check_substr!(vec![utf8("joséésoj"), int(-5)], Ok(Some("joséésoj")));
        check_substr!(vec![utf8("alphabet"), int(1)], Ok(Some("alphabet")));
        check_substr!(vec![utf8("alphabet"), int(2)], Ok(Some("lphabet")));
        check_substr!(vec![utf8("alphabet"), int(3)], Ok(Some("phabet")));
        check_substr!(vec![utf8("alphabet"), int(-3)], Ok(Some("alphabet")));
        check_substr!(vec![utf8("alphabet"), int(30)], Ok(Some("")));
        check_substr!(vec![utf8("alphabet"), null_int()], Ok(None));

        // Utf8 inputs, start position and length.
        check_substr!(vec![utf8("alphabet"), int(3), int(2)], Ok(Some("ph")));
        check_substr!(vec![utf8("alphabet"), int(3), int(20)], Ok(Some("phabet")));
        check_substr!(vec![utf8("alphabet"), int(0), int(5)], Ok(Some("alph")));
        // A start left of the string consumes part of the length.
        check_substr!(vec![utf8("alphabet"), int(-5), int(10)], Ok(Some("alph")));
        check_substr!(vec![utf8("alphabet"), int(-5), int(4)], Ok(Some("")));
        check_substr!(vec![utf8("alphabet"), int(-5), int(5)], Ok(Some("")));
        check_substr!(vec![utf8("alphabet"), null_int(), int(20)], Ok(None));
        check_substr!(vec![utf8("alphabet"), int(3), null_int()], Ok(None));
        check_substr!(
            vec![utf8("alphabet"), int(1), int(-1)],
            exec_err!("negative substring length not allowed: substr(<str>, 1, -1)")
        );
        check_substr!(vec![utf8("joséésoj"), int(5), int(2)], Ok(Some("és")));

        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            SubstrFunc::new(),
            &[utf8("alphabet"), int(0)],
            internal_err!(
                "function substr requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8View,
            StringViewArray
        );

        // Extreme start positions.
        check_substr!(vec![utf8("abc"), int(i64::MIN)], Ok(Some("abc")));
        check_substr!(
            vec![utf8("overflow"), int(i64::MIN), int(1)],
            exec_err!("negative overflow when calculating skip value")
        );

        Ok(())
    }
}