datafusion_functions/core/
overlay.rs1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
21use arrow::datatypes::DataType;
22
23use crate::utils::{make_scalar_function, utf8_to_str_type};
24use datafusion_common::cast::{
25 as_generic_string_array, as_int64_array, as_string_view_array,
26};
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
29use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
30use datafusion_macros::user_doc;
31
/// Scalar UDF behind the SQL expression
/// `overlay(str PLACING substr FROM pos [FOR count])`.
///
/// The user-facing documentation (description, syntax, example and argument
/// descriptions) is declared in the `user_doc` attribute on this type.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the string which is replaced by another string from the specified position and specified count length.",
    syntax_example = "overlay(str PLACING substr FROM pos [FOR count])",
    sql_example = r#"```sql
> select overlay('Txxxxas' placing 'hom' from 2 for 4);
+--------------------------------------------------------+
| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) |
+--------------------------------------------------------+
| Thomas |
+--------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to replace in str."),
    argument(
        name = "pos",
        description = "The start position to start the replace in str."
    ),
    argument(
        name = "count",
        description = "The count of characters to be replaced from start position of str. If not specified, will use substr length instead."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct OverlayFunc {
    /// Argument type combinations this function accepts, together with its
    /// volatility.
    signature: Signature,
}
59
60impl Default for OverlayFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl OverlayFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::one_of(
71 vec![
72 TypeSignature::Exact(vec![Utf8View, Utf8View, Int64, Int64]),
73 TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]),
74 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
75 TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
76 TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
77 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
78 ],
79 Volatility::Immutable,
80 ),
81 }
82 }
83}
84
85impl ScalarUDFImpl for OverlayFunc {
86 fn name(&self) -> &str {
87 "overlay"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95 utf8_to_str_type(&arg_types[0], "overlay")
96 }
97
98 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99 match args.args[0].data_type() {
100 DataType::Utf8View | DataType::Utf8 => {
101 make_scalar_function(overlay::<i32>, vec![])(&args.args)
102 }
103 DataType::LargeUtf8 => {
104 make_scalar_function(overlay::<i64>, vec![])(&args.args)
105 }
106 other => exec_err!("Unsupported data type {other:?} for function overlay"),
107 }
108 }
109
110 fn documentation(&self) -> Option<&Documentation> {
111 self.doc()
112 }
113}
114
/// Converts a character index within `string` into a byte offset.
///
/// * When `is_ascii` is `true`, `char_idx` is used directly as the byte
///   offset, clamped to `string.len()`. This is only a valid character
///   boundary if `string` really contains ASCII only, so callers must pass
///   the flag accordingly.
/// * Otherwise the string is walked character by character and the byte
///   offset of the `char_idx`-th character is returned.
///
/// In both modes an index at or past the end of the string yields
/// `string.len()`.
fn byte_index_for_char(string: &str, char_idx: usize, is_ascii: bool) -> usize {
    let byte_len = string.len();
    if is_ascii {
        // One byte per character: the index is the offset.
        return byte_len.min(char_idx);
    }
    match string.char_indices().nth(char_idx) {
        Some((offset, _)) => offset,
        None => byte_len,
    }
}
127
128fn overlay_one(
133 string: &str,
134 characters: &str,
135 start_pos: i64,
136 replace_len: i64,
137) -> Result<String> {
138 if start_pos < 1 {
139 return exec_err!("negative substring length not allowed");
140 }
141
142 let is_ascii = string.is_ascii();
143 let string_char_len = if is_ascii {
144 string.len() as i64
145 } else {
146 string.chars().count() as i64
147 };
148
149 let start_char_idx = start_pos - 1;
157 let end_char_idx = start_char_idx.saturating_add(replace_len);
158
159 let prefix_char_idx = usize::try_from(start_char_idx).unwrap_or(usize::MAX);
160 let prefix_end_byte = byte_index_for_char(string, prefix_char_idx, is_ascii);
161
162 let mut res = String::with_capacity(string.len() + characters.len());
163 res.push_str(&string[..prefix_end_byte]);
164 res.push_str(characters);
165
166 if end_char_idx < string_char_len {
167 let suffix_char_idx = usize::try_from(end_char_idx.max(0)).unwrap_or(usize::MAX);
168 let suffix_start_byte = byte_index_for_char(string, suffix_char_idx, is_ascii);
169 res.push_str(&string[suffix_start_byte..]);
170 }
171 Ok(res)
172}
173
/// Applies `overlay_one` row by row and collects the results into a
/// `GenericStringArray<T>` wrapped in a `Result`.
///
/// Two forms are supported:
///
/// * `(strings, characters, positions)` — the number of replaced characters
///   is the character count of the row's `characters` value;
/// * `(strings, characters, positions, lengths)` — the number of replaced
///   characters is taken from `lengths`.
///
/// A row in which any input is null produces a null output. The first row
/// for which `overlay_one` fails aborts the collection with that error.
///
/// `T` is not a macro parameter: it is resolved where the macro is expanded,
/// so the macro must be used inside an item that is generic over `T`.
macro_rules! process_overlay {
    ($string_array:expr, $characters_array:expr, $pos_array:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_array.iter())
            .map(|((text, placing), pos)| {
                // `zip` yields `None` as soon as one input is null.
                text.zip(placing)
                    .zip(pos)
                    .map(|((text, placing), pos)| {
                        let placing_chars = placing.chars().count() as i64;
                        overlay_one(text, placing, pos, placing_chars)
                    })
                    .transpose()
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};

    ($string_array:expr, $characters_array:expr, $pos_array:expr, $len_array:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_array.iter())
            .zip($len_array.iter())
            .map(|(((text, placing), pos), count)| {
                // `zip` yields `None` as soon as one input is null.
                text.zip(placing)
                    .zip(pos.zip(count))
                    .map(|((text, placing), (pos, count))| {
                        overlay_one(text, placing, pos, count)
                    })
                    .transpose()
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};
}
216
217fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
229 if !matches!(args.len(), 3 | 4) {
230 return exec_err!(
231 "overlay was called with {} arguments. It requires 3 or 4.",
232 args.len()
233 );
234 }
235 if args[0].data_type() == &DataType::Utf8View {
236 string_view_overlay::<T>(args)
237 } else {
238 string_overlay::<T>(args)
239 }
240}
241
242fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
243 let string_array = as_generic_string_array::<T>(&args[0])?;
244 let characters_array = as_generic_string_array::<T>(&args[1])?;
245 let pos_array = as_int64_array(&args[2])?;
246
247 let result = if args.len() == 4 {
248 let len_array = as_int64_array(&args[3])?;
249 process_overlay!(string_array, characters_array, pos_array, len_array)?
250 } else {
251 process_overlay!(string_array, characters_array, pos_array)?
252 };
253 Ok(Arc::new(result) as ArrayRef)
254}
255
256fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
257 let string_array = as_string_view_array(&args[0])?;
258 let characters_array = as_string_view_array(&args[1])?;
259 let pos_array = as_int64_array(&args[2])?;
260
261 let result = if args.len() == 4 {
262 let len_array = as_int64_array(&args[3])?;
263 process_overlay!(string_array, characters_array, pos_array, len_array)?
264 } else {
265 process_overlay!(string_array, characters_array, pos_array)?
266 };
267 Ok(Arc::new(result) as ArrayRef)
268}
269
#[cfg(test)]
mod tests {
    use arrow::array::{Int64Array, StringArray};

    use super::*;

    /// Exercises the four-argument form on `Utf8` input: a start position
    /// just past the end of the string, a replacement covering the whole
    /// string, and replacements that leave a tail of the original string.
    #[test]
    fn to_overlay() -> Result<()> {
        let strings: ArrayRef =
            Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"]));
        let replacements: ArrayRef =
            Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"]));
        let positions: ArrayRef = Arc::new(Int64Array::from(vec![4, 1, 1, 2]));
        // Number of characters to replace in each row (the `FOR` argument).
        let counts: ArrayRef = Arc::new(Int64Array::from(vec![5, 7, 2, 4]));

        let output = overlay::<i32>(&[strings, replacements, positions, counts]).unwrap();
        let actual = as_generic_string_array::<i32>(&output).unwrap();

        let expected = StringArray::from(vec!["123abc", "qwertyasdfg", "ijkz", "Thomas"]);
        assert_eq!(&expected, actual);

        Ok(())
    }
}