datafusion_comet_spark_expr/array_funcs/
list_extract.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
19use arrow::datatypes::{DataType, FieldRef, Schema};
20use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
21use datafusion::common::{
22    cast::{as_int32_array, as_large_list_array, as_list_array},
23    internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
24};
25use datafusion::logical_expr::ColumnarValue;
26use datafusion::physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29    any::Any,
30    fmt::{Debug, Display, Formatter},
31    sync::Arc,
32};
33
34#[derive(Debug, Eq)]
35pub struct ListExtract {
36    child: Arc<dyn PhysicalExpr>,
37    ordinal: Arc<dyn PhysicalExpr>,
38    default_value: Option<Arc<dyn PhysicalExpr>>,
39    one_based: bool,
40    fail_on_error: bool,
41}
42
43impl Hash for ListExtract {
44    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45        self.child.hash(state);
46        self.ordinal.hash(state);
47        self.default_value.hash(state);
48        self.one_based.hash(state);
49        self.fail_on_error.hash(state);
50    }
51}
52impl PartialEq for ListExtract {
53    fn eq(&self, other: &Self) -> bool {
54        self.child.eq(&other.child)
55            && self.ordinal.eq(&other.ordinal)
56            && self.default_value.eq(&other.default_value)
57            && self.one_based.eq(&other.one_based)
58            && self.fail_on_error.eq(&other.fail_on_error)
59    }
60}
61
62impl ListExtract {
63    pub fn new(
64        child: Arc<dyn PhysicalExpr>,
65        ordinal: Arc<dyn PhysicalExpr>,
66        default_value: Option<Arc<dyn PhysicalExpr>>,
67        one_based: bool,
68        fail_on_error: bool,
69    ) -> Self {
70        Self {
71            child,
72            ordinal,
73            default_value,
74            one_based,
75            fail_on_error,
76        }
77    }
78
79    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
80        match self.child.data_type(input_schema)? {
81            DataType::List(field) | DataType::LargeList(field) => Ok(field),
82            data_type => Err(DataFusionError::Internal(format!(
83                "Unexpected data type in ListExtract: {data_type:?}"
84            ))),
85        }
86    }
87}
88
89impl PhysicalExpr for ListExtract {
90    fn as_any(&self) -> &dyn Any {
91        self
92    }
93
94    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
95        unimplemented!()
96    }
97
98    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
99        Ok(self.child_field(input_schema)?.data_type().clone())
100    }
101
102    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
103        // Only non-nullable if fail_on_error is enabled and the element is non-nullable
104        Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable())
105    }
106
107    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
108        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
109        let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
110
111        let default_value = self
112            .default_value
113            .as_ref()
114            .map(|d| {
115                d.evaluate(batch).map(|value| match value {
116                    ColumnarValue::Scalar(scalar)
117                        if !scalar.data_type().equals_datatype(child_value.data_type()) =>
118                    {
119                        scalar.cast_to(child_value.data_type())
120                    }
121                    ColumnarValue::Scalar(scalar) => Ok(scalar),
122                    v => Err(DataFusionError::Execution(format!(
123                        "Expected scalar default value for ListExtract, got {v:?}"
124                    ))),
125                })
126            })
127            .transpose()?
128            .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
129
130        let adjust_index = if self.one_based {
131            one_based_index
132        } else {
133            zero_based_index
134        };
135
136        match child_value.data_type() {
137            DataType::List(_) => {
138                let list_array = as_list_array(&child_value)?;
139                let index_array = as_int32_array(&ordinal_value)?;
140
141                list_extract(
142                    list_array,
143                    index_array,
144                    &default_value,
145                    self.fail_on_error,
146                    adjust_index,
147                )
148            }
149            DataType::LargeList(_) => {
150                let list_array = as_large_list_array(&child_value)?;
151                let index_array = as_int32_array(&ordinal_value)?;
152
153                list_extract(
154                    list_array,
155                    index_array,
156                    &default_value,
157                    self.fail_on_error,
158                    adjust_index,
159                )
160            }
161            data_type => Err(DataFusionError::Internal(format!(
162                "Unexpected child type for ListExtract: {data_type:?}"
163            ))),
164        }
165    }
166
167    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
168        vec![&self.child, &self.ordinal]
169    }
170
171    fn with_new_children(
172        self: Arc<Self>,
173        children: Vec<Arc<dyn PhysicalExpr>>,
174    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
175        match children.len() {
176            2 => Ok(Arc::new(ListExtract::new(
177                Arc::clone(&children[0]),
178                Arc::clone(&children[1]),
179                self.default_value.clone(),
180                self.one_based,
181                self.fail_on_error,
182            ))),
183            _ => internal_err!("ListExtract should have exactly two children"),
184        }
185    }
186}
187
188fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
189    if index == 0 {
190        return Err(DataFusionError::Execution(
191            "Invalid index of 0 for one-based ListExtract".to_string(),
192        ));
193    }
194
195    let abs_index = index.abs().as_usize();
196    if abs_index <= len {
197        if index > 0 {
198            Ok(Some(abs_index - 1))
199        } else {
200            Ok(Some(len - abs_index))
201        }
202    } else {
203        Ok(None)
204    }
205}
206
207fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
208    if index < 0 {
209        Ok(None)
210    } else {
211        let positive_index = index.as_usize();
212        if positive_index < len {
213            Ok(Some(positive_index))
214        } else {
215            Ok(None)
216        }
217    }
218}
219
220fn list_extract<O: OffsetSizeTrait>(
221    list_array: &GenericListArray<O>,
222    index_array: &Int32Array,
223    default_value: &ScalarValue,
224    fail_on_error: bool,
225    adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
226) -> DataFusionResult<ColumnarValue> {
227    let values = list_array.values();
228    let offsets = list_array.offsets();
229
230    let data = values.to_data();
231
232    let default_data = default_value.to_array()?.to_data();
233
234    let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
235
236    for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
237        let start = offset_window[0].as_usize();
238        let len = offset_window[1].as_usize() - start;
239
240        if let Some(i) = adjust_index(*index, len)? {
241            mutable.extend(0, start + i, start + i + 1);
242        } else if list_array.is_null(row) {
243            mutable.extend_nulls(1);
244        } else if fail_on_error {
245            return Err(DataFusionError::Execution(
246                "Index out of bounds for array".to_string(),
247            ));
248        } else {
249            mutable.extend(1, 0, 1);
250        }
251    }
252
253    let data = mutable.freeze();
254    Ok(ColumnarValue::Array(arrow::array::make_array(data)))
255}
256
257impl Display for ListExtract {
258    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259        write!(
260            f,
261            "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
262            self.child, self.ordinal,  self.default_value, self.one_based, self.fail_on_error
263        )
264    }
265}
266
#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Array, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;

    /// Runs `list_extract` with zero-based, non-failing semantics and unwraps the array result.
    fn extract_zero_based(
        list: &ListArray,
        indices: &Int32Array,
        default: &ScalarValue,
    ) -> Result<Arc<dyn Array>> {
        match list_extract(list, indices, default, false, zero_based_index)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => unreachable!(),
        }
    }

    #[test]
    fn test_list_extract_default_value() -> Result<()> {
        // Rows: an in-bounds list, a null list, and an empty (out-of-bounds) list.
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1)]),
            None,
            Some(vec![]),
        ]);
        let indices = Int32Array::from(vec![0, 0, 0]);

        // A null default leaves the out-of-bounds row null.
        let with_null_default = extract_zero_based(&list, &indices, &ScalarValue::Int32(None))?;
        assert_eq!(
            with_null_default.to_data(),
            Int32Array::from(vec![Some(1), None, None]).to_data()
        );

        // A non-null default fills the out-of-bounds row; the null list row stays null.
        let with_zero_default =
            extract_zero_based(&list, &indices, &ScalarValue::Int32(Some(0)))?;
        assert_eq!(
            with_zero_default.to_data(),
            Int32Array::from(vec![Some(1), None, Some(0)]).to_data()
        );

        Ok(())
    }
}