// datafusion_comet_spark_expr/array_funcs/list_extract.rs
use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow_schema::{DataType, FieldRef, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{
    cast::{as_int32_array, as_large_list_array, as_list_array},
    internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
};
use datafusion_physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    sync::Arc,
};

34#[derive(Debug, Eq)]
35pub struct ListExtract {
36 child: Arc<dyn PhysicalExpr>,
37 ordinal: Arc<dyn PhysicalExpr>,
38 default_value: Option<Arc<dyn PhysicalExpr>>,
39 one_based: bool,
40 fail_on_error: bool,
41}
42
43impl Hash for ListExtract {
44 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45 self.child.hash(state);
46 self.ordinal.hash(state);
47 self.default_value.hash(state);
48 self.one_based.hash(state);
49 self.fail_on_error.hash(state);
50 }
51}
52impl PartialEq for ListExtract {
53 fn eq(&self, other: &Self) -> bool {
54 self.child.eq(&other.child)
55 && self.ordinal.eq(&other.ordinal)
56 && self.default_value.eq(&other.default_value)
57 && self.one_based.eq(&other.one_based)
58 && self.fail_on_error.eq(&other.fail_on_error)
59 }
60}
61
62impl ListExtract {
63 pub fn new(
64 child: Arc<dyn PhysicalExpr>,
65 ordinal: Arc<dyn PhysicalExpr>,
66 default_value: Option<Arc<dyn PhysicalExpr>>,
67 one_based: bool,
68 fail_on_error: bool,
69 ) -> Self {
70 Self {
71 child,
72 ordinal,
73 default_value,
74 one_based,
75 fail_on_error,
76 }
77 }
78
79 fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
80 match self.child.data_type(input_schema)? {
81 DataType::List(field) | DataType::LargeList(field) => Ok(field),
82 data_type => Err(DataFusionError::Internal(format!(
83 "Unexpected data type in ListExtract: {:?}",
84 data_type
85 ))),
86 }
87 }
88}
89
90impl PhysicalExpr for ListExtract {
91 fn as_any(&self) -> &dyn Any {
92 self
93 }
94
95 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
96 Ok(self.child_field(input_schema)?.data_type().clone())
97 }
98
99 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
100 Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable())
102 }
103
104 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
105 let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
106 let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
107
108 let default_value = self
109 .default_value
110 .as_ref()
111 .map(|d| {
112 d.evaluate(batch).map(|value| match value {
113 ColumnarValue::Scalar(scalar)
114 if !scalar.data_type().equals_datatype(child_value.data_type()) =>
115 {
116 scalar.cast_to(child_value.data_type())
117 }
118 ColumnarValue::Scalar(scalar) => Ok(scalar),
119 v => Err(DataFusionError::Execution(format!(
120 "Expected scalar default value for ListExtract, got {:?}",
121 v
122 ))),
123 })
124 })
125 .transpose()?
126 .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
127
128 let adjust_index = if self.one_based {
129 one_based_index
130 } else {
131 zero_based_index
132 };
133
134 match child_value.data_type() {
135 DataType::List(_) => {
136 let list_array = as_list_array(&child_value)?;
137 let index_array = as_int32_array(&ordinal_value)?;
138
139 list_extract(
140 list_array,
141 index_array,
142 &default_value,
143 self.fail_on_error,
144 adjust_index,
145 )
146 }
147 DataType::LargeList(_) => {
148 let list_array = as_large_list_array(&child_value)?;
149 let index_array = as_int32_array(&ordinal_value)?;
150
151 list_extract(
152 list_array,
153 index_array,
154 &default_value,
155 self.fail_on_error,
156 adjust_index,
157 )
158 }
159 data_type => Err(DataFusionError::Internal(format!(
160 "Unexpected child type for ListExtract: {:?}",
161 data_type
162 ))),
163 }
164 }
165
166 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
167 vec![&self.child, &self.ordinal]
168 }
169
170 fn with_new_children(
171 self: Arc<Self>,
172 children: Vec<Arc<dyn PhysicalExpr>>,
173 ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
174 match children.len() {
175 2 => Ok(Arc::new(ListExtract::new(
176 Arc::clone(&children[0]),
177 Arc::clone(&children[1]),
178 self.default_value.clone(),
179 self.one_based,
180 self.fail_on_error,
181 ))),
182 _ => internal_err!("ListExtract should have exactly two children"),
183 }
184 }
185}
186
187fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
188 if index == 0 {
189 return Err(DataFusionError::Execution(
190 "Invalid index of 0 for one-based ListExtract".to_string(),
191 ));
192 }
193
194 let abs_index = index.abs().as_usize();
195 if abs_index <= len {
196 if index > 0 {
197 Ok(Some(abs_index - 1))
198 } else {
199 Ok(Some(len - abs_index))
200 }
201 } else {
202 Ok(None)
203 }
204}
205
206fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
207 if index < 0 {
208 Ok(None)
209 } else {
210 let positive_index = index.as_usize();
211 if positive_index < len {
212 Ok(Some(positive_index))
213 } else {
214 Ok(None)
215 }
216 }
217}
218
219fn list_extract<O: OffsetSizeTrait>(
220 list_array: &GenericListArray<O>,
221 index_array: &Int32Array,
222 default_value: &ScalarValue,
223 fail_on_error: bool,
224 adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
225) -> DataFusionResult<ColumnarValue> {
226 let values = list_array.values();
227 let offsets = list_array.offsets();
228
229 let data = values.to_data();
230
231 let default_data = default_value.to_array()?.to_data();
232
233 let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
234
235 for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
236 let start = offset_window[0].as_usize();
237 let len = offset_window[1].as_usize() - start;
238
239 if let Some(i) = adjust_index(*index, len)? {
240 mutable.extend(0, start + i, start + i + 1);
241 } else if list_array.is_null(row) {
242 mutable.extend_nulls(1);
243 } else if fail_on_error {
244 return Err(DataFusionError::Execution(
245 "Index out of bounds for array".to_string(),
246 ));
247 } else {
248 mutable.extend(1, 0, 1);
249 }
250 }
251
252 let data = mutable.freeze();
253 Ok(ColumnarValue::Array(arrow::array::make_array(data)))
254}
255
256impl Display for ListExtract {
257 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
258 write!(
259 f,
260 "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
261 self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error
262 )
263 }
264}
265
#[cfg(test)]
mod test {
    use super::*;
    use arrow::datatypes::Int32Type;
    use arrow_array::{Array, Int32Array, ListArray};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    /// Runs zero-based, non-failing extraction and unwraps the array it produces.
    fn extract_zero_based(
        list: &ListArray,
        indices: &Int32Array,
        default: &ScalarValue,
    ) -> Result<Arc<dyn Array>> {
        match list_extract(list, indices, default, false, zero_based_index)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => unreachable!(),
        }
    }

    #[test]
    fn test_list_extract_default_value() -> Result<()> {
        // Rows: a one-element list, a null list, and an empty list.
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1)]),
            None,
            Some(vec![]),
        ]);
        let indices = Int32Array::from(vec![0, 0, 0]);

        // (default, expected): the empty list falls back to the default; the null list stays null.
        let cases = [
            (
                ScalarValue::Int32(None),
                Int32Array::from(vec![Some(1), None, None]),
            ),
            (
                ScalarValue::Int32(Some(0)),
                Int32Array::from(vec![Some(1), None, Some(0)]),
            ),
        ];

        for (default, expected) in &cases {
            let result = extract_zero_based(&list, &indices, default)?;
            assert_eq!(&result.to_data(), &expected.to_data());
        }
        Ok(())
    }
}