// datafusion_spark/function/string/substring.rs
use arrow::array::{
19 Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array,
20 OffsetSizeTrait, StringArrayType, StringViewBuilder,
21};
22use arrow::datatypes::DataType;
23use datafusion_common::arrow::datatypes::{Field, FieldRef};
24use datafusion_common::cast::as_int64_array;
25use datafusion_common::types::{
26 NativeType, logical_int32, logical_int64, logical_string,
27};
28use datafusion_common::{Result, exec_err};
29use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass};
30use datafusion_expr::{
31 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32 Volatility,
33};
34use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end};
35use datafusion_functions::utils::make_scalar_function;
36use std::any::Any;
37use std::sync::Arc;
38
/// Spark-compatible `substring(str, pos[, len])` scalar UDF.
///
/// Positions are 1-based; `pos = 0` behaves like `pos = 1` and negative
/// positions count back from the end of the string (see
/// `spark_start_to_datafusion_start`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkSubstring {
    // Accepts (string, int64) or (string, int64, int64); built in `new`.
    signature: Signature,
    // Secondary registered name: "substr".
    aliases: Vec<String>,
}
50
51impl Default for SparkSubstring {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl SparkSubstring {
58 pub fn new() -> Self {
59 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
60 let int64 = Coercion::new_implicit(
61 TypeSignatureClass::Native(logical_int64()),
62 vec![TypeSignatureClass::Native(logical_int32())],
63 NativeType::Int64,
64 );
65 Self {
66 signature: Signature::one_of(
67 vec![
68 TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
69 TypeSignature::Coercible(vec![
70 string.clone(),
71 int64.clone(),
72 int64.clone(),
73 ]),
74 ],
75 Volatility::Immutable,
76 )
77 .with_parameter_names(vec![
78 "str".to_string(),
79 "pos".to_string(),
80 "length".to_string(),
81 ])
82 .expect("valid parameter names"),
83 aliases: vec![String::from("substr")],
84 }
85 }
86}
87
88impl ScalarUDFImpl for SparkSubstring {
89 fn as_any(&self) -> &dyn Any {
90 self
91 }
92
93 fn name(&self) -> &str {
94 "substring"
95 }
96
97 fn signature(&self) -> &Signature {
98 &self.signature
99 }
100
101 fn aliases(&self) -> &[String] {
102 &self.aliases
103 }
104
105 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106 make_scalar_function(spark_substring, vec![])(&args.args)
107 }
108
109 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
110 datafusion_common::internal_err!(
111 "return_type should not be called for Spark substring"
112 )
113 }
114
115 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
116 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
118
119 Ok(Arc::new(Field::new(
120 "substring",
121 args.arg_fields[0].data_type().clone(),
122 nullable,
123 )))
124 }
125}
126
127fn spark_substring(args: &[ArrayRef]) -> Result<ArrayRef> {
128 let start_array = as_int64_array(&args[1])?;
129 let length_array = if args.len() > 2 {
130 Some(as_int64_array(&args[2])?)
131 } else {
132 None
133 };
134
135 match args[0].data_type() {
136 DataType::Utf8 => spark_substring_impl(
137 &args[0].as_string::<i32>(),
138 start_array,
139 length_array,
140 GenericStringBuilder::<i32>::new(),
141 ),
142 DataType::LargeUtf8 => spark_substring_impl(
143 &args[0].as_string::<i64>(),
144 start_array,
145 length_array,
146 GenericStringBuilder::<i64>::new(),
147 ),
148 DataType::Utf8View => spark_substring_impl(
149 &args[0].as_string_view(),
150 start_array,
151 length_array,
152 StringViewBuilder::new(),
153 ),
154 other => exec_err!(
155 "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8."
156 ),
157 }
158}
159
/// Converts a Spark substring position to DataFusion's 1-based start.
///
/// Spark positions are 1-based; `0` behaves like `1`, and a negative `start`
/// counts back from the end (`-1` is the last character). The result is
/// always clamped to at least 1. Arithmetic saturates so extreme inputs
/// (e.g. `i64::MIN`) cannot overflow.
#[inline]
fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 {
    let one_based = if start < 0 {
        // `len + start + 1`: e.g. start = -1 with len = 3 -> position 3.
        i64::try_from(len)
            .unwrap_or(i64::MAX)
            .saturating_add(start)
            .saturating_add(1)
    } else {
        start
    };
    one_based.max(1)
}
178
/// Minimal builder abstraction so `spark_substring_impl` can append to either
/// `GenericStringBuilder` or `StringViewBuilder` through a single code path.
trait StringArrayBuilder: ArrayBuilder {
    /// Appends a non-null string value.
    fn append_value(&mut self, val: &str);
    /// Appends a null entry.
    fn append_null(&mut self);
}
183
// Pure delegation; the qualified paths make explicit that the builder's own
// (inherent) methods are called, not this trait's methods of the same name.
impl<O: OffsetSizeTrait> StringArrayBuilder for GenericStringBuilder<O> {
    fn append_value(&mut self, val: &str) {
        GenericStringBuilder::append_value(self, val);
    }
    fn append_null(&mut self) {
        GenericStringBuilder::append_null(self);
    }
}
192
// Pure delegation to `StringViewBuilder`'s inherent methods; see the note on
// the `GenericStringBuilder` impl regarding the qualified call paths.
impl StringArrayBuilder for StringViewBuilder {
    fn append_value(&mut self, val: &str) {
        StringViewBuilder::append_value(self, val);
    }
    fn append_null(&mut self) {
        StringViewBuilder::append_null(self);
    }
}
201
202fn spark_substring_impl<'a, V, B>(
203 string_array: &V,
204 start_array: &Int64Array,
205 length_array: Option<&Int64Array>,
206 mut builder: B,
207) -> Result<ArrayRef>
208where
209 V: StringArrayType<'a>,
210 B: StringArrayBuilder,
211{
212 let is_ascii = enable_ascii_fast_path(string_array, start_array, length_array);
213
214 for i in 0..string_array.len() {
215 if string_array.is_null(i) || start_array.is_null(i) {
216 builder.append_null();
217 continue;
218 }
219
220 if let Some(len_arr) = length_array
221 && len_arr.is_null(i)
222 {
223 builder.append_null();
224 continue;
225 }
226
227 let string = string_array.value(i);
228 let start = start_array.value(i);
229 let len_opt = length_array.map(|arr| arr.value(i));
230
231 if let Some(len) = len_opt
233 && len < 0
234 {
235 builder.append_value("");
236 continue;
237 }
238
239 let string_len = if is_ascii {
240 string.len()
241 } else {
242 string.chars().count()
243 };
244
245 let adjusted_start = spark_start_to_datafusion_start(start, string_len);
246
247 let (byte_start, byte_end) = get_true_start_end(
248 string,
249 adjusted_start,
250 len_opt.map(|l| l as u64),
251 is_ascii,
252 );
253 let substr = &string[byte_start..byte_end];
254 builder.append_value(substr);
255 }
256
257 Ok(builder.finish())
258}