// datafusion_comet_spark_expr/array_funcs/list_extract.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
19use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
20use arrow_schema::{DataType, FieldRef, Schema};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{
23    cast::{as_int32_array, as_large_list_array, as_list_array},
24    internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
25};
26use datafusion_physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29    any::Any,
30    fmt::{Debug, Display, Formatter},
31    sync::Arc,
32};
33
34#[derive(Debug, Eq)]
35pub struct ListExtract {
36    child: Arc<dyn PhysicalExpr>,
37    ordinal: Arc<dyn PhysicalExpr>,
38    default_value: Option<Arc<dyn PhysicalExpr>>,
39    one_based: bool,
40    fail_on_error: bool,
41}
42
43impl Hash for ListExtract {
44    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45        self.child.hash(state);
46        self.ordinal.hash(state);
47        self.default_value.hash(state);
48        self.one_based.hash(state);
49        self.fail_on_error.hash(state);
50    }
51}
52impl PartialEq for ListExtract {
53    fn eq(&self, other: &Self) -> bool {
54        self.child.eq(&other.child)
55            && self.ordinal.eq(&other.ordinal)
56            && self.default_value.eq(&other.default_value)
57            && self.one_based.eq(&other.one_based)
58            && self.fail_on_error.eq(&other.fail_on_error)
59    }
60}
61
62impl ListExtract {
63    pub fn new(
64        child: Arc<dyn PhysicalExpr>,
65        ordinal: Arc<dyn PhysicalExpr>,
66        default_value: Option<Arc<dyn PhysicalExpr>>,
67        one_based: bool,
68        fail_on_error: bool,
69    ) -> Self {
70        Self {
71            child,
72            ordinal,
73            default_value,
74            one_based,
75            fail_on_error,
76        }
77    }
78
79    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
80        match self.child.data_type(input_schema)? {
81            DataType::List(field) | DataType::LargeList(field) => Ok(field),
82            data_type => Err(DataFusionError::Internal(format!(
83                "Unexpected data type in ListExtract: {:?}",
84                data_type
85            ))),
86        }
87    }
88}
89
90impl PhysicalExpr for ListExtract {
91    fn as_any(&self) -> &dyn Any {
92        self
93    }
94
95    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
96        Ok(self.child_field(input_schema)?.data_type().clone())
97    }
98
99    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
100        // Only non-nullable if fail_on_error is enabled and the element is non-nullable
101        Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable())
102    }
103
104    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
105        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
106        let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
107
108        let default_value = self
109            .default_value
110            .as_ref()
111            .map(|d| {
112                d.evaluate(batch).map(|value| match value {
113                    ColumnarValue::Scalar(scalar)
114                        if !scalar.data_type().equals_datatype(child_value.data_type()) =>
115                    {
116                        scalar.cast_to(child_value.data_type())
117                    }
118                    ColumnarValue::Scalar(scalar) => Ok(scalar),
119                    v => Err(DataFusionError::Execution(format!(
120                        "Expected scalar default value for ListExtract, got {:?}",
121                        v
122                    ))),
123                })
124            })
125            .transpose()?
126            .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
127
128        let adjust_index = if self.one_based {
129            one_based_index
130        } else {
131            zero_based_index
132        };
133
134        match child_value.data_type() {
135            DataType::List(_) => {
136                let list_array = as_list_array(&child_value)?;
137                let index_array = as_int32_array(&ordinal_value)?;
138
139                list_extract(
140                    list_array,
141                    index_array,
142                    &default_value,
143                    self.fail_on_error,
144                    adjust_index,
145                )
146            }
147            DataType::LargeList(_) => {
148                let list_array = as_large_list_array(&child_value)?;
149                let index_array = as_int32_array(&ordinal_value)?;
150
151                list_extract(
152                    list_array,
153                    index_array,
154                    &default_value,
155                    self.fail_on_error,
156                    adjust_index,
157                )
158            }
159            data_type => Err(DataFusionError::Internal(format!(
160                "Unexpected child type for ListExtract: {:?}",
161                data_type
162            ))),
163        }
164    }
165
166    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
167        vec![&self.child, &self.ordinal]
168    }
169
170    fn with_new_children(
171        self: Arc<Self>,
172        children: Vec<Arc<dyn PhysicalExpr>>,
173    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
174        match children.len() {
175            2 => Ok(Arc::new(ListExtract::new(
176                Arc::clone(&children[0]),
177                Arc::clone(&children[1]),
178                self.default_value.clone(),
179                self.one_based,
180                self.fail_on_error,
181            ))),
182            _ => internal_err!("ListExtract should have exactly two children"),
183        }
184    }
185}
186
187fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
188    if index == 0 {
189        return Err(DataFusionError::Execution(
190            "Invalid index of 0 for one-based ListExtract".to_string(),
191        ));
192    }
193
194    let abs_index = index.abs().as_usize();
195    if abs_index <= len {
196        if index > 0 {
197            Ok(Some(abs_index - 1))
198        } else {
199            Ok(Some(len - abs_index))
200        }
201    } else {
202        Ok(None)
203    }
204}
205
206fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
207    if index < 0 {
208        Ok(None)
209    } else {
210        let positive_index = index.as_usize();
211        if positive_index < len {
212            Ok(Some(positive_index))
213        } else {
214            Ok(None)
215        }
216    }
217}
218
219fn list_extract<O: OffsetSizeTrait>(
220    list_array: &GenericListArray<O>,
221    index_array: &Int32Array,
222    default_value: &ScalarValue,
223    fail_on_error: bool,
224    adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
225) -> DataFusionResult<ColumnarValue> {
226    let values = list_array.values();
227    let offsets = list_array.offsets();
228
229    let data = values.to_data();
230
231    let default_data = default_value.to_array()?.to_data();
232
233    let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
234
235    for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
236        let start = offset_window[0].as_usize();
237        let len = offset_window[1].as_usize() - start;
238
239        if let Some(i) = adjust_index(*index, len)? {
240            mutable.extend(0, start + i, start + i + 1);
241        } else if list_array.is_null(row) {
242            mutable.extend_nulls(1);
243        } else if fail_on_error {
244            return Err(DataFusionError::Execution(
245                "Index out of bounds for array".to_string(),
246            ));
247        } else {
248            mutable.extend(1, 0, 1);
249        }
250    }
251
252    let data = mutable.freeze();
253    Ok(ColumnarValue::Array(arrow::array::make_array(data)))
254}
255
256impl Display for ListExtract {
257    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
258        write!(
259            f,
260            "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
261            self.child, self.ordinal,  self.default_value, self.one_based, self.fail_on_error
262        )
263    }
264}
265
#[cfg(test)]
mod test {
    use super::*;
    use arrow::datatypes::Int32Type;
    use arrow_array::{Array, Int32Array, ListArray};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    /// Runs `list_extract` with zero-based indexing and `fail_on_error = false`, returning the
    /// resulting array.
    fn extract_zero_based(
        list: &ListArray,
        indices: &Int32Array,
        default: &ScalarValue,
    ) -> Result<arrow_array::ArrayRef> {
        match list_extract(list, indices, default, false, zero_based_index)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => unreachable!(),
        }
    }

    #[test]
    fn test_list_extract_default_value() -> Result<()> {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1)]),
            None,
            Some(vec![]),
        ]);
        let indices = Int32Array::from(vec![0, 0, 0]);

        // With a null default, the empty list's out-of-bounds slot stays null.
        let result = extract_zero_based(&list, &indices, &ScalarValue::Int32(None))?;
        assert_eq!(
            result.to_data(),
            Int32Array::from(vec![Some(1), None, None]).to_data()
        );

        // A non-null default fills that slot, while the null list still yields null.
        let result = extract_zero_based(&list, &indices, &ScalarValue::Int32(Some(0)))?;
        assert_eq!(
            result.to_data(),
            Int32Array::from(vec![Some(1), None, Some(0)]).to_data()
        );

        Ok(())
    }
}