datafusion_comet_spark_expr/array_funcs/
get_array_struct_fields.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::record_batch::RecordBatch;
19use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray};
20use arrow_schema::{DataType, FieldRef, Schema};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{
23    cast::{as_large_list_array, as_list_array},
24    internal_err, DataFusionError, Result as DataFusionResult,
25};
26use datafusion_physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29    any::Any,
30    fmt::{Debug, Display, Formatter},
31    sync::Arc,
32};
33
34#[derive(Debug, Eq)]
35pub struct GetArrayStructFields {
36    child: Arc<dyn PhysicalExpr>,
37    ordinal: usize,
38}
39
40impl Hash for GetArrayStructFields {
41    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42        self.child.hash(state);
43        self.ordinal.hash(state);
44    }
45}
46impl PartialEq for GetArrayStructFields {
47    fn eq(&self, other: &Self) -> bool {
48        self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
49    }
50}
51
52impl GetArrayStructFields {
53    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
54        Self { child, ordinal }
55    }
56
57    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
58        match self.child.data_type(input_schema)? {
59            DataType::List(field) | DataType::LargeList(field) => Ok(field),
60            data_type => Err(DataFusionError::Internal(format!(
61                "Unexpected data type in GetArrayStructFields: {:?}",
62                data_type
63            ))),
64        }
65    }
66
67    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
68        match self.list_field(input_schema)?.data_type() {
69            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
70            data_type => Err(DataFusionError::Internal(format!(
71                "Unexpected data type in GetArrayStructFields: {:?}",
72                data_type
73            ))),
74        }
75    }
76}
77
78impl PhysicalExpr for GetArrayStructFields {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
84        let struct_field = self.child_field(input_schema)?;
85        match self.child.data_type(input_schema)? {
86            DataType::List(_) => Ok(DataType::List(struct_field)),
87            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
88            data_type => Err(DataFusionError::Internal(format!(
89                "Unexpected data type in GetArrayStructFields: {:?}",
90                data_type
91            ))),
92        }
93    }
94
95    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
96        Ok(self.list_field(input_schema)?.is_nullable()
97            || self.child_field(input_schema)?.is_nullable())
98    }
99
100    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
101        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
102
103        match child_value.data_type() {
104            DataType::List(_) => {
105                let list_array = as_list_array(&child_value)?;
106
107                get_array_struct_fields(list_array, self.ordinal)
108            }
109            DataType::LargeList(_) => {
110                let list_array = as_large_list_array(&child_value)?;
111
112                get_array_struct_fields(list_array, self.ordinal)
113            }
114            data_type => Err(DataFusionError::Internal(format!(
115                "Unexpected child type for ListExtract: {:?}",
116                data_type
117            ))),
118        }
119    }
120
121    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
122        vec![&self.child]
123    }
124
125    fn with_new_children(
126        self: Arc<Self>,
127        children: Vec<Arc<dyn PhysicalExpr>>,
128    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
129        match children.len() {
130            1 => Ok(Arc::new(GetArrayStructFields::new(
131                Arc::clone(&children[0]),
132                self.ordinal,
133            ))),
134            _ => internal_err!("GetArrayStructFields should have exactly one child"),
135        }
136    }
137}
138
139fn get_array_struct_fields<O: OffsetSizeTrait>(
140    list_array: &GenericListArray<O>,
141    ordinal: usize,
142) -> DataFusionResult<ColumnarValue> {
143    let values = list_array
144        .values()
145        .as_any()
146        .downcast_ref::<StructArray>()
147        .expect("A struct is expected");
148
149    let column = Arc::clone(values.column(ordinal));
150    let field = Arc::clone(&values.fields()[ordinal]);
151
152    let offsets = list_array.offsets();
153    let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
154
155    Ok(ColumnarValue::Array(Arc::new(array)))
156}
157
158impl Display for GetArrayStructFields {
159    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160        write!(
161            f,
162            "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
163            self.child, self.ordinal
164        )
165    }
166}