1use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
19use arrow::datatypes::{DataType, Field, Schema};
20use arrow::{
21 array::{as_primitive_array, Capacities, MutableArrayData},
22 buffer::{NullBuffer, OffsetBuffer},
23 datatypes::ArrowNativeType,
24 record_batch::RecordBatch,
25};
26use datafusion::common::{
27 cast::{as_large_list_array, as_list_array},
28 internal_err, DataFusionError, Result as DataFusionResult,
29};
30use datafusion::logical_expr::ColumnarValue;
31use datafusion::physical_expr::PhysicalExpr;
32use std::hash::Hash;
33use std::{
34 any::Any,
35 fmt::{Debug, Display, Formatter},
36 sync::Arc,
37};
38
/// Largest length a single result array may have.
///
/// `array_insert` rejects any row whose resulting list would be longer than this and
/// reports the value as Spark's maximum array length.
// NOTE(review): presumably mirrors Spark's `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`
// (`Integer.MAX_VALUE - 15`) — confirm against the Spark sources.
const MAX_ROUNDED_ARRAY_LENGTH: usize = i32::MAX as usize - 15;
43
/// Physical expression that inserts an item into each list of a list column.
///
/// Evaluation produces three arrays (source lists, positions, items) from the child
/// expressions and delegates the row-by-row work to `array_insert`.
#[derive(Debug, Eq)]
pub struct ArrayInsert {
    /// Expression producing the source array; must evaluate to `List` or `LargeList`.
    src_array_expr: Arc<dyn PhysicalExpr>,
    /// Expression producing the insert position; `evaluate` requires it to be `Int32`.
    pos_expr: Arc<dyn PhysicalExpr>,
    /// Expression producing the item to insert; its type must equal the list element type.
    item_expr: Arc<dyn PhysicalExpr>,
    /// Forwarded to `array_insert` as `legacy_mode`; changes how negative positions
    /// are resolved.
    legacy_negative_index: bool,
}
51
52impl Hash for ArrayInsert {
53 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54 self.src_array_expr.hash(state);
55 self.pos_expr.hash(state);
56 self.item_expr.hash(state);
57 self.legacy_negative_index.hash(state);
58 }
59}
60impl PartialEq for ArrayInsert {
61 fn eq(&self, other: &Self) -> bool {
62 self.src_array_expr.eq(&other.src_array_expr)
63 && self.pos_expr.eq(&other.pos_expr)
64 && self.item_expr.eq(&other.item_expr)
65 && self.legacy_negative_index.eq(&other.legacy_negative_index)
66 }
67}
68
69impl ArrayInsert {
70 pub fn new(
71 src_array_expr: Arc<dyn PhysicalExpr>,
72 pos_expr: Arc<dyn PhysicalExpr>,
73 item_expr: Arc<dyn PhysicalExpr>,
74 legacy_negative_index: bool,
75 ) -> Self {
76 Self {
77 src_array_expr,
78 pos_expr,
79 item_expr,
80 legacy_negative_index,
81 }
82 }
83
84 pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
85 match data_type {
86 DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
87 DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
88 data_type => Err(DataFusionError::Internal(format!(
89 "Unexpected src array type in ArrayInsert: {data_type:?}"
90 ))),
91 }
92 }
93}
94
95impl PhysicalExpr for ArrayInsert {
96 fn as_any(&self) -> &dyn Any {
97 self
98 }
99
100 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
101 unimplemented!()
102 }
103
104 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
105 self.array_type(&self.src_array_expr.data_type(input_schema)?)
106 }
107
108 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
109 self.src_array_expr.nullable(input_schema)
110 }
111
112 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
113 let pos_value = self
114 .pos_expr
115 .evaluate(batch)?
116 .into_array(batch.num_rows())?;
117
118 if !matches!(pos_value.data_type(), DataType::Int32) {
121 return Err(DataFusionError::Internal(format!(
122 "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
123 pos_value.data_type()
124 )));
125 }
126
127 let src_value = self
129 .src_array_expr
130 .evaluate(batch)?
131 .into_array(batch.num_rows())?;
132
133 let src_element_type = match self.array_type(src_value.data_type())? {
134 DataType::List(field) => &field.data_type().clone(),
135 DataType::LargeList(field) => &field.data_type().clone(),
136 _ => unreachable!(),
137 };
138
139 let item_value = self
141 .item_expr
142 .evaluate(batch)?
143 .into_array(batch.num_rows())?;
144 if item_value.data_type() != src_element_type {
145 return Err(DataFusionError::Internal(format!(
146 "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}",
147 src_element_type,
148 item_value.data_type()
149 )));
150 }
151
152 match src_value.data_type() {
153 DataType::List(_) => {
154 let list_array = as_list_array(&src_value)?;
155 array_insert(
156 list_array,
157 &item_value,
158 &pos_value,
159 self.legacy_negative_index,
160 )
161 }
162 DataType::LargeList(_) => {
163 let list_array = as_large_list_array(&src_value)?;
164 array_insert(
165 list_array,
166 &item_value,
167 &pos_value,
168 self.legacy_negative_index,
169 )
170 }
171 _ => unreachable!(), }
173 }
174
175 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
176 vec![&self.src_array_expr, &self.pos_expr, &self.item_expr]
177 }
178
179 fn with_new_children(
180 self: Arc<Self>,
181 children: Vec<Arc<dyn PhysicalExpr>>,
182 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
183 match children.len() {
184 3 => Ok(Arc::new(ArrayInsert::new(
185 Arc::clone(&children[0]),
186 Arc::clone(&children[1]),
187 Arc::clone(&children[2]),
188 self.legacy_negative_index,
189 ))),
190 _ => internal_err!("ArrayInsert should have exactly three childrens"),
191 }
192 }
193}
194
195fn array_insert<O: OffsetSizeTrait>(
196 list_array: &GenericListArray<O>,
197 items_array: &ArrayRef,
198 pos_array: &ArrayRef,
199 legacy_mode: bool,
200) -> DataFusionResult<ColumnarValue> {
201 let values = list_array.values();
208 let offsets = list_array.offsets();
209 let values_data = values.to_data();
210 let item_data = items_array.to_data();
211 let new_capacity = Capacities::Array(values_data.len() + item_data.len());
212
213 let mut mutable_values =
214 MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
215
216 let mut new_offsets = vec![O::usize_as(0)];
217 let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
218
219 let pos_data: &Int32Array = as_primitive_array(&pos_array); for (row_index, offset_window) in offsets.windows(2).enumerate() {
222 let pos = pos_data.values()[row_index];
223 let start = offset_window[0].as_usize();
224 let end = offset_window[1].as_usize();
225 let is_item_null = items_array.is_null(row_index);
226
227 if list_array.is_null(row_index) {
228 mutable_values.extend_nulls(1);
230 new_offsets.push(new_offsets[row_index] + O::one());
231 new_nulls.push(false);
232 continue;
233 }
234
235 if pos == 0 {
236 return Err(DataFusionError::Internal(
237 "Position for array_insert should be greter or less than zero".to_string(),
238 ));
239 }
240
241 if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
242 let corrected_pos = if pos > 0 {
243 (pos - 1).as_usize()
244 } else {
245 end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
246 };
247 let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
248 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
249 return Err(DataFusionError::Internal(format!(
250 "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
251 )));
252 }
253
254 if (start + corrected_pos) <= end {
255 mutable_values.extend(0, start, start + corrected_pos);
256 mutable_values.extend(1, row_index, row_index + 1);
257 mutable_values.extend(0, start + corrected_pos, end);
258 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
259 } else {
260 mutable_values.extend(0, start, end);
261 mutable_values.extend_nulls(new_array_len - (end - start));
262 mutable_values.extend(1, row_index, row_index + 1);
263 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
266 }
267 } else {
268 let base_offset = if legacy_mode { 1 } else { 0 };
273 let new_array_len = (-pos + base_offset).as_usize();
274 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
275 return Err(DataFusionError::Internal(format!(
276 "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
277 )));
278 }
279 mutable_values.extend(1, row_index, row_index + 1);
280 mutable_values.extend_nulls(new_array_len - (end - start + 1));
281 mutable_values.extend(0, start, end);
282 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
283 }
284 if is_item_null {
285 if (start == end) || (values.is_null(row_index)) {
286 new_nulls.push(false)
287 } else {
288 new_nulls.push(true)
289 }
290 } else {
291 new_nulls.push(true)
292 }
293 }
294
295 let data = make_array(mutable_values.freeze());
296 let data_type = match list_array.data_type() {
297 DataType::List(field) => field.data_type(),
298 DataType::LargeList(field) => field.data_type(),
299 _ => unreachable!(),
300 };
301 let new_array = GenericListArray::<O>::try_new(
302 Arc::new(Field::new("item", data_type.clone(), true)),
303 OffsetBuffer::new(new_offsets.into()),
304 data,
305 Some(NullBuffer::new(new_nulls.into())),
306 )?;
307
308 Ok(ColumnarValue::Array(Arc::new(new_array)))
309}
310
311impl Display for ArrayInsert {
312 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
313 write!(
314 f,
315 "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]",
316 self.src_array_expr, self.pos_expr, self.item_expr
317 )
318 }
319}
320
#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion::common::Result;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Builds an `Int32` list array from nested optional rows.
    fn int_lists(rows: Vec<Option<Vec<Option<i32>>>>) -> ListArray {
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows)
    }

    /// Runs `array_insert` on the given inputs and compares the outcome with
    /// `expected`.
    fn check_insert(
        list: ListArray,
        positions: Vec<i32>,
        items: Vec<Option<i32>>,
        legacy_mode: bool,
        expected: ListArray,
    ) -> Result<()> {
        let items: ArrayRef = Arc::new(Int32Array::from(items));
        let positions: ArrayRef = Arc::new(Int32Array::from(positions));

        let ColumnarValue::Array(result) = array_insert(&list, &items, &positions, legacy_mode)?
        else {
            unreachable!()
        };

        assert_eq!(&result.to_data(), &expected.to_data());
        Ok(())
    }

    /// Positive positions: inside the list, at the front, next to a null element,
    /// beyond the end (null padding), and on a null list.
    #[test]
    fn test_array_insert() -> Result<()> {
        check_insert(
            int_lists(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
                Some(vec![None]),
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(1), Some(2), Some(3)]),
                None,
            ]),
            vec![2, 1, 1, 5, 6, 1],
            vec![
                Some(10),
                Some(20),
                Some(30),
                Some(100),
                Some(100),
                Some(40),
            ],
            false,
            int_lists(vec![
                Some(vec![Some(1), Some(10), Some(2), Some(3)]),
                Some(vec![Some(20), Some(4), Some(5)]),
                Some(vec![Some(30), None]),
                Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
                Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
                None,
            ]),
        )
    }

    /// Negative positions in the default mode: inside the list, appending with `-1`,
    /// before the start (null padding), and on a null list.
    #[test]
    fn test_array_insert_negative_index() -> Result<()> {
        check_insert(
            int_lists(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
                Some(vec![Some(1)]),
                None,
            ]),
            vec![-2, -1, -3, -1],
            vec![Some(10), Some(20), Some(100), Some(30)],
            false,
            int_lists(vec![
                Some(vec![Some(1), Some(2), Some(10), Some(3)]),
                Some(vec![Some(4), Some(5), Some(20)]),
                Some(vec![Some(100), None, Some(1)]),
                None,
            ]),
        )
    }

    /// Legacy mode: `-1` inserts before the last element instead of appending.
    #[test]
    fn test_array_insert_legacy_mode() -> Result<()> {
        check_insert(
            int_lists(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
                None,
            ]),
            vec![-1, -1, -1],
            vec![Some(10), Some(20), Some(30)],
            true,
            int_lists(vec![
                Some(vec![Some(1), Some(2), Some(10), Some(3)]),
                Some(vec![Some(4), Some(20), Some(5)]),
                None,
            ]),
        )
    }
}