1use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
19use arrow::datatypes::{DataType, Field, Schema};
20use arrow::{
21 array::{as_primitive_array, Capacities, MutableArrayData},
22 buffer::{NullBuffer, OffsetBuffer},
23 datatypes::ArrowNativeType,
24 record_batch::RecordBatch,
25};
26use datafusion::common::{
27 cast::{as_large_list_array, as_list_array},
28 internal_err, DataFusionError, Result as DataFusionResult,
29};
30use datafusion::logical_expr::ColumnarValue;
31use datafusion::physical_expr::PhysicalExpr;
32use std::hash::Hash;
33use std::{
34 any::Any,
35 fmt::{Debug, Display, Formatter},
36 sync::Arc,
37};
38
/// Largest list length `array_insert` is willing to produce (`i32::MAX - 15`).
///
/// Rows whose result would exceed this length are rejected with an error that reports this
/// value as the maximum array length in Spark.
const MAX_ROUNDED_ARRAY_LENGTH: usize = 2_147_483_632;
43
/// Physical expression that inserts one item into each row of a list column at a given
/// position.
#[derive(Debug, Eq)]
pub struct ArrayInsert {
    /// Expression producing the source array; `evaluate` accepts only `List` or `LargeList`.
    src_array_expr: Arc<dyn PhysicalExpr>,
    /// Expression producing the insert position; `evaluate` requires it to be `Int32`.
    pos_expr: Arc<dyn PhysicalExpr>,
    /// Expression producing the item to insert; its type must equal the list element type.
    item_expr: Arc<dyn PhysicalExpr>,
    /// Forwarded to `array_insert` as `legacy_mode`; changes how negative positions are
    /// resolved.
    legacy_negative_index: bool,
}
51
52impl Hash for ArrayInsert {
53 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54 self.src_array_expr.hash(state);
55 self.pos_expr.hash(state);
56 self.item_expr.hash(state);
57 self.legacy_negative_index.hash(state);
58 }
59}
60impl PartialEq for ArrayInsert {
61 fn eq(&self, other: &Self) -> bool {
62 self.src_array_expr.eq(&other.src_array_expr)
63 && self.pos_expr.eq(&other.pos_expr)
64 && self.item_expr.eq(&other.item_expr)
65 && self.legacy_negative_index.eq(&other.legacy_negative_index)
66 }
67}
68
69impl ArrayInsert {
70 pub fn new(
71 src_array_expr: Arc<dyn PhysicalExpr>,
72 pos_expr: Arc<dyn PhysicalExpr>,
73 item_expr: Arc<dyn PhysicalExpr>,
74 legacy_negative_index: bool,
75 ) -> Self {
76 Self {
77 src_array_expr,
78 pos_expr,
79 item_expr,
80 legacy_negative_index,
81 }
82 }
83
84 pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
85 match data_type {
86 DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
87 DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
88 data_type => Err(DataFusionError::Internal(format!(
89 "Unexpected src array type in ArrayInsert: {:?}",
90 data_type
91 ))),
92 }
93 }
94}
95
96impl PhysicalExpr for ArrayInsert {
97 fn as_any(&self) -> &dyn Any {
98 self
99 }
100
101 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
102 unimplemented!()
103 }
104
105 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
106 self.array_type(&self.src_array_expr.data_type(input_schema)?)
107 }
108
109 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
110 self.src_array_expr.nullable(input_schema)
111 }
112
113 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
114 let pos_value = self
115 .pos_expr
116 .evaluate(batch)?
117 .into_array(batch.num_rows())?;
118
119 if !matches!(pos_value.data_type(), DataType::Int32) {
122 return Err(DataFusionError::Internal(format!(
123 "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
124 pos_value.data_type()
125 )));
126 }
127
128 let src_value = self
130 .src_array_expr
131 .evaluate(batch)?
132 .into_array(batch.num_rows())?;
133
134 let src_element_type = match self.array_type(src_value.data_type())? {
135 DataType::List(field) => &field.data_type().clone(),
136 DataType::LargeList(field) => &field.data_type().clone(),
137 _ => unreachable!(),
138 };
139
140 let item_value = self
142 .item_expr
143 .evaluate(batch)?
144 .into_array(batch.num_rows())?;
145 if item_value.data_type() != src_element_type {
146 return Err(DataFusionError::Internal(format!(
147 "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}",
148 src_element_type,
149 item_value.data_type()
150 )));
151 }
152
153 match src_value.data_type() {
154 DataType::List(_) => {
155 let list_array = as_list_array(&src_value)?;
156 array_insert(
157 list_array,
158 &item_value,
159 &pos_value,
160 self.legacy_negative_index,
161 )
162 }
163 DataType::LargeList(_) => {
164 let list_array = as_large_list_array(&src_value)?;
165 array_insert(
166 list_array,
167 &item_value,
168 &pos_value,
169 self.legacy_negative_index,
170 )
171 }
172 _ => unreachable!(), }
174 }
175
176 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
177 vec![&self.src_array_expr, &self.pos_expr, &self.item_expr]
178 }
179
180 fn with_new_children(
181 self: Arc<Self>,
182 children: Vec<Arc<dyn PhysicalExpr>>,
183 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
184 match children.len() {
185 3 => Ok(Arc::new(ArrayInsert::new(
186 Arc::clone(&children[0]),
187 Arc::clone(&children[1]),
188 Arc::clone(&children[2]),
189 self.legacy_negative_index,
190 ))),
191 _ => internal_err!("ArrayInsert should have exactly three childrens"),
192 }
193 }
194}
195
196fn array_insert<O: OffsetSizeTrait>(
197 list_array: &GenericListArray<O>,
198 items_array: &ArrayRef,
199 pos_array: &ArrayRef,
200 legacy_mode: bool,
201) -> DataFusionResult<ColumnarValue> {
202 let values = list_array.values();
209 let offsets = list_array.offsets();
210 let values_data = values.to_data();
211 let item_data = items_array.to_data();
212 let new_capacity = Capacities::Array(values_data.len() + item_data.len());
213
214 let mut mutable_values =
215 MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
216
217 let mut new_offsets = vec![O::usize_as(0)];
218 let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
219
220 let pos_data: &Int32Array = as_primitive_array(&pos_array); for (row_index, offset_window) in offsets.windows(2).enumerate() {
223 let pos = pos_data.values()[row_index];
224 let start = offset_window[0].as_usize();
225 let end = offset_window[1].as_usize();
226 let is_item_null = items_array.is_null(row_index);
227
228 if list_array.is_null(row_index) {
229 mutable_values.extend_nulls(1);
231 new_offsets.push(new_offsets[row_index] + O::one());
232 new_nulls.push(false);
233 continue;
234 }
235
236 if pos == 0 {
237 return Err(DataFusionError::Internal(
238 "Position for array_insert should be greter or less than zero".to_string(),
239 ));
240 }
241
242 if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
243 let corrected_pos = if pos > 0 {
244 (pos - 1).as_usize()
245 } else {
246 end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
247 };
248 let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
249 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
250 return Err(DataFusionError::Internal(format!(
251 "Max array length in Spark is {:?}, but got {:?}",
252 MAX_ROUNDED_ARRAY_LENGTH, new_array_len
253 )));
254 }
255
256 if (start + corrected_pos) <= end {
257 mutable_values.extend(0, start, start + corrected_pos);
258 mutable_values.extend(1, row_index, row_index + 1);
259 mutable_values.extend(0, start + corrected_pos, end);
260 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
261 } else {
262 mutable_values.extend(0, start, end);
263 mutable_values.extend_nulls(new_array_len - (end - start));
264 mutable_values.extend(1, row_index, row_index + 1);
265 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
268 }
269 } else {
270 let base_offset = if legacy_mode { 1 } else { 0 };
275 let new_array_len = (-pos + base_offset).as_usize();
276 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
277 return Err(DataFusionError::Internal(format!(
278 "Max array length in Spark is {:?}, but got {:?}",
279 MAX_ROUNDED_ARRAY_LENGTH, new_array_len
280 )));
281 }
282 mutable_values.extend(1, row_index, row_index + 1);
283 mutable_values.extend_nulls(new_array_len - (end - start + 1));
284 mutable_values.extend(0, start, end);
285 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
286 }
287 if is_item_null {
288 if (start == end) || (values.is_null(row_index)) {
289 new_nulls.push(false)
290 } else {
291 new_nulls.push(true)
292 }
293 } else {
294 new_nulls.push(true)
295 }
296 }
297
298 let data = make_array(mutable_values.freeze());
299 let data_type = match list_array.data_type() {
300 DataType::List(field) => field.data_type(),
301 DataType::LargeList(field) => field.data_type(),
302 _ => unreachable!(),
303 };
304 let new_array = GenericListArray::<O>::try_new(
305 Arc::new(Field::new("item", data_type.clone(), true)),
306 OffsetBuffer::new(new_offsets.into()),
307 data,
308 Some(NullBuffer::new(new_nulls.into())),
309 )?;
310
311 Ok(ColumnarValue::Array(Arc::new(new_array)))
312}
313
314impl Display for ArrayInsert {
315 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
316 write!(
317 f,
318 "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]",
319 self.src_array_expr, self.pos_expr, self.item_expr
320 )
321 }
322}
323
#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion::common::Result;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Builds an `Int32` list column from per-row optional element vectors.
    fn int32_list(rows: Vec<Option<Vec<Option<i32>>>>) -> ListArray {
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows)
    }

    /// Runs `array_insert` on `list` with the given per-row positions and items and returns
    /// the resulting array.
    fn run_insert(
        list: &ListArray,
        positions: Vec<i32>,
        items: Vec<Option<i32>>,
        legacy_mode: bool,
    ) -> Result<ArrayRef> {
        let positions: ArrayRef = Arc::new(Int32Array::from(positions));
        let items: ArrayRef = Arc::new(Int32Array::from(items));

        match array_insert(list, &items, &positions, legacy_mode)? {
            ColumnarValue::Array(result) => Ok(result),
            _ => unreachable!(),
        }
    }

    /// Asserts that `actual` holds the same data as `expected`.
    fn assert_same_data(actual: &ArrayRef, expected: &ListArray) {
        assert_eq!(&actual.to_data(), &expected.to_data());
    }

    /// Positive positions: inside the list, at the front, past the end, and on a null row.
    #[test]
    fn test_array_insert() -> Result<()> {
        let list = int32_list(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![None]),
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(1), Some(2), Some(3)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![2, 1, 1, 5, 6, 1],
            vec![
                Some(10),
                Some(20),
                Some(30),
                Some(100),
                Some(100),
                Some(40),
            ],
            false,
        )?;

        let expected = int32_list(vec![
            Some(vec![Some(1), Some(10), Some(2), Some(3)]),
            Some(vec![Some(20), Some(4), Some(5)]),
            Some(vec![Some(30), None]),
            Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
            Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
            None,
        ]);

        assert_same_data(&result, &expected);

        Ok(())
    }

    /// Negative positions in non-legacy mode, including one reaching past the list start.
    #[test]
    fn test_array_insert_negative_index() -> Result<()> {
        let list = int32_list(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![Some(1)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![-2, -1, -3, -1],
            vec![Some(10), Some(20), Some(100), Some(30)],
            false,
        )?;

        let expected = int32_list(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(5), Some(20)]),
            Some(vec![Some(100), None, Some(1)]),
            None,
        ]);

        assert_same_data(&result, &expected);

        Ok(())
    }

    /// Legacy mode: `-1` inserts before the last element instead of appending.
    #[test]
    fn test_array_insert_legacy_mode() -> Result<()> {
        let list = int32_list(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![-1, -1, -1],
            vec![Some(10), Some(20), Some(30)],
            true,
        )?;

        let expected = int32_list(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(20), Some(5)]),
            None,
        ]);

        assert_same_data(&result, &expected);

        Ok(())
    }
}