1use std::any::Any;
19use std::sync::Arc;
20
21use crate::strings::make_and_append_view;
22use crate::utils::make_scalar_function;
23use arrow::array::{
24 Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
25 StringViewArray, StringViewBuilder,
26};
27use arrow::buffer::ScalarBuffer;
28use arrow::datatypes::DataType;
29use datafusion_common::cast::as_int64_array;
30use datafusion_common::types::{
31 NativeType, logical_int32, logical_int64, logical_string,
32};
33use datafusion_common::{Result, exec_err};
34use datafusion_expr::{
35 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature,
36 TypeSignatureClass, Volatility,
37};
38use datafusion_macros::user_doc;
39
// Scalar UDF implementing `substr(str, start_pos[, length])` (alias `substring`).
//
// The `user_doc` attribute below holds the user-facing documentation for the
// function: section, description, syntax, an SQL example and the argument
// descriptions.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Extracts a substring of a specified number of characters from a specific starting position in a string.",
    syntax_example = "substr(str, start_pos[, length])",
    alternative_syntax = "substring(str from start_pos for length)",
    sql_example = r#"```sql
> select substr('datafusion', 5, 3);
+----------------------------------------------+
| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
+----------------------------------------------+
| fus |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "start_pos",
        description = "Character position to start the substring at. The first character in the string has a position of 1."
    ),
    argument(
        name = "length",
        description = "Number of characters to extract. If not specified, returns the rest of the string after the start position."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrFunc {
    // Accepted argument combinations; built in `SubstrFunc::new` and exposed
    // through `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative function names; exposed through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
68
69impl Default for SubstrFunc {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl SubstrFunc {
76 pub fn new() -> Self {
77 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
78 let int64 = Coercion::new_implicit(
79 TypeSignatureClass::Native(logical_int64()),
80 vec![TypeSignatureClass::Native(logical_int32())],
81 NativeType::Int64,
82 );
83 Self {
84 signature: Signature::one_of(
85 vec![
86 TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
87 TypeSignature::Coercible(vec![
88 string.clone(),
89 int64.clone(),
90 int64.clone(),
91 ]),
92 ],
93 Volatility::Immutable,
94 )
95 .with_parameter_names(vec![
96 "str".to_string(),
97 "start_pos".to_string(),
98 "length".to_string(),
99 ])
100 .expect("valid parameter names"),
101 aliases: vec![String::from("substring")],
102 }
103 }
104}
105
106impl ScalarUDFImpl for SubstrFunc {
107 fn as_any(&self) -> &dyn Any {
108 self
109 }
110
111 fn name(&self) -> &str {
112 "substr"
113 }
114
115 fn signature(&self) -> &Signature {
116 &self.signature
117 }
118
119 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
121 Ok(DataType::Utf8View)
122 }
123
124 fn invoke_with_args(
125 &self,
126 args: datafusion_expr::ScalarFunctionArgs,
127 ) -> Result<ColumnarValue> {
128 make_scalar_function(substr, vec![])(&args.args)
129 }
130
131 fn aliases(&self) -> &[String] {
132 &self.aliases
133 }
134
135 fn documentation(&self) -> Option<&Documentation> {
136 self.doc()
137 }
138}
139
140fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
145 match args[0].data_type() {
146 DataType::Utf8 => {
147 let string_array = args[0].as_string::<i32>();
148 string_substr::<_>(string_array, &args[1..])
149 }
150 DataType::LargeUtf8 => {
151 let string_array = args[0].as_string::<i64>();
152 string_substr::<_>(string_array, &args[1..])
153 }
154 DataType::Utf8View => {
155 let string_array = args[0].as_string_view();
156 string_view_substr(string_array, &args[1..])
157 }
158 other => exec_err!(
159 "Unsupported data type {other:?} for function substr,\
160 expected Utf8View, Utf8 or LargeUtf8."
161 ),
162 }
163}
164
/// Translates a 1-based, character-oriented `(start, count)` request into a
/// byte range `(begin, end)` that can be used to slice `input`.
///
/// * `start` - 1-based character position; values below 1 are allowed and
///   simply consume part of `count` before the string begins.
/// * `count` - number of characters to take, or `None` for "to the end".
/// * `is_input_ascii_only` - when `true`, character positions are used
///   directly as byte offsets and no decoding is performed.
///
/// The requested range may lie partly or wholly outside the string; the
/// returned range is the part of it that overlaps `input`, with
/// `begin <= end <= input.len()`.
///
/// Examples (`'Hi🌏'` is 6 bytes long):
/// * `get_true_start_end("Hi🌏", 1, None, false)    == (0, 6)`
/// * `get_true_start_end("Hi🌏", 1, Some(1), false) == (0, 1)`
/// * `get_true_start_end("Hi🌏", -10, Some(2), false) == (0, 0)`
pub fn get_true_start_end(
    input: &str,
    start: i64,
    count: Option<u64>,
    is_input_ascii_only: bool,
) -> (usize, usize) {
    let byte_len = input.len();
    let upper = byte_len as i64;

    // Convert to a 0-based position; `i64::MIN` cannot be decremented and is
    // left unchanged (it is clamped to 0 below anyway).
    let zero_based = start.checked_sub(1).unwrap_or(start);
    let requested_end = match count {
        Some(n) => zero_based.saturating_add(i64::try_from(n).unwrap_or(i64::MAX)),
        None => upper,
    };

    // Clamp against the byte length. For non-ASCII input this is only a
    // loose upper bound on the character count, which the scan below refines.
    let first = zero_based.clamp(0, upper) as usize;
    let last = requested_end.clamp(0, upper) as usize;

    if is_input_ascii_only {
        // One byte per character: positions are already byte offsets.
        return (first, last);
    }

    // Walk the character boundaries once: skip to the `first`-th character,
    // then (if a count was given) `last - first` characters further.
    let mut offsets = input.char_indices().map(|(offset, _)| offset);
    let begin = match offsets.nth(first) {
        Some(offset) => offset,
        // The start lies at or beyond the end of the string.
        None => return (byte_len, byte_len),
    };
    let end = match (count, last - first) {
        (None, _) => byte_len,
        (Some(_), 0) => begin,
        // `nth` already consumed the character at `begin`, hence `span - 1`.
        (Some(_), span) => offsets.nth(span - 1).unwrap_or(byte_len),
    };
    (begin, end)
}
230
231pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
242 string_array: &V,
243 start: &Int64Array,
244 count: Option<&Int64Array>,
245) -> bool {
246 let is_short_prefix = match count {
247 Some(count) => {
248 let short_prefix_threshold = 32.0;
249 let n_sample = 10;
250
251 let total_prefix_len = start
254 .iter()
255 .zip(count.iter())
256 .take(n_sample)
257 .map(|(start, count)| {
258 let start = start.unwrap_or(0);
259 let count = count.unwrap_or(0);
260 start.saturating_add(count)
262 })
263 .fold(0i64, |acc, val| acc.saturating_add(val));
264
265 (total_prefix_len as f64 / n_sample as f64) <= short_prefix_threshold
266 }
267 None => false,
268 };
269
270 if is_short_prefix {
271 false
273 } else {
274 string_array.is_ascii()
275 }
276}
277
278fn string_view_substr(
281 string_view_array: &StringViewArray,
282 args: &[ArrayRef],
283) -> Result<ArrayRef> {
284 let mut views_buf = Vec::with_capacity(string_view_array.len());
285 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
286
287 let start_array = as_int64_array(&args[0])?;
288 let count_array_opt = if args.len() == 2 {
289 Some(as_int64_array(&args[1])?)
290 } else {
291 None
292 };
293
294 let enable_ascii_fast_path =
295 enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);
296
297 match args.len() {
300 1 => {
301 for ((str_opt, raw_view), start_opt) in string_view_array
302 .iter()
303 .zip(string_view_array.views().iter())
304 .zip(start_array.iter())
305 {
306 if let (Some(str), Some(start)) = (str_opt, start_opt) {
307 let (start, end) =
308 get_true_start_end(str, start, None, enable_ascii_fast_path);
309 let substr = &str[start..end];
310
311 make_and_append_view(
312 &mut views_buf,
313 &mut null_builder,
314 raw_view,
315 substr,
316 start as u32,
317 );
318 } else {
319 null_builder.append_null();
320 views_buf.push(0);
321 }
322 }
323 }
324 2 => {
325 let count_array = count_array_opt.unwrap();
326 for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
327 .iter()
328 .zip(string_view_array.views().iter())
329 .zip(start_array.iter())
330 .zip(count_array.iter())
331 {
332 if let (Some(str), Some(start), Some(count)) =
333 (str_opt, start_opt, count_opt)
334 {
335 if count < 0 {
336 return exec_err!(
337 "negative substring length not allowed: substr(<str>, {start}, {count})"
338 );
339 } else {
340 if start == i64::MIN {
341 return exec_err!(
342 "negative overflow when calculating skip value"
343 );
344 }
345 let (start, end) = get_true_start_end(
346 str,
347 start,
348 Some(count as u64),
349 enable_ascii_fast_path,
350 );
351 let substr = &str[start..end];
352
353 make_and_append_view(
354 &mut views_buf,
355 &mut null_builder,
356 raw_view,
357 substr,
358 start as u32,
359 );
360 }
361 } else {
362 null_builder.append_null();
363 views_buf.push(0);
364 }
365 }
366 }
367 other => {
368 return exec_err!(
369 "substr was called with {other} arguments. It requires 2 or 3."
370 );
371 }
372 }
373
374 let views_buf = ScalarBuffer::from(views_buf);
375 let nulls_buf = null_builder.finish();
376
377 unsafe {
382 let array = StringViewArray::new_unchecked(
383 views_buf,
384 string_view_array.data_buffers().to_vec(),
385 nulls_buf,
386 );
387 Ok(Arc::new(array) as ArrayRef)
388 }
389}
390
391fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
392where
393 V: StringArrayType<'a>,
394{
395 let start_array = as_int64_array(&args[0])?;
396 let count_array_opt = if args.len() == 2 {
397 Some(as_int64_array(&args[1])?)
398 } else {
399 None
400 };
401
402 let enable_ascii_fast_path =
403 enable_ascii_fast_path(&string_array, start_array, count_array_opt);
404
405 match args.len() {
406 1 => {
407 let iter = ArrayIter::new(string_array);
408 let mut result_builder = StringViewBuilder::new();
409 for (string, start) in iter.zip(start_array.iter()) {
410 match (string, start) {
411 (Some(string), Some(start)) => {
412 let (start, end) = get_true_start_end(
413 string,
414 start,
415 None,
416 enable_ascii_fast_path,
417 ); let substr = &string[start..end];
419 result_builder.append_value(substr);
420 }
421 _ => {
422 result_builder.append_null();
423 }
424 }
425 }
426 Ok(Arc::new(result_builder.finish()) as ArrayRef)
427 }
428 2 => {
429 let iter = ArrayIter::new(string_array);
430 let count_array = count_array_opt.unwrap();
431 let mut result_builder = StringViewBuilder::new();
432
433 for ((string, start), count) in
434 iter.zip(start_array.iter()).zip(count_array.iter())
435 {
436 match (string, start, count) {
437 (Some(string), Some(start), Some(count)) => {
438 if count < 0 {
439 return exec_err!(
440 "negative substring length not allowed: substr(<str>, {start}, {count})"
441 );
442 } else {
443 if start == i64::MIN {
444 return exec_err!(
445 "negative overflow when calculating skip value"
446 );
447 }
448 let (start, end) = get_true_start_end(
449 string,
450 start,
451 Some(count as u64),
452 enable_ascii_fast_path,
453 ); let substr = &string[start..end];
455 result_builder.append_value(substr);
456 }
457 }
458 _ => {
459 result_builder.append_null();
460 }
461 }
462 }
463 Ok(Arc::new(result_builder.finish()) as ArrayRef)
464 }
465 other => {
466 exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
467 }
468 }
469}
470
// Unit tests for `SubstrFunc`, driven through the `test_function!` macro.
// Each invocation passes the function, its scalar arguments, the expected
// result (value, null or error) and the expected output type/array type,
// which is `Utf8View` / `StringViewArray` in every case.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringViewArray};
    use arrow::datatypes::DataType::Utf8View;

    use datafusion_common::{Result, ScalarValue, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substr::SubstrFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // ---- Utf8View input ----

        // Null string -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 0 (before the first character) returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII string with start and length: positions count characters,
        // not bytes.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this és longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some(" é")),
            &str,
            Utf8View,
            StringViewArray
        );
        // ASCII string, start only: rest of the string from position 5.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this is longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some(" is longer than 12B")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII string, start only.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "joséésoj"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start and length fully inside the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length running past the end is truncated at the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );

        // ---- Scalars built with `ScalarValue::from(&str)` ----

        // Start position 0 returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII string, start only.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
            ],
            Ok(Some("joséésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 1 is the first character.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 2 skips one character.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("lphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start position 3 skips two characters.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative start without a length returns the whole string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start beyond the end of the string yields an empty string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(30i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null start -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start and length fully inside the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Length running past the end is truncated at the end of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start 0 with length 5 covers positions 0..=4, i.e. only the first
        // four characters of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 10 covers positions -5..=4, i.e. the first
        // four characters of the string.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 4 covers positions -5..=-2, entirely before
        // the string, so the result is empty.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(4i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Start -5 with length 5 covers positions -5..=-1, still entirely
        // before the string, so the result is empty.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null start with a length -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null length -> null result.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Negative length is rejected with an execution error.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII string with start and length.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("és")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Only compiled when the `unicode_expressions` feature is disabled.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            SubstrFunc::new(),
            &[
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            internal_err!(
                "function substr requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8View,
            StringViewArray
        );
        // `i64::MIN` as start without a length: no overflow, whole string returned.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abc")),
                ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
            ],
            Ok(Some("abc")),
            &str,
            Utf8View,
            StringViewArray
        );
        // `i64::MIN` as start together with a length is rejected.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("overflow")),
                ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            exec_err!("negative overflow when calculating skip value"),
            &str,
            Utf8View,
            StringViewArray
        );
        // `i64::MAX` as length does not overflow; the rest of the string is returned.
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("large count")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
                ColumnarValue::Scalar(ScalarValue::from(i64::MAX)),
            ],
            Ok(Some("arge count")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }
}