1use arrow::{
19 array::{as_primitive_array, Capacities, MutableArrayData},
20 buffer::{NullBuffer, OffsetBuffer},
21 datatypes::ArrowNativeType,
22 record_batch::RecordBatch,
23};
24use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
25use arrow_schema::{DataType, Field, Schema};
26use datafusion::logical_expr::ColumnarValue;
27use datafusion_common::{
28 cast::{as_large_list_array, as_list_array},
29 internal_err, DataFusionError, Result as DataFusionResult,
30};
31use datafusion_physical_expr::PhysicalExpr;
32use std::hash::Hash;
33use std::{
34 any::Any,
35 fmt::{Debug, Display, Formatter},
36 sync::Arc,
37};
38
/// Upper bound on the length of a single array produced by `array_insert`.
///
/// The value equals `i32::MAX - 15`. The error messages in this file describe it as the
/// maximum array length in Spark.
const MAX_ROUNDED_ARRAY_LENGTH: usize = 2_147_483_632;
43
44#[derive(Debug, Eq)]
45pub struct ArrayInsert {
46 src_array_expr: Arc<dyn PhysicalExpr>,
47 pos_expr: Arc<dyn PhysicalExpr>,
48 item_expr: Arc<dyn PhysicalExpr>,
49 legacy_negative_index: bool,
50}
51
52impl Hash for ArrayInsert {
53 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54 self.src_array_expr.hash(state);
55 self.pos_expr.hash(state);
56 self.item_expr.hash(state);
57 self.legacy_negative_index.hash(state);
58 }
59}
60impl PartialEq for ArrayInsert {
61 fn eq(&self, other: &Self) -> bool {
62 self.src_array_expr.eq(&other.src_array_expr)
63 && self.pos_expr.eq(&other.pos_expr)
64 && self.item_expr.eq(&other.item_expr)
65 && self.legacy_negative_index.eq(&other.legacy_negative_index)
66 }
67}
68
69impl ArrayInsert {
70 pub fn new(
71 src_array_expr: Arc<dyn PhysicalExpr>,
72 pos_expr: Arc<dyn PhysicalExpr>,
73 item_expr: Arc<dyn PhysicalExpr>,
74 legacy_negative_index: bool,
75 ) -> Self {
76 Self {
77 src_array_expr,
78 pos_expr,
79 item_expr,
80 legacy_negative_index,
81 }
82 }
83
84 pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
85 match data_type {
86 DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
87 DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
88 data_type => Err(DataFusionError::Internal(format!(
89 "Unexpected src array type in ArrayInsert: {:?}",
90 data_type
91 ))),
92 }
93 }
94}
95
96impl PhysicalExpr for ArrayInsert {
97 fn as_any(&self) -> &dyn Any {
98 self
99 }
100
101 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
102 self.array_type(&self.src_array_expr.data_type(input_schema)?)
103 }
104
105 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
106 self.src_array_expr.nullable(input_schema)
107 }
108
109 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
110 let pos_value = self
111 .pos_expr
112 .evaluate(batch)?
113 .into_array(batch.num_rows())?;
114
115 if !matches!(pos_value.data_type(), DataType::Int32) {
118 return Err(DataFusionError::Internal(format!(
119 "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
120 pos_value.data_type()
121 )));
122 }
123
124 let src_value = self
126 .src_array_expr
127 .evaluate(batch)?
128 .into_array(batch.num_rows())?;
129
130 let src_element_type = match self.array_type(src_value.data_type())? {
131 DataType::List(field) => &field.data_type().clone(),
132 DataType::LargeList(field) => &field.data_type().clone(),
133 _ => unreachable!(),
134 };
135
136 let item_value = self
138 .item_expr
139 .evaluate(batch)?
140 .into_array(batch.num_rows())?;
141 if item_value.data_type() != src_element_type {
142 return Err(DataFusionError::Internal(format!(
143 "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}",
144 src_element_type,
145 item_value.data_type()
146 )));
147 }
148
149 match src_value.data_type() {
150 DataType::List(_) => {
151 let list_array = as_list_array(&src_value)?;
152 array_insert(
153 list_array,
154 &item_value,
155 &pos_value,
156 self.legacy_negative_index,
157 )
158 }
159 DataType::LargeList(_) => {
160 let list_array = as_large_list_array(&src_value)?;
161 array_insert(
162 list_array,
163 &item_value,
164 &pos_value,
165 self.legacy_negative_index,
166 )
167 }
168 _ => unreachable!(), }
170 }
171
172 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
173 vec![&self.src_array_expr, &self.pos_expr, &self.item_expr]
174 }
175
176 fn with_new_children(
177 self: Arc<Self>,
178 children: Vec<Arc<dyn PhysicalExpr>>,
179 ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
180 match children.len() {
181 3 => Ok(Arc::new(ArrayInsert::new(
182 Arc::clone(&children[0]),
183 Arc::clone(&children[1]),
184 Arc::clone(&children[2]),
185 self.legacy_negative_index,
186 ))),
187 _ => internal_err!("ArrayInsert should have exactly three childrens"),
188 }
189 }
190}
191
192fn array_insert<O: OffsetSizeTrait>(
193 list_array: &GenericListArray<O>,
194 items_array: &ArrayRef,
195 pos_array: &ArrayRef,
196 legacy_mode: bool,
197) -> DataFusionResult<ColumnarValue> {
198 let values = list_array.values();
205 let offsets = list_array.offsets();
206 let values_data = values.to_data();
207 let item_data = items_array.to_data();
208 let new_capacity = Capacities::Array(values_data.len() + item_data.len());
209
210 let mut mutable_values =
211 MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
212
213 let mut new_offsets = vec![O::usize_as(0)];
214 let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
215
216 let pos_data: &Int32Array = as_primitive_array(&pos_array); for (row_index, offset_window) in offsets.windows(2).enumerate() {
219 let pos = pos_data.values()[row_index];
220 let start = offset_window[0].as_usize();
221 let end = offset_window[1].as_usize();
222 let is_item_null = items_array.is_null(row_index);
223
224 if list_array.is_null(row_index) {
225 mutable_values.extend_nulls(1);
227 new_offsets.push(new_offsets[row_index] + O::one());
228 new_nulls.push(false);
229 continue;
230 }
231
232 if pos == 0 {
233 return Err(DataFusionError::Internal(
234 "Position for array_insert should be greter or less than zero".to_string(),
235 ));
236 }
237
238 if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
239 let corrected_pos = if pos > 0 {
240 (pos - 1).as_usize()
241 } else {
242 end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
243 };
244 let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
245 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
246 return Err(DataFusionError::Internal(format!(
247 "Max array length in Spark is {:?}, but got {:?}",
248 MAX_ROUNDED_ARRAY_LENGTH, new_array_len
249 )));
250 }
251
252 if (start + corrected_pos) <= end {
253 mutable_values.extend(0, start, start + corrected_pos);
254 mutable_values.extend(1, row_index, row_index + 1);
255 mutable_values.extend(0, start + corrected_pos, end);
256 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
257 } else {
258 mutable_values.extend(0, start, end);
259 mutable_values.extend_nulls(new_array_len - (end - start));
260 mutable_values.extend(1, row_index, row_index + 1);
261 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
264 }
265 } else {
266 let base_offset = if legacy_mode { 1 } else { 0 };
271 let new_array_len = (-pos + base_offset).as_usize();
272 if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
273 return Err(DataFusionError::Internal(format!(
274 "Max array length in Spark is {:?}, but got {:?}",
275 MAX_ROUNDED_ARRAY_LENGTH, new_array_len
276 )));
277 }
278 mutable_values.extend(1, row_index, row_index + 1);
279 mutable_values.extend_nulls(new_array_len - (end - start + 1));
280 mutable_values.extend(0, start, end);
281 new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
282 }
283 if is_item_null {
284 if (start == end) || (values.is_null(row_index)) {
285 new_nulls.push(false)
286 } else {
287 new_nulls.push(true)
288 }
289 } else {
290 new_nulls.push(true)
291 }
292 }
293
294 let data = make_array(mutable_values.freeze());
295 let data_type = match list_array.data_type() {
296 DataType::List(field) => field.data_type(),
297 DataType::LargeList(field) => field.data_type(),
298 _ => unreachable!(),
299 };
300 let new_array = GenericListArray::<O>::try_new(
301 Arc::new(Field::new("item", data_type.clone(), true)),
302 OffsetBuffer::new(new_offsets.into()),
303 data,
304 Some(NullBuffer::new(new_nulls.into())),
305 )?;
306
307 Ok(ColumnarValue::Array(Arc::new(new_array)))
308}
309
310impl Display for ArrayInsert {
311 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
312 write!(
313 f,
314 "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]",
315 self.src_array_expr, self.pos_expr, self.item_expr
316 )
317 }
318}
319
#[cfg(test)]
mod test {
    use super::*;
    use arrow::datatypes::Int32Type;
    use arrow_array::{Array, ArrayRef, Int32Array, ListArray};
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// One row of an `Int32` list column; `None` is a NULL list.
    type ListRow = Option<Vec<Option<i32>>>;

    /// Builds an `Int32` list array from the given rows.
    fn int_lists(rows: Vec<ListRow>) -> ListArray {
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows)
    }

    /// Runs `array_insert` on `list` and unwraps the array result.
    fn run_insert(
        list: &ListArray,
        positions: Vec<i32>,
        items: Vec<Option<i32>>,
        legacy_mode: bool,
    ) -> Result<ArrayRef> {
        let items: ArrayRef = Arc::new(Int32Array::from(items));
        let positions: ArrayRef = Arc::new(Int32Array::from(positions));
        let ColumnarValue::Array(result) = array_insert(list, &items, &positions, legacy_mode)?
        else {
            unreachable!()
        };
        Ok(result)
    }

    /// Positive positions: inside the array, at its end, and past its end (null padding).
    #[test]
    fn test_array_insert() -> Result<()> {
        let list = int_lists(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![None]),
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(1), Some(2), Some(3)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![2, 1, 1, 5, 6, 1],
            vec![
                Some(10),
                Some(20),
                Some(30),
                Some(100),
                Some(100),
                Some(40),
            ],
            false,
        )?;

        let expected = int_lists(vec![
            Some(vec![Some(1), Some(10), Some(2), Some(3)]),
            Some(vec![Some(20), Some(4), Some(5)]),
            Some(vec![Some(30), None]),
            Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
            Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }

    /// Negative positions outside legacy mode, including one before the array start.
    #[test]
    fn test_array_insert_negative_index() -> Result<()> {
        let list = int_lists(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![Some(1)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![-2, -1, -3, -1],
            vec![Some(10), Some(20), Some(100), Some(30)],
            false,
        )?;

        let expected = int_lists(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(5), Some(20)]),
            Some(vec![Some(100), None, Some(1)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }

    /// Negative positions in legacy mode: `-1` inserts before the last element.
    #[test]
    fn test_array_insert_legacy_mode() -> Result<()> {
        let list = int_lists(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ]);

        let result = run_insert(
            &list,
            vec![-1, -1, -1],
            vec![Some(10), Some(20), Some(30)],
            true,
        )?;

        let expected = int_lists(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(20), Some(5)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }
}