datafusion_comet_spark_expr/array_funcs/
list_extract.rs1use arrow::array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
19use arrow::datatypes::{DataType, FieldRef, Schema};
20use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
21use datafusion::common::{
22 cast::{as_int32_array, as_large_list_array, as_list_array},
23 internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
24};
25use datafusion::logical_expr::ColumnarValue;
26use datafusion::physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29 any::Any,
30 fmt::{Debug, Display, Formatter},
31 sync::Arc,
32};
33
/// Physical expression that extracts one element per row from a `List` or
/// `LargeList` column, at the position given by an ordinal column that is
/// evaluated as an `Int32` array.
///
/// The result type is the list's element type.
#[derive(Debug, Eq)]
pub struct ListExtract {
    /// Expression producing the list column to index into.
    child: Arc<dyn PhysicalExpr>,
    /// Expression producing the per-row ordinal (read as an `Int32` array).
    ordinal: Arc<dyn PhysicalExpr>,
    /// Optional expression for the value emitted when the ordinal is out of
    /// bounds and `fail_on_error` is false; it must evaluate to a scalar.
    /// When `None`, a null is emitted instead.
    default_value: Option<Arc<dyn PhysicalExpr>>,
    /// When true, ordinals are 1-based, a negative ordinal counts back from the
    /// end of the list, and 0 is rejected with an error. When false, ordinals
    /// are 0-based and negative ordinals are out of bounds.
    one_based: bool,
    /// When true, an out-of-bounds ordinal is an execution error instead of
    /// producing the default value.
    fail_on_error: bool,
}
42
43impl Hash for ListExtract {
44 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
45 self.child.hash(state);
46 self.ordinal.hash(state);
47 self.default_value.hash(state);
48 self.one_based.hash(state);
49 self.fail_on_error.hash(state);
50 }
51}
52impl PartialEq for ListExtract {
53 fn eq(&self, other: &Self) -> bool {
54 self.child.eq(&other.child)
55 && self.ordinal.eq(&other.ordinal)
56 && self.default_value.eq(&other.default_value)
57 && self.one_based.eq(&other.one_based)
58 && self.fail_on_error.eq(&other.fail_on_error)
59 }
60}
61
62impl ListExtract {
63 pub fn new(
64 child: Arc<dyn PhysicalExpr>,
65 ordinal: Arc<dyn PhysicalExpr>,
66 default_value: Option<Arc<dyn PhysicalExpr>>,
67 one_based: bool,
68 fail_on_error: bool,
69 ) -> Self {
70 Self {
71 child,
72 ordinal,
73 default_value,
74 one_based,
75 fail_on_error,
76 }
77 }
78
79 fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
80 match self.child.data_type(input_schema)? {
81 DataType::List(field) | DataType::LargeList(field) => Ok(field),
82 data_type => Err(DataFusionError::Internal(format!(
83 "Unexpected data type in ListExtract: {data_type:?}"
84 ))),
85 }
86 }
87}
88
89impl PhysicalExpr for ListExtract {
90 fn as_any(&self) -> &dyn Any {
91 self
92 }
93
94 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
95 unimplemented!()
96 }
97
98 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
99 Ok(self.child_field(input_schema)?.data_type().clone())
100 }
101
102 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
103 Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable())
105 }
106
107 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
108 let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
109 let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
110
111 let default_value = self
112 .default_value
113 .as_ref()
114 .map(|d| {
115 d.evaluate(batch).map(|value| match value {
116 ColumnarValue::Scalar(scalar)
117 if !scalar.data_type().equals_datatype(child_value.data_type()) =>
118 {
119 scalar.cast_to(child_value.data_type())
120 }
121 ColumnarValue::Scalar(scalar) => Ok(scalar),
122 v => Err(DataFusionError::Execution(format!(
123 "Expected scalar default value for ListExtract, got {v:?}"
124 ))),
125 })
126 })
127 .transpose()?
128 .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
129
130 let adjust_index = if self.one_based {
131 one_based_index
132 } else {
133 zero_based_index
134 };
135
136 match child_value.data_type() {
137 DataType::List(_) => {
138 let list_array = as_list_array(&child_value)?;
139 let index_array = as_int32_array(&ordinal_value)?;
140
141 list_extract(
142 list_array,
143 index_array,
144 &default_value,
145 self.fail_on_error,
146 adjust_index,
147 )
148 }
149 DataType::LargeList(_) => {
150 let list_array = as_large_list_array(&child_value)?;
151 let index_array = as_int32_array(&ordinal_value)?;
152
153 list_extract(
154 list_array,
155 index_array,
156 &default_value,
157 self.fail_on_error,
158 adjust_index,
159 )
160 }
161 data_type => Err(DataFusionError::Internal(format!(
162 "Unexpected child type for ListExtract: {data_type:?}"
163 ))),
164 }
165 }
166
167 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
168 vec![&self.child, &self.ordinal]
169 }
170
171 fn with_new_children(
172 self: Arc<Self>,
173 children: Vec<Arc<dyn PhysicalExpr>>,
174 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
175 match children.len() {
176 2 => Ok(Arc::new(ListExtract::new(
177 Arc::clone(&children[0]),
178 Arc::clone(&children[1]),
179 self.default_value.clone(),
180 self.one_based,
181 self.fail_on_error,
182 ))),
183 _ => internal_err!("ListExtract should have exactly two children"),
184 }
185 }
186}
187
188fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
189 if index == 0 {
190 return Err(DataFusionError::Execution(
191 "Invalid index of 0 for one-based ListExtract".to_string(),
192 ));
193 }
194
195 let abs_index = index.abs().as_usize();
196 if abs_index <= len {
197 if index > 0 {
198 Ok(Some(abs_index - 1))
199 } else {
200 Ok(Some(len - abs_index))
201 }
202 } else {
203 Ok(None)
204 }
205}
206
207fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
208 if index < 0 {
209 Ok(None)
210 } else {
211 let positive_index = index.as_usize();
212 if positive_index < len {
213 Ok(Some(positive_index))
214 } else {
215 Ok(None)
216 }
217 }
218}
219
220fn list_extract<O: OffsetSizeTrait>(
221 list_array: &GenericListArray<O>,
222 index_array: &Int32Array,
223 default_value: &ScalarValue,
224 fail_on_error: bool,
225 adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
226) -> DataFusionResult<ColumnarValue> {
227 let values = list_array.values();
228 let offsets = list_array.offsets();
229
230 let data = values.to_data();
231
232 let default_data = default_value.to_array()?.to_data();
233
234 let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
235
236 for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
237 let start = offset_window[0].as_usize();
238 let len = offset_window[1].as_usize() - start;
239
240 if let Some(i) = adjust_index(*index, len)? {
241 mutable.extend(0, start + i, start + i + 1);
242 } else if list_array.is_null(row) {
243 mutable.extend_nulls(1);
244 } else if fail_on_error {
245 return Err(DataFusionError::Execution(
246 "Index out of bounds for array".to_string(),
247 ));
248 } else {
249 mutable.extend(1, 0, 1);
250 }
251 }
252
253 let data = mutable.freeze();
254 Ok(ColumnarValue::Array(arrow::array::make_array(data)))
255}
256
257impl Display for ListExtract {
258 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259 write!(
260 f,
261 "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
262 self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error
263 )
264 }
265}
266
#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Array, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion::common::{Result, ScalarValue};
    use datafusion::physical_plan::ColumnarValue;

    /// Index 0 is extracted from a one-element list, a null list, and an empty
    /// list. The empty list falls back to the default value; the null list
    /// stays null regardless of the default.
    #[test]
    fn test_list_extract_default_value() -> Result<()> {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1)]),
            None,
            Some(vec![]),
        ]);
        let indices = Int32Array::from(vec![0, 0, 0]);

        let cases = [
            (ScalarValue::Int32(None), vec![Some(1), None, None]),
            (ScalarValue::Int32(Some(0)), vec![Some(1), None, Some(0)]),
        ];

        for (default, expected) in cases {
            match list_extract(&list, &indices, &default, false, zero_based_index)? {
                ColumnarValue::Array(result) => assert_eq!(
                    result.to_data(),
                    Int32Array::from(expected).to_data()
                ),
                ColumnarValue::Scalar(_) => unreachable!(),
            }
        }
        Ok(())
    }
}