datafusion_comet_spark_expr/array_funcs/array_insert.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::{
    array::{as_primitive_array, Capacities, MutableArrayData},
    buffer::{NullBuffer, OffsetBuffer},
    datatypes::ArrowNativeType,
    record_batch::RecordBatch,
};
use datafusion::common::{
    cast::{as_large_list_array, as_list_array},
    internal_err, DataFusionError, Result as DataFusionResult,
};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    sync::Arc,
};

// 2147483632 == java.lang.Integer.MAX_VALUE - 15
// This is the value of ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH:
// https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java
const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632;

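/// Implementation of Spark's `array_insert` expression: inserts the value of
/// `item_expr` into the array produced by `src_array_expr` at the 1-based
/// position produced by `pos_expr`. With `legacy_negative_index` set, negative
/// positions are resolved the way Spark's "legacy" mode does (the insertion
/// point is shifted by one relative to the current behavior).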
#[derive(Debug, Eq)]
pub struct ArrayInsert {
    src_array_expr: Arc<dyn PhysicalExpr>,
    pos_expr: Arc<dyn PhysicalExpr>,
    item_expr: Arc<dyn PhysicalExpr>,
    legacy_negative_index: bool,
}

impl Hash for ArrayInsert {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.src_array_expr.hash(state);
        self.pos_expr.hash(state);
        self.item_expr.hash(state);
        self.legacy_negative_index.hash(state);
    }
}
impl PartialEq for ArrayInsert {
    fn eq(&self, other: &Self) -> bool {
        self.src_array_expr.eq(&other.src_array_expr)
            && self.pos_expr.eq(&other.pos_expr)
            && self.item_expr.eq(&other.item_expr)
            && self.legacy_negative_index.eq(&other.legacy_negative_index)
    }
}

impl ArrayInsert {
    pub fn new(
        src_array_expr: Arc<dyn PhysicalExpr>,
        pos_expr: Arc<dyn PhysicalExpr>,
        item_expr: Arc<dyn PhysicalExpr>,
        legacy_negative_index: bool,
    ) -> Self {
        Self {
            src_array_expr,
            pos_expr,
            item_expr,
            legacy_negative_index,
        }
    }

    pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
        match data_type {
            DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
            DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
            data_type => Err(DataFusionError::Internal(format!(
                "Unexpected src array type in ArrayInsert: {:?}",
                data_type
            ))),
        }
    }
}

impl PhysicalExpr for ArrayInsert {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        unimplemented!()
    }

    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
        self.array_type(&self.src_array_expr.data_type(input_schema)?)
    }

    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
        self.src_array_expr.nullable(input_schema)
    }

    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let pos_value = self
            .pos_expr
            .evaluate(batch)?
            .into_array(batch.num_rows())?;

        // Spark supports only IntegerType (Int32):
        // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4737
        if !matches!(pos_value.data_type(), DataType::Int32) {
            return Err(DataFusionError::Internal(format!(
                "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
                pos_value.data_type()
            )));
        }

        // Check that the src array is actually an array and get its element type
        let src_value = self
            .src_array_expr
            .evaluate(batch)?
            .into_array(batch.num_rows())?;

        let src_element_type = match self.array_type(src_value.data_type())? {
            DataType::List(field) | DataType::LargeList(field) => field.data_type().clone(),
            _ => unreachable!(),
        };

        // Check that the inserted value has the same type as the array elements
        let item_value = self
            .item_expr
            .evaluate(batch)?
            .into_array(batch.num_rows())?;
        if item_value.data_type() != &src_element_type {
            return Err(DataFusionError::Internal(format!(
                "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}",
                src_element_type,
                item_value.data_type()
            )));
        }

        match src_value.data_type() {
            DataType::List(_) => {
                let list_array = as_list_array(&src_value)?;
                array_insert(
                    list_array,
                    &item_value,
                    &pos_value,
                    self.legacy_negative_index,
                )
            }
            DataType::LargeList(_) => {
                let list_array = as_large_list_array(&src_value)?;
                array_insert(
                    list_array,
                    &item_value,
                    &pos_value,
                    self.legacy_negative_index,
                )
            }
            _ => unreachable!(), // This case is checked already
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.src_array_expr, &self.pos_expr, &self.item_expr]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
        match children.len() {
            3 => Ok(Arc::new(ArrayInsert::new(
                Arc::clone(&children[0]),
                Arc::clone(&children[1]),
                Arc::clone(&children[2]),
                self.legacy_negative_index,
            ))),
            _ => internal_err!("ArrayInsert should have exactly three children"),
        }
    }
}

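/// Row-by-row kernel for `array_insert` over a (Large)List array. Positions are
/// 1-based; a position of 0 is an error. Negative positions count from the end
/// of the array, with `legacy_mode` selecting Spark's legacy resolution of
/// negative indices. Inserting past the end of an array pads the gap with
/// nulls, matching Spark's behavior.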
fn array_insert<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    items_array: &ArrayRef,
    pos_array: &ArrayRef,
    legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
    // This code is based on the implementation of array_append in Apache DataFusion:
    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
    //
    // It also follows the implementation of array_insert in Apache Spark:
    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713

    let values = list_array.values();
    let offsets = list_array.offsets();
    let values_data = values.to_data();
    let item_data = items_array.to_data();
    let new_capacity = Capacities::Array(values_data.len() + item_data.len());

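    // The output child values are built from two sources: source 0 is the
    // original child values of the list array, source 1 is the items array.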
    let mut mutable_values =
        MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);

    let mut new_offsets = vec![O::usize_as(0)];
    let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());

    let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions

    for (row_index, offset_window) in offsets.windows(2).enumerate() {
        let pos = pos_data.values()[row_index];
        let start = offset_window[0].as_usize();
        let end = offset_window[1].as_usize();
        let is_item_null = items_array.is_null(row_index);

        if list_array.is_null(row_index) {
            // In Spark, if the array value is NULL then nothing happens
            mutable_values.extend_nulls(1);
            new_offsets.push(new_offsets[row_index] + O::one());
            new_nulls.push(false);
            continue;
        }

        if pos == 0 {
            return Err(DataFusionError::Internal(
                "Position for array_insert should be greater than or less than zero".to_string(),
            ));
        }

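        // The insertion point is either a positive position or a negative one
        // that still falls within the current array bounds.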
        if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
            let corrected_pos = if pos > 0 {
                (pos - 1).as_usize()
            } else {
                end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
            };
            let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
                return Err(DataFusionError::Internal(format!(
                    "Max array length in Spark is {:?}, but got {:?}",
                    MAX_ROUNDED_ARRAY_LENGTH, new_array_len
                )));
            }

            if (start + corrected_pos) <= end {
                mutable_values.extend(0, start, start + corrected_pos);
                mutable_values.extend(1, row_index, row_index + 1);
                mutable_values.extend(0, start + corrected_pos, end);
                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
            } else {
                mutable_values.extend(0, start, end);
                mutable_values.extend_nulls(new_array_len - (end - start));
                mutable_values.extend(1, row_index, row_index + 1);
                // In this case Spark actually makes the array longer than expected;
                // for example, if pos is 5 and len is 3, the resulting len will be 5
                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
            }
        } else {
            // This comment is taken from the Apache Spark source code as-is:
            // special case- if the new position is negative but larger than the current array size
            // place the new item at start of array, place the current array contents at the end
            // and fill the newly created array elements inbetween with a null
            let base_offset = if legacy_mode { 1 } else { 0 };
            let new_array_len = (-pos + base_offset).as_usize();
            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
                return Err(DataFusionError::Internal(format!(
                    "Max array length in Spark is {:?}, but got {:?}",
                    MAX_ROUNDED_ARRAY_LENGTH, new_array_len
                )));
            }
            mutable_values.extend(1, row_index, row_index + 1);
            mutable_values.extend_nulls(new_array_len - (end - start + 1));
            mutable_values.extend(0, start, end);
            new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
        }
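        // Decide the top-level validity of the output row: new_nulls feeds the
        // NullBuffer of the result list array, so pushing `false` marks the
        // row itself as NULL.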
        if is_item_null {
            if (start == end) || (values.is_null(row_index)) {
                new_nulls.push(false)
            } else {
                new_nulls.push(true)
            }
        } else {
            new_nulls.push(true)
        }
    }

    let data = make_array(mutable_values.freeze());
    let data_type = match list_array.data_type() {
        DataType::List(field) => field.data_type(),
        DataType::LargeList(field) => field.data_type(),
        _ => unreachable!(),
    };
    let new_array = GenericListArray::<O>::try_new(
        Arc::new(Field::new("item", data_type.clone(), true)),
        OffsetBuffer::new(new_offsets.into()),
        data,
        Some(NullBuffer::new(new_nulls.into())),
    )?;

    Ok(ColumnarValue::Array(Arc::new(new_array)))
}

impl Display for ArrayInsert {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]",
            self.src_array_expr, self.pos_expr, self.item_expr
        )
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion::common::Result;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    #[test]
    fn test_array_insert() -> Result<()> {
        // Test inserting an item into a list array
        // Inputs and expected values are taken from Spark's results
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![None]),
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(1), Some(2), Some(3)]),
            None,
        ]);

        let positions = Int32Array::from(vec![2, 1, 1, 5, 6, 1]);
        let items = Int32Array::from(vec![
            Some(10),
            Some(20),
            Some(30),
            Some(100),
            Some(100),
            Some(40),
        ]);

        let ColumnarValue::Array(result) = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            false,
        )?
        else {
            unreachable!()
        };

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(10), Some(2), Some(3)]),
            Some(vec![Some(20), Some(4), Some(5)]),
            Some(vec![Some(30), None]),
            Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
            Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }

    #[test]
    fn test_array_insert_negative_index() -> Result<()> {
        // Test insertion with a negative index
        // Inputs and expected values are taken from Spark's results
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            Some(vec![Some(1)]),
            None,
        ]);

        let positions = Int32Array::from(vec![-2, -1, -3, -1]);
        let items = Int32Array::from(vec![Some(10), Some(20), Some(100), Some(30)]);

        let ColumnarValue::Array(result) = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            false,
        )?
        else {
            unreachable!()
        };

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(5), Some(20)]),
            Some(vec![Some(100), None, Some(1)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }

    #[test]
    fn test_array_insert_legacy_mode() -> Result<()> {
        // Test the so-called "legacy" mode that exists in Spark
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ]);

        let positions = Int32Array::from(vec![-1, -1, -1]);
        let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]);

        let ColumnarValue::Array(result) = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            true,
        )?
        else {
            unreachable!()
        };

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(10), Some(3)]),
            Some(vec![Some(4), Some(20), Some(5)]),
            None,
        ]);

        assert_eq!(&result.to_data(), &expected.to_data());

        Ok(())
    }
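
    // Additional check, not part of the original suite: array_insert positions
    // are 1-based, so position 0 must be rejected. This is a minimal sketch
    // that only asserts an error is returned, not its exact message.
    #[test]
    fn test_array_insert_zero_position_is_rejected() {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
            Some(1),
            Some(2),
        ])]);
        let positions = Int32Array::from(vec![0]);
        let items = Int32Array::from(vec![Some(10)]);

        let result = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            false,
        );
        assert!(result.is_err());
    }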
}