1use arrow::array::{
21 new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray,
22 MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32 ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
// The three invocations below wire each UDF struct in this file to a function
// name, an argument list, a one-line description and a `*_udf` identifier.
// NOTE(review): `make_udf_expr_and_func!` is defined elsewhere; presumably it
// generates the expression-builder function and the `*_udf()` accessor named
// in each invocation — confirm against the macro definition.

// `array_replace(array, from, to)`, backed by `ArrayReplace`.
make_udf_expr_and_func!(ArrayReplace,
    array_replace,
    array from to,
    "replaces the first occurrence of the specified element with another specified element.",
    array_replace_udf
);
// `array_replace_n(array, from, to, max)`, backed by `ArrayReplaceN`.
make_udf_expr_and_func!(ArrayReplaceN,
    array_replace_n,
    array from to max,
    "replaces the first `max` occurrences of the specified element with another specified element.",
    array_replace_n_udf
);
// `array_replace_all(array, from, to)`, backed by `ArrayReplaceAll`.
make_udf_expr_and_func!(ArrayReplaceAll,
    array_replace_all,
    array from to,
    "replaces all occurrences of the specified element with another specified element.",
    array_replace_all_udf
);
61
/// Scalar UDF `array_replace` (alias `list_replace`): replaces the first
/// occurrence of an element in a list with another element.
///
/// The user-facing documentation is declared in the `user_doc` attribute
/// below and exposed through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first occurrence of the specified element with another specified element.",
    syntax_example = "array_replace(array, from, to)",
    sql_example = r#"```sql
> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
+--------------------------------------------------------+
| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+--------------------------------------------------------+
| [1, 5, 2, 3, 2, 1, 4] |
+--------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug)]
pub struct ArrayReplace {
    /// Accepted argument types and volatility; returned by `signature()`.
    signature: Signature,
    /// Alternative SQL names for the function; returned by `aliases()`.
    aliases: Vec<String>,
}
86
87impl Default for ArrayReplace {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl ArrayReplace {
94 pub fn new() -> Self {
95 Self {
96 signature: Signature {
97 type_signature: TypeSignature::ArraySignature(
98 ArrayFunctionSignature::Array {
99 arguments: vec![
100 ArrayFunctionArgument::Array,
101 ArrayFunctionArgument::Element,
102 ArrayFunctionArgument::Element,
103 ],
104 array_coercion: Some(ListCoercion::FixedSizedListToList),
105 },
106 ),
107 volatility: Volatility::Immutable,
108 },
109 aliases: vec![String::from("list_replace")],
110 }
111 }
112}
113
114impl ScalarUDFImpl for ArrayReplace {
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "array_replace"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
128 Ok(args[0].clone())
129 }
130
131 fn invoke_with_args(
132 &self,
133 args: datafusion_expr::ScalarFunctionArgs,
134 ) -> Result<ColumnarValue> {
135 make_scalar_function(array_replace_inner)(&args.args)
136 }
137
138 fn aliases(&self) -> &[String] {
139 &self.aliases
140 }
141
142 fn documentation(&self) -> Option<&Documentation> {
143 self.doc()
144 }
145}
146
/// Scalar UDF `array_replace_n` (alias `list_replace_n`): replaces the first
/// `max` occurrences of an element in a list with another element.
///
/// The user-facing documentation is declared in the `user_doc` attribute
/// below and exposed through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first `max` occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_n(array, from, to, max)",
    sql_example = r#"```sql
> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
+-------------------------------------------------------------------+
| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
+-------------------------------------------------------------------+
| [1, 5, 5, 3, 2, 1, 4] |
+-------------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element."),
    argument(name = "max", description = "Number of first occurrences to replace.")
)]
#[derive(Debug)]
pub(super) struct ArrayReplaceN {
    /// Accepted argument types and volatility; returned by `signature()`.
    signature: Signature,
    /// Alternative SQL names for the function; returned by `aliases()`.
    aliases: Vec<String>,
}
172
173impl ArrayReplaceN {
174 pub fn new() -> Self {
175 Self {
176 signature: Signature {
177 type_signature: TypeSignature::ArraySignature(
178 ArrayFunctionSignature::Array {
179 arguments: vec![
180 ArrayFunctionArgument::Array,
181 ArrayFunctionArgument::Element,
182 ArrayFunctionArgument::Element,
183 ArrayFunctionArgument::Index,
184 ],
185 array_coercion: Some(ListCoercion::FixedSizedListToList),
186 },
187 ),
188 volatility: Volatility::Immutable,
189 },
190 aliases: vec![String::from("list_replace_n")],
191 }
192 }
193}
194
195impl ScalarUDFImpl for ArrayReplaceN {
196 fn as_any(&self) -> &dyn Any {
197 self
198 }
199
200 fn name(&self) -> &str {
201 "array_replace_n"
202 }
203
204 fn signature(&self) -> &Signature {
205 &self.signature
206 }
207
208 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
209 Ok(args[0].clone())
210 }
211
212 fn invoke_with_args(
213 &self,
214 args: datafusion_expr::ScalarFunctionArgs,
215 ) -> Result<ColumnarValue> {
216 make_scalar_function(array_replace_n_inner)(&args.args)
217 }
218
219 fn aliases(&self) -> &[String] {
220 &self.aliases
221 }
222
223 fn documentation(&self) -> Option<&Documentation> {
224 self.doc()
225 }
226}
227
/// Scalar UDF `array_replace_all` (alias `list_replace_all`): replaces every
/// occurrence of an element in a list with another element.
///
/// The user-facing documentation is declared in the `user_doc` attribute
/// below and exposed through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces all occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_all(array, from, to)",
    sql_example = r#"```sql
> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
+------------------------------------------------------------+
| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+------------------------------------------------------------+
| [1, 5, 5, 3, 5, 1, 4] |
+------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug)]
pub(super) struct ArrayReplaceAll {
    /// Accepted argument types and volatility; returned by `signature()`.
    signature: Signature,
    /// Alternative SQL names for the function; returned by `aliases()`.
    aliases: Vec<String>,
}
252
253impl ArrayReplaceAll {
254 pub fn new() -> Self {
255 Self {
256 signature: Signature {
257 type_signature: TypeSignature::ArraySignature(
258 ArrayFunctionSignature::Array {
259 arguments: vec![
260 ArrayFunctionArgument::Array,
261 ArrayFunctionArgument::Element,
262 ArrayFunctionArgument::Element,
263 ],
264 array_coercion: Some(ListCoercion::FixedSizedListToList),
265 },
266 ),
267 volatility: Volatility::Immutable,
268 },
269 aliases: vec![String::from("list_replace_all")],
270 }
271 }
272}
273
274impl ScalarUDFImpl for ArrayReplaceAll {
275 fn as_any(&self) -> &dyn Any {
276 self
277 }
278
279 fn name(&self) -> &str {
280 "array_replace_all"
281 }
282
283 fn signature(&self) -> &Signature {
284 &self.signature
285 }
286
287 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
288 Ok(args[0].clone())
289 }
290
291 fn invoke_with_args(
292 &self,
293 args: datafusion_expr::ScalarFunctionArgs,
294 ) -> Result<ColumnarValue> {
295 make_scalar_function(array_replace_all_inner)(&args.args)
296 }
297
298 fn aliases(&self) -> &[String] {
299 &self.aliases
300 }
301
302 fn documentation(&self) -> Option<&Documentation> {
303 self.doc()
304 }
305}
306
307fn general_replace<O: OffsetSizeTrait>(
325 list_array: &GenericListArray<O>,
326 from_array: &ArrayRef,
327 to_array: &ArrayRef,
328 arr_n: Vec<i64>,
329) -> Result<ArrayRef> {
330 let mut offsets: Vec<O> = vec![O::usize_as(0)];
332 let values = list_array.values();
333 let original_data = values.to_data();
334 let to_data = to_array.to_data();
335 let capacity = Capacities::Array(original_data.len());
336
337 let mut mutable = MutableArrayData::with_capacities(
339 vec![&original_data, &to_data],
340 false,
341 capacity,
342 );
343
344 let mut valid = NullBufferBuilder::new(list_array.len());
345
346 for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
347 if list_array.is_null(row_index) {
348 offsets.push(offsets[row_index]);
349 valid.append_null();
350 continue;
351 }
352
353 let start = offset_window[0];
354 let end = offset_window[1];
355
356 let list_array_row = list_array.value(row_index);
357
358 let eq_array =
361 compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
362
363 let original_idx = O::usize_as(0);
364 let replace_idx = O::usize_as(1);
365 let n = arr_n[row_index];
366 let mut counter = 0;
367
368 if eq_array.false_count() == eq_array.len() {
370 mutable.extend(
371 original_idx.to_usize().unwrap(),
372 start.to_usize().unwrap(),
373 end.to_usize().unwrap(),
374 );
375 offsets.push(offsets[row_index] + (end - start));
376 valid.append_non_null();
377 continue;
378 }
379
380 for (i, to_replace) in eq_array.iter().enumerate() {
381 let i = O::usize_as(i);
382 if let Some(true) = to_replace {
383 mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
384 counter += 1;
385 if counter == n {
386 mutable.extend(
388 original_idx.to_usize().unwrap(),
389 (start + i).to_usize().unwrap() + 1,
390 end.to_usize().unwrap(),
391 );
392 break;
393 }
394 } else {
395 mutable.extend(
397 original_idx.to_usize().unwrap(),
398 (start + i).to_usize().unwrap(),
399 (start + i).to_usize().unwrap() + 1,
400 );
401 }
402 }
403
404 offsets.push(offsets[row_index] + (end - start));
405 valid.append_non_null();
406 }
407
408 let data = mutable.freeze();
409
410 Ok(Arc::new(GenericListArray::<O>::try_new(
411 Arc::new(Field::new_list_field(list_array.value_type(), true)),
412 OffsetBuffer::<O>::new(offsets.into()),
413 arrow::array::make_array(data),
414 valid.finish(),
415 )?))
416}
417
418pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
419 let [array, from, to] = take_function_args("array_replace", args)?;
420
421 let arr_n = vec![1; array.len()];
423 match array.data_type() {
424 DataType::List(_) => {
425 let list_array = array.as_list::<i32>();
426 general_replace::<i32>(list_array, from, to, arr_n)
427 }
428 DataType::LargeList(_) => {
429 let list_array = array.as_list::<i64>();
430 general_replace::<i64>(list_array, from, to, arr_n)
431 }
432 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
433 array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
434 }
435}
436
437pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
438 let [array, from, to, max] = take_function_args("array_replace_n", args)?;
439
440 let arr_n = as_int64_array(max)?.values().to_vec();
442 match array.data_type() {
443 DataType::List(_) => {
444 let list_array = array.as_list::<i32>();
445 general_replace::<i32>(list_array, from, to, arr_n)
446 }
447 DataType::LargeList(_) => {
448 let list_array = array.as_list::<i64>();
449 general_replace::<i64>(list_array, from, to, arr_n)
450 }
451 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
452 array_type => {
453 exec_err!("array_replace_n does not support type '{array_type:?}'.")
454 }
455 }
456}
457
458pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
459 let [array, from, to] = take_function_args("array_replace_all", args)?;
460
461 let arr_n = vec![i64::MAX; array.len()];
463 match array.data_type() {
464 DataType::List(_) => {
465 let list_array = array.as_list::<i32>();
466 general_replace::<i32>(list_array, from, to, arr_n)
467 }
468 DataType::LargeList(_) => {
469 let list_array = array.as_list::<i64>();
470 general_replace::<i64>(list_array, from, to, arr_n)
471 }
472 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
473 array_type => {
474 exec_err!("array_replace_all does not support type '{array_type:?}'.")
475 }
476 }
477}