1use arrow::array::{
21 Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
22 NullBufferBuilder, OffsetSizeTrait, new_null_array,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{Result, exec_err, utils::take_function_args};
30use datafusion_expr::{
31 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32 ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::sync::Arc;
40
// Register the three replace UDFs through the crate's `make_udf_expr_and_func!` macro.
// Each invocation passes: the UDF struct, a function name, the argument names, a
// one-line description, and an accessor name.
// NOTE(review): the macro is defined elsewhere in the crate; presumably it generates an
// expression-builder function plus a `*_udf()` accessor — confirm against its definition.

// `array_replace(array, from, to)`: first occurrence only.
make_udf_expr_and_func!(ArrayReplace,
    array_replace,
    array from to,
    "replaces the first occurrence of the specified element with another specified element.",
    array_replace_udf
);
// `array_replace_n(array, from, to, max)`: first `max` occurrences.
make_udf_expr_and_func!(ArrayReplaceN,
    array_replace_n,
    array from to max,
    "replaces the first `max` occurrences of the specified element with another specified element.",
    array_replace_n_udf
);
// `array_replace_all(array, from, to)`: every occurrence.
make_udf_expr_and_func!(ArrayReplaceAll,
    array_replace_all,
    array from to,
    "replaces all occurrences of the specified element with another specified element.",
    array_replace_all_udf
);
60
61#[user_doc(
62 doc_section(label = "Array Functions"),
63 description = "Replaces the first occurrence of the specified element with another specified element.",
64 syntax_example = "array_replace(array, from, to)",
65 sql_example = r#"```sql
66> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
67+--------------------------------------------------------+
68| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
69+--------------------------------------------------------+
70| [1, 5, 2, 3, 2, 1, 4] |
71+--------------------------------------------------------+
72```"#,
73 argument(
74 name = "array",
75 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
76 ),
77 argument(name = "from", description = "Initial element."),
78 argument(name = "to", description = "Final element.")
79)]
80#[derive(Debug, PartialEq, Eq, Hash)]
81pub struct ArrayReplace {
82 signature: Signature,
83 aliases: Vec<String>,
84}
85
86impl Default for ArrayReplace {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl ArrayReplace {
93 pub fn new() -> Self {
94 Self {
95 signature: Signature {
96 type_signature: TypeSignature::ArraySignature(
97 ArrayFunctionSignature::Array {
98 arguments: vec![
99 ArrayFunctionArgument::Array,
100 ArrayFunctionArgument::Element,
101 ArrayFunctionArgument::Element,
102 ],
103 array_coercion: Some(ListCoercion::FixedSizedListToList),
104 },
105 ),
106 volatility: Volatility::Immutable,
107 parameter_names: None,
108 },
109 aliases: vec![String::from("list_replace")],
110 }
111 }
112}
113
114impl ScalarUDFImpl for ArrayReplace {
115 fn name(&self) -> &str {
116 "array_replace"
117 }
118
119 fn signature(&self) -> &Signature {
120 &self.signature
121 }
122
123 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
124 Ok(args[0].clone())
125 }
126
127 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
128 make_scalar_function(array_replace_inner)(&args.args)
129 }
130
131 fn aliases(&self) -> &[String] {
132 &self.aliases
133 }
134
135 fn documentation(&self) -> Option<&Documentation> {
136 self.doc()
137 }
138}
139
140#[user_doc(
141 doc_section(label = "Array Functions"),
142 description = "Replaces the first `max` occurrences of the specified element with another specified element.",
143 syntax_example = "array_replace_n(array, from, to, max)",
144 sql_example = r#"```sql
145> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
146+-------------------------------------------------------------------+
147| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
148+-------------------------------------------------------------------+
149| [1, 5, 5, 3, 2, 1, 4] |
150+-------------------------------------------------------------------+
151```"#,
152 argument(
153 name = "array",
154 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
155 ),
156 argument(name = "from", description = "Initial element."),
157 argument(name = "to", description = "Final element."),
158 argument(name = "max", description = "Number of first occurrences to replace.")
159)]
160#[derive(Debug, PartialEq, Eq, Hash)]
161pub(super) struct ArrayReplaceN {
162 signature: Signature,
163 aliases: Vec<String>,
164}
165
166impl ArrayReplaceN {
167 pub fn new() -> Self {
168 Self {
169 signature: Signature {
170 type_signature: TypeSignature::ArraySignature(
171 ArrayFunctionSignature::Array {
172 arguments: vec![
173 ArrayFunctionArgument::Array,
174 ArrayFunctionArgument::Element,
175 ArrayFunctionArgument::Element,
176 ArrayFunctionArgument::Index,
177 ],
178 array_coercion: Some(ListCoercion::FixedSizedListToList),
179 },
180 ),
181 volatility: Volatility::Immutable,
182 parameter_names: None,
183 },
184 aliases: vec![String::from("list_replace_n")],
185 }
186 }
187}
188
189impl ScalarUDFImpl for ArrayReplaceN {
190 fn name(&self) -> &str {
191 "array_replace_n"
192 }
193
194 fn signature(&self) -> &Signature {
195 &self.signature
196 }
197
198 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
199 Ok(args[0].clone())
200 }
201
202 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
203 make_scalar_function(array_replace_n_inner)(&args.args)
204 }
205
206 fn aliases(&self) -> &[String] {
207 &self.aliases
208 }
209
210 fn documentation(&self) -> Option<&Documentation> {
211 self.doc()
212 }
213}
214
215#[user_doc(
216 doc_section(label = "Array Functions"),
217 description = "Replaces all occurrences of the specified element with another specified element.",
218 syntax_example = "array_replace_all(array, from, to)",
219 sql_example = r#"```sql
220> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
221+------------------------------------------------------------+
222| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
223+------------------------------------------------------------+
224| [1, 5, 5, 3, 5, 1, 4] |
225+------------------------------------------------------------+
226```"#,
227 argument(
228 name = "array",
229 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
230 ),
231 argument(name = "from", description = "Initial element."),
232 argument(name = "to", description = "Final element.")
233)]
234#[derive(Debug, PartialEq, Eq, Hash)]
235pub(super) struct ArrayReplaceAll {
236 signature: Signature,
237 aliases: Vec<String>,
238}
239
240impl ArrayReplaceAll {
241 pub fn new() -> Self {
242 Self {
243 signature: Signature {
244 type_signature: TypeSignature::ArraySignature(
245 ArrayFunctionSignature::Array {
246 arguments: vec![
247 ArrayFunctionArgument::Array,
248 ArrayFunctionArgument::Element,
249 ArrayFunctionArgument::Element,
250 ],
251 array_coercion: Some(ListCoercion::FixedSizedListToList),
252 },
253 ),
254 volatility: Volatility::Immutable,
255 parameter_names: None,
256 },
257 aliases: vec![String::from("list_replace_all")],
258 }
259 }
260}
261
262impl ScalarUDFImpl for ArrayReplaceAll {
263 fn name(&self) -> &str {
264 "array_replace_all"
265 }
266
267 fn signature(&self) -> &Signature {
268 &self.signature
269 }
270
271 fn return_type(&self, args: &[DataType]) -> Result<DataType> {
272 Ok(args[0].clone())
273 }
274
275 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
276 make_scalar_function(array_replace_all_inner)(&args.args)
277 }
278
279 fn aliases(&self) -> &[String] {
280 &self.aliases
281 }
282
283 fn documentation(&self) -> Option<&Documentation> {
284 self.doc()
285 }
286}
287
288fn general_replace<O: OffsetSizeTrait>(
306 list_array: &GenericListArray<O>,
307 from_array: &ArrayRef,
308 to_array: &ArrayRef,
309 arr_n: &[i64],
310) -> Result<ArrayRef> {
311 let mut offsets: Vec<O> = vec![O::usize_as(0)];
313 let values = list_array.values();
314 let original_data = values.to_data();
315 let to_data = to_array.to_data();
316 let capacity = Capacities::Array(original_data.len());
317
318 let mut mutable = MutableArrayData::with_capacities(
320 vec![&original_data, &to_data],
321 false,
322 capacity,
323 );
324
325 let mut valid = NullBufferBuilder::new(list_array.len());
326
327 for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
328 if list_array.is_null(row_index) {
329 offsets.push(offsets[row_index]);
330 valid.append_null();
331 continue;
332 }
333
334 let start = offset_window[0];
335 let end = offset_window[1];
336
337 let list_array_row = list_array.value(row_index);
338
339 let eq_array =
342 compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
343
344 let original_idx = O::usize_as(0);
345 let replace_idx = O::usize_as(1);
346 let n = arr_n[row_index];
347 let mut counter = 0;
348
349 if n <= 0 || !eq_array.has_true() {
351 mutable.extend(
352 original_idx.to_usize().unwrap(),
353 start.to_usize().unwrap(),
354 end.to_usize().unwrap(),
355 );
356 offsets.push(offsets[row_index] + (end - start));
357 valid.append_non_null();
358 continue;
359 }
360
361 let mut pending_retain: Option<O> = None;
362 for (i, to_replace) in eq_array.iter().enumerate() {
363 let i = O::usize_as(i);
364 if to_replace == Some(true) && counter < n {
365 if let Some(rs) = pending_retain.take() {
367 mutable.extend(
368 original_idx.to_usize().unwrap(),
369 (start + rs).to_usize().unwrap(),
370 (start + i).to_usize().unwrap(),
371 );
372 }
373 mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
374 counter += 1;
375 if counter == n {
376 mutable.extend(
378 original_idx.to_usize().unwrap(),
379 (start + i).to_usize().unwrap() + 1,
380 end.to_usize().unwrap(),
381 );
382 break;
383 }
384 } else if pending_retain.is_none() {
385 pending_retain = Some(i);
386 }
387 }
388
389 if counter < n
392 && let Some(rs) = pending_retain
393 {
394 mutable.extend(
395 original_idx.to_usize().unwrap(),
396 (start + rs).to_usize().unwrap(),
397 end.to_usize().unwrap(),
398 );
399 }
400
401 offsets.push(offsets[row_index] + (end - start));
402 valid.append_non_null();
403 }
404
405 let data = mutable.freeze();
406
407 Ok(Arc::new(GenericListArray::<O>::try_new(
408 Arc::new(Field::new_list_field(list_array.value_type(), true)),
409 OffsetBuffer::<O>::new(offsets.into()),
410 arrow::array::make_array(data),
411 valid.finish(),
412 )?))
413}
414
415fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
416 let [array, from, to] = take_function_args("array_replace", args)?;
417
418 let arr_n = vec![1; array.len()];
420 match array.data_type() {
421 DataType::List(_) => {
422 let list_array = array.as_list::<i32>();
423 general_replace::<i32>(list_array, from, to, &arr_n)
424 }
425 DataType::LargeList(_) => {
426 let list_array = array.as_list::<i64>();
427 general_replace::<i64>(list_array, from, to, &arr_n)
428 }
429 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
430 array_type => exec_err!("array_replace does not support type '{array_type}'."),
431 }
432}
433
434fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
435 let [array, from, to, max] = take_function_args("array_replace_n", args)?;
436
437 let arr_n = as_int64_array(max)?.values().to_vec();
439 match array.data_type() {
440 DataType::List(_) => {
441 let list_array = array.as_list::<i32>();
442 general_replace::<i32>(list_array, from, to, &arr_n)
443 }
444 DataType::LargeList(_) => {
445 let list_array = array.as_list::<i64>();
446 general_replace::<i64>(list_array, from, to, &arr_n)
447 }
448 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
449 array_type => {
450 exec_err!("array_replace_n does not support type '{array_type}'.")
451 }
452 }
453}
454
455fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
456 let [array, from, to] = take_function_args("array_replace_all", args)?;
457
458 let arr_n = vec![i64::MAX; array.len()];
460 match array.data_type() {
461 DataType::List(_) => {
462 let list_array = array.as_list::<i32>();
463 general_replace::<i32>(list_array, from, to, &arr_n)
464 }
465 DataType::LargeList(_) => {
466 let list_array = array.as_list::<i64>();
467 general_replace::<i64>(list_array, from, to, &arr_n)
468 }
469 DataType::Null => Ok(new_null_array(array.data_type(), 1)),
470 array_type => {
471 exec_err!("array_replace_all does not support type '{array_type}'.")
472 }
473 }
474}