datafusion_comet_spark_expr/array_funcs/
get_array_struct_fields.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{make_array, Array, GenericListArray, OffsetSizeTrait, StructArray};
19use arrow::buffer::NullBuffer;
20use arrow::datatypes::{DataType, FieldRef, Schema};
21use arrow::record_batch::RecordBatch;
22use datafusion::common::{
23    cast::{as_large_list_array, as_list_array},
24    internal_err, DataFusionError, Result as DataFusionResult,
25};
26use datafusion::logical_expr::ColumnarValue;
27use datafusion::physical_expr::PhysicalExpr;
28use std::hash::Hash;
29use std::{
30    any::Any,
31    fmt::{Debug, Display, Formatter},
32    sync::Arc,
33};
34
35#[derive(Debug, Eq)]
36pub struct GetArrayStructFields {
37    child: Arc<dyn PhysicalExpr>,
38    ordinal: usize,
39}
40
41impl Hash for GetArrayStructFields {
42    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
43        self.child.hash(state);
44        self.ordinal.hash(state);
45    }
46}
47impl PartialEq for GetArrayStructFields {
48    fn eq(&self, other: &Self) -> bool {
49        self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
50    }
51}
52
53impl GetArrayStructFields {
54    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
55        Self { child, ordinal }
56    }
57
58    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
59        match self.child.data_type(input_schema)? {
60            DataType::List(field) | DataType::LargeList(field) => Ok(field),
61            data_type => Err(DataFusionError::Internal(format!(
62                "Unexpected data type in GetArrayStructFields: {data_type:?}"
63            ))),
64        }
65    }
66
67    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
68        match self.list_field(input_schema)?.data_type() {
69            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
70            data_type => Err(DataFusionError::Internal(format!(
71                "Unexpected data type in GetArrayStructFields: {data_type:?}"
72            ))),
73        }
74    }
75}
76
77impl PhysicalExpr for GetArrayStructFields {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
83        let struct_field = self.child_field(input_schema)?;
84        match self.child.data_type(input_schema)? {
85            DataType::List(_) => Ok(DataType::List(struct_field)),
86            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
87            data_type => Err(DataFusionError::Internal(format!(
88                "Unexpected data type in GetArrayStructFields: {data_type:?}"
89            ))),
90        }
91    }
92
93    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
94        Ok(self.list_field(input_schema)?.is_nullable()
95            || self.child_field(input_schema)?.is_nullable())
96    }
97
98    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
99        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
100
101        match child_value.data_type() {
102            DataType::List(_) => {
103                let list_array = as_list_array(&child_value)?;
104
105                get_array_struct_fields(list_array, self.ordinal)
106            }
107            DataType::LargeList(_) => {
108                let list_array = as_large_list_array(&child_value)?;
109
110                get_array_struct_fields(list_array, self.ordinal)
111            }
112            data_type => Err(DataFusionError::Internal(format!(
113                "Unexpected child type for ListExtract: {data_type:?}"
114            ))),
115        }
116    }
117
118    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119        vec![&self.child]
120    }
121
122    fn with_new_children(
123        self: Arc<Self>,
124        children: Vec<Arc<dyn PhysicalExpr>>,
125    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
126        match children.len() {
127            1 => Ok(Arc::new(GetArrayStructFields::new(
128                Arc::clone(&children[0]),
129                self.ordinal,
130            ))),
131            _ => internal_err!("GetArrayStructFields should have exactly one child"),
132        }
133    }
134
135    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
136        unimplemented!()
137    }
138}
139
140fn get_array_struct_fields<O: OffsetSizeTrait>(
141    list_array: &GenericListArray<O>,
142    ordinal: usize,
143) -> DataFusionResult<ColumnarValue> {
144    let values = list_array
145        .values()
146        .as_any()
147        .downcast_ref::<StructArray>()
148        .expect("A StructType is expected");
149
150    let field = Arc::clone(&values.fields()[ordinal]);
151    // Get struct column by ordinal
152    let extracted_column = values.column(ordinal);
153
154    let data = if values.null_count() == extracted_column.null_count() {
155        Arc::clone(extracted_column)
156    } else {
157        // In some cases the column obtained from struct by ordinal doesn't
158        // represent all nulls that imposed by parent values.
159        // This maybe caused by a low level reader bug and needs more investigation.
160        // For this specific case we patch the null buffer for the column by merging nulls buffers
161        // from parent and column
162        let merged_nulls = NullBuffer::union(values.nulls(), extracted_column.nulls());
163        make_array(
164            extracted_column
165                .into_data()
166                .into_builder()
167                .nulls(merged_nulls)
168                .build()?,
169        )
170    };
171
172    let array = GenericListArray::new(
173        field,
174        list_array.offsets().clone(),
175        data,
176        list_array.nulls().cloned(),
177    );
178
179    Ok(ColumnarValue::Array(Arc::new(array)))
180}
181
182impl Display for GetArrayStructFields {
183    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184        write!(
185            f,
186            "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
187            self.child, self.ordinal
188        )
189    }
190}