1use arrow::array::{
21 new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray,
22 MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32 ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
// Declarations of the three replace functions via the project macro
// `make_udf_expr_and_func!`. Each call passes, in order: the UDF type, an
// expression-function name, that function's argument names, a one-line doc
// string, and a UDF accessor name.
// NOTE(review): the exact items generated are defined by the macro, not here —
// confirm against its definition before relying on them.

// `array_replace(array, from, to)`: replace the first match only.
make_udf_expr_and_func!(ArrayReplace,
    array_replace,
    array from to,
    "replaces the first occurrence of the specified element with another specified element.",
    array_replace_udf
);
// `array_replace_n(array, from, to, max)`: replace the first `max` matches.
make_udf_expr_and_func!(ArrayReplaceN,
    array_replace_n,
    array from to max,
    "replaces the first `max` occurrences of the specified element with another specified element.",
    array_replace_n_udf
);
// `array_replace_all(array, from, to)`: replace every match.
make_udf_expr_and_func!(ArrayReplaceAll,
    array_replace_all,
    array from to,
    "replaces all occurrences of the specified element with another specified element.",
    array_replace_all_udf
);
61
/// Scalar UDF `array_replace` (also registered as `list_replace`): replaces the
/// first occurrence of `from` with `to` in each list row.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first occurrence of the specified element with another specified element.",
    syntax_example = "array_replace(array, from, to)",
    sql_example = r#"```sql
> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
+--------------------------------------------------------+
| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+--------------------------------------------------------+
| [1, 5, 2, 3, 2, 1, 4]                                  |
+--------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayReplace {
    /// Accepted argument types: `(array, element, element)`, built in `new`.
    signature: Signature,
    /// Alternative SQL names under which the function is registered.
    aliases: Vec<String>,
}
106 ),
107 volatility: Volatility::Immutable,
108 parameter_names: None,
109 },
110 aliases: vec![String::from("list_replace")],
111 }
112 }
113}
114
115impl ScalarUDFImpl for ArrayReplace {
116 fn as_any(&self) -> &dyn Any {
117 self
118 }
119
120 fn name(&self) -> &str {
121 "array_replace"
122 }
123
124 fn signature(&self) -> &Signature {
125 &self.signature
126 }
127
128 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
129 Ok(args[0].clone())
130 }
131
132 fn invoke_with_args(
133 &self,
134 args: datafusion_expr::ScalarFunctionArgs,
135 ) -> Result<ColumnarValue> {
136 make_scalar_function(array_replace_inner)(&args.args)
137 }
138
139 fn aliases(&self) -> &[String] {
140 &self.aliases
141 }
142
143 fn documentation(&self) -> Option<&Documentation> {
144 self.doc()
145 }
146}
147
148#[user_doc(
149 doc_section(label = "Array Functions"),
150 description = "Replaces the first `max` occurrences of the specified element with another specified element.",
151 syntax_example = "array_replace_n(array, from, to, max)",
152 sql_example = r#"```sql
153> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
154+-------------------------------------------------------------------+
155| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
156+-------------------------------------------------------------------+
157| [1, 5, 5, 3, 2, 1, 4] |
158+-------------------------------------------------------------------+
159```"#,
160 argument(
161 name = "array",
162 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
163 ),
164 argument(name = "from", description = "Initial element."),
165 argument(name = "to", description = "Final element."),
166 argument(name = "max", description = "Number of first occurrences to replace.")
167)]
168#[derive(Debug, PartialEq, Eq, Hash)]
169pub(super) struct ArrayReplaceN {
170 signature: Signature,
171 aliases: Vec<String>,
172}
173
174impl ArrayReplaceN {
175 pub fn new() -> Self {
176 Self {
177 signature: Signature {
178 type_signature: TypeSignature::ArraySignature(
179 ArrayFunctionSignature::Array {
180 arguments: vec![
181 ArrayFunctionArgument::Array,
182 ArrayFunctionArgument::Element,
183 ArrayFunctionArgument::Element,
184 ArrayFunctionArgument::Index,
185 ],
186 array_coercion: Some(ListCoercion::FixedSizedListToList),
187 },
188 ),
189 volatility: Volatility::Immutable,
190 parameter_names: None,
191 },
192 aliases: vec![String::from("list_replace_n")],
193 }
194 }
195}
196
197impl ScalarUDFImpl for ArrayReplaceN {
198 fn as_any(&self) -> &dyn Any {
199 self
200 }
201
202 fn name(&self) -> &str {
203 "array_replace_n"
204 }
205
206 fn signature(&self) -> &Signature {
207 &self.signature
208 }
209
210 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
211 Ok(args[0].clone())
212 }
213
214 fn invoke_with_args(
215 &self,
216 args: datafusion_expr::ScalarFunctionArgs,
217 ) -> Result<ColumnarValue> {
218 make_scalar_function(array_replace_n_inner)(&args.args)
219 }
220
221 fn aliases(&self) -> &[String] {
222 &self.aliases
223 }
224
225 fn documentation(&self) -> Option<&Documentation> {
226 self.doc()
227 }
228}
229
230#[user_doc(
231 doc_section(label = "Array Functions"),
232 description = "Replaces all occurrences of the specified element with another specified element.",
233 syntax_example = "array_replace_all(array, from, to)",
234 sql_example = r#"```sql
235> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
236+------------------------------------------------------------+
237| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
238+------------------------------------------------------------+
239| [1, 5, 5, 3, 5, 1, 4] |
240+------------------------------------------------------------+
241```"#,
242 argument(
243 name = "array",
244 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
245 ),
246 argument(name = "from", description = "Initial element."),
247 argument(name = "to", description = "Final element.")
248)]
249#[derive(Debug, PartialEq, Eq, Hash)]
250pub(super) struct ArrayReplaceAll {
251 signature: Signature,
252 aliases: Vec<String>,
253}
254
255impl ArrayReplaceAll {
256 pub fn new() -> Self {
257 Self {
258 signature: Signature {
259 type_signature: TypeSignature::ArraySignature(
260 ArrayFunctionSignature::Array {
261 arguments: vec![
262 ArrayFunctionArgument::Array,
263 ArrayFunctionArgument::Element,
264 ArrayFunctionArgument::Element,
265 ],
266 array_coercion: Some(ListCoercion::FixedSizedListToList),
267 },
268 ),
269 volatility: Volatility::Immutable,
270 parameter_names: None,
271 },
272 aliases: vec![String::from("list_replace_all")],
273 }
274 }
275}
276
277impl ScalarUDFImpl for ArrayReplaceAll {
278 fn as_any(&self) -> &dyn Any {
279 self
280 }
281
282 fn name(&self) -> &str {
283 "array_replace_all"
284 }
285
286 fn signature(&self) -> &Signature {
287 &self.signature
288 }
289
290 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
291 Ok(args[0].clone())
292 }
293
294 fn invoke_with_args(
295 &self,
296 args: datafusion_expr::ScalarFunctionArgs,
297 ) -> Result<ColumnarValue> {
298 make_scalar_function(array_replace_all_inner)(&args.args)
299 }
300
301 fn aliases(&self) -> &[String] {
302 &self.aliases
303 }
304
305 fn documentation(&self) -> Option<&Documentation> {
306 self.doc()
307 }
308}
309
310fn general_replace<O: OffsetSizeTrait>(
328 list_array: &GenericListArray<O>,
329 from_array: &ArrayRef,
330 to_array: &ArrayRef,
331 arr_n: Vec<i64>,
332) -> Result<ArrayRef> {
333 let mut offsets: Vec<O> = vec![O::usize_as(0)];
335 let values = list_array.values();
336 let original_data = values.to_data();
337 let to_data = to_array.to_data();
338 let capacity = Capacities::Array(original_data.len());
339
340 let mut mutable = MutableArrayData::with_capacities(
342 vec![&original_data, &to_data],
343 false,
344 capacity,
345 );
346
347 let mut valid = NullBufferBuilder::new(list_array.len());
348
349 for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
350 if list_array.is_null(row_index) {
351 offsets.push(offsets[row_index]);
352 valid.append_null();
353 continue;
354 }
355
356 let start = offset_window[0];
357 let end = offset_window[1];
358
359 let list_array_row = list_array.value(row_index);
360
361 let eq_array =
364 compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
365
366 let original_idx = O::usize_as(0);
367 let replace_idx = O::usize_as(1);
368 let n = arr_n[row_index];
369 let mut counter = 0;
370
371 if eq_array.false_count() == eq_array.len() {
373 mutable.extend(
374 original_idx.to_usize().unwrap(),
375 start.to_usize().unwrap(),
376 end.to_usize().unwrap(),
377 );
378 offsets.push(offsets[row_index] + (end - start));
379 valid.append_non_null();
380 continue;
381 }
382
383 for (i, to_replace) in eq_array.iter().enumerate() {
384 let i = O::usize_as(i);
385 if let Some(true) = to_replace {
386 mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
387 counter += 1;
388 if counter == n {
389 mutable.extend(
391 original_idx.to_usize().unwrap(),
392 (start + i).to_usize().unwrap() + 1,
393 end.to_usize().unwrap(),
394 );
395 break;
396 }
397 } else {
398 mutable.extend(
400 original_idx.to_usize().unwrap(),
401 (start + i).to_usize().unwrap(),
402 (start + i).to_usize().unwrap() + 1,
403 );
404 }
405 }
406
407 offsets.push(offsets[row_index] + (end - start));
408 valid.append_non_null();
409 }
410
411 let data = mutable.freeze();
412
413 Ok(Arc::new(GenericListArray::<O>::try_new(
414 Arc::new(Field::new_list_field(list_array.value_type(), true)),
415 OffsetBuffer::<O>::new(offsets.into()),
416 arrow::array::make_array(data),
417 valid.finish(),
418 )?))
419}
420
421pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
422 let [array, from, to] = take_function_args("array_replace", args)?;
423
424 let arr_n = vec![1; array.len()];
426 match array.data_type() {
427 DataType::List(_) => {
428 let list_array = array.as_list::<i32>();
429 general_replace::<i32>(list_array, from, to, arr_n)
430 }
431 DataType::LargeList(_) => {
432 let list_array = array.as_list::<i64>();
433 general_replace::<i64>(list_array, from, to, arr_n)
434 }
435 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
436 array_type => exec_err!("array_replace does not support type '{array_type}'."),
437 }
438}
439
440pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
441 let [array, from, to, max] = take_function_args("array_replace_n", args)?;
442
443 let arr_n = as_int64_array(max)?.values().to_vec();
445 match array.data_type() {
446 DataType::List(_) => {
447 let list_array = array.as_list::<i32>();
448 general_replace::<i32>(list_array, from, to, arr_n)
449 }
450 DataType::LargeList(_) => {
451 let list_array = array.as_list::<i64>();
452 general_replace::<i64>(list_array, from, to, arr_n)
453 }
454 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
455 array_type => {
456 exec_err!("array_replace_n does not support type '{array_type}'.")
457 }
458 }
459}
460
461pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
462 let [array, from, to] = take_function_args("array_replace_all", args)?;
463
464 let arr_n = vec![i64::MAX; array.len()];
466 match array.data_type() {
467 DataType::List(_) => {
468 let list_array = array.as_list::<i32>();
469 general_replace::<i32>(list_array, from, to, arr_n)
470 }
471 DataType::LargeList(_) => {
472 let list_array = array.as_list::<i64>();
473 general_replace::<i64>(list_array, from, to, arr_n)
474 }
475 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
476 array_type => {
477 exec_err!("array_replace_all does not support type '{array_type}'.")
478 }
479 }
480}