datafusion_comet_spark_expr/array_funcs/
get_array_struct_fields.rs1use arrow::array::{make_array, Array, GenericListArray, OffsetSizeTrait, StructArray};
19use arrow::buffer::NullBuffer;
20use arrow::datatypes::{DataType, FieldRef, Schema};
21use arrow::record_batch::RecordBatch;
22use datafusion::common::{
23 cast::{as_large_list_array, as_list_array},
24 internal_err, DataFusionError, Result as DataFusionResult,
25};
26use datafusion::logical_expr::ColumnarValue;
27use datafusion::physical_expr::PhysicalExpr;
28use std::hash::Hash;
29use std::{
30 any::Any,
31 fmt::{Debug, Display, Formatter},
32 sync::Arc,
33};
34
35#[derive(Debug, Eq)]
36pub struct GetArrayStructFields {
37 child: Arc<dyn PhysicalExpr>,
38 ordinal: usize,
39}
40
41impl Hash for GetArrayStructFields {
42 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
43 self.child.hash(state);
44 self.ordinal.hash(state);
45 }
46}
47impl PartialEq for GetArrayStructFields {
48 fn eq(&self, other: &Self) -> bool {
49 self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
50 }
51}
52
53impl GetArrayStructFields {
54 pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
55 Self { child, ordinal }
56 }
57
58 fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
59 match self.child.data_type(input_schema)? {
60 DataType::List(field) | DataType::LargeList(field) => Ok(field),
61 data_type => Err(DataFusionError::Internal(format!(
62 "Unexpected data type in GetArrayStructFields: {data_type:?}"
63 ))),
64 }
65 }
66
67 fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
68 match self.list_field(input_schema)?.data_type() {
69 DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
70 data_type => Err(DataFusionError::Internal(format!(
71 "Unexpected data type in GetArrayStructFields: {data_type:?}"
72 ))),
73 }
74 }
75}
76
77impl PhysicalExpr for GetArrayStructFields {
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81
82 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
83 let struct_field = self.child_field(input_schema)?;
84 match self.child.data_type(input_schema)? {
85 DataType::List(_) => Ok(DataType::List(struct_field)),
86 DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
87 data_type => Err(DataFusionError::Internal(format!(
88 "Unexpected data type in GetArrayStructFields: {data_type:?}"
89 ))),
90 }
91 }
92
93 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
94 Ok(self.list_field(input_schema)?.is_nullable()
95 || self.child_field(input_schema)?.is_nullable())
96 }
97
98 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
99 let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
100
101 match child_value.data_type() {
102 DataType::List(_) => {
103 let list_array = as_list_array(&child_value)?;
104
105 get_array_struct_fields(list_array, self.ordinal)
106 }
107 DataType::LargeList(_) => {
108 let list_array = as_large_list_array(&child_value)?;
109
110 get_array_struct_fields(list_array, self.ordinal)
111 }
112 data_type => Err(DataFusionError::Internal(format!(
113 "Unexpected child type for ListExtract: {data_type:?}"
114 ))),
115 }
116 }
117
118 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119 vec![&self.child]
120 }
121
122 fn with_new_children(
123 self: Arc<Self>,
124 children: Vec<Arc<dyn PhysicalExpr>>,
125 ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
126 match children.len() {
127 1 => Ok(Arc::new(GetArrayStructFields::new(
128 Arc::clone(&children[0]),
129 self.ordinal,
130 ))),
131 _ => internal_err!("GetArrayStructFields should have exactly one child"),
132 }
133 }
134
135 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
136 unimplemented!()
137 }
138}
139
140fn get_array_struct_fields<O: OffsetSizeTrait>(
141 list_array: &GenericListArray<O>,
142 ordinal: usize,
143) -> DataFusionResult<ColumnarValue> {
144 let values = list_array
145 .values()
146 .as_any()
147 .downcast_ref::<StructArray>()
148 .expect("A StructType is expected");
149
150 let field = Arc::clone(&values.fields()[ordinal]);
151 let extracted_column = values.column(ordinal);
153
154 let data = if values.null_count() == extracted_column.null_count() {
155 Arc::clone(extracted_column)
156 } else {
157 let merged_nulls = NullBuffer::union(values.nulls(), extracted_column.nulls());
163 make_array(
164 extracted_column
165 .into_data()
166 .into_builder()
167 .nulls(merged_nulls)
168 .build()?,
169 )
170 };
171
172 let array = GenericListArray::new(
173 field,
174 list_array.offsets().clone(),
175 data,
176 list_array.nulls().cloned(),
177 );
178
179 Ok(ColumnarValue::Array(Arc::new(array)))
180}
181
182impl Display for GetArrayStructFields {
183 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184 write!(
185 f,
186 "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
187 self.child, self.ordinal
188 )
189 }
190}