datafusion_functions/unicode/
substrindex.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23 GenericStringBuilder, OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{DataType, Int32Type, Int64Type};
26
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::{Result, exec_err, utils::take_function_args};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34
#[user_doc(
    doc_section(label = "String Functions"),
    description = r#"Returns the substring from str before count occurrences of the delimiter delim.
If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#,
    syntax_example = "substr_index(str, delim, count)",
    sql_example = r#"```sql
> select substr_index('www.apache.org', '.', 1);
+---------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
+---------------------------------------------------------+
| www                                                     |
+---------------------------------------------------------+
> select substr_index('www.apache.org', '.', -1);
+----------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
+----------------------------------------------------------+
| org                                                      |
+----------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "delim",
        description = "The string to find in str to split str."
    ),
    argument(
        name = "count",
        description = "The number of times to search for the delimiter. Can be either a positive or negative number."
    )
)]
/// Scalar UDF implementing `substr_index(str, delim, count)`; it can also be
/// called under the alias `substring_index`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrIndexFunc {
    /// Accepted argument types; built in `SubstrIndexFunc::new`.
    signature: Signature,
    /// Alternative names under which the function is registered.
    aliases: Vec<String>,
}
70
71impl Default for SubstrIndexFunc {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl SubstrIndexFunc {
78 pub fn new() -> Self {
79 use DataType::*;
80 Self {
81 signature: Signature::one_of(
82 vec![
83 Exact(vec![Utf8View, Utf8View, Int64]),
84 Exact(vec![Utf8, Utf8, Int64]),
85 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
86 ],
87 Volatility::Immutable,
88 ),
89 aliases: vec![String::from("substring_index")],
90 }
91 }
92}
93
94impl ScalarUDFImpl for SubstrIndexFunc {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "substr_index"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108 utf8_to_str_type(&arg_types[0], "substr_index")
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: datafusion_expr::ScalarFunctionArgs,
114 ) -> Result<ColumnarValue> {
115 make_scalar_function(substr_index, vec![])(&args.args)
116 }
117
118 fn aliases(&self) -> &[String] {
119 &self.aliases
120 }
121
122 fn documentation(&self) -> Option<&Documentation> {
123 self.doc()
124 }
125}
126
127fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
133 let [str, delim, count] = take_function_args("substr_index", args)?;
134
135 match str.data_type() {
136 DataType::Utf8 => {
137 let string_array = str.as_string::<i32>();
138 let delimiter_array = delim.as_string::<i32>();
139 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
140 substr_index_general::<Int32Type, _, _>(
141 string_array,
142 delimiter_array,
143 count_array,
144 )
145 }
146 DataType::LargeUtf8 => {
147 let string_array = str.as_string::<i64>();
148 let delimiter_array = delim.as_string::<i64>();
149 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
150 substr_index_general::<Int64Type, _, _>(
151 string_array,
152 delimiter_array,
153 count_array,
154 )
155 }
156 DataType::Utf8View => {
157 let string_array = str.as_string_view();
158 let delimiter_array = delim.as_string_view();
159 let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
160 substr_index_general::<Int32Type, _, _>(
161 string_array,
162 delimiter_array,
163 count_array,
164 )
165 }
166 other => {
167 exec_err!("Unsupported data type {other:?} for function substr_index")
168 }
169 }
170}
171
172fn substr_index_general<
173 'a,
174 T: ArrowPrimitiveType,
175 V: ArrayAccessor<Item = &'a str>,
176 P: ArrayAccessor<Item = i64>,
177>(
178 string_array: V,
179 delimiter_array: V,
180 count_array: P,
181) -> Result<ArrayRef>
182where
183 T::Native: OffsetSizeTrait,
184{
185 let num_rows = string_array.len();
186 let mut builder = GenericStringBuilder::<T::Native>::with_capacity(num_rows, 0);
187 let string_iter = ArrayIter::new(string_array);
188 let delimiter_array_iter = ArrayIter::new(delimiter_array);
189 let count_array_iter = ArrayIter::new(count_array);
190 string_iter
191 .zip(delimiter_array_iter)
192 .zip(count_array_iter)
193 .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
194 (Some(string), Some(delimiter), Some(n)) => {
195 if n == 0 || string.is_empty() || delimiter.is_empty() {
197 builder.append_value("");
198 return;
199 }
200
201 let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
202 let result_idx = if delimiter.len() == 1 {
203 let d_byte = delimiter.as_bytes()[0];
205 let bytes = string.as_bytes();
206
207 if n > 0 {
208 bytes
209 .iter()
210 .enumerate()
211 .filter(|&(_, &b)| b == d_byte)
212 .nth(occurrences - 1)
213 .map(|(idx, _)| idx)
214 } else {
215 bytes
216 .iter()
217 .enumerate()
218 .rev()
219 .filter(|&(_, &b)| b == d_byte)
220 .nth(occurrences - 1)
221 .map(|(idx, _)| idx + 1)
222 }
223 } else if n > 0 {
224 string
226 .match_indices(delimiter)
227 .nth(occurrences - 1)
228 .map(|(idx, _)| idx)
229 } else {
230 string
232 .rmatch_indices(delimiter)
233 .nth(occurrences - 1)
234 .map(|(idx, _)| idx + delimiter.len())
235 };
236 match result_idx {
237 Some(idx) => {
238 if n > 0 {
239 builder.append_value(&string[..idx]);
240 } else {
241 builder.append_value(&string[idx..]);
242 }
243 }
244 None => builder.append_value(string),
245 }
246 }
247 _ => builder.append_null(),
248 });
249
250 Ok(Arc::new(builder.finish()) as ArrayRef)
251}
252
#[cfg(test)]
mod tests {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    // Invokes `substr_index` with scalar `Utf8` string/delimiter arguments and
    // an `Int64` count, and checks the `Utf8` result against `$expected`.
    macro_rules! check_utf8 {
        ($string:expr, $delim:expr, $count:expr, $expected:expr) => {
            test_function!(
                SubstrIndexFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::from($string)),
                    ColumnarValue::Scalar(ScalarValue::from($delim)),
                    ColumnarValue::Scalar(ScalarValue::from($count)),
                ],
                $expected,
                &str,
                Utf8,
                StringArray
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("www")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
            ],
            Ok(Some("apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Delimiters longer than one byte go through `match_indices` /
    /// `rmatch_indices` rather than the single-byte scan exercised above.
    #[test]
    fn test_multi_byte_delimiter() -> Result<()> {
        check_utf8!("a::b::c::d", "::", 2i64, Ok(Some("a::b")));
        check_utf8!("a::b::c::d", "::", 3i64, Ok(Some("a::b::c")));
        check_utf8!("a::b::c::d", "::", -2i64, Ok(Some("c::d")));
        check_utf8!("a::b::c::d", "::", -3i64, Ok(Some("b::c::d")));
        Ok(())
    }

    /// Results must be sliced on character boundaries for non-ASCII text,
    /// for both the single-byte and the multi-byte delimiter paths.
    #[test]
    fn test_non_ascii() -> Result<()> {
        check_utf8!("αβγαβδ", "β", 1i64, Ok(Some("α")));
        check_utf8!("αβγαβδ", "β", 2i64, Ok(Some("αβγα")));
        check_utf8!("αβγαβδ", "β", -1i64, Ok(Some("δ")));
        check_utf8!("αβγαβδ", "β", -2i64, Ok(Some("γαβδ")));
        check_utf8!("ä.ö.ü", ".", 1i64, Ok(Some("ä")));
        check_utf8!("ä.ö.ü", ".", -1i64, Ok(Some("ü")));
        check_utf8!("ä.ö.ü", ".", -2i64, Ok(Some("ö.ü")));
        Ok(())
    }

    /// When the delimiter occurs fewer than |count| times (including not at
    /// all), the whole input string is returned.
    #[test]
    fn test_count_exceeds_occurrences() -> Result<()> {
        check_utf8!("www.apache.org", ".", 3i64, Ok(Some("www.apache.org")));
        check_utf8!("www.apache.org", ".", -3i64, Ok(Some("www.apache.org")));
        check_utf8!("a::b::c::d", "::", 4i64, Ok(Some("a::b::c::d")));
        check_utf8!("a::b::c::d", "::", -4i64, Ok(Some("a::b::c::d")));
        check_utf8!("www.apache.org", "#", 1i64, Ok(Some("www.apache.org")));
        check_utf8!("www.apache.org", "::", -1i64, Ok(Some("www.apache.org")));
        // Extreme counts must not overflow when turned into an occurrence count.
        check_utf8!("www.apache.org", ".", i64::MAX, Ok(Some("www.apache.org")));
        check_utf8!("www.apache.org", ".", i64::MIN, Ok(Some("www.apache.org")));
        Ok(())
    }

    /// A null in any argument produces a null result.
    #[test]
    fn test_null_inputs() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// `LargeUtf8` input keeps `LargeUtf8` output; `Utf8View` input yields `Utf8`.
    #[test]
    fn test_large_utf8_and_utf8_view() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "www.apache.org"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("org")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "www.apache.org"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }
}